├── FPL-plus.png
├── PyMIC
├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── docs
│ ├── .debug.yml
│ ├── Documentation
│ │ └── README.md
│ ├── Gemfile
│ ├── Makefile
│ ├── README.md
│ ├── _config.yml
│ ├── index.md
│ ├── make.bat
│ └── source
│ │ ├── api.rst
│ │ ├── conf.py
│ │ ├── index.rst
│ │ ├── installation.rst
│ │ ├── modules.rst
│ │ ├── pymic.io.rst
│ │ ├── pymic.layer.rst
│ │ ├── pymic.loss.cls.rst
│ │ ├── pymic.loss.rst
│ │ ├── pymic.loss.seg.rst
│ │ ├── pymic.net.cls.rst
│ │ ├── pymic.net.net2d.rst
│ │ ├── pymic.net.net3d.rst
│ │ ├── pymic.net.rst
│ │ ├── pymic.net_run.rst
│ │ ├── pymic.net_run_nll.rst
│ │ ├── pymic.net_run_ssl.rst
│ │ ├── pymic.net_run_wsl.rst
│ │ ├── pymic.rst
│ │ ├── pymic.transform.rst
│ │ ├── pymic.util.rst
│ │ ├── usage.fsl.rst
│ │ ├── usage.nll.rst
│ │ ├── usage.quickstart.rst
│ │ ├── usage.rst
│ │ ├── usage.ssl.rst
│ │ └── usage.wsl.rst
├── pymic
│ ├── __init__.py
│ ├── io
│ │ ├── __init__.py
│ │ ├── h5_dataset.py
│ │ ├── image_read_write.py
│ │ └── nifty_dataset.py
│ ├── layer
│ │ ├── __init__.py
│ │ ├── activation.py
│ │ ├── convolution.py
│ │ ├── deconvolution.py
│ │ └── space2channel.py
│ ├── loss
│ │ ├── __init__.py
│ │ ├── cls
│ │ │ ├── __init__.py
│ │ │ ├── basic.py
│ │ │ └── util.py
│ │ ├── loss_dict_cls.py
│ │ ├── loss_dict_seg.py
│ │ └── seg
│ │ │ ├── __init__.py
│ │ │ ├── abstract.py
│ │ │ ├── ce.py
│ │ │ ├── combined.py
│ │ │ ├── deep_sup.py
│ │ │ ├── dice.py
│ │ │ ├── exp_log.py
│ │ │ ├── gatedcrf.py
│ │ │ ├── mse.py
│ │ │ ├── mumford_shah.py
│ │ │ ├── slsr.py
│ │ │ ├── ssl.py
│ │ │ └── util.py
│ ├── net
│ │ ├── __init__.py
│ │ ├── cls
│ │ │ ├── __init__.py
│ │ │ └── torch_pretrained_net.py
│ │ ├── net3d
│ │ │ ├── __init__.py
│ │ │ ├── scse3d.py
│ │ │ ├── unet2d5.py
│ │ │ ├── unet2d5_dsbn.py
│ │ │ ├── unet3d.py
│ │ │ └── unet3d_scse.py
│ │ ├── net_dict_cls.py
│ │ └── net_dict_seg.py
│ ├── net_run
│ │ ├── __init__.py
│ │ ├── agent_abstract.py
│ │ ├── agent_cls.py
│ │ ├── agent_seg.py
│ │ ├── get_optimizer.py
│ │ ├── infer_func.py
│ │ └── net_run.py
│ ├── net_run_dsbn
│ │ ├── __init__.py
│ │ ├── agent_abstract.py
│ │ ├── agent_cls.py
│ │ ├── agent_seg.py
│ │ ├── dsbn.py
│ │ ├── get_optimizer.py
│ │ ├── infer_func.py
│ │ └── net_run.py
│ ├── net_run_nll
│ │ ├── __init__.py
│ │ ├── nll_clslsr.py
│ │ ├── nll_co_teaching.py
│ │ ├── nll_dast.py
│ │ ├── nll_main.py
│ │ └── nll_trinet.py
│ ├── net_run_ssl
│ │ ├── __init__.py
│ │ ├── ssl_abstract.py
│ │ ├── ssl_cct.py
│ │ ├── ssl_cps.py
│ │ ├── ssl_em.py
│ │ ├── ssl_main.py
│ │ ├── ssl_mt.py
│ │ ├── ssl_uamt.py
│ │ └── ssl_urpc.py
│ ├── net_run_wsl
│ │ ├── __init__.py
│ │ ├── wsl_abstract.py
│ │ ├── wsl_dmpls.py
│ │ ├── wsl_em.py
│ │ ├── wsl_gatedcrf.py
│ │ ├── wsl_main.py
│ │ ├── wsl_mumford_shah.py
│ │ ├── wsl_tv.py
│ │ └── wsl_ustm.py
│ ├── transform
│ │ ├── __init__.py
│ │ ├── abstract_transform.py
│ │ ├── crop.py
│ │ ├── flip.py
│ │ ├── intensity.py
│ │ ├── label_convert.py
│ │ ├── normalize.py
│ │ ├── pad.py
│ │ ├── rescale.py
│ │ ├── rotate.py
│ │ ├── threshold.py
│ │ └── trans_dict.py
│ └── util
│ │ ├── __init__.py
│ │ ├── evaluation_cls.py
│ │ ├── evaluation_seg.py
│ │ ├── evaluation_seg_train.py
│ │ ├── general.py
│ │ ├── image_process.py
│ │ ├── make_noise.py
│ │ ├── model_operate.py
│ │ ├── parse_config.py
│ │ ├── post_process.py
│ │ ├── preprocess.py
│ │ └── ramps.py
├── pyproject.toml
├── requirements.txt
└── setup.py
├── README.md
├── __init__.py
├── config_dual
├── data_vs
│ ├── test_hrT2.csv
│ ├── train_ceT1.csv
│ ├── train_ceT1_like.csv
│ ├── train_hrT2-ceT1_cyc.csv
│ ├── train_hrT2.csv
│ ├── train_hrT2_like.csv
│ ├── train_hrT2_pair.csv
│ ├── train_vs_t1s_wi+wp.csv
│ ├── valid_hrT2.csv
│ ├── vs_t1s_S.cfg
│ ├── vs_t1s_g.cfg
│ ├── vs_t1s_g_fake.cfg
│ └── vs_t1s_weights.cfg
└── evaluation.cfg
├── data
├── get_image_weight.py
├── get_pixel_weight.py
├── preprocess_bst.py
├── preprocess_mmwhs.py
├── preprocess_vs.py
└── write_csv.py
├── dataset
├── ceT1_train
│ ├── img
│ │ └── vs_gk_99_t1.nii.gz
│ └── lab
│ │ └── vs_gk_99_t1.nii.gz
├── fake_data
│ ├── ceT1-hrT2-ceT1_ac
│ │ └── vs_gk_99_t1.nii.gz
│ ├── ceT1-hrT2-ceT1_cc
│ │ └── vs_gk_99_t1.nii.gz
│ ├── ceT1-hrT2_auxcyc
│ │ └── vs_gk_99_t1.nii.gz
│ ├── ceT1-hrT2_cyc
│ │ └── vs_gk_99_t1.nii.gz
│ └── hrT2-ceT1_train_cyc
│ │ └── vs_gk_98_t2.nii.gz
├── hrT2_test
│ ├── vs_gk_9_t2.nii.gz
│ └── vs_gk_9_t2_seg.nii.gz
├── hrT2_train
│ ├── img
│ │ └── vs_gk_98_t2.nii.gz
│ └── lab
│ │ └── vs_gk_98_t2.nii.gz
├── hrT2_valid
│ ├── vs_gk_95_t2.nii.gz
│ └── vs_gk_95_t2_seg.nii.gz
└── weight
│ └── cyc121_vst1s-gan.npy
├── docs
└── image_info.odt
├── merge_pixelw.py
└── run.sh
/FPL-plus.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/FPL-plus.png
--------------------------------------------------------------------------------
/PyMIC/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | build/*
3 | dist/*
4 | *egg*/*
5 | *stop*
6 | files.txt
7 |
8 | # Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks
9 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks
10 |
11 | ### JupyterNotebooks ###
12 | # gitignore template for Jupyter Notebooks
13 | # website: http://jupyter.org/
14 |
15 | .ipynb_checkpoints
16 | */.ipynb_checkpoints/*
17 |
18 | # IPython
19 | profile_default/
20 | ipython_config.py
21 |
22 | # Remove previous ipynb_checkpoints
23 | # git rm -r .ipynb_checkpoints/
24 |
25 | ### Python ###
26 | # Byte-compiled / optimized / DLL files
27 | __pycache__/
28 | *.py[cod]
29 | *$py.class
30 |
31 | # C extensions
32 | *.so
33 |
34 | # Distribution / packaging
35 | .Python
36 | build/
37 | develop-eggs/
38 | dist/
39 | downloads/
40 | eggs/
41 | .eggs/
42 | lib/
43 | lib64/
44 | parts/
45 | sdist/
46 | var/
47 | wheels/
48 | share/python-wheels/
49 | *.egg-info/
50 | .installed.cfg
51 | *.egg
52 | MANIFEST
53 |
54 | # PyInstaller
55 | # Usually these files are written by a python script from a template
56 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
57 | *.manifest
58 | *.spec
59 |
60 | # Installer logs
61 | pip-log.txt
62 | pip-delete-this-directory.txt
63 |
64 | # Unit test / coverage reports
65 | htmlcov/
66 | .tox/
67 | .nox/
68 | .coverage
69 | .coverage.*
70 | .cache
71 | nosetests.xml
72 | coverage.xml
73 | *.cover
74 | *.py,cover
75 | .hypothesis/
76 | .pytest_cache/
77 | cover/
78 |
79 | # Translations
80 | *.mo
81 | *.pot
82 |
83 | # Django stuff:
84 | *.log
85 | local_settings.py
86 | db.sqlite3
87 | db.sqlite3-journal
88 |
89 | # Flask stuff:
90 | instance/
91 | .webassets-cache
92 |
93 | # Scrapy stuff:
94 | .scrapy
95 |
96 | # Sphinx documentation
97 | docs/_build/
98 |
99 | # PyBuilder
100 | .pybuilder/
101 | target/
102 |
103 | # Jupyter Notebook
104 |
105 | # IPython
106 |
107 | # pyenv
108 | # For a library or package, you might want to ignore these files since the code is
109 | # intended to run in multiple environments; otherwise, check them in:
110 | # .python-version
111 |
112 | # pipenv
113 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
114 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
115 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
116 | # install all needed dependencies.
117 | #Pipfile.lock
118 |
119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
120 | __pypackages__/
121 |
122 | # Celery stuff
123 | celerybeat-schedule
124 | celerybeat.pid
125 |
126 | # SageMath parsed files
127 | *.sage.py
128 |
129 | # Environments
130 | .env
131 | .venv
132 | env/
133 | venv/
134 | ENV/
135 | env.bak/
136 | venv.bak/
137 |
138 | # Spyder project settings
139 | .spyderproject
140 | .spyproject
141 |
142 | # Rope project settings
143 | .ropeproject
144 |
145 | # mkdocs documentation
146 | /site
147 |
148 | # mypy
149 | .mypy_cache/
150 | .dmypy.json
151 | dmypy.json
152 |
153 | # Pyre type checker
154 | .pyre/
155 |
156 | # pytype static type analyzer
157 | .pytype/
158 |
159 | # Cython debug symbols
160 | cython_debug/
161 |
162 | # End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks
163 |
--------------------------------------------------------------------------------
/PyMIC/README.md:
--------------------------------------------------------------------------------
1 | # PyMIC: A PyTorch-Based Toolkit for Medical Image Computing
2 |
3 | PyMIC is a PyTorch-based toolkit for medical image computing with annotation-efficient deep learning. Although PyTorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward: medical images often have high dimensionality, large volumes and multiple modalities, and they are difficult to annotate. This toolkit is developed to make training and testing deep learning models easier for medical image computing researchers, and it is very friendly to those who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations.
4 |
5 | Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. If you use this toolkit, please cite the following paper:
6 |
7 | * G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. (2022).
8 | [PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation.][arxiv2022] arXiv, 2208.09350.
9 |
10 | [arxiv2022]:http://arxiv.org/abs/2208.09350
11 |
12 | BibTeX entry:
13 |
14 | @article{Wang2022pymic,
15 | author = {Guotai Wang and Xiangde Luo and Ran Gu and Shuojue Yang and Yijie Qu and Shuwei Zhai and Qianfei Zhao and Kang Li and Shaoting Zhang},
16 | title = {{PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation}},
17 | year = {2022},
18 | url = {http://arxiv.org/abs/2208.09350},
19 | journal = {arXiv},
20 | volume = {2208.09350},
21 | pages = {1-10},
22 | }
23 |
24 | # Features
25 | PyMIC provides flexible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions:
26 | * Support for annotation-efficient image segmentation, especially for semi-supervised, weakly-supervised and noisy-label learning.
27 | * User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc) and easily integrate them into PyMIC.
28 | * Easy-to-use I/O interface to read and write different 2D and 3D images.
29 | * Various data pre-processing/transformation methods before sending a tensor into a network.
30 | * Implementation of typical neural networks for medical image segmentation.
31 | * Reusable training and testing pipelines that can be transferred to different tasks.
32 | * Evaluation metrics for quantitative evaluation of your methods.
33 |
34 | # Usage
35 | ## Requirement
36 | * [PyTorch][torch_link] version >=1.0.1
37 | * [TensorboardX][tbx_link] to visualize training performance
38 | * Some common Python packages such as NumPy, Pandas and SimpleITK
39 | * See `requirements.txt` for details.
40 |
41 | [torch_link]:https://pytorch.org/
42 | [tbx_link]:https://github.com/lanpa/tensorboardX
43 |
44 | ## Installation
45 | Run the following command to install the latest released version of PyMIC:
46 |
47 | ```bash
48 | pip install PYMIC
49 | ```
50 | To install a specific version of PYMIC such as 0.3.0, run:
51 |
52 | ```bash
53 | pip install PYMIC==0.3.0
54 | ```
55 | Alternatively, you can download the source code for the latest version. Run the following command to compile and install:
56 |
57 | ```bash
58 | python setup.py install
59 | ```
60 |
61 | ## How to start
62 | * [PyMIC_examples][exp_link] provides examples for getting started with PyMIC.
63 | * [PyMIC_doc][docs_link] provides documentation of this project.
64 |
65 | [docs_link]:https://pymic.readthedocs.io/en/latest/
66 | [exp_link]:https://github.com/HiLab-git/PyMIC_examples
67 |
68 | ## Projects based on PyMIC
69 | PyMIC makes it easy to develop deep learning models for different projects, such as the following:
70 | 
71 | 1. [MyoPS][myops]: winner of the MICCAI 2020 myocardial pathology segmentation (MyoPS) Challenge.
72 | 
73 | 2. [COPLE-Net][coplenet] (TMI 2020): COVID-19 pneumonia segmentation from CT images.
74 | 
75 | 3. [Head-Neck-GTV][hn_gtv] (Neurocomputing 2020): Nasopharyngeal Carcinoma (NPC) GTV segmentation from head and neck CT images.
76 | 
77 | 4. [UGIR][ugir] (MICCAI 2020): uncertainty-guided interactive refinement for medical image segmentation.
78 |
79 | [myops]: https://github.com/HiLab-git/MyoPS2020
80 | [coplenet]:https://github.com/HiLab-git/COPLE-Net
81 | [hn_gtv]: https://github.com/HiLab-git/Head-Neck-GTV
82 | [ugir]: https://github.com/HiLab-git/UGIR
83 |
84 |
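85 | After installation, a quick sanity check is to confirm that the package can be imported (`pymic` is the import name of the package, as opposed to the `PYMIC` name used on PyPI):
86 | 
87 | ```bash
88 | python -c "import pymic"
89 | ```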
--------------------------------------------------------------------------------
/PyMIC/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/__init__.py
--------------------------------------------------------------------------------
/PyMIC/docs/.debug.yml:
--------------------------------------------------------------------------------
1 | remote_theme: false
2 |
3 | theme: jekyll-rtd-theme
--------------------------------------------------------------------------------
/PyMIC/docs/Documentation/README.md:
--------------------------------------------------------------------------------
1 | ---
2 | sort: 2
3 | ---
4 | # Readme for Documentation
--------------------------------------------------------------------------------
/PyMIC/docs/Gemfile:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/PyMIC/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/PyMIC/docs/README.md:
--------------------------------------------------------------------------------
1 | ## Welcome to PyMIC
2 |
3 | PyMIC is a PyTorch-based toolkit for medical image computing with annotation-efficient deep learning. Although PyTorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward: medical images often have high dimensionality, large volumes and multiple modalities, and they are difficult to annotate. This toolkit is developed to make training and testing deep learning models easier for medical image computing researchers, and it is very friendly to those who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations.
4 |
5 | Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. It was originally developed for COVID-19 pneumonia lesion segmentation from CT images.
6 |
7 | ### Features
8 | PyMIC provides flexible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions:
9 | * Support for annotation-efficient image segmentation, especially for semi-supervised, weakly-supervised and noisy-label learning.
10 | * User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc) and easily integrate them into PyMIC.
11 | * Easy-to-use I/O interface to read and write different 2D and 3D images.
12 | * Various data pre-processing/transformation methods before sending a tensor into a network.
13 | * Implementation of typical neural networks for medical image segmentation.
14 | * Reusable training and testing pipelines that can be transferred to different tasks.
15 | * Evaluation metrics for quantitative evaluation of your methods.
16 |
17 | ### Installation
18 | Run the following command to install the latest released version of PyMIC:
19 |
20 | ```bash
21 | pip install PYMIC
22 | ```
23 |
24 | Alternatively, you can download the source code for the latest version. Run the following command to compile and install:
25 |
26 | ```bash
27 | python setup.py install
28 | ```
29 |
30 | ### Quick Start
31 | [PyMIC_examples][examples] provides examples for getting started with PyMIC. At the beginning, you only need to edit the configuration files to select different datasets, networks and training methods for running the code. When you are more familiar with PyMIC, you can customize different modules in the PyMIC package. Both types of examples can be found there.
32 |
33 | [examples]: https://github.com/HiLab-git/PyMIC_examples
34 |
35 |
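36 | For instance, training and testing a fully supervised segmentation model (a minimal sketch, assuming a configuration file named `myconfig.cfg` as in the examples) is done with:
37 | 
38 | ```bash
39 | pymic_run train myconfig.cfg
40 | pymic_run test myconfig.cfg
41 | ```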
--------------------------------------------------------------------------------
/PyMIC/docs/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-leap-day
--------------------------------------------------------------------------------
/PyMIC/docs/index.md:
--------------------------------------------------------------------------------
1 | ## Welcome to PyMIC
2 |
3 | PyMIC is a PyTorch-based toolkit for medical image computing with annotation-efficient deep learning. Although PyTorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward: medical images often have high dimensionality, large volumes and multiple modalities, and they are difficult to annotate. This toolkit is developed to make training and testing deep learning models easier for medical image computing researchers, and it is very friendly to those who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations.
4 |
5 | Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. It was originally developed for COVID-19 pneumonia lesion segmentation from CT images.
6 |
7 | ### Features
8 | PyMIC provides flexible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions:
9 | * Support for annotation-efficient image segmentation, especially for semi-supervised, weakly-supervised and noisy-label learning.
10 | * User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc) and easily integrate them into PyMIC.
11 | * Easy-to-use I/O interface to read and write different 2D and 3D images.
12 | * Various data pre-processing/transformation methods before sending a tensor into a network.
13 | * Implementation of typical neural networks for medical image segmentation.
14 | * Reusable training and testing pipelines that can be transferred to different tasks.
15 | * Evaluation metrics for quantitative evaluation of your methods.
16 |
17 | ### Installation
18 | Run the following command to install the latest released version of PyMIC:
19 |
20 | ```bash
21 | pip install PYMIC
22 | ```
23 |
24 | Alternatively, you can download the source code for the latest version. Run the following command to compile and install:
25 |
26 | ```bash
27 | python setup.py install
28 | ```
29 |
30 | ### Quick Start
31 | [PyMIC_examples][examples] provides examples for getting started with PyMIC. At the beginning, you only need to edit the configuration files to select different datasets, networks and training methods for running the code. When you are more familiar with PyMIC, you can customize different modules in the PyMIC package. Both types of examples can be found there.
32 |
33 | [examples]: https://github.com/HiLab-git/PyMIC_examples
34 |
35 |
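36 | After testing a model, quantitative evaluation against the ground truth can be run with PyMIC's evaluation commands (a minimal sketch, assuming an `evaluation.cfg` configuration as in the examples):
37 | 
38 | ```bash
39 | pymic_eval_seg evaluation.cfg   # for segmentation tasks
40 | pymic_eval_cls evaluation.cfg   # for classification tasks
41 | ```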
--------------------------------------------------------------------------------
/PyMIC/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/api.rst:
--------------------------------------------------------------------------------
1 | API
2 | ===
3 |
4 | .. toctree::
5 | :maxdepth: 4
6 |
7 | pymic.io
8 | pymic.layer
9 | pymic.loss
10 | pymic.net
11 | pymic.net_run
12 | pymic.net_run_nll
13 | pymic.net_run_ssl
14 | pymic.net_run_wsl
15 | pymic.transform
16 | pymic.util
--------------------------------------------------------------------------------
/PyMIC/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 |
3 | # -- Project information
4 | import os
5 | import sys
6 | sys.path.insert(0, os.path.abspath('./../..'))
7 |
8 | project = 'PyMIC'
9 | copyright = '2021, HiLab'
10 | author = 'HiLab'
11 |
12 | release = '0.1'
13 | version = '0.1.0'
14 |
15 | # -- General configuration
16 |
17 | extensions = [
18 | 'sphinx.ext.duration',
19 | 'sphinx.ext.doctest',
20 | 'sphinx.ext.autodoc',
21 | 'sphinx.ext.autosummary',
22 | 'sphinx.ext.intersphinx',
23 | ]
24 |
25 | intersphinx_mapping = {
26 | 'python': ('https://docs.python.org/3/', None),
27 | 'sphinx': ('https://www.sphinx-doc.org/en/master/', None),
28 | }
29 | intersphinx_disabled_domains = ['std']
30 |
31 | templates_path = ['_templates']
32 |
33 | # -- Options for HTML output
34 |
35 | html_theme = 'sphinx_rtd_theme'
36 |
37 | # -- Options for EPUB output
38 | epub_show_urls = 'footnote'
39 |
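40 | # To build the HTML documentation locally, run "make html" from the docs/
41 | # directory (or use make.bat on Windows). The accompanying Makefile calls
42 | # sphinx-build with SOURCEDIR=source and BUILDDIR=build, so the rendered
43 | # pages end up under docs/build/html.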
--------------------------------------------------------------------------------
/PyMIC/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to PyMIC's documentation!
2 | ===================================
3 |
4 | PyMIC is a PyTorch-based toolkit for medical image computing with annotation-efficient
5 | deep learning. PyMIC is developed to support learning with imperfect labels, including
6 | semi-supervised and weakly supervised learning, and learning with noisy annotations.
7 | 
8 | Check out the :doc:`installation` section for installing PyMIC, and go to the :doc:`usage`
9 | section to understand the modules of the segmentation pipeline designed in PyMIC.
10 | Please follow `PyMIC_examples <https://github.com/HiLab-git/PyMIC_examples>`_
11 | to quickly get started with PyMIC.
12 |
13 | .. note::
14 |
15 | This project is under active development. It will be updated later.
16 |
17 |
18 | .. toctree::
19 | :maxdepth: 2
20 | :caption: Getting Started
21 |
22 | installation
23 | usage
24 | api
25 |
26 | Citation
27 | --------
28 |
29 | If you use PyMIC for your research, please acknowledge it accordingly by citing our paper:
30 |
31 | `G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. (2022).
32 | PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation.
33 | arXiv, 2208.09350. <http://arxiv.org/abs/2208.09350>`_
34 |
35 |
36 | BibTeX entry:
37 |
38 | .. code-block:: none
39 |
40 | @article{Wang2022pymic,
41 | author = {Guotai Wang and Xiangde Luo and Ran Gu and Shuojue Yang and Yijie Qu and Shuwei Zhai and Qianfei Zhao and Kang Li and Shaoting Zhang},
42 | title = {{PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation}},
43 | year = {2022},
44 | url = {http://arxiv.org/abs/2208.09350},
45 | journal = {arXiv},
46 | volume = {2208.09350},
47 | pages = {1-10},
48 | }
49 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | Install PyMIC using pip (e.g., within a Python virtual environment):
5 |
6 | .. code-block:: bash
7 |
8 | pip install PYMIC
9 |
10 |
11 | Alternatively, you can download or clone the code from `GitHub <https://github.com/HiLab-git/PyMIC>`_ and install PyMIC by
12 |
13 | .. code-block:: bash
14 |
15 | git clone https://github.com/HiLab-git/PyMIC
16 | cd PyMIC
17 | python setup.py install
18 |
19 |
20 | PyMIC requires Python 3.6 (or higher) and depends on the following packages:
21 |
22 | - `pandas <https://pandas.pydata.org/>`_
23 | - `h5py <https://www.h5py.org/>`_
24 | - `NumPy <https://numpy.org/>`_
25 | - `scikit-image <https://scikit-image.org/>`_
26 | - `SciPy <https://scipy.org/>`_
27 | - `SimpleITK <https://simpleitk.org/>`_
28 |
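29 | For example, a clean installation inside a fresh virtual environment looks like this
30 | (a minimal sketch for Linux/macOS; the environment name ``pymic-env`` is arbitrary):
31 | 
32 | .. code-block:: bash
33 | 
34 |     python -m venv pymic-env
35 |     source pymic-env/bin/activate
36 |     pip install PYMIC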
--------------------------------------------------------------------------------
/PyMIC/docs/source/modules.rst:
--------------------------------------------------------------------------------
1 | pymic
2 | =====
3 |
4 | .. toctree::
5 | :maxdepth: 4
6 |
7 | pymic
8 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/pymic.io.rst:
--------------------------------------------------------------------------------
1 | pymic.io package
2 | ================
3 |
4 | Submodules
5 | ----------
6 |
7 | pymic.io.h5\_dataset module
8 | ---------------------------
9 |
10 | .. automodule:: pymic.io.h5_dataset
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | pymic.io.image\_read\_write module
16 | ----------------------------------
17 |
18 | .. automodule:: pymic.io.image_read_write
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | pymic.io.nifty\_dataset module
24 | ------------------------------
25 |
26 | .. automodule:: pymic.io.nifty_dataset
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | Module contents
32 | ---------------
33 |
34 | .. automodule:: pymic.io
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/pymic.layer.rst:
--------------------------------------------------------------------------------
1 | pymic.layer package
2 | ===================
3 |
4 | Submodules
5 | ----------
6 |
7 | pymic.layer.activation module
8 | -----------------------------
9 |
10 | .. automodule:: pymic.layer.activation
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | pymic.layer.convolution module
16 | ------------------------------
17 |
18 | .. automodule:: pymic.layer.convolution
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | pymic.layer.deconvolution module
24 | --------------------------------
25 |
26 | .. automodule:: pymic.layer.deconvolution
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | pymic.layer.space2channel module
32 | --------------------------------
33 |
34 | .. automodule:: pymic.layer.space2channel
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | Module contents
40 | ---------------
41 |
42 | .. automodule:: pymic.layer
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/pymic.loss.cls.rst:
--------------------------------------------------------------------------------
1 | pymic.loss.cls package
2 | ======================
3 |
4 | Submodules
5 | ----------
6 |
7 | pymic.loss.cls.basic module
8 | ---------------------------
9 |
10 | .. automodule:: pymic.loss.cls.basic
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | pymic.loss.cls.util module
16 | --------------------------
17 |
18 | .. automodule:: pymic.loss.cls.util
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | Module contents
24 | ---------------
25 |
26 | .. automodule:: pymic.loss.cls
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/pymic.loss.rst:
--------------------------------------------------------------------------------
1 | pymic.loss package
2 | ==================
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 | :maxdepth: 4
9 |
10 | pymic.loss.cls
11 | pymic.loss.seg
12 |
13 | Submodules
14 | ----------
15 |
16 | pymic.loss.loss\_dict\_cls module
17 | ---------------------------------
18 |
19 | .. automodule:: pymic.loss.loss_dict_cls
20 | :members:
21 | :undoc-members:
22 | :show-inheritance:
23 |
24 | pymic.loss.loss\_dict\_seg module
25 | ---------------------------------
26 |
27 | .. automodule:: pymic.loss.loss_dict_seg
28 | :members:
29 | :undoc-members:
30 | :show-inheritance:
31 |
32 | Module contents
33 | ---------------
34 |
35 | .. automodule:: pymic.loss
36 | :members:
37 | :undoc-members:
38 | :show-inheritance:
39 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/pymic.loss.seg.rst:
--------------------------------------------------------------------------------
1 | pymic.loss.seg package
2 | ======================
3 |
4 | Submodules
5 | ----------
6 |
7 | pymic.loss.seg.abstract module
8 | ------------------------------
9 |
10 | .. automodule:: pymic.loss.seg.abstract
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | pymic.loss.seg.ce module
16 | ------------------------
17 |
18 | .. automodule:: pymic.loss.seg.ce
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | pymic.loss.seg.combined module
24 | ------------------------------
25 |
26 | .. automodule:: pymic.loss.seg.combined
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | pymic.loss.seg.deep\_sup module
32 | -------------------------------
33 |
34 | .. automodule:: pymic.loss.seg.deep_sup
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | pymic.loss.seg.dice module
40 | --------------------------
41 |
42 | .. automodule:: pymic.loss.seg.dice
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
47 | pymic.loss.seg.exp\_log module
48 | ------------------------------
49 |
50 | .. automodule:: pymic.loss.seg.exp_log
51 | :members:
52 | :undoc-members:
53 | :show-inheritance:
54 |
55 | pymic.loss.seg.gatedcrf module
56 | ------------------------------
57 |
58 | .. automodule:: pymic.loss.seg.gatedcrf
59 | :members:
60 | :undoc-members:
61 | :show-inheritance:
62 |
63 | pymic.loss.seg.mse module
64 | -------------------------
65 |
66 | .. automodule:: pymic.loss.seg.mse
67 | :members:
68 | :undoc-members:
69 | :show-inheritance:
70 |
71 | pymic.loss.seg.mumford\_shah module
72 | -----------------------------------
73 |
74 | .. automodule:: pymic.loss.seg.mumford_shah
75 | :members:
76 | :undoc-members:
77 | :show-inheritance:
78 |
79 | pymic.loss.seg.slsr module
80 | --------------------------
81 |
82 | .. automodule:: pymic.loss.seg.slsr
83 | :members:
84 | :undoc-members:
85 | :show-inheritance:
86 |
87 | pymic.loss.seg.ssl module
88 | -------------------------
89 |
90 | .. automodule:: pymic.loss.seg.ssl
91 | :members:
92 | :undoc-members:
93 | :show-inheritance:
94 |
95 | pymic.loss.seg.util module
96 | --------------------------
97 |
98 | .. automodule:: pymic.loss.seg.util
99 | :members:
100 | :undoc-members:
101 | :show-inheritance:
102 |
103 | Module contents
104 | ---------------
105 |
106 | .. automodule:: pymic.loss.seg
107 | :members:
108 | :undoc-members:
109 | :show-inheritance:
110 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/pymic.net.cls.rst:
--------------------------------------------------------------------------------
1 | pymic.net.cls package
2 | =====================
3 |
4 | Submodules
5 | ----------
6 |
7 | pymic.net.cls.torch\_pretrained\_net module
8 | -------------------------------------------
9 |
10 | .. automodule:: pymic.net.cls.torch_pretrained_net
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | Module contents
16 | ---------------
17 |
18 | .. automodule:: pymic.net.cls
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/pymic.net.net2d.rst:
--------------------------------------------------------------------------------
1 | pymic.net.net2d package
2 | =======================
3 |
4 | Submodules
5 | ----------
6 |
7 | pymic.net.net2d.cople\_net module
8 | ---------------------------------
9 |
10 | .. automodule:: pymic.net.net2d.cople_net
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | pymic.net.net2d.scse2d module
16 | -----------------------------
17 |
18 | .. automodule:: pymic.net.net2d.scse2d
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | pymic.net.net2d.unet2d module
24 | -----------------------------
25 |
26 | .. automodule:: pymic.net.net2d.unet2d
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | pymic.net.net2d.unet2d\_attention module
32 | ----------------------------------------
33 |
34 | .. automodule:: pymic.net.net2d.unet2d_attention
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | pymic.net.net2d.unet2d\_cct module
40 | ----------------------------------
41 |
42 | .. automodule:: pymic.net.net2d.unet2d_cct
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
47 | pymic.net.net2d.unet2d\_dual\_branch module
48 | -------------------------------------------
49 |
50 | .. automodule:: pymic.net.net2d.unet2d_dual_branch
51 | :members:
52 | :undoc-members:
53 | :show-inheritance:
54 |
55 | pymic.net.net2d.unet2d\_nest module
56 | -----------------------------------
57 |
58 | .. automodule:: pymic.net.net2d.unet2d_nest
59 | :members:
60 | :undoc-members:
61 | :show-inheritance:
62 |
63 | pymic.net.net2d.unet2d\_scse module
64 | -----------------------------------
65 |
66 | .. automodule:: pymic.net.net2d.unet2d_scse
67 | :members:
68 | :undoc-members:
69 | :show-inheritance:
70 |
71 | pymic.net.net2d.unet2d\_urpc module
72 | -----------------------------------
73 |
74 | .. automodule:: pymic.net.net2d.unet2d_urpc
75 | :members:
76 | :undoc-members:
77 | :show-inheritance:
78 |
79 | Module contents
80 | ---------------
81 |
82 | .. automodule:: pymic.net.net2d
83 | :members:
84 | :undoc-members:
85 | :show-inheritance:
86 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/pymic.net.net3d.rst:
--------------------------------------------------------------------------------
1 | pymic.net.net3d package
2 | =======================
3 |
4 | Submodules
5 | ----------
6 |
7 | pymic.net.net3d.scse3d module
8 | -----------------------------
9 |
10 | .. automodule:: pymic.net.net3d.scse3d
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | pymic.net.net3d.unet2d5 module
16 | ------------------------------
17 |
18 | .. automodule:: pymic.net.net3d.unet2d5
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | pymic.net.net3d.unet3d module
24 | -----------------------------
25 |
26 | .. automodule:: pymic.net.net3d.unet3d
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | pymic.net.net3d.unet3d\_scse module
32 | -----------------------------------
33 |
34 | .. automodule:: pymic.net.net3d.unet3d_scse
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | Module contents
40 | ---------------
41 |
42 | .. automodule:: pymic.net.net3d
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/pymic.net.rst:
--------------------------------------------------------------------------------
1 | pymic.net package
2 | =================
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 | :maxdepth: 4
9 |
10 | pymic.net.cls
11 | pymic.net.net2d
12 | pymic.net.net3d
13 |
14 | Submodules
15 | ----------
16 |
17 | pymic.net.net\_dict\_cls module
18 | -------------------------------
19 |
20 | .. automodule:: pymic.net.net_dict_cls
21 | :members:
22 | :undoc-members:
23 | :show-inheritance:
24 |
25 | pymic.net.net\_dict\_seg module
26 | -------------------------------
27 |
28 | .. automodule:: pymic.net.net_dict_seg
29 | :members:
30 | :undoc-members:
31 | :show-inheritance:
32 |
33 | Module contents
34 | ---------------
35 |
36 | .. automodule:: pymic.net
37 | :members:
38 | :undoc-members:
39 | :show-inheritance:
40 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/pymic.net_run.rst:
--------------------------------------------------------------------------------
1 | pymic.net\_run package
2 | ======================
3 |
4 | Submodules
5 | ----------
6 |
7 | pymic.net\_run.agent\_abstract module
8 | -------------------------------------
9 |
10 | .. automodule:: pymic.net_run.agent_abstract
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | pymic.net\_run.agent\_cls module
16 | --------------------------------
17 |
18 | .. automodule:: pymic.net_run.agent_cls
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | pymic.net\_run.agent\_seg module
24 | --------------------------------
25 |
26 | .. automodule:: pymic.net_run.agent_seg
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | pymic.net\_run.get\_optimizer module
32 | ------------------------------------
33 |
34 | .. automodule:: pymic.net_run.get_optimizer
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | pymic.net\_run.infer\_func module
40 | ---------------------------------
41 |
42 | .. automodule:: pymic.net_run.infer_func
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
47 | pymic.net\_run.net\_run module
48 | ------------------------------
49 |
50 | .. automodule:: pymic.net_run.net_run
51 | :members:
52 | :undoc-members:
53 | :show-inheritance:
54 |
55 | Module contents
56 | ---------------
57 |
58 | .. automodule:: pymic.net_run
59 | :members:
60 | :undoc-members:
61 | :show-inheritance:
62 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/pymic.net_run_nll.rst:
--------------------------------------------------------------------------------
1 | pymic.net\_run\_nll package
2 | ===========================
3 |
4 | Submodules
5 | ----------
6 |
7 | pymic.net\_run\_nll.nll\_clslsr module
8 | --------------------------------------
9 |
10 | .. automodule:: pymic.net_run_nll.nll_clslsr
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | pymic.net\_run\_nll.nll\_co\_teaching module
16 | --------------------------------------------
17 |
18 | .. automodule:: pymic.net_run_nll.nll_co_teaching
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | pymic.net\_run\_nll.nll\_dast module
24 | ------------------------------------
25 |
26 | .. automodule:: pymic.net_run_nll.nll_dast
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | pymic.net\_run\_nll.nll\_main module
32 | ------------------------------------
33 |
34 | .. automodule:: pymic.net_run_nll.nll_main
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | pymic.net\_run\_nll.nll\_trinet module
40 | --------------------------------------
41 |
42 | .. automodule:: pymic.net_run_nll.nll_trinet
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
47 | Module contents
48 | ---------------
49 |
50 | .. automodule:: pymic.net_run_nll
51 | :members:
52 | :undoc-members:
53 | :show-inheritance:
54 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/pymic.net_run_ssl.rst:
--------------------------------------------------------------------------------
1 | pymic.net\_run\_ssl package
2 | ===========================
3 |
4 | Submodules
5 | ----------
6 |
7 | pymic.net\_run\_ssl.ssl\_abstract module
8 | ----------------------------------------
9 |
10 | .. automodule:: pymic.net_run_ssl.ssl_abstract
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | pymic.net\_run\_ssl.ssl\_cct module
16 | -----------------------------------
17 |
18 | .. automodule:: pymic.net_run_ssl.ssl_cct
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | pymic.net\_run\_ssl.ssl\_cps module
24 | -----------------------------------
25 |
26 | .. automodule:: pymic.net_run_ssl.ssl_cps
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | pymic.net\_run\_ssl.ssl\_em module
32 | ----------------------------------
33 |
34 | .. automodule:: pymic.net_run_ssl.ssl_em
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | pymic.net\_run\_ssl.ssl\_main module
40 | ------------------------------------
41 |
42 | .. automodule:: pymic.net_run_ssl.ssl_main
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
47 | pymic.net\_run\_ssl.ssl\_mt module
48 | ----------------------------------
49 |
50 | .. automodule:: pymic.net_run_ssl.ssl_mt
51 | :members:
52 | :undoc-members:
53 | :show-inheritance:
54 |
55 | pymic.net\_run\_ssl.ssl\_uamt module
56 | ------------------------------------
57 |
58 | .. automodule:: pymic.net_run_ssl.ssl_uamt
59 | :members:
60 | :undoc-members:
61 | :show-inheritance:
62 |
63 | pymic.net\_run\_ssl.ssl\_urpc module
64 | ------------------------------------
65 |
66 | .. automodule:: pymic.net_run_ssl.ssl_urpc
67 | :members:
68 | :undoc-members:
69 | :show-inheritance:
70 |
71 | Module contents
72 | ---------------
73 |
74 | .. automodule:: pymic.net_run_ssl
75 | :members:
76 | :undoc-members:
77 | :show-inheritance:
78 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/pymic.net_run_wsl.rst:
--------------------------------------------------------------------------------
1 | pymic.net\_run\_wsl package
2 | ===========================
3 |
4 | Submodules
5 | ----------
6 |
7 | pymic.net\_run\_wsl.wsl\_abstract module
8 | ----------------------------------------
9 |
10 | .. automodule:: pymic.net_run_wsl.wsl_abstract
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | pymic.net\_run\_wsl.wsl\_dmpls module
16 | -------------------------------------
17 |
18 | .. automodule:: pymic.net_run_wsl.wsl_dmpls
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | pymic.net\_run\_wsl.wsl\_em module
24 | ----------------------------------
25 |
26 | .. automodule:: pymic.net_run_wsl.wsl_em
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | pymic.net\_run\_wsl.wsl\_gatedcrf module
32 | ----------------------------------------
33 |
34 | .. automodule:: pymic.net_run_wsl.wsl_gatedcrf
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | pymic.net\_run\_wsl.wsl\_main module
40 | ------------------------------------
41 |
42 | .. automodule:: pymic.net_run_wsl.wsl_main
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
47 | pymic.net\_run\_wsl.wsl\_mumford\_shah module
48 | ---------------------------------------------
49 |
50 | .. automodule:: pymic.net_run_wsl.wsl_mumford_shah
51 | :members:
52 | :undoc-members:
53 | :show-inheritance:
54 |
55 | pymic.net\_run\_wsl.wsl\_tv module
56 | ----------------------------------
57 |
58 | .. automodule:: pymic.net_run_wsl.wsl_tv
59 | :members:
60 | :undoc-members:
61 | :show-inheritance:
62 |
63 | pymic.net\_run\_wsl.wsl\_ustm module
64 | ------------------------------------
65 |
66 | .. automodule:: pymic.net_run_wsl.wsl_ustm
67 | :members:
68 | :undoc-members:
69 | :show-inheritance:
70 |
71 | Module contents
72 | ---------------
73 |
74 | .. automodule:: pymic.net_run_wsl
75 | :members:
76 | :undoc-members:
77 | :show-inheritance:
78 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/pymic.rst:
--------------------------------------------------------------------------------
1 | pymic package
2 | =============
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 | :maxdepth: 4
9 |
10 | pymic.io
11 | pymic.layer
12 | pymic.loss
13 | pymic.net
14 | pymic.net_run
15 | pymic.net_run_nll
16 | pymic.net_run_ssl
17 | pymic.net_run_wsl
18 | pymic.transform
19 | pymic.util
20 |
21 | Module contents
22 | ---------------
23 |
24 | .. automodule:: pymic
25 | :members:
26 | :undoc-members:
27 | :show-inheritance:
28 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/pymic.transform.rst:
--------------------------------------------------------------------------------
1 | pymic.transform package
2 | =======================
3 |
4 | Submodules
5 | ----------
6 |
7 | pymic.transform.abstract\_transform module
8 | ------------------------------------------
9 |
10 | .. automodule:: pymic.transform.abstract_transform
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | pymic.transform.crop module
16 | ---------------------------
17 |
18 | .. automodule:: pymic.transform.crop
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | pymic.transform.flip module
24 | ---------------------------
25 |
26 | .. automodule:: pymic.transform.flip
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | pymic.transform.intensity module
32 | --------------------------------
33 |
34 | .. automodule:: pymic.transform.intensity
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | pymic.transform.label\_convert module
40 | -------------------------------------
41 |
42 | .. automodule:: pymic.transform.label_convert
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
47 | pymic.transform.normalize module
48 | --------------------------------
49 |
50 | .. automodule:: pymic.transform.normalize
51 | :members:
52 | :undoc-members:
53 | :show-inheritance:
54 |
55 | pymic.transform.pad module
56 | --------------------------
57 |
58 | .. automodule:: pymic.transform.pad
59 | :members:
60 | :undoc-members:
61 | :show-inheritance:
62 |
63 | pymic.transform.rescale module
64 | ------------------------------
65 |
66 | .. automodule:: pymic.transform.rescale
67 | :members:
68 | :undoc-members:
69 | :show-inheritance:
70 |
71 | pymic.transform.rotate module
72 | -----------------------------
73 |
74 | .. automodule:: pymic.transform.rotate
75 | :members:
76 | :undoc-members:
77 | :show-inheritance:
78 |
79 | pymic.transform.threshold module
80 | --------------------------------
81 |
82 | .. automodule:: pymic.transform.threshold
83 | :members:
84 | :undoc-members:
85 | :show-inheritance:
86 |
87 | pymic.transform.trans\_dict module
88 | ----------------------------------
89 |
90 | .. automodule:: pymic.transform.trans_dict
91 | :members:
92 | :undoc-members:
93 | :show-inheritance:
94 |
95 | Module contents
96 | ---------------
97 |
98 | .. automodule:: pymic.transform
99 | :members:
100 | :undoc-members:
101 | :show-inheritance:
102 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/pymic.util.rst:
--------------------------------------------------------------------------------
1 | pymic.util package
2 | ==================
3 |
4 | Submodules
5 | ----------
6 |
7 | pymic.util.evaluation\_cls module
8 | ---------------------------------
9 |
10 | .. automodule:: pymic.util.evaluation_cls
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | pymic.util.evaluation\_seg module
16 | ---------------------------------
17 |
18 | .. automodule:: pymic.util.evaluation_seg
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | pymic.util.general module
24 | -------------------------
25 |
26 | .. automodule:: pymic.util.general
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | pymic.util.image\_process module
32 | --------------------------------
33 |
34 | .. automodule:: pymic.util.image_process
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | pymic.util.parse\_config module
40 | -------------------------------
41 |
42 | .. automodule:: pymic.util.parse_config
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
47 | pymic.util.post\_process module
48 | -------------------------------
49 |
50 | .. automodule:: pymic.util.post_process
51 | :members:
52 | :undoc-members:
53 | :show-inheritance:
54 |
55 | pymic.util.preprocess module
56 | ----------------------------
57 |
58 | .. automodule:: pymic.util.preprocess
59 | :members:
60 | :undoc-members:
61 | :show-inheritance:
62 |
63 | pymic.util.ramps module
64 | -----------------------
65 |
66 | .. automodule:: pymic.util.ramps
67 | :members:
68 | :undoc-members:
69 | :show-inheritance:
70 |
71 | Module contents
72 | ---------------
73 |
74 | .. automodule:: pymic.util
75 | :members:
76 | :undoc-members:
77 | :show-inheritance:
78 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/usage.nll.rst:
--------------------------------------------------------------------------------
1 | .. _noisy_label_learning:
2 |
3 | Noisy Label Learning
4 | ====================
5 |
6 | pymic_nll
7 | ---------
8 |
9 | :mod:`pymic_nll` is the command for using built-in NLL methods for training.
10 | Similarly to :mod:`pymic_run`, it should be followed by two parameters, specifying the
11 | stage and configuration file, respectively. The training and testing commands are:
12 |
13 | .. code-block:: bash
14 |
15 | pymic_nll train myconfig_nll.cfg
16 | pymic_nll test myconfig_nll.cfg
17 |
18 | .. tip::
19 |
20 | If the NLL method only involves one network, either ``pymic_nll`` or ``pymic_run``
21 | can be used for inference. Their difference only exists in the training stage.
22 |
23 | .. note::
24 |
25 | Some NLL methods only use noise-robust loss functions without a complex
26 | training process; simply combining the standard :mod:`SegmentationAgent` with such
27 | a loss function works for training. ``pymic_run`` instead of ``pymic_nll`` should
28 | be used for these methods.
29 |
30 |
31 | NLL Configurations
32 | ------------------
33 |
34 | The configuration file for ``pymic_nll`` contains, in addition to the sections used in
35 | standard fully supervised learning, a ``noisy_label_learning`` section that is specifically
36 | designed for NLL methods. In that section, users need to specify the ``nll_method`` and
37 | configurations related to the NLL method. For example, the corresponding configuration for CoTeaching is:
38 |
39 | .. code-block:: none
40 |
41 | [dataset]
42 | ...
43 |
44 | [network]
45 | ...
46 |
47 | [training]
48 | ...
49 |
50 | [noisy_label_learning]
51 | nll_method = CoTeaching
52 | co_teaching_select_ratio = 0.8
53 | rampup_start = 1000
54 | rampup_end = 8000
55 |
56 | [testing]
57 | ...
58 |
59 | .. note::
60 |
61 | The configuration items vary with different NLL methods. Please refer to the API
62 | of each built-in NLL method for details of the corresponding configuration.
63 |
64 | Built-in NLL Methods
65 | --------------------
66 |
67 | Some NLL methods only use noise-robust loss functions. They are used with ``pymic_run``
68 | for training. Just set ``loss_type`` to one of them in the configuration file, similarly
69 | to fully supervised learning.
70 |
71 | * ``GCELoss``: (`NeurIPS 2018 `_)
72 | Generalized cross entropy loss.
73 |
74 | * ``MAELoss``: (`AAAI 2017 `_)
75 | Mean Absolute Error loss.
76 |
77 | * ``NRDiceLoss``: (`TMI 2020 `_)
78 | Noise-robust Dice loss.
79 |
80 | The other NLL methods are implemented in child classes of
81 | :mod:`pymic.net_run_nll.nll_abstract.NLLSegAgent`, and they are:
82 |
83 | * ``CLSLSR``: (`MICCAI 2020 `_)
84 | Confident learning with spatial label smoothing regularization.
85 |
86 | * ``CoTeaching``: (`NeurIPS 2018 `_)
87 | Co-teaching between two networks for learning from noisy labels.
88 |
89 | * ``TriNet``: (`MICCAI 2020 `_)
90 | Tri-network combined with sample selection.
91 |
92 | * ``DAST``: (`JBHI 2022 `_)
93 | Divergence-aware selective training.
94 |
95 | Customized NLL Methods
96 | ----------------------
97 |
98 | PyMIC also supports customizing NLL methods by inheriting the :mod:`NLLSegAgent` class.
99 | You may only need to rewrite the :mod:`training()` method and reuse most parts of the
100 | existing pipeline, such as data loading, validation and inference methods. For example:
101 |
102 | .. code-block:: python
103 | 
104 |     from pymic.net_run_nll.nll_abstract import NLLSegAgent
105 | 
106 |     class MyNLLMethod(NLLSegAgent):
107 |         def __init__(self, config, stage = 'train'):
108 |             super(MyNLLMethod, self).__init__(config, stage)
109 |             ...
110 | 
111 |         def training(self):
112 |             ...
113 | 
114 |     agent = MyNLLMethod(config, stage)
115 |     agent.run()
116 |
117 | You may need to check the source code of built-in NLL methods to become more familiar
118 | with how to implement your own NLL method.
119 |
120 | In addition, if you want to design a new noise-robust loss function,
121 | just follow :doc:`usage.fsl` to implement and use the customized loss.
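122 | 
123 | As an illustration (a hedged sketch only, not PyMIC's built-in implementation), a simple
124 | noise-robust loss such as the mean absolute error between predicted probabilities and
125 | soft labels can be written as a standard PyTorch module:
126 | 
127 | .. code-block:: python
128 | 
129 |     import torch
130 |     import torch.nn as nn
131 | 
132 |     class MyMAELoss(nn.Module):
133 |         # Mean absolute error between softmax probabilities and soft labels.
134 |         # Illustrative only; PyMIC already ships MAELoss as listed above.
135 |         def forward(self, predict, soft_label):
136 |             prob = torch.softmax(predict, dim = 1)
137 |             return torch.mean(torch.abs(prob - soft_label))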
--------------------------------------------------------------------------------
/PyMIC/docs/source/usage.quickstart.rst:
--------------------------------------------------------------------------------
1 | .. _quickstart:
2 |
3 | Quick Start
4 | ===========
5 |
6 |
7 | Train and Test
8 | --------------
9 |
10 | PyMIC accepts a configuration file for running. For example, to train a network
11 | for segmentation with full supervision, run the following command:
12 |
13 | .. code-block:: bash
14 |
15 | pymic_run train myconfig.cfg
16 |
17 | After training, run the following command for testing:
18 |
19 | .. code-block:: bash
20 |
21 | pymic_run test myconfig.cfg
22 |
23 | .. tip::
24 |
25 | We provide several examples in `PyMIC_examples <https://github.com/HiLab-git/PyMIC_examples>`_.
26 | Please run these examples to quickly get started with PyMIC.
27 |
28 |
29 | .. _configuration:
30 |
31 | Configuration File
32 | ------------------
33 |
34 | PyMIC uses configuration files to specify the settings and parameters of a deep
35 | learning pipeline, so that users can reuse the code and minimize their workload.
36 | Configuration files can be used to configure almost all the components involved,
37 | such as the dataset, network structure, loss function, optimizer, learning rate
38 | scheduler and post-processing methods.
39 |
40 | .. note::
41 |
42 | Generally, the configuration file has four sections: ``dataset``, ``network``,
43 | ``training`` and ``testing``.
44 |
45 | The following is an example configuration
46 | file used for segmentation of the lung from radiographs, which can be found in
47 | ``PyMIC_examples/segmentation/JSRT``.
48 |
49 | .. code-block:: none
50 |
51 | [dataset]
52 | # tensor type (float or double)
53 | tensor_type = float
54 | task_type = seg
55 | root_dir = ../../PyMIC_data/JSRT
56 | train_csv = config/jsrt_train.csv
57 | valid_csv = config/jsrt_valid.csv
58 | test_csv = config/jsrt_test.csv
59 | train_batch_size = 4
60 |
61 | # data transforms
62 | train_transform = [NormalizeWithMeanStd, RandomCrop, LabelConvert, LabelToProbability]
63 | valid_transform = [NormalizeWithMeanStd, LabelConvert, LabelToProbability]
64 | test_transform = [NormalizeWithMeanStd]
65 |
66 | NormalizeWithMeanStd_channels = [0]
67 | RandomCrop_output_size = [240, 240]
68 |
69 | LabelConvert_source_list = [0, 255]
70 | LabelConvert_target_list = [0, 1]
71 |
72 | [network]
73 | net_type = UNet2D
74 | # Parameters for UNet2D
75 | class_num = 2
76 | in_chns = 1
77 | feature_chns = [16, 32, 64, 128, 256]
78 | dropout = [0, 0, 0.3, 0.4, 0.5]
79 | bilinear = False
80 | deep_supervise = False
81 |
82 | [training]
83 | # list of gpus
84 | gpus = [0]
85 | loss_type = DiceLoss
86 |
87 | # for optimizers
88 | optimizer = Adam
89 | learning_rate = 1e-3
90 | momentum = 0.9
91 | weight_decay = 1e-5
92 |
93 | # for lr scheduler (MultiStepLR)
94 | lr_scheduler = MultiStepLR
95 | lr_gamma = 0.5
96 | lr_milestones = [2000, 4000, 6000]
97 |
98 | ckpt_save_dir = model/unet_dice_loss
99 | ckpt_prefix = unet
100 |
101 | # start iter
102 | iter_start = 0
103 | iter_max = 8000
104 | iter_valid = 200
105 | iter_save = 8000
106 |
107 | [testing]
108 | # list of gpus
109 | gpus = [0]
110 | # checkpoint mode can be [0-latest, 1-best, 2-specified]
111 | ckpt_mode = 0
112 | output_dir = result
113 |
114 | # convert the label of prediction output
115 | label_source = [0, 1]
116 | label_target = [0, 255]
117 |
118 |
119 | Evaluation
120 | ----------
121 |
122 | To evaluate a model's prediction results against the ground truth,
123 | use the ``pymic_eval_seg`` and ``pymic_eval_cls`` commands for segmentation
124 | and classification tasks, respectively. Both of them accept a configuration
125 | file that specifies the evaluation metrics, predicted results, ground truth and
126 | other information.
127 |
128 | For example, for segmentation tasks, run:
129 |
130 | .. code-block:: bash
131 |
132 | pymic_eval_seg evaluation.cfg
133 |
134 | The configuration file looks like this (an example from ``PyMIC_examples/seg_ssl/ACDC``):
135 |
136 | .. code-block:: none
137 |
138 | [evaluation]
139 | metric = dice
140 | label_list = [1,2,3]
141 | organ_name = heart
142 |
143 | ground_truth_folder_root = ../../PyMIC_data/ACDC/preprocess
144 | segmentation_folder_root = result/unet2d_em
145 | evaluation_image_pair = config/data/image_test_gt_seg.csv
146 |
147 | See :mod:`pymic.util.evaluation_seg.evaluation` for details of the configuration required.
148 |
149 | For classification tasks, run:
150 |
151 | .. code-block:: bash
152 |
153 | pymic_eval_cls evaluation.cfg
154 |
155 | The configuration file looks like this (an example from ``PyMIC_examples/classification/CHNCXR``):
156 |
157 | .. code-block:: none
158 |
159 | [evaluation]
160 | metric_list = [accuracy, auc]
161 | ground_truth_csv = config/cxr_test.csv
162 | predict_csv = result/resnet18.csv
163 | predict_prob_csv = result/resnet18_prob.csv
164 |
165 | See :mod:`pymic.util.evaluation_cls.main` for details of the configuration required.
166 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/usage.rst:
--------------------------------------------------------------------------------
1 | Usage
2 | =====
3 |
4 | This section gives details of how to use PyMIC.
5 | Beginners can easily start by training a deep learning
6 | model with configuration files. Once you are more familiar with
7 | the PyMIC pipeline, you can define your own customized modules
8 | and reuse the remaining parts of the pipeline, with minimal
9 | workload.
10 |
11 | .. toctree::
12 | :maxdepth: 2
13 |
14 | usage.quickstart
15 | usage.fsl
16 | usage.ssl
17 | usage.wsl
18 | usage.nll
19 |
--------------------------------------------------------------------------------
/PyMIC/docs/source/usage.ssl.rst:
--------------------------------------------------------------------------------
1 | .. _semi_supervised_learning:
2 |
3 | Semi-Supervised Learning
4 | =========================
5 |
6 | pymic_ssl
7 | ---------
8 |
9 | :mod:`pymic_ssl` is the command for using built-in semi-supervised methods for training.
10 | Similar to :mod:`pymic_run`, it should be followed by two parameters, specifying the
11 | stage and the configuration file, respectively. The training and testing commands are:
12 |
13 | .. code-block:: bash
14 |
15 | pymic_ssl train myconfig_ssl.cfg
16 | pymic_ssl test myconfig_ssl.cfg
17 |
18 | .. tip::
19 |
20 | If the SSL method only involves one network, either ``pymic_ssl`` or ``pymic_run``
21 | can be used for inference. They differ only in the training stage.
22 |
23 | SSL Configurations
24 | ------------------
25 |
26 | In the configuration file for ``pymic_ssl``, in addition to the items used in fully
27 | supervised learning, there are some items specific to semi-supervised learning.
28 |
29 | Users should provide values for the following items in ``dataset`` section of
30 | the configuration file:
31 |
32 | * ``train_csv_unlab`` (string): the csv file for the unlabeled dataset.
33 | Note that ``train_csv`` is only used for the labeled dataset.
34 |
35 | * ``train_batch_size_unlab`` (int): the batch size for the unlabeled dataset.
36 | Note that ``train_batch_size`` is the batch size for the labeled dataset.
37 |
38 | * ``train_transform_unlab`` (list): a list of transforms used for unlabeled data.
39 |
40 |
41 | The following is an example of the ``dataset`` section for semi-supervised learning:
42 |
43 | .. code-block:: none
44 |
45 | ...
46 | root_dir = ../../PyMIC_data/ACDC/preprocess/
47 | train_csv = config/data/image_train_r10_lab.csv
48 | train_csv_unlab = config/data/image_train_r10_unlab.csv
49 | valid_csv = config/data/image_valid.csv
50 | test_csv = config/data/image_test.csv
51 |
52 | train_batch_size = 4
53 | train_batch_size_unlab = 4
54 |
55 | # data transforms
56 | train_transform = [Pad, RandomRotate, RandomCrop, RandomFlip, NormalizeWithMeanStd, GammaCorrection, GaussianNoise, LabelToProbability]
57 | train_transform_unlab = [Pad, RandomRotate, RandomCrop, RandomFlip, NormalizeWithMeanStd, GammaCorrection, GaussianNoise]
58 | valid_transform = [NormalizeWithMeanStd, Pad, LabelToProbability]
59 | test_transform = [NormalizeWithMeanStd, Pad]
60 | ...
61 |
62 | In addition, there is a ``semi_supervised_learning`` section that is specifically designed
63 | for SSL methods. In that section, users need to specify the ``ssl_method`` and the configurations
64 | related to it. For example, the corresponding configuration for CPS is:
65 |
66 | .. code-block:: none
67 |
68 | ...
69 | [semi_supervised_learning]
70 | ssl_method = CPS
71 | regularize_w = 0.1
72 | rampup_start = 1000
73 | rampup_end = 20000
74 | ...
75 |
76 | .. note::
77 |
78 | The configuration items vary with different SSL methods. Please refer to the API
79 | of each built-in SSL method for details of the corresponding configuration.
80 |
81 | Built-in SSL Methods
82 | --------------------
83 |
84 | :mod:`pymic.net_run_ssl.ssl_abstract.SSLSegAgent` is the abstract class used for
85 | semi-supervised learning. The built-in SSL methods are child classes of :mod:`SSLSegAgent`.
86 | The available SSL methods implemented in PyMIC are listed in :mod:`pymic.net_run_ssl.ssl_main.SSLMethodDict`,
87 | and they are:
88 |
89 | * ``EntropyMinimization``: (`NeurIPS 2005 `_)
90 | Using entropy minimization to regularize unannotated samples.
91 |
92 | * ``MeanTeacher``: (`NeurIPS 2017 `_) Use self-ensembling mean teacher to supervise the student model on
93 | unannotated samples.
94 |
95 | * ``UAMT``: (`MICCAI 2019 `_) Uncertainty aware mean teacher.
96 |
97 | * ``CCT``: (`CVPR 2020 `_) Cross-consistency training.
98 |
99 | * ``CPS``: (`CVPR 2021 `_) Cross-pseudo supervision.
100 |
101 | * ``URPC``: (`MIA 2022 `_) Uncertainty rectified pyramid consistency.
102 |
103 | Customized SSL Methods
104 | ----------------------
105 |
106 | PyMIC also supports customizing SSL methods by inheriting the :mod:`SSLSegAgent` class.
107 | You may only need to rewrite the :mod:`training()` method and reuse most parts of the
108 | existing pipeline, such as the data loading, validation and inference methods. For example:
109 |
110 | .. code-block:: python
111 |
112 | from pymic.net_run_ssl.ssl_abstract import SSLSegAgent
113 |
114 | class MySSLMethod(SSLSegAgent):
115 | def __init__(self, config, stage = 'train'):
116 | super(MySSLMethod, self).__init__(config, stage)
117 | ...
118 |
119 | def training(self):
120 | ...
121 |
122 | agent = MySSLMethod(config, stage)
123 | agent.run()
124 |
125 | You may want to check the source code of the built-in SSL methods to become more familiar
126 | with how to implement your own SSL method.
--------------------------------------------------------------------------------
/PyMIC/pymic/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/__init__.py
--------------------------------------------------------------------------------
/PyMIC/pymic/io/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/io/__init__.py
--------------------------------------------------------------------------------
/PyMIC/pymic/io/h5_dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import itertools
4 | import os
5 | import torch
6 | import random
7 | import h5py
8 | import numpy as np
9 | import pandas as pd
10 | from torch.utils.data import Dataset, Sampler
11 |
12 | class H5DataSet(Dataset):
13 | """
14 | Dataset for loading images stored in h5 format. It generates
15 | 4D tensors with dimension order [C, D, H, W] for 3D images, and
16 | 3D tensors with dimension order [C, H, W] for 2D images.
17 |
18 | Args:
19 | root_dir (str): the root dir of images. \n
20 | sample_list_name (str): a file name for the sample list. \n
21 | transform (list): A list of transform objects applied to a sample.
22 | """
23 | def __init__(self, root_dir, sample_list_name, transform = None):
24 | self.root_dir = root_dir
25 | self.transform = transform
26 | with open(sample_list_name, 'r') as f:
27 | lines = f.readlines()
28 | self.sample_list = [item.replace('\n', '') for item in lines]
29 |
30 | def __len__(self):
31 | return len(self.sample_list)
32 |
33 | def __getitem__(self, idx):
34 | sample_name = self.sample_list[idx]
35 | h5f = h5py.File(self.root_dir + '/' + sample_name, 'r')
36 | image = h5f['image'][:]
37 | label = h5f['label'][:]
38 | sample = {'image': image, 'label': label}
39 | if self.transform:
40 | sample = self.transform(sample)
41 | return sample
42 |
43 | class TwoStreamBatchSampler(Sampler):
44 | """Iterate two sets of indices
45 |
46 | An 'epoch' is one iteration through the primary indices.
47 | During the epoch, the secondary indices are iterated through
48 | as many times as needed.
49 | """
50 |
51 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
52 | self.primary_indices = primary_indices
53 | self.secondary_indices = secondary_indices
54 | self.secondary_batch_size = secondary_batch_size
55 | self.primary_batch_size = batch_size - secondary_batch_size
56 |
57 | assert len(self.primary_indices) >= self.primary_batch_size > 0
58 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0
59 |
60 | def __iter__(self):
61 | primary_iter = iterate_once(self.primary_indices)
62 | secondary_iter = iterate_eternally(self.secondary_indices)
63 | return (
64 | primary_batch + secondary_batch
65 | for (primary_batch, secondary_batch)
66 | in zip(grouper(primary_iter, self.primary_batch_size),
67 | grouper(secondary_iter, self.secondary_batch_size))
68 | )
69 |
70 | def __len__(self):
71 | return len(self.primary_indices) // self.primary_batch_size
72 |
73 |
74 | def iterate_once(iterable):
75 | return np.random.permutation(iterable)
76 |
77 |
78 | def iterate_eternally(indices):
79 | def infinite_shuffles():
80 | while True:
81 | yield np.random.permutation(indices)
82 | return itertools.chain.from_iterable(infinite_shuffles())
83 |
84 |
85 | def grouper(iterable, n):
86 | "Collect data into fixed-length chunks or blocks"
87 | # grouper('ABCDEFG', 3) --> ABC DEF
88 | args = [iter(iterable)] * n
89 | return zip(*args)
90 |
91 |
92 | if __name__ == "__main__":
93 | root_dir = "/home/guotai/disk2t/projects/semi_supervise/SSL4MIS/data/ACDC/data/slices"
94 | file_name = "/home/guotai/disk2t/projects/semi_supervise/slices.txt"
95 | dataset = H5DataSet(root_dir, file_name)
96 | train_loader = torch.utils.data.DataLoader(dataset,
97 | batch_size = 4, shuffle=True, num_workers= 1)
98 | for sample in train_loader:
99 | image = sample['image']
100 | label = sample['label']
101 | print(image.shape, label.shape)
102 | print(image.min(), image.max(), label.max())
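# Illustrative use of TwoStreamBatchSampler (an assumption for demonstration,
# not part of the original file): treat the first half of the samples as
# labeled (primary) and the second half as unlabeled (secondary).
n_lab = len(dataset) // 2
batch_sampler = TwoStreamBatchSampler(
    primary_indices = list(range(n_lab)),
    secondary_indices = list(range(n_lab, len(dataset))),
    batch_size = 4, secondary_batch_size = 2)
two_stream_loader = torch.utils.data.DataLoader(dataset, batch_sampler = batch_sampler)
for sample in two_stream_loader:
    print(sample['image'].shape, sample['label'].shape)
    break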
--------------------------------------------------------------------------------
/PyMIC/pymic/layer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/layer/__init__.py
--------------------------------------------------------------------------------
/PyMIC/pymic/layer/activation.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | def get_acti_func(acti_func, params):
9 | acti_func = acti_func.lower()
10 | if(acti_func == 'relu'):
11 | inplace = params.get('relu_inplace', False)
12 | return nn.ReLU(inplace)
13 |
14 | elif(acti_func == 'leakyrelu'):
15 | slope = params.get('leakyrelu_negative_slope', 1e-2)
16 | inplace = params.get('leakyrelu_inplace', False)
17 | return nn.LeakyReLU(slope, inplace)
18 |
19 | elif(acti_func == 'prelu'):
20 | num_params = params.get('prelu_num_parameters', 1)
21 | init_value = params.get('prelu_init', 0.25)
22 | return nn.PReLU(num_params, init_value)
23 |
24 | elif(acti_func == 'rrelu'):
25 | lower = params.get('rrelu_lower', 1.0 /8)
26 | upper = params.get('rrelu_upper', 1.0 /3)
27 | inplace = params.get('rrelu_inplace', False)
28 | return nn.RReLU(lower, upper, inplace)
29 |
30 | elif(acti_func == 'elu'):
31 | alpha = params.get('elu_alpha', 1.0)
32 | inplace = params.get('elu_inplace', False)
33 | return nn.ELU(alpha, inplace)
34 |
35 | elif(acti_func == 'celu'):
36 | alpha = params.get('celu_alpha', 1.0)
37 | inplace = params.get('celu_inplace', False)
38 | return nn.CELU(alpha, inplace)
39 |
40 | elif(acti_func == 'selu'):
41 | inplace = params.get('selu_inplace', False)
42 | return nn.SELU(inplace)
43 |
44 | elif(acti_func == 'glu'):
45 | dim = params.get('glu_dim', -1)
46 | return nn.GLU(dim)
47 |
48 | elif(acti_func == 'sigmoid'):
49 | return nn.Sigmoid()
50 |
51 | elif(acti_func == 'logsigmoid'):
52 | return nn.LogSigmoid()
53 |
54 | elif(acti_func == 'tanh'):
55 | return nn.Tanh()
56 |
57 | elif(acti_func == 'hardtanh'):
58 | min_val = params.get('hardtanh_min_val', -1.0)
59 | max_val = params.get('hardtanh_max_val', 1.0)
60 | inplace = params.get('hardtanh_inplace', False)
61 | return nn.Hardtanh(min_val, max_val, inplace)
62 |
63 | elif(acti_func == 'softplus'):
64 | beta = params.get('softplus_beta', 1.0)
65 | threshold = params.get('softplus_threshold', 20)
66 | return nn.Softplus(beta, threshold)
67 |
68 | elif(acti_func == 'softshrink'):
69 | lambd = params.get('softshrink_lambda', 0.5)
70 | return nn.Softshrink(lambd)
71 |
72 | elif(acti_func == 'softsign'):
73 | return nn.Softsign()
74 |
75 | else:
76 | raise ValueError("Not implemented: {0:}".format(acti_func))
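
if __name__ == "__main__":
    # Illustrative usage (not part of the original file): build an activation
    # layer from a name and a parameter dictionary, then apply it.
    acti = get_acti_func('leakyrelu', {'leakyrelu_negative_slope': 0.1})
    y = acti(torch.randn(2, 8))
    print(y.shape)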
--------------------------------------------------------------------------------
/PyMIC/pymic/layer/space2channel.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import SimpleITK as sitk
8 |
9 | class SpaceToChannel3D(nn.Module):
10 | """
11 | Space to channel transform for 3D input."""
12 | def __init__(self):
13 | super(SpaceToChannel3D, self).__init__()
14 |
15 | def forward(self, x):
16 | # only 3D images (5D tensors are supported)
17 | input_shape = list(x.shape)
18 | assert(len(input_shape) == 5)
19 | [B,C, D, H, W] = input_shape
20 | assert((D % 2 == 0) and (H % 2 == 0) and (W % 2 == 0))
21 | halfD = int(D/2)
22 | halfH = int(H/2)
23 | halfW = int(W/2)
24 | # split along D axis
25 | x1 = x.contiguous().view([B, C, halfD, 2, H, W])
26 | # permute to [B, C, 2, halfD, H, W]
27 | x2 = x1.permute(0, 1, 3, 2, 4, 5)
28 | # view as [B, 2*C, halfD, H, W] and [B, C*2, halfD, halfH, 2, W]
29 | x3 = x2.contiguous().view([B, C*2, halfD, halfH, 2, W])
30 | # permute to [B, C*2, 2, halfD, halfH, W]
31 | x4 = x3.permute(0, 1, 4, 2, 3, 5)
32 | # view as [B, C*4, halfD, halfH, W] and [B, C*4, halfD, halfH, halfW, 2]
33 | x5 = x4.contiguous().view([B, C*4, halfD, halfH, halfW, 2])
34 | # permute to [B, C*4, 2, halfD, halfH, halfW]
35 | x6 = x5.permute(0, 1, 5, 2, 3, 4)
36 | x7 = x6.contiguous().view([B, C*8, halfD, halfH, halfW])
37 | return x7
38 |
39 | class ChannelToSpace3D(nn.Module):
40 | """
41 | Channel to space transform for 3D input."""
42 | def __init__(self):
43 | super(ChannelToSpace3D, self).__init__()
44 |
45 | def forward(self, x):
46 | # only 3D images (5D tensors are supported)
47 | input_shape = list(x.shape)
48 | assert(len(input_shape) == 5)
49 | [B,C, D, H, W] = input_shape
50 | assert(C % 8 == 0)
51 | Cd8 = int(C/8)
52 | Cd4 = 2 * Cd8
53 | Cd2 = 2 * Cd4
54 | x6 = x.contiguous().view([B, Cd2, 2, D, H, W])
55 | # permute to [B, Cd4, D, H, W, 2]
56 | x5 = x6.permute(0, 1, 3, 4, 5, 2)
57 | x4 = x5.contiguous().view([B, Cd4, 2, D, H, 2*W])
58 | # permute to [B, Cd2, D, H, 2, 2*W]
59 | x3 = x4.permute(0, 1, 3, 4, 2, 5)
60 | x2 = x3.contiguous().view([B, Cd8, 2, D, 2* H, 2*W])
61 | x1 = x2.permute(0, 1, 3, 2, 4, 5)
62 | x0 = x1.contiguous().view([B, Cd8, 2*D, 2* H, 2*W])
63 | return x0
64 |
65 |
66 | if __name__ == "__main__":
67 | s2c = SpaceToChannel3D()
68 | s2c = s2c.double()
69 |
70 | c2s = ChannelToSpace3D()
71 | c2s = c2s.double()
72 |
73 | img_name = "/home/disk2t/data/brats/BraTS2018_Training/HGG/Brats18_2013_2_1/Brats18_2013_2_1_flair.nii.gz"
74 | img_obj = sitk.ReadImage(img_name)
75 | img_data = sitk.GetArrayFromImage(img_obj)
76 | img_data = img_data[:-1]
77 | print(img_data.shape)
78 | x = img_data.reshape([1, 1] + list(img_data.shape))
79 | # x = np.random.rand(4, 4, 96, 96, 96)
80 | xt = torch.from_numpy(x)
81 | xt = xt.double()  # match the modules converted with .double()
82 |
83 | y = s2c(xt)
84 | z = c2s(y)
85 | y = y.detach().numpy()[0]
86 | print(y.shape)
87 | for i in range(8):
88 | sub_img = sitk.GetImageFromArray(y[i])
89 | # sub_img.CopyInformation(img_obj)
90 | save_name = "../../temp/{0:}.nii.gz".format(i)
91 | sitk.WriteImage(sub_img, save_name)
92 | z = z.detach().numpy()[0]
93 | print(z.shape)
94 | rec_img = sitk.GetImageFromArray(z[0])
95 | save_name = "../../temp/rec.nii.gz"
96 | sitk.WriteImage(rec_img, save_name)
--------------------------------------------------------------------------------
/PyMIC/pymic/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/loss/__init__.py
--------------------------------------------------------------------------------
/PyMIC/pymic/loss/cls/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/loss/cls/__init__.py
--------------------------------------------------------------------------------
/PyMIC/pymic/loss/cls/basic.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 | import torch
4 | import torch.nn as nn
5 | from pymic.loss.cls.util import get_soft_label
6 |
7 | class AbstractClassificationLoss(nn.Module):
8 | """
9 | Abstract Classification Loss.
10 | """
11 | def __init__(self, params = None):
12 | super(AbstractClassificationLoss, self).__init__()
13 |
14 | def forward(self, loss_input_dict):
15 | """
16 | The arguments should be written in the `loss_input_dict` dictionary, and it has the
17 | following fields.
18 |
19 | :param prediction: A prediction with shape of [N, C] where C is the class number.
20 | :param ground_truth: The corresponding ground truth, with shape of [N, 1].
21 |
22 | Note that `prediction` is the logit output of a network, before softmax.
23 | """
24 | pass
25 |
26 | class CrossEntropyLoss(AbstractClassificationLoss):
27 | """
28 | Standard Softmax-based CE loss.
29 | """
30 | def __init__(self, params = None):
31 | super(CrossEntropyLoss, self).__init__(params)
32 | self.ce_loss = nn.CrossEntropyLoss()
33 |
34 | def forward(self, loss_input_dict):
35 | predict = loss_input_dict['prediction']
36 | labels = loss_input_dict['ground_truth']
37 | loss = self.ce_loss(predict, labels)
38 | return loss
39 |
40 | class SigmoidCELoss(AbstractClassificationLoss):
41 | """
42 | Sigmoid-based CE loss.
43 | """
44 | def __init__(self, params = None):
45 | super(SigmoidCELoss, self).__init__(params)
46 |
47 | def forward(self, loss_input_dict):
48 | predict = loss_input_dict['prediction']
49 | labels = loss_input_dict['ground_truth']
50 | predict = nn.Sigmoid()(predict) * 0.999 + 5e-4
51 | loss = - labels * torch.log(predict) - (1 - labels) * torch.log( 1 - predict)
52 | loss = loss.mean()
53 | return loss
54 |
55 | class L1Loss(AbstractClassificationLoss):
56 | """
57 | L1 (MAE) loss for classification
58 | """
59 | def __init__(self, params = None):
60 | super(L1Loss, self).__init__(params)
61 | self.l1_loss = nn.L1Loss()
62 |
63 | def forward(self, loss_input_dict):
64 | predict = loss_input_dict['prediction']
65 | labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1
66 | softmax = nn.Softmax(dim = 1)
67 | predict = softmax(predict)
68 | num_class = list(predict.size())[1]
69 | data_type = 'float' if(predict.dtype is torch.float32) else 'double'
70 | soft_y = get_soft_label(labels, num_class, data_type)
71 | loss = self.l1_loss(predict, soft_y)
72 | return loss
73 |
74 | class MSELoss(AbstractClassificationLoss):
75 | """
76 | Mean Square Error loss for classification.
77 | """
78 | def __init__(self, params = None):
79 | super(MSELoss, self).__init__(params)
80 | self.mse_loss = nn.MSELoss()
81 |
82 | def forward(self, loss_input_dict):
83 | predict = loss_input_dict['prediction']
84 | labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1
85 | softmax = nn.Softmax(dim = 1)
86 | predict = softmax(predict)
87 | num_class = list(predict.size())[1]
88 | data_type = 'float' if(predict.dtype is torch.float32) else 'double'
89 | soft_y = get_soft_label(labels, num_class, data_type)
90 | loss = self.mse_loss(predict, soft_y)
91 | return loss
92 |
93 | class NLLLoss(AbstractClassificationLoss):
94 | """
95 | The negative log likelihood loss for classification.
96 | """
97 | def __init__(self, params = None):
98 | super(NLLLoss, self).__init__(params)
99 | self.nll_loss = nn.NLLLoss()
100 |
101 | def forward(self, loss_input_dict):
102 | predict = loss_input_dict['prediction']
103 | labels = loss_input_dict['ground_truth']
104 | logsoft = nn.LogSoftmax(dim = 1)
105 | predict = logsoft(predict)
106 | loss = self.nll_loss(predict, labels)
107 | return loss
--------------------------------------------------------------------------------
/PyMIC/pymic/loss/cls/util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 |
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 |
8 | def get_soft_label(input_tensor, num_class, data_type = 'float'):
9 | """
10 | Convert a label tensor to one-hot soft label.
11 |
12 | :param input_tensor: Tensor with shape of [B, 1].
13 | :param num_class: (int) Class number.
14 | :param data_type: (str) `float` or `double`.
15 | :return: Tensor with shape of [B, num_class].
16 | """
17 | tensor_list = []
18 | for i in range(num_class):
19 | temp_prob = input_tensor == i*torch.ones_like(input_tensor)
20 | tensor_list.append(temp_prob)
21 | output_tensor = torch.cat(tensor_list, dim = 1)
22 | if(data_type == 'float'):
23 | output_tensor = output_tensor.float()
24 | elif(data_type == 'double'):
25 | output_tensor = output_tensor.double()
26 | else:
27 | raise ValueError("data type can only be float and double: {0:}".format(data_type))
28 |
29 | return output_tensor
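
if __name__ == "__main__":
    # Illustrative usage (not part of the original file): convert class
    # indices of shape [B, 1] to a one-hot tensor of shape [B, num_class].
    labels = torch.tensor([[0], [2], [1]])
    one_hot = get_soft_label(labels, num_class = 3)
    print(one_hot)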
--------------------------------------------------------------------------------
/PyMIC/pymic/loss/loss_dict_cls.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Built-in loss functions for classification.
4 |
5 | * CrossEntropyLoss :mod:`pymic.loss.cls.basic.CrossEntropyLoss`
6 | * SigmoidCELoss :mod:`pymic.loss.cls.basic.SigmoidCELoss`
7 | * L1Loss :mod:`pymic.loss.cls.basic.L1Loss`
8 | * MSELoss :mod:`pymic.loss.cls.basic.MSELoss`
9 | * NLLLoss :mod:`pymic.loss.cls.basic.NLLLoss`
10 |
11 | """
12 | from __future__ import print_function, division
13 | from pymic.loss.cls.basic import *
14 |
15 | PyMICClsLossDict = {"CrossEntropyLoss": CrossEntropyLoss,
16 | "SigmoidCELoss": SigmoidCELoss,
17 | "L1Loss": L1Loss,
18 | "MSELoss": MSELoss,
19 | "NLLLoss": NLLLoss}
20 |
--------------------------------------------------------------------------------
/PyMIC/pymic/loss/loss_dict_seg.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Built-in loss functions for segmentation.
4 | The following are for fully supervised learning, or learning from noisy labels:
5 |
6 | * CrossEntropyLoss :mod:`pymic.loss.seg.ce.CrossEntropyLoss`
7 | * GeneralizedCELoss :mod:`pymic.loss.seg.ce.GeneralizedCELoss`
8 | * DiceLoss :mod:`pymic.loss.seg.dice.DiceLoss`
9 | * FocalDiceLoss :mod:`pymic.loss.seg.dice.FocalDiceLoss`
10 | * NoiseRobustDiceLoss :mod:`pymic.loss.seg.dice.NoiseRobustDiceLoss`
11 | * ExpLogLoss :mod:`pymic.loss.seg.exp_log.ExpLogLoss`
12 | * MAELoss :mod:`pymic.loss.seg.mse.MAELoss`
13 | * MSELoss :mod:`pymic.loss.seg.mse.MSELoss`
14 | * SLSRLoss :mod:`pymic.loss.seg.slsr.SLSRLoss`
15 |
16 | The following are for semi-supervised or weakly supervised learning:
17 |
18 | * EntropyLoss :mod:`pymic.loss.seg.ssl.EntropyLoss`
19 | * GatedCRFLoss: :mod:`pymic.loss.seg.gatedcrf.GatedCRFLoss`
20 | * MumfordShahLoss :mod:`pymic.loss.seg.mumford_shah.MumfordShahLoss`
21 | * TotalVariationLoss :mod:`pymic.loss.seg.ssl.TotalVariationLoss`
22 | """
23 | from __future__ import print_function, division
24 | import torch.nn as nn
25 | from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCELoss
26 | from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, NoiseRobustDiceLoss
27 | from pymic.loss.seg.exp_log import ExpLogLoss
28 | from pymic.loss.seg.mse import MSELoss, MAELoss
29 | from pymic.loss.seg.slsr import SLSRLoss
30 |
31 | SegLossDict = {
32 | 'CrossEntropyLoss': CrossEntropyLoss,
33 | 'GeneralizedCELoss': GeneralizedCELoss,
34 | 'DiceLoss': DiceLoss,
35 | 'FocalDiceLoss': FocalDiceLoss,
36 | 'NoiseRobustDiceLoss': NoiseRobustDiceLoss,
37 | 'ExpLogLoss': ExpLogLoss,
38 | 'MAELoss': MAELoss,
39 | 'MSELoss': MSELoss,
40 | 'SLSRLoss': SLSRLoss
41 | }
42 |
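# Usage sketch (an illustration, not part of the original file): the training
# agents look up a loss class by the `loss_type` string from the configuration
# file and instantiate it with the parameter dictionary, e.g.:
#   loss_func = SegLossDict['DiceLoss']({'loss_softmax': True})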
43 |
--------------------------------------------------------------------------------
/PyMIC/pymic/loss/seg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/loss/seg/__init__.py
--------------------------------------------------------------------------------
/PyMIC/pymic/loss/seg/abstract.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | class AbstractSegLoss(nn.Module):
8 | """
9 | Abstract class for loss function of segmentation tasks.
10 | The parameters should be written in the `params` dictionary, and it has the
11 | following fields:
12 |
13 | :param `loss_softmax`: (optional, bool)
14 | Apply softmax to the prediction of network or not. Default is True.
15 | """
16 | def __init__(self, params = None):
17 | super(AbstractSegLoss, self).__init__()
18 | if(params is None):
19 | self.softmax = True
20 | else:
21 | self.softmax = params.get('loss_softmax', True)
22 |
23 | def forward(self, loss_input_dict):
24 | """
25 | Forward pass for calculating the loss.
26 | The arguments should be written in the `loss_input_dict` dictionary,
27 | and it has the following fields:
28 |
29 | :param `prediction`: (tensor) Prediction of a network, with the
30 | shape of [N, C, D, H, W] or [N, C, H, W].
31 | :param `ground_truth`: (tensor) Ground truth, with the
32 | shape of [N, C, D, H, W] or [N, C, H, W].
33 | :param `pixel_weight`: (optional) Pixel-wise weight map, with the
34 | shape of [N, 1, D, H, W] or [N, 1, H, W]. Default is None.
35 | :return: Loss function value.
36 | """
37 | pass
38 |
--------------------------------------------------------------------------------
/PyMIC/pymic/loss/seg/ce.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 |
4 | import torch
5 | import torch.nn as nn
6 | from pymic.loss.seg.abstract import AbstractSegLoss
7 | from pymic.loss.seg.util import reshape_tensor_to_2D
8 |
9 | class CrossEntropyLoss(AbstractSegLoss):
10 | """
11 | Cross entropy loss for segmentation tasks.
12 |
13 | The parameters should be written in the `params` dictionary, and it has the
14 | following fields:
15 |
16 | :param `loss_softmax`: (optional, bool)
17 | Apply softmax to the prediction of network or not. Default is True.
18 | """
19 | def __init__(self, params = None):
20 | super(CrossEntropyLoss, self).__init__(params)
21 |
22 |
23 | def forward(self, loss_input_dict):
24 | predict = loss_input_dict['prediction']
25 | soft_y = loss_input_dict['ground_truth']
26 | pix_w = loss_input_dict.get('pixel_weight', None)
27 |
28 | if(isinstance(predict, (list, tuple))):
29 | predict = predict[0]
30 | if(self.softmax):
31 | predict = nn.Softmax(dim = 1)(predict)
32 | predict = reshape_tensor_to_2D(predict)
33 | soft_y = reshape_tensor_to_2D(soft_y)
34 |
35 | # for numeric stability
36 | predict = predict * 0.999 + 5e-4
37 | ce = - soft_y* torch.log(predict)
38 | ce = torch.sum(ce, dim = 1) # shape is [N]
39 | if(pix_w is None):
40 | ce = torch.mean(ce)
41 | else:
42 | pix_w = torch.squeeze(reshape_tensor_to_2D(pix_w))
43 | ce = torch.sum(pix_w * ce) / (pix_w.sum() + 1e-5)
44 | return ce
45 |
46 | class GeneralizedCELoss(AbstractSegLoss):
47 | """
48 | Generalized cross entropy loss to deal with noisy labels.
49 |
50 | * Reference: Z. Zhang et al. Generalized Cross Entropy Loss for Training Deep Neural Networks
51 | with Noisy Labels, NeurIPS 2018.
52 |
53 | The parameters should be written in the `params` dictionary, and it has the
54 | following fields:
55 |
56 | :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not.
57 | :param `loss_gce_q`: (float): hyper-parameter in the range of (0, 1).
58 | :param `loss_with_pixel_weight`: (optional, bool): Use pixel weighting or not.
59 | :param `loss_class_weight`: (optional, list or none): If not none, a list of weight for each class.
60 |
61 | """
62 | def __init__(self, params):
63 | super(GeneralizedCELoss, self).__init__(params)
64 | self.q = params.get('loss_gce_q', 0.5)
65 | self.enable_pix_weight = params.get('loss_with_pixel_weight', False)
66 | self.cls_weight = params.get('loss_class_weight', None)
67 |
68 | def forward(self, loss_input_dict):
69 | predict = loss_input_dict['prediction']
70 | soft_y = loss_input_dict['ground_truth']
71 |
72 | if(isinstance(predict, (list, tuple))):
73 | predict = predict[0]
74 | if(self.softmax):
75 | predict = nn.Softmax(dim = 1)(predict)
76 | predict = reshape_tensor_to_2D(predict)
77 | soft_y = reshape_tensor_to_2D(soft_y)
78 | gce = (1.0 - torch.pow(predict, self.q)) / self.q * soft_y
79 |
80 | if(self.cls_weight is not None):
81 | # `loss_class_weight` comes from the config as a list; convert it to a tensor
82 | cls_w = torch.tensor(self.cls_weight, dtype = gce.dtype, device = gce.device)
83 | gce = torch.sum(gce * cls_w, dim = 1)
82 | else:
83 | gce = torch.sum(gce, dim = 1)
84 |
85 | if(self.enable_pix_weight):
86 | pix_w = loss_input_dict.get('pixel_weight', None)
87 | if(pix_w is None):
88 | raise ValueError("Pixel weight is enabled but not defined")
89 | pix_w = reshape_tensor_to_2D(pix_w)
90 | gce = torch.sum(gce * pix_w) / torch.sum(pix_w)
91 | else:
92 | gce = torch.mean(gce)
93 | return gce
94 |
--------------------------------------------------------------------------------
/PyMIC/pymic/loss/seg/combined.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 |
4 | import torch
5 | import torch.nn as nn
6 | from pymic.loss.seg.abstract import AbstractSegLoss
7 |
8 | class CombinedLoss(AbstractSegLoss):
9 | '''
10 | A combination of a list of loss functions.
11 | Parameters should be saved in the `params` dictionary.
12 |
13 | :param `loss_softmax`: (optional, bool)
14 | Apply softmax to the prediction of network or not. Default is True.
15 | :param `loss_type`: (list) A list of loss function names.
16 | :param `loss_weight`: (list) A list of weights for each loss function.
17 | :param `loss_dict`: (dictionary) A dictionary of available loss functions.
18 |
19 | '''
20 | def __init__(self, params, loss_dict):
21 | super(CombinedLoss, self).__init__(params)
22 | loss_names = params['loss_type']
23 | self.loss_weight = params['loss_weight']
24 | assert (len(loss_names) == len(self.loss_weight))
25 | self.loss_list = []
26 | for loss_name in loss_names:
27 | if(loss_name in loss_dict):
28 | one_loss = loss_dict[loss_name](params)
29 | self.loss_list.append(one_loss)
30 | else:
31 | raise ValueError("{0:} is not defined, or has not been added to the \
32 | loss dictionary".format(loss_name))
33 |
34 | def forward(self, loss_input_dict):
35 | loss_value = 0.0
36 | for i in range(len(self.loss_list)):
37 | loss_value += self.loss_weight[i]*self.loss_list[i](loss_input_dict)
39 | return loss_value
40 |
--------------------------------------------------------------------------------
/PyMIC/pymic/loss/seg/deep_sup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 |
4 | import torch.nn as nn
5 | from pymic.loss.seg.abstract import AbstractSegLoss
6 |
7 | class DeepSuperviseLoss(AbstractSegLoss):
8 | '''
9 | Combine deep supervision with a basic loss function.
10 | Arguments should be provided in the `params` dictionary, and it has the
11 | following fields:
12 |
13 | :param `loss_softmax`: (optional, bool)
14 | Apply softmax to the prediction of network or not. Default is True.
15 | :param `deep_suervise_weight`: (list) A list of weights for each deep supervision scale. \n
16 | :param `base_loss`: (nn.Module) The base loss function used at each scale.
17 |
18 | '''
19 | def __init__(self, params):
20 | super(DeepSuperviseLoss, self).__init__(params)
21 | self.deep_sup_weight = params.get('deep_suervise_weight', None)
22 | self.base_loss = params['base_loss']
23 |
24 | def forward(self, loss_input_dict):
25 | predict = loss_input_dict['prediction']
26 | if(not isinstance(predict, (list,tuple))):
27 | raise ValueError("""For deep supervision, the prediction should
28 | be a list or a tuple""")
29 | predict_num = len(predict)
30 | if(self.deep_sup_weight is None):
31 | self.deep_sup_weight = [1.0] * predict_num
32 | else:
33 | assert(predict_num == len(self.deep_sup_weight))
34 | loss_sum, weight_sum = 0.0, 0.0
35 | for i in range(predict_num):
36 | loss_input_dict['prediction'] = predict[i]
37 | temp_loss = self.base_loss(loss_input_dict)
38 | loss_sum += temp_loss * self.deep_sup_weight[i]
39 | weight_sum += self.deep_sup_weight[i]
40 | loss = loss_sum/weight_sum
41 | return loss
--------------------------------------------------------------------------------
/PyMIC/pymic/loss/seg/exp_log.py:
--------------------------------------------------------------------------------
1 |
2 | # -*- coding: utf-8 -*-
3 | from __future__ import print_function, division
4 |
5 | import torch
6 | import torch.nn as nn
7 | from pymic.loss.seg.abstract import AbstractSegLoss
8 | from pymic.loss.seg.util import reshape_tensor_to_2D, get_classwise_dice
9 |
10 | class ExpLogLoss(AbstractSegLoss):
11 | """
12 | The exponential logarithmic loss in this paper:
13 |
14 | * K. Wong et al.: 3D Segmentation with Exponential Logarithmic Loss for Highly
15 | Unbalanced Object Sizes. `MICCAI 2018. `_
16 |
17 | The arguments should be written in the `params` dictionary, and it has the
18 | following fields:
19 |
20 | :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not.
21 | :param `ExpLogLoss_w_dice`: (float) Weight of ExpLog Dice loss in the range of [0, 1].
22 | :param `ExpLogLoss_gamma`: (float) Hyper-parameter gamma.
23 | """
24 | def __init__(self, params):
25 | super(ExpLogLoss, self).__init__(params)
26 | self.w_dice = params['ExpLogLoss_w_dice'.lower()]
27 | self.gamma = params['ExpLogLoss_gamma'.lower()]
28 |
29 | def forward(self, loss_input_dict):
30 | predict = loss_input_dict['prediction']
31 | soft_y = loss_input_dict['ground_truth']
32 |
33 | if(isinstance(predict, (list, tuple))):
34 | predict = predict[0]
35 | if(self.softmax):
36 | predict = nn.Softmax(dim = 1)(predict)
37 | predict = reshape_tensor_to_2D(predict)
38 | soft_y = reshape_tensor_to_2D(soft_y)
39 |
40 | dice_score = get_classwise_dice(predict, soft_y)
41 | dice_score = 0.005 + dice_score * 0.99
42 | exp_dice = -torch.log(dice_score)
43 | exp_dice = torch.pow(exp_dice, self.gamma)
44 | exp_dice = torch.mean(exp_dice)
45 |
46 | predict= 0.005 + predict * 0.99
47 | wc = torch.mean(soft_y, dim = 0)
48 | wc = 1.0 / (wc + 0.1)
49 | wc = torch.pow(wc, 0.5)
50 | ce = - torch.log(predict)
51 | exp_ce = wc * torch.pow(ce, self.gamma)
52 | exp_ce = torch.sum(soft_y * exp_ce, dim = 1)
53 | exp_ce = torch.mean(exp_ce)
54 |
55 | loss = exp_dice * self.w_dice + exp_ce * (1.0 - self.w_dice)
56 | return loss
--------------------------------------------------------------------------------
/PyMIC/pymic/loss/seg/mse.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from pymic.loss.seg.abstract import AbstractSegLoss
4 |
5 | class MSELoss(AbstractSegLoss):
6 | """
7 | Mean Square Error loss for segmentation tasks.
8 | The parameters should be written in the `params` dictionary, and it has the
9 | following fields:
10 |
11 | :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not.
12 | """
13 | def __init__(self, params = None):
14 | super(MSELoss, self).__init__(params)
15 |
16 | def forward(self, loss_input_dict):
17 | predict = loss_input_dict['prediction']
18 | soft_y = loss_input_dict['ground_truth']
19 |
20 | if(isinstance(predict, (list, tuple))):
21 | predict = predict[0]
22 | if(self.softmax):
23 | predict = nn.Softmax(dim = 1)(predict)
24 | mse = torch.square(predict - soft_y)
25 | mse = torch.mean(mse)
26 | return mse
27 |
28 |
29 | class MAELoss(AbstractSegLoss):
30 | """
31 | Mean Absolute Error loss for segmentation tasks.
32 | The arguments should be written in the `params` dictionary, and it has the
33 | following fields:
34 |
35 | :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not.
36 | """
37 | def __init__(self, params = None):
38 | super(MAELoss, self).__init__(params)
39 |
40 | def forward(self, loss_input_dict):
41 | predict = loss_input_dict['prediction']
42 | soft_y = loss_input_dict['ground_truth']
43 |
44 | if(isinstance(predict, (list, tuple))):
45 | predict = predict[0]
46 | if(self.softmax):
47 | predict = nn.Softmax(dim = 1)(predict)
48 | mae = torch.abs(predict - soft_y)
49 | mae = torch.mean(mae)
50 | return mae
51 |
--------------------------------------------------------------------------------
/PyMIC/pymic/loss/seg/mumford_shah.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | class MumfordShahLoss(nn.Module):
8 | """
9 | Implementation of Mumford Shah Loss for weakly supervised learning.
10 |
11 | * Reference: Boah Kim and Jong Chul Ye: Mumford–Shah Loss Functional
12 | for Image Segmentation With Deep Learning. IEEE TIP, 2019.
13 |
14 | The original implementation is available on `Github.
15 | `_
16 | Currently only the 2D version is supported.
17 |
18 | The parameters should be written in the `params` dictionary, and it has the
19 | following fields:
20 |
21 | :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not.
22 | :param `MumfordShahLoss_penalty`: (optional, str) `l1` or `l2`. Default is `l1`.
23 | :param `MumfordShahLoss_lambda`: (optional, float) Hyper-parameter lambda, default is 1.0.
24 | """
25 | def __init__(self, params = None):
26 | super(MumfordShahLoss, self).__init__()
27 | if(params is None):
28 | params = {}
29 | self.penalty = params.get('MumfordShahLoss_penalty', "l1")
30 | self.grad_w = params.get('MumfordShahLoss_lambda', 1.0)
31 | self.softmax = params.get('loss_softmax', True)
31 |
32 | def get_levelset_loss(self, output, target):
33 | """
34 | Get the level set loss value.
35 |
36 | :param `output`: (tensor) softmax output of a network.
37 | :param `target`: (tensor) the input image.
38 | :return: the level set loss.
39 | """
40 | outshape = output.shape
41 | tarshape = target.shape
42 | loss = 0.0
43 | for ich in range(tarshape[1]):
44 | target_ = torch.unsqueeze(target[:,ich], 1)
45 | target_ = target_.expand(tarshape[0], outshape[1], tarshape[2], tarshape[3])
46 | pcentroid = torch.sum(target_ * output, (2,3))/torch.sum(output, (2,3))
47 | pcentroid = pcentroid.view(tarshape[0], outshape[1], 1, 1)
48 | plevel = target_ - pcentroid.expand(tarshape[0], outshape[1], tarshape[2], tarshape[3])
49 | pLoss = plevel * plevel * output
50 | loss += torch.sum(pLoss)
51 | return loss
52 |
53 | def get_gradient_loss(self, pred, penalty = "l2"):
54 | dH = torch.abs(pred[:, :, 1:, :] - pred[:, :, :-1, :])
55 | dW = torch.abs(pred[:, :, :, 1:] - pred[:, :, :, :-1])
56 | if penalty == "l2":
57 | dH = dH * dH
58 | dW = dW * dW
59 | loss = torch.sum(dH) + torch.sum(dW)
60 | return loss
61 |
62 | def forward(self, loss_input_dict):
63 | """
64 | Forward pass for calculating the loss.
65 | The arguments should be written in the `loss_input_dict` dictionary,
66 | and it has the following fields:
67 |
68 | :param `prediction`: (tensor) Prediction of a network, with the
69 | shape of [N, C, D, H, W] or [N, C, H, W].
70 | :param `image`: (tensor) Image, with the
71 | shape of [N, C, D, H, W] or [N, C, H, W].
72 |
73 | :return: Loss function value.
74 | """
75 | predict = loss_input_dict['prediction']
76 | image = loss_input_dict['image']
77 | if(isinstance(predict, (list, tuple))):
78 | predict = predict[0]
79 | if(self.softmax):
80 | predict = nn.Softmax(dim = 1)(predict)
81 |
82 | pred_shape = list(predict.shape)
83 | if(len(pred_shape) == 5):
84 | [N, C, D, H, W] = pred_shape
85 | new_shape = [N*D, C, H, W]
86 | predict = torch.transpose(predict, 1, 2)
87 | predict = torch.reshape(predict, new_shape)
88 | [N, C, D, H, W] = list(image.shape)
89 | new_shape = [N*D, C, H, W]
90 | image = torch.transpose(image, 1, 2)
91 | image = torch.reshape(image, new_shape)
92 | loss0 = self.get_levelset_loss(predict, image)
93 | loss1 = self.get_gradient_loss(predict, self.penalty)
94 | loss = loss0 + self.grad_w * loss1
95 | return loss/torch.numel(predict)
96 |
--------------------------------------------------------------------------------
/PyMIC/pymic/loss/seg/slsr.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function, division
4 |
5 | import torch
6 | import torch.nn as nn
7 | from pymic.loss.seg.abstract import AbstractSegLoss
8 | from pymic.loss.seg.util import reshape_tensor_to_2D
9 |
10 | class SLSRLoss(AbstractSegLoss):
11 | """
12 | Spatial Label Smoothing Regularization (SLSR) loss for learning from
13 | noisy annotations. This loss requires pixel weighting, so please make sure
14 | that a `pixel_weight` field is provided in the csv file of the training images.
15 |
16 | The pixel weight here is actually the confidence mask, i.e., if the value is one,
17 | it means the label of the corresponding pixel is noisy and should be smoothed.
18 |
19 | * Reference: Minqing Zhang, Jiantao Gao et al.: Characterizing Label Errors: Confident Learning for Noisy-Labeled Image
20 | Segmentation, `MICCAI 2020. `_
21 |
22 | The arguments should be written in the `params` dictionary, and it has the
23 | following fields:
24 |
25 | :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not.
26 | :param `slsrloss_epsilon`: (optional, float) Hyper-parameter epsilon. Default is 0.25.
27 | """
28 | def __init__(self, params = None):
29 | super(SLSRLoss, self).__init__(params)
30 | if(params is None):
31 | params = {}
32 | self.epsilon = params.get('slsrloss_epsilon', 0.25)
33 |
34 | def forward(self, loss_input_dict):
35 | predict = loss_input_dict['prediction']
36 | soft_y = loss_input_dict['ground_truth']
37 | pix_w = loss_input_dict.get('pixel_weight', None)
38 |
39 | if(isinstance(predict, (list, tuple))):
40 | predict = predict[0]
41 | if(self.softmax):
42 | predict = nn.Softmax(dim = 1)(predict)
43 | predict = reshape_tensor_to_2D(predict)
44 | soft_y = reshape_tensor_to_2D(soft_y)
45 | if(pix_w is not None):
46 | pix_w = reshape_tensor_to_2D(pix_w > 0).float()
47 | # smooth labels for pixels in the unconfident mask
48 | smooth_y = (soft_y - 0.5) * (0.5 - self.epsilon) / 0.5 + 0.5
49 | smooth_y = pix_w * smooth_y + (1 - pix_w) * soft_y
50 | else:
51 | smooth_y = soft_y
52 |
53 | # for numeric stability
54 | predict = predict * 0.999 + 5e-4
55 | ce = - smooth_y* torch.log(predict)
56 | ce = torch.sum(ce, dim = 1) # shape is [N]
57 | ce = torch.mean(ce)
58 | return ce
59 |
--------------------------------------------------------------------------------
/PyMIC/pymic/loss/seg/ssl.py:
--------------------------------------------------------------------------------
1 |
2 | # -*- coding: utf-8 -*-
3 | from __future__ import print_function, division
4 |
5 | import torch
6 | import torch.nn as nn
7 | import numpy as np
8 | from pymic.loss.seg.util import reshape_tensor_to_2D
9 |
10 | class EntropyLoss(nn.Module):
11 | """
12 | Entropy Minimization for segmentation tasks.
13 | The parameters should be written in the `params` dictionary, and it has the
14 | following fields:
15 |
16 | :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not.
17 | """
18 | def __init__(self, params = None):
19 | super(EntropyLoss, self).__init__()
20 | self.softmax = True if (params is None) else params.get('loss_softmax', True)
20 |
21 | def forward(self, loss_input_dict):
22 | """
23 | Forward pass for calculating the loss.
24 | The arguments should be written in the `loss_input_dict` dictionary,
25 | and it has the following fields:
26 |
27 | :param `prediction`: (tensor) Prediction of a network, with the
28 | shape of [N, C, D, H, W] or [N, C, H, W].
29 |
30 | :return: Loss function value.
31 | """
32 | predict = loss_input_dict['prediction']
33 |
34 | if(isinstance(predict, (list, tuple))):
35 | predict = predict[0]
36 | if(self.softmax):
37 | predict = nn.Softmax(dim = 1)(predict)
38 |
39 | # for numeric stability
40 | predict = predict * 0.999 + 5e-4
41 | C = list(predict.shape)[1]
42 | entropy = torch.sum(-predict*torch.log(predict), dim=1) / np.log(C)
43 | avg_ent = torch.mean(entropy)
44 | return avg_ent
45 |
46 | class TotalVariationLoss(nn.Module):
47 | """
48 | Total Variation Loss for segmentation tasks.
49 | The parameters should be written in the `params` dictionary, and it has the
50 | following fields:
51 |
52 | :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not.
53 | """
54 | def __init__(self, params = None):
55 | super(TotalVariationLoss, self).__init__()
56 | self.softmax = True if (params is None) else params.get('loss_softmax', True)
56 |
57 | def forward(self, loss_input_dict):
58 | """
59 | Forward pass for calculating the loss.
60 | The arguments should be written in the `loss_input_dict` dictionary,
61 | and it has the following fields:
62 |
63 | :param `prediction`: (tensor) Prediction of a network, with the
64 | shape of [N, C, D, H, W] or [N, C, H, W].
65 |
66 | :return: Loss function value.
67 | """
68 | predict = loss_input_dict['prediction']
69 |
70 | if(isinstance(predict, (list, tuple))):
71 | predict = predict[0]
72 | if(self.softmax):
73 | predict = nn.Softmax(dim = 1)(predict)
74 |
75 | # for numeric stability
76 | predict = predict * 0.999 + 5e-4
77 | dim = len(list(predict.shape)[2:])
78 | if(dim == 2):
79 | pred_min = -1 * nn.functional.max_pool2d(-1*predict, (3, 3), 1, 1)
80 | pred_max = nn.functional.max_pool2d(pred_min, (3, 3), 1, 1)
81 | else:
82 | pred_min = -1 * nn.functional.max_pool3d(-1*predict, (3, 3, 3), 1, 1)
83 | pred_max = nn.functional.max_pool3d(pred_min, (3, 3, 3), 1, 1)
84 | contour = torch.relu(pred_max - pred_min)
85 | length = torch.mean(contour)
86 | return length
--------------------------------------------------------------------------------
/PyMIC/pymic/loss/seg/util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 |
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 |
8 | def get_soft_label(input_tensor, num_class, data_type = 'float'):
9 | """
10 | Convert a label tensor to one-hot label for segmentation tasks.
11 |
12 | :param `input_tensor`: (tensor) Tensor with shape [B, 1, D, H, W] or [B, 1, H, W].
13 | :param `num_class`: (int) The class number.
14 | :param `data_type`: (optional, str) Type of data, `float` (default) or `double`.
15 |
16 | :return: A tensor with shape [B, num_class, D, H, W] or [B, num_class, H, W]
17 | """
18 |
19 | shape = input_tensor.shape
20 | if len(shape) == 5:
21 | output_tensor = torch.nn.functional.one_hot(input_tensor[:, 0], num_classes = num_class).permute(0, 4, 1, 2, 3)
22 | elif len(shape) == 4:
23 | output_tensor = torch.nn.functional.one_hot(input_tensor[:, 0], num_classes = num_class).permute(0, 3, 1, 2)
24 | else:
25 | raise ValueError("dimention of data can only be 4 or 5: {0:}".format(len(shape)))
26 |
27 | if(data_type == 'float'):
28 | output_tensor = output_tensor.float()
29 | elif(data_type == 'double'):
30 | output_tensor = output_tensor.double()
31 | else:
32 | raise ValueError("data type can only be float and double: {0:}".format(data_type))
33 |
34 | return output_tensor
35 |
36 | def reshape_tensor_to_2D(x):
37 | """
38 | Reshape input tensor of shape [N, C, D, H, W] or [N, C, H, W] to [voxel_n, C]
39 | """
40 | tensor_dim = len(x.size())
41 | num_class = list(x.size())[1]
42 | if(tensor_dim == 5):
43 | x_perm = x.permute(0, 2, 3, 4, 1)
44 | elif(tensor_dim == 4):
45 | x_perm = x.permute(0, 2, 3, 1)
46 | else:
47 | raise ValueError("{0:}D tensor not supported".format(tensor_dim))
48 |
49 | y = torch.reshape(x_perm, (-1, num_class))
50 | return y
51 |
52 | def dice_weight_loss(predict, target):
53 | target = target.float()
54 | smooth = 1e-4
55 | intersect = torch.sum(predict*target)
56 | dice = (2 * intersect + smooth)/(torch.sum(target)+torch.sum(predict*predict)+smooth)
57 | loss = 1.0 - dice
58 | return loss
59 | def reshape_prediction_and_ground_truth(predict, soft_y):
60 | """
61 | Reshape the two input tensors to 2D.
62 |
63 | :param predict: (tensor) A tensor of shape [N, C, D, H, W] or [N, C, H, W].
64 | :param soft_y: (tensor) A tensor of shape [N, C, D, H, W] or [N, C, H, W].
65 |
66 | :return: Two output tensors with shape [voxel_n, C] that correspond to the two inputs.
67 | """
69 | tensor_dim = len(predict.size())
70 | num_class = list(predict.size())[1]
71 | if(tensor_dim == 5):
72 | soft_y = soft_y.permute(0, 2, 3, 4, 1)
73 | predict = predict.permute(0, 2, 3, 4, 1)
74 | elif(tensor_dim == 4):
75 | soft_y = soft_y.permute(0, 2, 3, 1)
76 | predict = predict.permute(0, 2, 3, 1)
77 | else:
78 | raise ValueError("{0:}D tensor not supported".format(tensor_dim))
79 |
80 | predict = torch.reshape(predict, (-1, num_class))
81 | soft_y = torch.reshape(soft_y, (-1, num_class))
82 |
83 | return predict, soft_y
84 |
85 | def get_classwise_dice(predict, soft_y, pix_w = None):
86 | """
87 | Get dice scores for each class in predict (after softmax) and soft_y.
88 |
89 | :param predict: (tensor) Prediction of a segmentation network after softmax.
90 | :param soft_y: (tensor) The one-hot segmentation ground truth.
91 | :param pix_w: (optional, tensor) The pixel weight map. Default is None.
92 |
93 | :return: Dice score for each class.
94 | """
95 |
96 | if(pix_w is None):
97 | y_vol = torch.sum(soft_y, dim = 0)
98 | p_vol = torch.sum(predict, dim = 0)
99 | intersect = torch.sum(soft_y * predict, dim = 0)
101 | else:
102 | y_vol = torch.sum(soft_y * pix_w, dim = 0)
103 | p_vol = torch.sum(predict * pix_w, dim = 0)
104 | intersect = torch.sum(soft_y * predict * pix_w, dim = 0)
106 | dice_score = (2.0 * intersect + 1e-5)/ (y_vol + p_vol + 1e-5)
107 | return dice_score
108 |
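if __name__ == "__main__":
    # Illustrative example (not part of the original file): a perfect
    # prediction yields per-class dice scores close to 1.
    pred = torch.tensor([[0.0, 1.0], [1.0, 0.0]])  # [voxel_n, C] after softmax
    gt   = torch.tensor([[0.0, 1.0], [1.0, 0.0]])  # one-hot ground truth
    print(get_classwise_dice(pred, gt))            # approximately tensor([1., 1.])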
--------------------------------------------------------------------------------
/PyMIC/pymic/net/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/net/__init__.py
--------------------------------------------------------------------------------
/PyMIC/pymic/net/cls/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/net/cls/__init__.py
--------------------------------------------------------------------------------
/PyMIC/pymic/net/net3d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/net/net3d/__init__.py
--------------------------------------------------------------------------------
/PyMIC/pymic/net/net3d/scse3d.py:
--------------------------------------------------------------------------------
1 | """
2 | 3D implementation of: \n
3 | 1. Channel Squeeze and Excitation \n
4 | 2. Spatial Squeeze and Excitation \n
5 | 3. Concurrent Spatial and Channel Squeeze & Excitation
6 |
7 | Original file is on `Github.
8 | `_
9 | """
10 | from __future__ import print_function, division
11 |
12 | from enum import Enum
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 |
17 | class ChannelSELayer3D(nn.Module):
18 | """
19 | 3D implementation of Squeeze-and-Excitation (SE) block.
20 |
21 | * Reference: Jie Hu, Li Shen, Gang Sun: Squeeze-and-Excitation Networks.
22 | `CVPR 2018. `_
23 |
24 | :param num_channels: Number of input channels
25 | :param reduction_ratio: The ratio by which num_channels should be reduced
26 | """
27 | def __init__(self, num_channels, reduction_ratio=2):
28 | super(ChannelSELayer3D, self).__init__()
29 | num_channels_reduced = num_channels // reduction_ratio
30 | self.reduction_ratio = reduction_ratio
31 | self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
32 | self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
33 | self.relu = nn.ReLU()
34 | self.sigmoid = nn.Sigmoid()
35 |
36 | def forward(self, input_tensor):
37 | """
38 | :param input_tensor: X, shape = (batch_size, num_channels, D, H, W)
39 | :return: output tensor
40 | """
41 |
42 | batch_size, num_channels, D, H, W = input_tensor.size()
43 | # Average along each channel
44 | squeeze_tensor = input_tensor.view(batch_size, num_channels, -1).mean(dim=2)
45 | # channel excitation
46 | fc_out_1 = self.relu(self.fc1(squeeze_tensor))
47 | fc_out_2 = self.sigmoid(self.fc2(fc_out_1))
48 |
49 | a, b = squeeze_tensor.size()
50 | output_tensor = torch.mul(input_tensor, fc_out_2.view(a, b, 1, 1, 1))
51 |
52 | return output_tensor
53 |
54 | class SpatialSELayer3D(nn.Module):
55 | """
56 |     3D re-implementation of the SE block that squeezes channel-wise and excites spatially, as described in:
57 |
58 | * Reference: Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in
59 | Fully Convolutional Networks, MICCAI 2018.
60 |
61 | :param num_channels: Number of input channels
62 | """
63 | def __init__(self, num_channels):
64 | super(SpatialSELayer3D, self).__init__()
65 | self.conv = nn.Conv3d(num_channels, 1, 1)
66 | self.sigmoid = nn.Sigmoid()
67 |
68 | def forward(self, input_tensor, weights=None):
69 | """
70 | :param weights: weights for few shot learning
71 | :param input_tensor: X, shape = (batch_size, num_channels, D, H, W)
72 | :return: output_tensor
73 | """
74 |         # channel squeeze: a 1x1x1 conv collapses all channels into one spatial map
75 | batch_size, channel, D, H, W = input_tensor.size()
76 | if weights is not None:
77 | weights = torch.mean(weights, dim=0)
78 | weights = weights.view(1, channel, 1, 1)
79 | out = F.conv3d(input_tensor, weights)
80 | else:
81 | out = self.conv(input_tensor)
82 | squeeze_tensor = self.sigmoid(out)
83 |
84 | # spatial excitation
86 | squeeze_tensor = squeeze_tensor.view(batch_size, 1, D, H, W)
87 | output_tensor = torch.mul(input_tensor, squeeze_tensor)
88 |
90 | return output_tensor
91 |
92 |
93 | class ChannelSpatialSELayer3D(nn.Module):
94 | """
95 | 3D Re-implementation of concurrent spatial and channel squeeze & excitation.
96 |
97 | * Reference: Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in
98 | Fully Convolutional Networks, MICCAI 2018.
99 |
100 | :param num_channels: Number of input channels
101 |     :param reduction_ratio: By how much the num_channels should be reduced
102 | """
103 | def __init__(self, num_channels, reduction_ratio=2):
104 | super(ChannelSpatialSELayer3D, self).__init__()
105 | self.cSE = ChannelSELayer3D(num_channels, reduction_ratio)
106 | self.sSE = SpatialSELayer3D(num_channels)
107 |
108 | def forward(self, input_tensor):
109 | """
110 | :param input_tensor: X, shape = (batch_size, num_channels, D, H, W)
111 | :return: output_tensor
112 | """
113 | output_tensor = torch.max(self.cSE(input_tensor), self.sSE(input_tensor))
114 | return output_tensor
115 |
116 | class SELayer(Enum):
117 | """
118 |     Enum restricting the types of SE blocks available, so that type checking can be added when adding these blocks to
119 |     a neural network::
120 |         if self.se_block_type == se.SELayer.CSE.value:
121 |             self.SELayer = se.ChannelSELayer(params['num_filters'])
122 |         elif self.se_block_type == se.SELayer.SSE.value:
123 |             self.SELayer = se.SpatialSELayer(params['num_filters'])
124 |         elif self.se_block_type == se.SELayer.CSSE.value:
125 |             self.SELayer = se.ChannelSpatialSELayer(params['num_filters'])
127 | NONE = 'NONE'
128 | CSE = 'CSE'
129 | SSE = 'SSE'
130 | CSSE = 'CSSE'
--------------------------------------------------------------------------------
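A minimal sketch (not part of the repository) of applying the concurrent block to a dummy 5D feature map; the output keeps the input shape:

    import torch
    from pymic.net.net3d.scse3d import ChannelSpatialSELayer3D

    x  = torch.randn(2, 16, 8, 32, 32)   # (batch, channels, D, H, W)
    se = ChannelSpatialSELayer3D(num_channels=16, reduction_ratio=2)
    y  = se(x)                           # element-wise max of cSE and sSE outputs
    assert y.shape == x.shape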
/PyMIC/pymic/net/net_dict_cls.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Built-in networks for classification.
4 |
5 | * resnet18 :mod:`pymic.net.cls.torch_pretrained_net.ResNet18`
6 | * vgg16 :mod:`pymic.net.cls.torch_pretrained_net.VGG16`
7 | * mobilenetv2 :mod:`pymic.net.cls.torch_pretrained_net.MobileNetV2`
8 | """
9 |
10 | from __future__ import print_function, division
11 | from pymic.net.cls.torch_pretrained_net import *
12 |
13 | TorchClsNetDict = {
14 | 'resnet18': ResNet18,
15 | 'vgg16': VGG16,
16 | 'mobilenetv2':MobileNetV2
17 | }
18 |
--------------------------------------------------------------------------------
/PyMIC/pymic/net/net_dict_seg.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Built-in networks for segmentation.
4 |
5 | * UNet2D :mod:`pymic.net.net2d.unet2d.UNet2D`
6 | * UNet2D_DualBranch :mod:`pymic.net.net2d.unet2d_dual_branch.UNet2D_DualBranch`
7 | * UNet2D_URPC :mod:`pymic.net.net2d.unet2d_urpc.UNet2D_URPC`
8 | * UNet2D_CCT :mod:`pymic.net.net2d.unet2d_cct.UNet2D_CCT`
9 | * UNet2D_ScSE :mod:`pymic.net.net2d.unet2d_scse.UNet2D_ScSE`
10 | * AttentionUNet2D :mod:`pymic.net.net2d.unet2d_attention.AttentionUNet2D`
11 | * NestedUNet2D :mod:`pymic.net.net2d.unet2d_nest.NestedUNet2D`
12 | * COPLENet :mod:`pymic.net.net2d.cople_net.COPLENet`
13 | * UNet2D5 :mod:`pymic.net.net3d.unet2d5.UNet2D5`
14 | * UNet3D :mod:`pymic.net.net3d.unet3d.UNet3D`
15 | * UNet3D_ScSE :mod:`pymic.net.net3d.unet3d_scse.UNet3D_ScSE`
16 | """
17 | from __future__ import print_function, division
18 | from pymic.net.net2d.unet2d import UNet2D
19 | # from pymic.net.net2d.unet2d_dsbn import UNet2D_dsbn
20 | # from pymic.net.net2d.unet2d_dsbn_copy import UNet2D_dsbn
21 | from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch
22 | from pymic.net.net2d.unet2d_urpc import UNet2D_URPC
23 | from pymic.net.net2d.unet2d_cct import UNet2D_CCT
24 | from pymic.net.net2d.cople_net import COPLENet
25 | from pymic.net.net2d.unet2d_attention import AttentionUNet2D
26 | from pymic.net.net2d.unet2d_nest import NestedUNet2D
27 | from pymic.net.net2d.unet2d_scse import UNet2D_ScSE
28 | from pymic.net.net3d.unet2d5 import UNet2D5
29 | from pymic.net.net3d.unet3d import UNet3D
30 | from pymic.net.net3d.unet2d5_dsbn import UNet2D5_dsbn,Dis
31 | from pymic.net.net3d.unet3d_scse import UNet3D_ScSE
32 |
33 | SegNetDict = {
34 | 'UNet2D': UNet2D,
35 | 'UNet2D_DualBranch': UNet2D_DualBranch,
36 | 'Dis': Dis,
37 | 'UNet2D_URPC': UNet2D_URPC,
38 | 'UNet2D_CCT': UNet2D_CCT,
39 | 'COPLENet': COPLENet,
40 | 'AttentionUNet2D': AttentionUNet2D,
41 | 'NestedUNet2D': NestedUNet2D,
42 | 'UNet2D_ScSE': UNet2D_ScSE,
43 | 'UNet2D5': UNet2D5,
44 | 'UNet2D5_dsbn': UNet2D5_dsbn,
45 | 'UNet3D': UNet3D,
46 | 'UNet3D_ScSE': UNet3D_ScSE
47 | }
48 |
--------------------------------------------------------------------------------
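A minimal sketch (not part of the repository) of the lookup pattern: the training agents read the network name from the configuration, resolve the class through this dictionary, and pass the `network` section of the configuration to the constructor. The required keys depend on the chosen class.

    from pymic.net.net_dict_seg import SegNetDict

    net_class = SegNetDict['UNet2D']       # class lookup by name
    # net = net_class(config['network'])   # each class documents its own parameters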
/PyMIC/pymic/net_run/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/net_run/__init__.py
--------------------------------------------------------------------------------
/PyMIC/pymic/net_run/get_optimizer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 |
4 | import torch
5 | from torch import optim
6 | from torch.optim import lr_scheduler
7 | from pymic.util.general import keyword_match
8 |
9 | def get_optimizer(name, net_params, optim_params):
10 | lr = optim_params['learning_rate']
11 | momentum = optim_params['momentum']
12 | weight_decay = optim_params['weight_decay']
13 | if(keyword_match(name, "SGD")):
14 | return optim.SGD(net_params, lr,
15 | momentum = momentum, weight_decay = weight_decay)
16 | elif(keyword_match(name, "Adam")):
17 | return optim.Adam(net_params, lr, weight_decay = weight_decay)
18 | elif(keyword_match(name, "SparseAdam")):
19 | return optim.SparseAdam(net_params, lr)
20 | elif(keyword_match(name, "Adadelta")):
21 | return optim.Adadelta(net_params, lr, weight_decay = weight_decay)
22 | elif(keyword_match(name, "Adagrad")):
23 | return optim.Adagrad(net_params, lr, weight_decay = weight_decay)
24 | elif(keyword_match(name, "Adamax")):
25 | return optim.Adamax(net_params, lr, weight_decay = weight_decay)
26 | elif(keyword_match(name, "ASGD")):
27 | return optim.ASGD(net_params, lr, weight_decay = weight_decay)
28 | elif(keyword_match(name, "LBFGS")):
29 | return optim.LBFGS(net_params, lr)
30 | elif(keyword_match(name, "RMSprop")):
31 | return optim.RMSprop(net_params, lr, momentum = momentum,
32 | weight_decay = weight_decay)
33 | elif(keyword_match(name, "Rprop")):
34 | return optim.Rprop(net_params, lr)
35 | else:
36 | raise ValueError("unsupported optimizer {0:}".format(name))
37 |
38 |
39 | def get_lr_scheduler(optimizer, sched_params):
40 | name = sched_params["lr_scheduler"]
41 | if(name is None):
42 | return None
43 | lr_gamma = sched_params["lr_gamma"]
44 | if(keyword_match(name, "ReduceLROnPlateau")):
45 | patience_it = sched_params["ReduceLROnPlateau_patience".lower()]
46 | val_it = sched_params["iter_valid"]
47 | patience = patience_it / val_it
48 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
49 | mode = "max", factor=lr_gamma, patience = patience)
50 | elif(keyword_match(name, "MultiStepLR")):
51 | lr_milestones = sched_params["lr_milestones"]
52 | last_iter = sched_params["last_iter"]
53 | scheduler = lr_scheduler.MultiStepLR(optimizer,
54 | lr_milestones, lr_gamma, last_iter)
55 | else:
56 | raise ValueError("unsupported lr scheduler {0:}".format(name))
57 | return scheduler
--------------------------------------------------------------------------------
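A minimal sketch (not part of the repository) of driving get_optimizer with a hand-built `optim_params` section; note that all three keys are read unconditionally, even when the chosen optimizer ignores some of them:

    import torch
    from pymic.net_run.get_optimizer import get_optimizer

    net = torch.nn.Linear(4, 2)  # stand-in network
    optim_params = {'learning_rate': 1e-3, 'momentum': 0.9, 'weight_decay': 1e-5}
    opt = get_optimizer('Adam', net.parameters(), optim_params)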
/PyMIC/pymic/net_run/net_run.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 | import logging
4 | import os
5 | import sys
6 | from pymic.util.parse_config import *
7 | from pymic.net_run.agent_cls import ClassificationAgent
8 | from pymic.net_run.agent_seg import SegmentationAgent
9 |
10 | def main():
11 | """
12 | The main function for running a network for training or inference.
13 | """
14 | if(len(sys.argv) < 3):
15 | print('Number of arguments should be 3. e.g.')
16 | print(' pymic_run train config.cfg')
17 | exit()
18 | stage = str(sys.argv[1])
19 | cfg_file = str(sys.argv[2])
20 | config = parse_config(cfg_file)
21 | config = synchronize_config(config)
22 | log_dir = config['training']['ckpt_save_dir']
23 | if(not os.path.exists(log_dir)):
24 | os.makedirs(log_dir, exist_ok=True)
25 | logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
26 | format='%(message)s')
27 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
28 | logging_config(config)
29 | task = config['dataset']['task_type']
30 | assert task in ['cls', 'cls_nexcl', 'seg']
31 | if(task == 'cls' or task == 'cls_nexcl'):
32 | agent = ClassificationAgent(config, stage)
33 | else:
34 | agent = SegmentationAgent(config, stage)
35 | agent.run()
36 |
37 | if __name__ == "__main__":
38 | main()
39 |
40 |
41 |
--------------------------------------------------------------------------------
/PyMIC/pymic/net_run_dsbn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/net_run_dsbn/__init__.py
--------------------------------------------------------------------------------
/PyMIC/pymic/net_run_dsbn/dsbn.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class _DomainSpecificBatchNorm(nn.Module):
5 | _version = 2
6 |
7 | def __init__(self, num_features, num_domains):
8 | super(_DomainSpecificBatchNorm, self).__init__()
9 | self.bns = nn.ModuleList(
10 | [nn.BatchNorm2d(num_features) for _ in range(num_domains)])
11 |
12 | def reset_running_stats(self):
13 | for bn in self.bns:
14 | bn.reset_running_stats()
15 |
16 | def reset_parameters(self):
17 | for bn in self.bns:
18 | bn.reset_parameters()
19 |
20 | def _check_input_dim(self, input):
21 | raise NotImplementedError
22 |
23 | def forward(self, x, domain_label):
24 | self._check_input_dim(x)
25 | bn = self.bns[domain_label[0]]
26 | return bn(x), domain_label
27 |
28 |
29 | class DomainSpecificBatchNorm2d(_DomainSpecificBatchNorm):
30 | def _check_input_dim(self, input):
31 | if input.dim() != 4:
32 | raise ValueError('expected 4D input (got {}D input)'
33 | .format(input.dim()))
34 |
35 | class _DomainSpecificBatchNorm3d(nn.Module):
36 | _version = 2
37 |
38 | def __init__(self, num_features, num_domains):
39 | super(_DomainSpecificBatchNorm3d, self).__init__()
40 | self.bns = nn.ModuleList(
41 | [nn.BatchNorm3d(num_features) for _ in range(num_domains)])
42 |
43 | def reset_running_stats(self):
44 | for bn in self.bns:
45 | bn.reset_running_stats()
46 |
47 | def reset_parameters(self):
48 | for bn in self.bns:
49 | bn.reset_parameters()
50 |
51 | def _check_input_dim(self, input):
52 | raise NotImplementedError
53 |
54 | def forward(self, x, domain_label):
55 | self._check_input_dim(x)
56 | bn = self.bns[domain_label[0]]
57 | return bn(x), domain_label
58 |
59 |
60 | class DomainSpecificBatchNorm3d(_DomainSpecificBatchNorm3d):
61 | def _check_input_dim(self, input):
62 | if input.dim() != 5:
63 | raise ValueError('expected 5D input (got {}D input)'
64 | .format(input.dim()))
--------------------------------------------------------------------------------
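A minimal sketch (not part of the repository): each layer holds one BatchNorm per domain and routes the whole batch through the BatchNorm selected by the first entry of domain_label, so a batch is assumed to come from a single domain:

    import torch
    from pymic.net_run_dsbn.dsbn import DomainSpecificBatchNorm2d

    dsbn = DomainSpecificBatchNorm2d(num_features=16, num_domains=2)
    x = torch.randn(4, 16, 32, 32)
    y, d = dsbn(x, [1])   # normalize with the statistics of domain 1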
/PyMIC/pymic/net_run_dsbn/get_optimizer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 |
4 | import torch
5 | from torch import optim
6 | from torch.optim import lr_scheduler
7 | from pymic.util.general import keyword_match
8 |
9 | def get_optimizer(name, net_params, optim_params):
10 | lr = optim_params['learning_rate']
11 | momentum = optim_params['momentum']
12 | weight_decay = optim_params['weight_decay']
13 | if(keyword_match(name, "SGD")):
14 | return optim.SGD(net_params, lr,
15 | momentum = momentum, weight_decay = weight_decay)
16 | elif(keyword_match(name, "Adam")):
17 | return optim.Adam(net_params, lr, weight_decay = weight_decay)
18 | elif(keyword_match(name, "SparseAdam")):
19 | return optim.SparseAdam(net_params, lr)
20 | elif(keyword_match(name, "Adadelta")):
21 | return optim.Adadelta(net_params, lr, weight_decay = weight_decay)
22 | elif(keyword_match(name, "Adagrad")):
23 | return optim.Adagrad(net_params, lr, weight_decay = weight_decay)
24 | elif(keyword_match(name, "Adamax")):
25 | return optim.Adamax(net_params, lr, weight_decay = weight_decay)
26 | elif(keyword_match(name, "ASGD")):
27 | return optim.ASGD(net_params, lr, weight_decay = weight_decay)
28 | elif(keyword_match(name, "LBFGS")):
29 | return optim.LBFGS(net_params, lr)
30 | elif(keyword_match(name, "RMSprop")):
31 | return optim.RMSprop(net_params, lr, momentum = momentum,
32 | weight_decay = weight_decay)
33 | elif(keyword_match(name, "Rprop")):
34 | return optim.Rprop(net_params, lr)
35 | else:
36 | raise ValueError("unsupported optimizer {0:}".format(name))
37 |
38 |
39 | def get_lr_scheduler(optimizer, sched_params):
40 | name = sched_params["lr_scheduler"]
41 | if(name is None):
42 | return None
43 | lr_gamma = sched_params["lr_gamma"]
44 | if(keyword_match(name, "ReduceLROnPlateau")):
45 | patience_it = sched_params["ReduceLROnPlateau_patience".lower()]
46 | val_it = sched_params["iter_valid"]
47 | patience = patience_it / val_it
48 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
49 | mode = "max", factor=lr_gamma, patience = patience)
50 | elif(keyword_match(name, "MultiStepLR")):
51 | lr_milestones = sched_params["lr_milestones"]
52 | last_iter = sched_params["last_iter"]
53 | scheduler = lr_scheduler.MultiStepLR(optimizer,
54 | lr_milestones, lr_gamma, last_iter)
55 | else:
56 | raise ValueError("unsupported lr scheduler {0:}".format(name))
57 | return scheduler
--------------------------------------------------------------------------------
/PyMIC/pymic/net_run_dsbn/net_run.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 | import logging
4 | import os
5 | import sys
6 | from pymic.util.parse_config import *
7 | from pymic.net_run_dsbn.agent_cls import ClassificationAgent
8 | from pymic.net_run_dsbn.agent_seg import SegmentationAgent
9 | from pymic.util.evaluation_seg_train import eva_main
10 |
11 | def main():
12 | """
13 | The main function for running a network for training or inference.
14 | """
15 | if(len(sys.argv) < 3):
16 | print('Number of arguments should be 3. e.g.')
17 | print(' pymic_run train config.cfg')
18 | exit()
19 | stage = str(sys.argv[1])
20 | cfg_file = str(sys.argv[2])
21 | config = parse_config(cfg_file)
22 | config = synchronize_config(config)
23 | log_dir = config['training']['ckpt_save_dir']
24 | if(not os.path.exists(log_dir)):
25 | os.makedirs(log_dir, exist_ok=True)
26 | logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
27 | format='%(message)s')
28 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
29 | logging_config(config)
30 | task = config['dataset']['task_type']
31 | assert task in ['cls', 'cls_nexcl', 'seg']
32 | if(task == 'cls' or task == 'cls_nexcl'):
33 | agent = ClassificationAgent(config, stage)
34 | else:
35 | agent = SegmentationAgent(config, stage)
36 | agent.run()
37 | if stage != 'test':
38 | agent2 = SegmentationAgent(config, 'test')
39 | agent2.run()
40 | eva_main(config)
41 |
42 | if __name__ == "__main__":
43 | main()
44 |
45 |
46 |
--------------------------------------------------------------------------------
/PyMIC/pymic/net_run_nll/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/net_run_nll/__init__.py
--------------------------------------------------------------------------------
/PyMIC/pymic/net_run_nll/nll_main.py:
--------------------------------------------------------------------------------
1 |
2 | # -*- coding: utf-8 -*-
3 | from __future__ import print_function, division
4 | import logging
5 | import os
6 | import sys
7 | from pymic.util.parse_config import *
8 | from pymic.net_run_nll.nll_co_teaching import NLLCoTeaching
9 | from pymic.net_run_nll.nll_trinet import NLLTriNet
10 | from pymic.net_run_nll.nll_dast import NLLDAST
11 |
12 | NLLMethodDict = {'CoTeaching': NLLCoTeaching,
13 | "TriNet": NLLTriNet,
14 | "DAST": NLLDAST}
15 |
16 | def main():
17 | """
18 | The main function for noisy label learning methods.
19 | """
20 | if(len(sys.argv) < 3):
21 | print('Number of arguments should be 3. e.g.')
22 | print(' pymic_nll train config.cfg')
23 | exit()
24 | stage = str(sys.argv[1])
25 | cfg_file = str(sys.argv[2])
26 | config = parse_config(cfg_file)
27 | config = synchronize_config(config)
28 | log_dir = config['training']['ckpt_save_dir']
29 | if(not os.path.exists(log_dir)):
30 |         os.makedirs(log_dir, exist_ok=True)
31 | logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
32 | format='%(message)s')
33 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
34 | logging_config(config)
35 | nll_method = config['noisy_label_learning']['nll_method']
36 | agent = NLLMethodDict[nll_method](config, stage)
37 | agent.run()
38 |
39 | if __name__ == "__main__":
40 | main()
41 |
42 |
--------------------------------------------------------------------------------
/PyMIC/pymic/net_run_ssl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/net_run_ssl/__init__.py
--------------------------------------------------------------------------------
/PyMIC/pymic/net_run_ssl/ssl_em.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 | import logging
4 | import numpy as np
5 | import torch
6 | from torch.optim import lr_scheduler
7 | from pymic.loss.seg.util import get_soft_label
8 | from pymic.loss.seg.util import reshape_prediction_and_ground_truth
9 | from pymic.loss.seg.util import get_classwise_dice
10 | from pymic.loss.seg.ssl import EntropyLoss
11 | from pymic.net_run_ssl.ssl_abstract import SSLSegAgent
12 | from pymic.transform.trans_dict import TransformDict
13 | from pymic.util.ramps import get_rampup_ratio
14 |
15 | class SSLEntropyMinimization(SSLSegAgent):
16 | """
17 | Using Entropy Minimization for semi-supervised segmentation.
18 |
19 | * Reference: Yves Grandvalet and Yoshua Bengio:
20 |       Semi-supervised Learning by Entropy Minimization.
21 | `NeurIPS, 2005. `_
22 |
23 | :param config: (dict) A dictionary containing the configuration.
24 |     :param stage: (str) One of the stages: `train` (default), `inference` or `test`.
25 |
26 | .. note::
27 |
28 | In the configuration dictionary, in addition to the four sections (`dataset`,
29 | `network`, `training` and `inference`) used in fully supervised learning, an
30 | extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details.
31 | """
32 | def __init__(self, config, stage = 'train'):
33 | super(SSLEntropyMinimization, self).__init__(config, stage)
34 | self.transform_dict = TransformDict
35 | self.train_set_unlab = None
36 |
37 | def training(self):
38 | class_num = self.config['network']['class_num']
39 | iter_valid = self.config['training']['iter_valid']
40 | ssl_cfg = self.config['semi_supervised_learning']
41 | iter_max = self.config['training']['iter_max']
42 | rampup_start = ssl_cfg.get('rampup_start', 0)
43 | rampup_end = ssl_cfg.get('rampup_end', iter_max)
44 | train_loss = 0
45 | train_loss_sup = 0
46 | train_loss_reg = 0
47 | train_dice_list = []
48 | self.net.train()
49 | for it in range(iter_valid):
50 | try:
51 | data_lab = next(self.trainIter)
52 | except StopIteration:
53 | self.trainIter = iter(self.train_loader)
54 | data_lab = next(self.trainIter)
55 | try:
56 | data_unlab = next(self.trainIter_unlab)
57 | except StopIteration:
58 | self.trainIter_unlab = iter(self.train_loader_unlab)
59 | data_unlab = next(self.trainIter_unlab)
60 |
61 | # get the inputs
62 | x0 = self.convert_tensor_type(data_lab['image'])
63 | y0 = self.convert_tensor_type(data_lab['label_prob'])
64 | x1 = self.convert_tensor_type(data_unlab['image'])
65 | inputs = torch.cat([x0, x1], dim = 0)
66 | inputs, y0 = inputs.to(self.device), y0.to(self.device)
67 |
68 | # zero the parameter gradients
69 | self.optimizer.zero_grad()
70 |
71 | # forward + backward + optimize
72 | outputs = self.net(inputs)
73 | n0 = list(x0.shape)[0]
74 | p0 = outputs[:n0]
75 | loss_sup = self.get_loss_value(data_lab, p0, y0)
76 | loss_dict = {"prediction":outputs, 'softmax':True}
77 | loss_reg = EntropyLoss()(loss_dict)
78 |
79 | rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid")
80 | regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio
81 |
82 | loss = loss_sup + regular_w*loss_reg
84 | loss.backward()
85 | self.optimizer.step()
86 | if(self.scheduler is not None and \
87 | not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
88 | self.scheduler.step()
89 |
90 | train_loss = train_loss + loss.item()
91 | train_loss_sup = train_loss_sup + loss_sup.item()
92 | train_loss_reg = train_loss_reg + loss_reg.item()
93 | # get dice evaluation for each class in annotated images
94 | if(isinstance(p0, tuple) or isinstance(p0, list)):
95 | p0 = p0[0]
96 | p0_argmax = torch.argmax(p0, dim = 1, keepdim = True)
97 | p0_soft = get_soft_label(p0_argmax, class_num, self.tensor_type)
98 | p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0)
99 | dice_list = get_classwise_dice(p0_soft, y0)
100 | train_dice_list.append(dice_list.cpu().numpy())
101 | train_avg_loss = train_loss / iter_valid
102 | train_avg_loss_sup = train_loss_sup / iter_valid
103 | train_avg_loss_reg = train_loss_reg / iter_valid
104 | train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
105 | train_avg_dice = train_cls_dice.mean()
106 |
107 | train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup,
108 | 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w,
109 | 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice}
110 | return train_scalers
--------------------------------------------------------------------------------
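For reference, a small sketch (not part of the repository) of how the regularization weight ramps up during training; this assumes get_rampup_ratio uses the common exp(-5 * (1 - t)^2) sigmoid schedule, where t is the training progress between rampup_start and rampup_end:

    import numpy as np

    def sigmoid_rampup_ratio(it, start, end):
        # assumed schedule; see pymic.util.ramps.get_rampup_ratio for the actual one
        t = np.clip((it - start) / float(end - start), 0.0, 1.0)
        return float(np.exp(-5.0 * (1.0 - t) ** 2))

    regular_w = 0.1 * sigmoid_rampup_ratio(2000, 0, 10000)  # small weight early in training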
/PyMIC/pymic/net_run_ssl/ssl_main.py:
--------------------------------------------------------------------------------
1 |
2 | # -*- coding: utf-8 -*-
3 | from __future__ import print_function, division
4 | import logging
5 | import os
6 | import sys
7 | from pymic.util.parse_config import *
8 | from pymic.net_run_ssl.ssl_em import SSLEntropyMinimization
9 | from pymic.net_run_ssl.ssl_mt import SSLMeanTeacher
10 | from pymic.net_run_ssl.ssl_uamt import SSLUncertaintyAwareMeanTeacher
11 | from pymic.net_run_ssl.ssl_cct import SSLCCT
12 | from pymic.net_run_ssl.ssl_cps import SSLCPS
13 | from pymic.net_run_ssl.ssl_urpc import SSLURPC
14 |
15 |
16 | SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization,
17 | 'MeanTeacher': SSLMeanTeacher,
18 | 'UAMT': SSLUncertaintyAwareMeanTeacher,
19 | 'CCT': SSLCCT,
20 | 'CPS': SSLCPS,
21 | 'URPC': SSLURPC}
22 |
23 | def main():
24 | """
25 | Main function for running a semi-supervised method.
26 | """
27 | if(len(sys.argv) < 3):
28 | print('Number of arguments should be 3. e.g.')
29 | print(' pymic_ssl train config.cfg')
30 | exit()
31 | stage = str(sys.argv[1])
32 | cfg_file = str(sys.argv[2])
33 | config = parse_config(cfg_file)
34 | config = synchronize_config(config)
35 | log_dir = config['training']['ckpt_save_dir']
36 | if(not os.path.exists(log_dir)):
37 |         os.makedirs(log_dir, exist_ok=True)
38 | logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
39 | format='%(message)s')
40 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
41 | logging_config(config)
42 | ssl_method = config['semi_supervised_learning']['ssl_method']
43 | agent = SSLMethodDict[ssl_method](config, stage)
44 | agent.run()
45 |
46 | if __name__ == "__main__":
47 | main()
48 |
49 |
--------------------------------------------------------------------------------
/PyMIC/pymic/net_run_wsl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/net_run_wsl/__init__.py
--------------------------------------------------------------------------------
/PyMIC/pymic/net_run_wsl/wsl_abstract.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 | import logging
4 | from pymic.net_run.agent_seg import SegmentationAgent
5 |
6 | class WSLSegAgent(SegmentationAgent):
7 | """
8 | Abstract agent for weakly supervised segmentation.
9 |
10 | :param config: (dict) A dictionary containing the configuration.
11 |     :param stage: (str) One of the stages: `train` (default), `inference` or `test`.
12 |
13 | .. note::
14 |
15 | In the configuration dictionary, in addition to the four sections (`dataset`,
16 | `network`, `training` and `inference`) used in fully supervised learning, an
17 | extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details.
18 | """
19 | def __init__(self, config, stage = 'train'):
20 | super(WSLSegAgent, self).__init__(config, stage)
21 |
22 | def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
23 | loss_scalar ={'train':train_scalars['loss'],
24 | 'valid':valid_scalars['loss']}
25 | loss_sup_scalar = {'train':train_scalars['loss_sup']}
26 | loss_upsup_scalar = {'train':train_scalars['loss_reg']}
27 | dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']}
28 | self.summ_writer.add_scalars('loss', loss_scalar, glob_it)
29 | self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it)
30 | self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it)
31 | self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it)
32 | self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it)
33 | self.summ_writer.add_scalars('dice', dice_scalar, glob_it)
34 | class_num = self.config['network']['class_num']
35 | for c in range(class_num):
36 | cls_dice_scalar = {'train':train_scalars['class_dice'][c], \
37 | 'valid':valid_scalars['class_dice'][c]}
38 | self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it)
39 | logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format(
40 | train_scalars['loss'], train_scalars['avg_dice']) + "[" + \
41 | ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]")
42 | logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format(
43 | valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \
44 | ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]")
45 |
--------------------------------------------------------------------------------
/PyMIC/pymic/net_run_wsl/wsl_em.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 | import logging
4 | import numpy as np
5 | import torch
6 | from torch.optim import lr_scheduler
7 | from pymic.loss.seg.util import get_soft_label
8 | from pymic.loss.seg.util import reshape_prediction_and_ground_truth
9 | from pymic.loss.seg.util import get_classwise_dice
10 | from pymic.loss.seg.ssl import EntropyLoss
11 | from pymic.net_run.agent_seg import SegmentationAgent
12 | from pymic.net_run_wsl.wsl_abstract import WSLSegAgent
13 | from pymic.util.ramps import get_rampup_ratio
14 |
15 | class WSLEntropyMinimization(WSLSegAgent):
16 | """
17 | Weakly supervised segmentation based on Entropy Minimization.
18 |
19 | * Reference: Yves Grandvalet and Yoshua Bengio:
20 |       Semi-supervised Learning by Entropy Minimization.
21 | `NeurIPS, 2005. `_
22 |
23 | :param config: (dict) A dictionary containing the configuration.
24 |     :param stage: (str) One of the stages: `train` (default), `inference` or `test`.
25 |
26 | .. note::
27 |
28 | In the configuration dictionary, in addition to the four sections (`dataset`,
29 | `network`, `training` and `inference`) used in fully supervised learning, an
30 | extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details.
31 | """
32 | def __init__(self, config, stage = 'train'):
33 | super(WSLEntropyMinimization, self).__init__(config, stage)
34 |
35 | def training(self):
36 | class_num = self.config['network']['class_num']
37 | iter_valid = self.config['training']['iter_valid']
38 | wsl_cfg = self.config['weakly_supervised_learning']
39 | iter_max = self.config['training']['iter_max']
40 | rampup_start = wsl_cfg.get('rampup_start', 0)
41 | rampup_end = wsl_cfg.get('rampup_end', iter_max)
42 | train_loss = 0
43 | train_loss_sup = 0
44 | train_loss_reg = 0
45 | train_dice_list = []
46 | self.net.train()
47 | for it in range(iter_valid):
48 | try:
49 | data = next(self.trainIter)
50 | except StopIteration:
51 | self.trainIter = iter(self.train_loader)
52 | data = next(self.trainIter)
53 |
54 | # get the inputs
55 | inputs = self.convert_tensor_type(data['image'])
56 | y = self.convert_tensor_type(data['label_prob'])
57 |
58 | inputs, y = inputs.to(self.device), y.to(self.device)
59 |
60 | # zero the parameter gradients
61 | self.optimizer.zero_grad()
62 |
63 | # forward + backward + optimize
64 | outputs = self.net(inputs)
65 | loss_sup = self.get_loss_value(data, outputs, y)
66 | loss_dict= {"prediction":outputs, 'softmax':True}
67 | loss_reg = EntropyLoss()(loss_dict)
68 |
69 | rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid")
70 | regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio
71 | loss = loss_sup + regular_w*loss_reg
72 |
73 | loss.backward()
74 | self.optimizer.step()
75 | if(self.scheduler is not None and \
76 | not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
77 | self.scheduler.step()
78 |
79 | train_loss = train_loss + loss.item()
80 | train_loss_sup = train_loss_sup + loss_sup.item()
81 | train_loss_reg = train_loss_reg + loss_reg.item()
82 | # get dice evaluation for each class in annotated images
83 | if(isinstance(outputs, tuple) or isinstance(outputs, list)):
84 | outputs = outputs[0]
85 | p_argmax = torch.argmax(outputs, dim = 1, keepdim = True)
86 | p_soft = get_soft_label(p_argmax, class_num, self.tensor_type)
87 | p_soft, y = reshape_prediction_and_ground_truth(p_soft, y)
88 | dice_list = get_classwise_dice(p_soft, y)
89 | train_dice_list.append(dice_list.cpu().numpy())
90 | train_avg_loss = train_loss / iter_valid
91 | train_avg_loss_sup = train_loss_sup / iter_valid
92 | train_avg_loss_reg = train_loss_reg / iter_valid
93 | train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
94 | train_avg_dice = train_cls_dice.mean()
95 |
96 | train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup,
97 | 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w,
98 | 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice}
99 | return train_scalers
--------------------------------------------------------------------------------
/PyMIC/pymic/net_run_wsl/wsl_main.py:
--------------------------------------------------------------------------------
1 |
2 | # -*- coding: utf-8 -*-
3 | from __future__ import print_function, division
4 | import logging
5 | import os
6 | import sys
7 | from pymic.util.parse_config import *
8 | from pymic.net_run_wsl.wsl_em import WSLEntropyMinimization
9 | from pymic.net_run_wsl.wsl_gatedcrf import WSLGatedCRF
10 | from pymic.net_run_wsl.wsl_mumford_shah import WSLMumfordShah
11 | from pymic.net_run_wsl.wsl_tv import WSLTotalVariation
12 | from pymic.net_run_wsl.wsl_ustm import WSLUSTM
13 | from pymic.net_run_wsl.wsl_dmpls import WSLDMPLS
14 |
15 | WSLMethodDict = {'EntropyMinimization': WSLEntropyMinimization,
16 | 'GatedCRF': WSLGatedCRF,
17 | 'MumfordShah': WSLMumfordShah,
18 | 'TotalVariation': WSLTotalVariation,
19 | 'USTM': WSLUSTM,
20 | 'DMPLS': WSLDMPLS}
21 |
22 | def main():
23 | """
24 | The main function for training and inference of weakly supervised segmentation.
25 | """
26 | if(len(sys.argv) < 3):
27 | print('Number of arguments should be 3. e.g.')
28 | print(' pymic_wsl train config.cfg')
29 | exit()
30 | stage = str(sys.argv[1])
31 | cfg_file = str(sys.argv[2])
32 | config = parse_config(cfg_file)
33 | config = synchronize_config(config)
34 | log_dir = config['training']['ckpt_save_dir']
35 | if(not os.path.exists(log_dir)):
36 |         os.makedirs(log_dir, exist_ok=True)
37 | logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
38 | format='%(message)s')
39 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
40 | logging_config(config)
41 | wsl_method = config['weakly_supervised_learning']['wsl_method']
42 | agent = WSLMethodDict[wsl_method](config, stage)
43 | agent.run()
44 |
45 | if __name__ == "__main__":
46 | main()
47 |
48 |
--------------------------------------------------------------------------------
/PyMIC/pymic/net_run_wsl/wsl_mumford_shah.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 | import logging
4 | import numpy as np
5 | import torch
6 | from torch.optim import lr_scheduler
7 | from pymic.loss.seg.util import get_soft_label
8 | from pymic.loss.seg.util import reshape_prediction_and_ground_truth
9 | from pymic.loss.seg.util import get_classwise_dice
10 | from pymic.loss.seg.mumford_shah import MumfordShahLoss
11 | from pymic.net_run_wsl.wsl_abstract import WSLSegAgent
12 | from pymic.util.ramps import get_rampup_ratio
13 |
14 | class WSLMumfordShah(WSLSegAgent):
15 | """
16 | Weakly supervised learning with Mumford Shah Loss.
17 |
18 | * Reference: Boah Kim and Jong Chul Ye: Mumford–Shah Loss Functional
19 | for Image Segmentation With Deep Learning.
20 | `IEEE TIP `_, 2019.
21 |
22 | :param config: (dict) A dictionary containing the configuration.
23 |     :param stage: (str) One of the stages: `train` (default), `inference` or `test`.
24 |
25 | .. note::
26 |
27 | In the configuration dictionary, in addition to the four sections (`dataset`,
28 | `network`, `training` and `inference`) used in fully supervised learning, an
29 | extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details.
30 | """
31 | def __init__(self, config, stage = 'train'):
32 | super(WSLMumfordShah, self).__init__(config, stage)
33 |
34 | def training(self):
35 | class_num = self.config['network']['class_num']
36 | iter_valid = self.config['training']['iter_valid']
37 | wsl_cfg = self.config['weakly_supervised_learning']
38 | iter_max = self.config['training']['iter_max']
39 | rampup_start = wsl_cfg.get('rampup_start', 0)
40 | rampup_end = wsl_cfg.get('rampup_end', iter_max)
41 | train_loss = 0
42 | train_loss_sup = 0
43 | train_loss_reg = 0
44 | train_dice_list = []
45 |
46 | reg_loss_calculator = MumfordShahLoss(wsl_cfg)
47 | self.net.train()
48 | for it in range(iter_valid):
49 | try:
50 | data = next(self.trainIter)
51 | except StopIteration:
52 | self.trainIter = iter(self.train_loader)
53 | data = next(self.trainIter)
54 |
55 | # get the inputs
56 | inputs = self.convert_tensor_type(data['image'])
57 | y = self.convert_tensor_type(data['label_prob'])
58 |
59 | inputs, y = inputs.to(self.device), y.to(self.device)
60 |
61 | # zero the parameter gradients
62 | self.optimizer.zero_grad()
63 |
64 | # forward + backward + optimize
65 | outputs = self.net(inputs)
66 | loss_sup = self.get_loss_value(data, outputs, y)
67 | loss_dict = {"prediction":outputs, 'image':inputs}
68 | loss_reg = reg_loss_calculator(loss_dict)
69 |
70 | rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid")
71 | regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio
72 | loss = loss_sup + regular_w*loss_reg
74 | loss.backward()
75 | self.optimizer.step()
76 | if(self.scheduler is not None and \
77 | not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
78 | self.scheduler.step()
79 |
80 | train_loss = train_loss + loss.item()
81 | train_loss_sup = train_loss_sup + loss_sup.item()
82 | train_loss_reg = train_loss_reg + loss_reg.item()
83 | # get dice evaluation for each class in annotated images
84 | if(isinstance(outputs, tuple) or isinstance(outputs, list)):
85 | outputs = outputs[0]
86 | p_argmax = torch.argmax(outputs, dim = 1, keepdim = True)
87 | p_soft = get_soft_label(p_argmax, class_num, self.tensor_type)
88 | p_soft, y = reshape_prediction_and_ground_truth(p_soft, y)
89 | dice_list = get_classwise_dice(p_soft, y)
90 | train_dice_list.append(dice_list.cpu().numpy())
91 | train_avg_loss = train_loss / iter_valid
92 | train_avg_loss_sup = train_loss_sup / iter_valid
93 | train_avg_loss_reg = train_loss_reg / iter_valid
94 | train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
95 | train_avg_dice = train_cls_dice.mean()
96 |
97 | train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup,
98 | 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w,
99 | 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice}
100 | return train_scalers
101 |
--------------------------------------------------------------------------------
/PyMIC/pymic/net_run_wsl/wsl_tv.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 | import logging
4 | import numpy as np
5 | import torch
6 | from torch.optim import lr_scheduler
7 | from pymic.loss.seg.util import get_soft_label
8 | from pymic.loss.seg.util import reshape_prediction_and_ground_truth
9 | from pymic.loss.seg.util import get_classwise_dice
10 | from pymic.loss.seg.ssl import TotalVariationLoss
11 | from pymic.net_run_wsl.wsl_abstract import WSLSegAgent
12 | from pymic.util.ramps import get_rampup_ratio
13 | from pymic.util.general import keyword_match
14 |
15 | class WSLTotalVariation(WSLSegAgent):
16 | """
17 |     Weakly supervised segmentation with Total Variation regularization.
18 |
19 | :param config: (dict) A dictionary containing the configuration.
20 |     :param stage: (str) One of the stages: `train` (default), `inference` or `test`.
21 |
22 | .. note::
23 |
24 | In the configuration dictionary, in addition to the four sections (`dataset`,
25 | `network`, `training` and `inference`) used in fully supervised learning, an
26 | extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details.
27 | """
28 | def __init__(self, config, stage = 'train'):
29 | super(WSLTotalVariation, self).__init__(config, stage)
30 |
31 | def training(self):
32 | class_num = self.config['network']['class_num']
33 | iter_valid = self.config['training']['iter_valid']
34 | wsl_cfg = self.config['weakly_supervised_learning']
35 | iter_max = self.config['training']['iter_max']
36 | rampup_start = wsl_cfg.get('rampup_start', 0)
37 | rampup_end = wsl_cfg.get('rampup_end', iter_max)
38 | train_loss = 0
39 | train_loss_sup = 0
40 | train_loss_reg = 0
41 | train_dice_list = []
42 | self.net.train()
43 | for it in range(iter_valid):
44 | try:
45 | data = next(self.trainIter)
46 | except StopIteration:
47 | self.trainIter = iter(self.train_loader)
48 | data = next(self.trainIter)
49 |
50 | # get the inputs
51 | inputs = self.convert_tensor_type(data['image'])
52 | y = self.convert_tensor_type(data['label_prob'])
53 |
54 | inputs, y = inputs.to(self.device), y.to(self.device)
55 |
56 | # zero the parameter gradients
57 | self.optimizer.zero_grad()
58 |
59 | # forward + backward + optimize
60 | outputs = self.net(inputs)
61 | loss_sup = self.get_loss_value(data, outputs, y)
62 | loss_dict = {"prediction":outputs, 'softmax':True}
63 | loss_reg = TotalVariationLoss()(loss_dict)
64 |
65 | rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid")
66 | regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio
67 | loss = loss_sup + regular_w*loss_reg
69 | loss.backward()
70 | self.optimizer.step()
71 | if(self.scheduler is not None and \
72 | not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
73 | self.scheduler.step()
74 |
75 | train_loss = train_loss + loss.item()
76 | train_loss_sup = train_loss_sup + loss_sup.item()
77 | train_loss_reg = train_loss_reg + loss_reg.item()
78 | # get dice evaluation for each class in annotated images
79 | if(isinstance(outputs, tuple) or isinstance(outputs, list)):
80 | outputs = outputs[0]
81 | p_argmax = torch.argmax(outputs, dim = 1, keepdim = True)
82 | p_soft = get_soft_label(p_argmax, class_num, self.tensor_type)
83 | p_soft, y = reshape_prediction_and_ground_truth(p_soft, y)
84 | dice_list = get_classwise_dice(p_soft, y)
85 | train_dice_list.append(dice_list.cpu().numpy())
86 | train_avg_loss = train_loss / iter_valid
87 | train_avg_loss_sup = train_loss_sup / iter_valid
88 | train_avg_loss_reg = train_loss_reg / iter_valid
89 | train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
90 | train_avg_dice = train_cls_dice.mean()
91 |
92 | train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup,
93 | 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w,
94 | 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice}
95 | return train_scalers
96 |
--------------------------------------------------------------------------------
/PyMIC/pymic/transform/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/transform/__init__.py
--------------------------------------------------------------------------------
/PyMIC/pymic/transform/abstract_transform.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 |
4 | class AbstractTransform(object):
5 | """
6 | The abstract class for Transform.
7 | """
8 | def __init__(self, params):
9 | self.task = params['Task'.lower()]
10 |
11 | def __call__(self, sample):
12 | """
13 | Forward pass of the transform.
14 |
15 | :arg sample: (dict) A dictionary for the input sample obtained by dataloader.
16 | """
17 | return sample
18 |
19 | def inverse_transform_for_prediction(self, sample):
20 | """
21 |         Inverse transform for the sample dictionary.
22 |         Specifically, it updates sample['predict'] (a network's prediction)
23 |         according to the inverse transform. This function is only useful for spatial transforms.
24 | """
25 | raise(ValueError("not implemented"))
26 |
--------------------------------------------------------------------------------
/PyMIC/pymic/transform/flip.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 |
4 | import torch
5 | import json
6 | import math
7 | import random
8 | import numpy as np
9 | from scipy import ndimage
10 | from pymic.transform.abstract_transform import AbstractTransform
11 | from pymic.util.image_process import *
12 |
13 |
14 | class RandomFlip(AbstractTransform):
15 |     """ Randomly flip the image. The shape is [C, D, H, W] or [C, H, W].
16 |
17 | The arguments should be written in the `params` dictionary, and it has the
18 | following fields:
19 |
20 | :param `RandomFlip_flip_depth`: (bool)
21 | Random flip along depth axis or not, only used for 3D images.
22 | :param `RandomFlip_flip_height`: (bool) Random flip along height axis or not.
23 | :param `RandomFlip_flip_width`: (bool) Random flip along width axis or not.
24 | :param `RandomFlip_inverse`: (optional, bool) Is inverse transform needed for inference.
25 | Default is `True`.
26 | """
27 | def __init__(self, params):
28 | super(RandomFlip, self).__init__(params)
29 | self.flip_depth = params['RandomFlip_flip_depth'.lower()]
30 | self.flip_height = params['RandomFlip_flip_height'.lower()]
31 | self.flip_width = params['RandomFlip_flip_width'.lower()]
32 | self.inverse = params.get('RandomFlip_inverse'.lower(), True)
33 |
34 | def __call__(self, sample):
35 | image = sample['image']
36 | input_shape = image.shape
37 | input_dim = len(input_shape) - 1
38 | flip_axis = []
39 | if(self.flip_width):
40 | if(random.random() > 0.5):
41 | flip_axis.append(-1)
42 | if(self.flip_height):
43 | if(random.random() > 0.5):
44 | flip_axis.append(-2)
45 | if(input_dim == 3 and self.flip_depth):
46 | if(random.random() > 0.5):
47 | flip_axis.append(-3)
48 |
49 | sample['RandomFlip_Param'] = json.dumps(flip_axis)
50 | if(len(flip_axis) > 0):
51 | # use .copy() to avoid negative strides of numpy array
52 | # current pytorch does not support negative strides
53 | image_t = np.flip(image, flip_axis).copy()
54 | sample['image'] = image_t
55 | if('label' in sample and self.task == 'segmentation'):
56 | sample['label'] = np.flip(sample['label'] , flip_axis).copy()
57 | if('pixel_weight' in sample and self.task == 'segmentation'):
58 | sample['pixel_weight'] = np.flip(sample['pixel_weight'] , flip_axis).copy()
59 | if('image1' in sample and self.task == 'segmentation'):
60 | sample['image1'] = np.flip(sample['image1'] , flip_axis).copy()
61 |
62 | return sample
63 |
64 | def inverse_transform_for_prediction(self, sample):
65 | if(isinstance(sample['RandomFlip_Param'], list) or \
66 | isinstance(sample['RandomFlip_Param'], tuple)):
67 | flip_axis = json.loads(sample['RandomFlip_Param'][0])
68 | else:
69 | flip_axis = json.loads(sample['RandomFlip_Param'])
70 | if(len(flip_axis) > 0):
71 | sample['predict'] = np.flip(sample['predict'] , flip_axis).copy()
72 | return sample
--------------------------------------------------------------------------------
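A minimal sketch (not part of the repository) of applying RandomFlip to a 2D sample; note that the parameter keys are lower-cased when read from `params`:

    import numpy as np
    from pymic.transform.flip import RandomFlip

    params = {'task': 'segmentation',
              'randomflip_flip_depth': False,
              'randomflip_flip_height': True,
              'randomflip_flip_width': True}
    flip = RandomFlip(params)
    sample = flip({'image': np.random.rand(1, 64, 64)})
    # sample['RandomFlip_Param'] records the flipped axes for the inverse pass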
/PyMIC/pymic/transform/intensity.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 |
4 | import torch
5 | import json
6 | import math
7 | import random
8 | import numpy as np
9 | from scipy import ndimage
10 | from pymic.transform.abstract_transform import AbstractTransform
11 | from pymic.util.image_process import *
12 |
13 |
14 | class GammaCorrection(AbstractTransform):
15 | """
16 | Apply random gamma correction to given channels.
17 |
18 | The arguments should be written in the `params` dictionary, and it has the
19 | following fields:
20 |
21 | :param `GammaCorrection_channels`: (list) A list of int for specifying the channels.
22 | :param `GammaCorrection_gamma_min`: (float) The minimal gamma value.
23 | :param `GammaCorrection_gamma_max`: (float) The maximal gamma value.
24 | :param `GammaCorrection_probability`: (optional, float)
25 | The probability of applying GammaCorrection. Default is 0.5.
26 | :param `GammaCorrection_inverse`: (optional, bool)
27 | Is inverse transform needed for inference. Default is `False`.
28 | """
29 | def __init__(self, params):
30 | super(GammaCorrection, self).__init__(params)
31 | self.channels = params['GammaCorrection_channels'.lower()]
32 | self.gamma_min = params['GammaCorrection_gamma_min'.lower()]
33 | self.gamma_max = params['GammaCorrection_gamma_max'.lower()]
34 | self.prob = params.get('GammaCorrection_probability'.lower(), 0.5)
35 | self.inverse = params.get('GammaCorrection_inverse'.lower(), False)
36 |
37 | def __call__(self, sample):
38 | if(np.random.uniform() > self.prob):
39 | return sample
40 | image= sample['image']
41 | for chn in self.channels:
42 | gamma_c = random.random() * (self.gamma_max - self.gamma_min) + self.gamma_min
43 | img_c = image[chn]
44 | v_min = img_c.min()
45 | v_max = img_c.max()
46 | img_c = (img_c - v_min)/(v_max - v_min)
47 | img_c = np.power(img_c, gamma_c)*(v_max - v_min) + v_min
48 | image[chn] = img_c
49 |
50 | sample['image'] = image
51 | return sample
52 |
53 | class GaussianNoise(AbstractTransform):
54 | """
55 | Add Gaussian Noise to given channels.
56 |
57 | The arguments should be written in the `params` dictionary, and it has the
58 | following fields:
59 |
60 | :param `GaussianNoise_channels`: (list) A list of int for specifying the channels.
61 | :param `GaussianNoise_mean`: (float) The mean value of noise.
62 | :param `GaussianNoise_std`: (float) The std of noise.
63 | :param `GaussianNoise_probability`: (optional, float)
64 | The probability of applying GaussianNoise. Default is 0.5.
65 | :param `GaussianNoise_inverse`: (optional, bool)
66 | Is inverse transform needed for inference. Default is `False`.
67 | """
68 | def __init__(self, params):
69 | super(GaussianNoise, self).__init__(params)
70 | self.channels = params['GaussianNoise_channels'.lower()]
71 | self.mean = params['GaussianNoise_mean'.lower()]
72 | self.std = params['GaussianNoise_std'.lower()]
73 | self.prob = params.get('GaussianNoise_probability'.lower(), 0.5)
74 | self.inverse = params.get('GaussianNoise_inverse'.lower(), False)
75 |
76 | def __call__(self, sample):
77 | if(np.random.uniform() > self.prob):
78 | return sample
79 | image= sample['image']
80 | for chn in self.channels:
81 | img_c = image[chn]
82 | noise = np.random.normal(self.mean, self.std, img_c.shape)
83 | image[chn] = img_c + noise
84 |
85 | sample['image'] = image
86 | return sample
87 |
88 | class GrayscaleToRGB(AbstractTransform):
89 | """
90 | Convert gray scale images to RGB by copying channels.
91 | """
92 | def __init__(self, params):
93 | super(GrayscaleToRGB, self).__init__(params)
94 | self.inverse = params.get('GrayscaleToRGB_inverse'.lower(), False)
95 |
96 | def __call__(self, sample):
97 | image= sample['image']
98 | assert(image.shape[0] == 1 or image.shape[0] == 3)
99 | if(image.shape[0] == 1):
100 | sample['image'] = np.concatenate([image, image, image])
101 | return sample
--------------------------------------------------------------------------------
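A minimal worked example (not part of the repository) of the per-channel gamma mapping implemented by GammaCorrection above: normalize to [0, 1], apply x**gamma, then restore the original intensity range:

    import numpy as np

    img = np.array([0.0, 1.0, 4.0])
    gamma = 0.5
    v_min, v_max = img.min(), img.max()
    out = np.power((img - v_min) / (v_max - v_min), gamma) * (v_max - v_min) + v_min
    # out -> [0.0, 2.0, 4.0]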
/PyMIC/pymic/transform/rotate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 |
4 | import torch
5 | import json
6 | import math
7 | import random
8 | import numpy as np
9 | from scipy import ndimage
10 | from pymic.transform.abstract_transform import AbstractTransform
11 | from pymic.util.image_process import *
12 |
13 |
14 | class RandomRotate(AbstractTransform):
15 | """
16 |     Randomly rotate an image with a shape of [C, D, H, W] or [C, H, W].
17 |
18 | The arguments should be written in the `params` dictionary, and it has the
19 | following fields:
20 |
21 | :param `RandomRotate_angle_range_d`: (list/tuple or None)
22 | Rotation angle (degree) range along depth axis (x-y plane), e.g., (-90, 90).
23 | If None, no rotation along this axis.
24 | :param `RandomRotate_angle_range_h`: (list/tuple or None)
25 | Rotation angle (degree) range along height axis (x-z plane), e.g., (-90, 90).
26 | If None, no rotation along this axis. Only used for 3D images.
27 | :param `RandomRotate_angle_range_w`: (list/tuple or None)
28 | Rotation angle (degree) range along width axis (y-z plane), e.g., (-90, 90).
29 | If None, no rotation along this axis. Only used for 3D images.
30 | :param `RandomRotate_inverse`: (optional, bool)
31 | Is inverse transform needed for inference. Default is `True`.
32 | """
33 | def __init__(self, params):
34 | super(RandomRotate, self).__init__(params)
35 | self.angle_range_d = params['RandomRotate_angle_range_d'.lower()]
36 | self.angle_range_h = params['RandomRotate_angle_range_h'.lower()]
37 | self.angle_range_w = params['RandomRotate_angle_range_w'.lower()]
38 | self.inverse = params.get('RandomRotate_inverse'.lower(), True)
39 |
40 | def __apply_transformation(self, image, transform_param_list, order = 1):
41 | """
42 | Apply rotation transformation to an ND image.
43 |
44 | :param image: The input ND image.
45 |         :param transform_param_list: (list) A list of rotation angles and axes.
46 | :param order: (int) Interpolation order.
47 | """
48 | for angle, axes in transform_param_list:
49 | image = ndimage.rotate(image, angle, axes, reshape = False, order = order)
50 | return image
51 |
52 | def __call__(self, sample):
53 | image = sample['image']
54 | input_shape = image.shape
55 | input_dim = len(input_shape) - 1
56 |
57 | transform_param_list = []
58 | if(self.angle_range_d is not None):
59 | angle_d = np.random.uniform(self.angle_range_d[0], self.angle_range_d[1])
60 | transform_param_list.append([angle_d, (-1, -2)])
61 | if(input_dim == 3):
62 | if(self.angle_range_h is not None):
63 | angle_h = np.random.uniform(self.angle_range_h[0], self.angle_range_h[1])
64 | transform_param_list.append([angle_h, (-1, -3)])
65 | if(self.angle_range_w is not None):
66 | angle_w = np.random.uniform(self.angle_range_w[0], self.angle_range_w[1])
67 | transform_param_list.append([angle_w, (-2, -3)])
68 | assert(len(transform_param_list) > 0)
69 |
70 | sample['RandomRotate_Param'] = json.dumps(transform_param_list)
71 | image_t = self.__apply_transformation(image, transform_param_list, 1)
72 | sample['image'] = image_t
73 | if('label' in sample and self.task == 'segmentation'):
74 | sample['label'] = self.__apply_transformation(sample['label'] ,
75 | transform_param_list, 0)
76 | if('pixel_weight' in sample and self.task == 'segmentation'):
77 | sample['pixel_weight'] = self.__apply_transformation(sample['pixel_weight'] ,
78 | transform_param_list, 1)
79 | return sample
80 |
81 | def inverse_transform_for_prediction(self, sample):
82 | if(isinstance(sample['RandomRotate_Param'], list) or \
83 | isinstance(sample['RandomRotate_Param'], tuple)):
84 | transform_param_list = json.loads(sample['RandomRotate_Param'][0])
85 | else:
86 | transform_param_list = json.loads(sample['RandomRotate_Param'])
87 | transform_param_list.reverse()
88 | for i in range(len(transform_param_list)):
89 | transform_param_list[i][0] = - transform_param_list[i][0]
90 | sample['predict'] = self.__apply_transformation(sample['predict'] ,
91 | transform_param_list, 1)
92 | return sample
--------------------------------------------------------------------------------
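A minimal sketch (not part of the repository) of the inverse-rotation logic above: reversing the recorded parameter list and negating each angle undoes the forward rotations:

    transform_param_list = [[30, (-1, -2)], [10, (-1, -3)]]   # forward pass
    inverse_list = [[-a, axes] for a, axes in reversed(transform_param_list)]
    # inverse_list -> [[-10, (-1, -3)], [-30, (-1, -2)]]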
/PyMIC/pymic/transform/trans_dict.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | The built-in transforms in PyMIC are:
4 |
5 | .. code-block:: none
6 |
7 | 'ChannelWiseThreshold': ChannelWiseThreshold,
8 | 'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize,
9 | 'CropWithBoundingBox': CropWithBoundingBox,
10 | 'CenterCrop': CenterCrop,
11 | 'GrayscaleToRGB': GrayscaleToRGB,
12 | 'GammaCorrection': GammaCorrection,
13 | 'GaussianNoise': GaussianNoise,
14 | 'LabelConvert': LabelConvert,
15 | 'LabelConvertNonzero': LabelConvertNonzero,
16 | 'LabelToProbability': LabelToProbability,
17 | 'NormalizeWithMeanStd': NormalizeWithMeanStd,
18 | 'NormalizeWithMinMax': NormalizeWithMinMax,
19 | 'NormalizeWithPercentiles': NormalizeWithPercentiles,
20 | 'PartialLabelToProbability':PartialLabelToProbability,
21 | 'RandomCrop': RandomCrop,
22 | 'RandomResizedCrop': RandomResizedCrop,
23 | 'RandomRescale': RandomRescale,
24 | 'RandomFlip': RandomFlip,
25 | 'RandomRotate': RandomRotate,
26 | 'ReduceLabelDim': ReduceLabelDim,
27 | 'Rescale': Rescale,
28 |     'Pad': Pad
29 |
30 | """
31 | from __future__ import print_function, division
32 | from pymic.transform.intensity import *
33 | from pymic.transform.flip import RandomFlip
34 | from pymic.transform.pad import Pad
35 | from pymic.transform.rotate import RandomRotate
36 | from pymic.transform.rescale import Rescale, RandomRescale
37 | from pymic.transform.threshold import *
38 | from pymic.transform.normalize import *
39 | from pymic.transform.crop import *
40 | from pymic.transform.label_convert import *
41 |
42 | TransformDict = {
43 | 'ChannelWiseThreshold': ChannelWiseThreshold,
44 | 'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize,
45 | 'CropWithBoundingBox': CropWithBoundingBox,
46 | 'CenterCrop': CenterCrop,
47 | 'GrayscaleToRGB': GrayscaleToRGB,
48 | 'GammaCorrection': GammaCorrection,
49 | 'GaussianNoise': GaussianNoise,
50 | 'LabelConvert': LabelConvert,
51 | 'LabelConvertNonzero': LabelConvertNonzero,
52 | 'LabelToProbability': LabelToProbability,
53 | 'NormalizeWithMeanStd': NormalizeWithMeanStd,
54 | 'NormalizeWithMeanStd_dual': NormalizeWithMeanStd_dual,
55 | 'NormalizeWithMinMax': NormalizeWithMinMax,
56 | 'NormalizeWithPercentiles': NormalizeWithPercentiles,
57 | 'PartialLabelToProbability':PartialLabelToProbability,
58 | 'RandomCrop': RandomCrop,
59 | 'RandomResizedCrop': RandomResizedCrop,
60 | 'RandomRescale': RandomRescale,
61 | 'RandomFlip': RandomFlip,
62 | 'RandomRotate': RandomRotate,
63 | 'ReduceLabelDim': ReduceLabelDim,
64 | 'Rescale': Rescale,
65 | 'Pad': Pad,
66 | }
67 |
--------------------------------------------------------------------------------
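
The registry above allows pipelines to be built from configuration names alone, as `get_transform_list` in `pymic/util/preprocess.py` does. A minimal sketch, with parameter keys matching the lower-cased convention used by the config files:

from pymic.transform.trans_dict import TransformDict

params = {
    'task': 'segmentation',
    'randomflip_flip_depth': False,
    'randomflip_flip_height': True,
    'randomflip_flip_width': True,
    'randomflip_inverse': False,
}
pipeline = [TransformDict[name](params) for name in ['RandomFlip']]
# sample = {'image': ...}
# for t in pipeline: sample = t(sample)
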
/PyMIC/pymic/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/util/__init__.py
--------------------------------------------------------------------------------
/PyMIC/pymic/util/general.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 | import torch
4 | import numpy as np
5 |
6 | def keyword_match(a,b):
7 | """
8 |     Test if two strings are equal when converted to lower case.
9 | """
10 | return a.lower() == b.lower()
11 |
12 | def get_one_hot_seg(label, class_num):
13 | """
14 | Convert a segmentation label to one-hot.
15 |
16 | :param label: A tensor with a shape of [N, 1, D, H, W] or [N, 1, H, W]
17 | :param class_num: Class number.
18 |
19 | :return: a one-hot tensor with a shape of [N, C, D, H, W] or [N, C, H, W].
20 | """
21 | size = list(label.size())
22 | if(size[1] != 1):
23 | raise ValueError("The channel should be 1, \
24 | rather than {0:} before one-hot encoding".format(size[1]))
25 | label = label.view(-1)
26 | ones = torch.sparse.torch.eye(class_num).to(label.device)
27 | one_hot = ones.index_select(0, label)
28 | size.append(class_num)
29 | one_hot = one_hot.view(*size)
30 | one_hot = torch.transpose(one_hot, 1, -1)
31 | one_hot = torch.squeeze(one_hot, -1)
32 | return one_hot
--------------------------------------------------------------------------------
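
A quick check of `get_one_hot_seg` on a toy 2D label map; the label must be an integer tensor with a singleton channel dimension:

import torch
from pymic.util.general import get_one_hot_seg

label = torch.zeros(1, 1, 4, 4, dtype=torch.long)   # shape [N, 1, H, W]
label[0, 0, 1:3, 1:3] = 1
one_hot = get_one_hot_seg(label, class_num=2)
print(one_hot.shape)                                # torch.Size([1, 2, 4, 4])
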
/PyMIC/pymic/util/model_operate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def rename_model_variable(input_file, output_file, input_var_list, output_var_list):
5 |     """Rename variables in a checkpoint's state dict and save to a new file."""
6 |     assert(len(input_var_list) == len(output_var_list))
7 |     checkpoint = torch.load(input_file)
8 |     state_dict = checkpoint['model_state_dict']
9 |     for input_var, output_var in zip(input_var_list, output_var_list):
10 |         # Move the value to the new key and drop the old one.
11 |         state_dict[output_var] = state_dict.pop(input_var)
12 |     checkpoint['model_state_dict'] = state_dict
13 |     torch.save(checkpoint, output_file)
14 | 
15 | 
16 | def get_average_model(checkpoint_name1, checkpoint_name2, checkpoint_name3, save_name):
17 |     """Average the weights of three checkpoints and save the result."""
18 |     state_dict1 = torch.load(checkpoint_name1)['model_state_dict']
19 |     state_dict2 = torch.load(checkpoint_name2)['model_state_dict']
20 |     state_dict3 = torch.load(checkpoint_name3)['model_state_dict']
21 | 
22 |     state_dict = {}
23 |     for item in state_dict1:
24 |         print(item)
25 |         state_dict[item] = (state_dict1[item] + state_dict2[item] + state_dict3[item]) / 3
26 | 
27 |     save_dict = {'model_state_dict': state_dict}
28 |     torch.save(save_dict, save_name)
29 | 
30 | 
31 | if __name__ == "__main__":
32 |     input_file = '/home/guotai/disk2t/projects/dlls/training_fetal_brain/exp1/model/unet2dres_bn1_20000.pt'
33 |     output_file = '/home/guotai/disk2t/projects/dlls/training_fetal_brain/exp1/model/unet2dres_bn1_20000_rename.pt'
34 |     input_var_list = ['conv.weight', 'conv.bias']
35 |     output_var_list = ['conv9.weight', 'conv9.bias']
36 |     rename_model_variable(input_file, output_file, input_var_list, output_var_list)
--------------------------------------------------------------------------------
/PyMIC/pymic/util/parse_config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import absolute_import, print_function
3 |
4 | import configparser
5 | import logging
6 |
7 | def is_int(val_str):
8 | start_digit = 0
9 | if(val_str[0] =='-'):
10 | start_digit = 1
11 | flag = True
12 | for i in range(start_digit, len(val_str)):
13 | if(str(val_str[i]) < '0' or str(val_str[i]) > '9'):
14 | flag = False
15 | break
16 | return flag
17 |
18 | def is_float(val_str):
19 | flag = False
20 | if('.' in val_str and len(val_str.split('.'))==2 and not('./' in val_str)):
21 | if(is_int(val_str.split('.')[0]) and is_int(val_str.split('.')[1])):
22 | flag = True
23 | else:
24 | flag = False
25 | elif('e' in val_str and val_str[0] != 'e' and len(val_str.split('e'))==2):
26 | if(is_int(val_str.split('e')[0]) and is_int(val_str.split('e')[1])):
27 | flag = True
28 | else:
29 | flag = False
30 | else:
31 | flag = False
32 | return flag
33 |
34 | def is_bool(var_str):
35 | if( var_str.lower() =='true' or var_str.lower() == 'false'):
36 | return True
37 | else:
38 | return False
39 |
40 | def parse_bool(var_str):
41 | if(var_str.lower() =='true'):
42 | return True
43 | else:
44 | return False
45 |
46 | def is_list(val_str):
47 | if(val_str[0] == '[' and val_str[-1] == ']'):
48 | return True
49 | else:
50 | return False
51 |
52 | def parse_list(val_str):
53 | sub_str = val_str[1:-1]
54 | splits = sub_str.split(',')
55 | output = []
56 | for item in splits:
57 | item = item.strip()
58 | if(is_int(item)):
59 | output.append(int(item))
60 | elif(is_float(item)):
61 | output.append(float(item))
62 | elif(is_bool(item)):
63 | output.append(parse_bool(item))
64 | elif(item.lower() == 'none'):
65 | output.append(None)
66 | else:
67 | output.append(item)
68 | return output
69 |
70 | def parse_value_from_string(val_str):
71 | # val_str = val_str.encode('ascii','ignore')
72 | if(is_int(val_str)):
73 | val = int(val_str)
74 | elif(is_float(val_str)):
75 | val = float(val_str)
76 | elif(is_list(val_str)):
77 | val = parse_list(val_str)
78 | elif(is_bool(val_str)):
79 | val = parse_bool(val_str)
80 | elif(val_str.lower() == 'none'):
81 | val = None
82 | else:
83 | val = val_str
84 | return val
85 |
86 | def parse_config(filename):
87 | config = configparser.ConfigParser()
88 | config.read(filename)
89 | output = {}
90 | for section in config.sections():
91 | output[section] = {}
92 | for key in config[section]:
93 | val_str = str(config[section][key])
94 |             if(len(val_str)>0):
95 |                 val = parse_value_from_string(val_str)
96 |             else:
97 |                 val = None
98 |             output[section][key] = val
99 |             print(section, key, val)
100 | return output
101 |
102 | def synchronize_config(config):
103 | data_cfg = config['dataset']
104 | net_cfg = config['network']
105 | # data_cfg["modal_num"] = net_cfg["in_chns"]
106 | data_cfg["LabelToProbability_class_num".lower()] = net_cfg["class_num"]
107 | if "PartialLabelToProbability" in data_cfg['train_transform']:
108 | data_cfg["PartialLabelToProbability_class_num".lower()] = net_cfg["class_num"]
109 | config['dataset'] = data_cfg
110 | config['network'] = net_cfg
111 | return config
112 |
113 | def logging_config(config):
114 | for section in config:
115 | for key in config[section]:
116 | value = config[section][key]
117 | logging.info("{0:} {1:} = {2:}".format(section, key, value))
118 |
119 | if __name__ == "__main__":
120 | print(is_int('555'))
121 | print(is_float('555.10'))
122 | a='[1 ,2 ,3 ]'
123 | print(a)
124 | print(parse_list(a))
125 |
126 |
127 |
--------------------------------------------------------------------------------
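
To illustrate the type coercion in `parse_config`, a hypothetical `demo.cfg` and the values it yields; note that `configparser` lower-cases keys by default:

# demo.cfg contains:
#   [dataset]
#   modal_num = 1
#   Pad_output_size = [28, 128, 128]
#   NormalizeWithMeanStd_mean = None
from pymic.util.parse_config import parse_config

config = parse_config('demo.cfg')
# config['dataset']['modal_num']                 -> 1 (int)
# config['dataset']['pad_output_size']           -> [28, 128, 128] (list of int)
# config['dataset']['normalizewithmeanstd_mean'] -> None
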
/PyMIC/pymic/util/post_process.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import absolute_import, print_function
3 |
4 | import os
5 | import numpy as np
6 | import SimpleITK as sitk
7 | from pymic.util.image_process import get_largest_k_components
8 |
9 | class PostProcess(object):
10 | """
11 |     The abstract class for post processing.
12 | """
13 | def __init__(self, params):
14 | self.params = params
15 |
16 | def __call__(self, seg):
17 | return seg
18 |
19 | class PostKeepLargestComponent(PostProcess):
20 | """
21 | Post process by keeping the largest component.
22 |
23 | The arguments should be written in the `params` dictionary, and it has the
24 | following fields:
25 |
26 | :param `KeepLargestComponent_mode`: (int)
27 | `1` means keep the largest component of the union of foreground classes.
28 | `2` means keep the largest component for each foreground class.
29 | """
30 | def __init__(self, params):
31 | super(PostKeepLargestComponent, self).__init__(params)
32 | self.mode = params.get("KeepLargestComponent_mode".lower(), 1)
33 |
34 | def __call__(self, seg):
35 | if(self.mode == 1):
36 | mask = np.asarray(seg > 0, np.uint8)
37 | mask = get_largest_k_components(mask)
38 | seg = seg * mask
39 |         elif(self.mode == 2):
40 |             class_num = seg.max()
41 |             output = np.zeros_like(seg)
42 |             for c in range(1, class_num + 1):
43 |                 seg_c = np.asarray(seg == c, np.uint8)
44 |                 seg_c = get_largest_k_components(seg_c)
45 |                 output = output + seg_c * c
46 |             seg = output
47 |         return seg
48 | 
49 | PostProcessDict = {
50 |     'KeepLargestComponent': PostKeepLargestComponent}
--------------------------------------------------------------------------------
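
A small sketch of `PostKeepLargestComponent` in mode 1, assuming `get_largest_k_components` keeps the largest connected component when called with a single argument (as above):

import numpy as np
from pymic.util.post_process import PostKeepLargestComponent

seg = np.zeros((1, 8, 8), np.uint8)
seg[0, 1:5, 1:5] = 1     # larger foreground component
seg[0, 6:8, 6:8] = 1     # smaller, spurious component
post = PostKeepLargestComponent({'keeplargestcomponent_mode': 1})
seg = post(seg)          # the smaller component is removed
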
/PyMIC/pymic/util/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import SimpleITK as sitk
4 | from pymic.io.image_read_write import load_image_as_nd_array
5 | from pymic.transform.trans_dict import TransformDict
6 | from pymic.util.parse_config import parse_config
7 |
8 | def get_transform_list(trans_config_file):
9 | """
10 | Create a list of transforms given a configuration file.
11 | """
12 | config = parse_config(trans_config_file)
13 | transform_list = []
14 |
15 | transform_param = config['dataset']
16 | transform_param['task'] = 'segmentation'
17 | transform_names = config['dataset']['transform']
18 | for name in transform_names:
19 | print(name)
20 | if(name not in TransformDict):
21 | raise(ValueError("Undefined transform {0:}".format(name)))
22 | one_transform = TransformDict[name](transform_param)
23 | transform_list.append(one_transform)
24 | return transform_list
25 |
26 | def preprocess_with_transform(transforms, img_in_name, img_out_name,
27 | lab_in_name = None, lab_out_name = None):
28 | """
29 |     Apply a list of data transforms for preprocessing,
30 |     such as image normalization, cropping, etc.
31 |     TODO: support multi-modality preprocessing.
32 | 
33 |     :param transforms: (list) A list of transform objects.
34 |     :param img_in_name: (str) Input file name.
35 |     :param img_out_name: (str) Output file name.
36 |     :param lab_in_name: (optional, str) If not None, also load and
37 |         preprocess the image's corresponding label.
38 | :param lab_out_name: (optional, str) The output label name.
39 | """
40 | image_dict = load_image_as_nd_array(img_in_name)
41 | sample = {'image': np.asarray(image_dict['data_array'], np.float32),
42 | 'origin':image_dict['origin'],
43 | 'spacing': image_dict['spacing'],
44 | 'direction':image_dict['direction']}
45 | if(lab_in_name is not None):
46 | label_dict = load_image_as_nd_array(lab_in_name)
47 | sample['label'] = label_dict['data_array']
48 | for transform in transforms:
49 | sample = transform(sample)
50 |
51 | out_img = sitk.GetImageFromArray(sample['image'][0])
52 | out_img.SetSpacing(sample['spacing'])
53 | out_img.SetOrigin(sample['origin'])
54 | out_img.SetDirection(sample['direction'])
55 | sitk.WriteImage(out_img, img_out_name)
56 | if(lab_in_name is not None and lab_out_name is not None):
57 | out_lab = sitk.GetImageFromArray(sample['label'][0])
58 | out_lab.CopyInformation(out_img)
59 | sitk.WriteImage(out_lab, lab_out_name)
60 |
61 |
62 |
63 |
--------------------------------------------------------------------------------
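
A sketch of the intended offline entry point; the configuration and image names here are placeholders:

from pymic.util.preprocess import get_transform_list, preprocess_with_transform

transforms = get_transform_list('preprocess.cfg')   # hypothetical config with a [dataset] section
preprocess_with_transform(transforms,
                          img_in_name  = 'case001.nii.gz',
                          img_out_name = 'case001_prep.nii.gz',
                          lab_in_name  = 'case001_seg.nii.gz',
                          lab_out_name = 'case001_seg_prep.nii.gz')
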
/PyMIC/pymic/util/ramps.py:
--------------------------------------------------------------------------------
1 |
2 | # -*- coding: utf-8 -*-
3 | """
4 | Functions for ramping hyperparameters up or down.
5 |
6 | Each function takes the current training step or epoch, and the
7 | ramp length (start and end step or epoch), and returns a multiplier between
8 | 0 and 1.
9 | """
10 | from __future__ import print_function, division
11 | import numpy as np
12 |
13 | def get_rampup_ratio(i, start, end, mode = "linear"):
14 | """
15 | Obtain the rampup ratio.
16 |
17 | :param i: (int) The current iteration.
18 | :param start: (int) The start iteration.
19 |     :param end: (int) The end iteration.
20 | :param mode: (str) Valid values are {`linear`, `sigmoid`, `cosine`}.
21 | """
22 | i = np.clip(i, start, end)
23 | if(mode == "linear"):
24 | rampup = (i - start) / (end - start)
25 | elif(mode == "sigmoid"):
26 | phase = 1.0 - (i - start) / (end - start)
27 | rampup = float(np.exp(-5.0 * phase * phase))
28 | elif(mode == "cosine"):
29 | phase = 1.0 - (i - start) / (end - start)
30 | rampup = float(.5 * (np.cos(np.pi * phase) + 1))
31 | else:
32 | raise ValueError("Undefined rampup mode {0:}".format(mode))
33 | return rampup
34 |
35 |
36 | def get_rampdown_ratio(i, start, end, mode = "linear"):
37 | """
38 | Obtain the rampdown ratio.
39 |
40 | :param i: (int) The current iteration.
41 | :param start: (int) The start iteration.
42 |     :param end: (int) The end iteration.
43 | :param mode: (str) Valid values are {`linear`, `sigmoid`, `cosine`}.
44 | """
45 | i = np.clip(i, start, end)
46 | if(mode == "linear"):
47 | rampdown = 1.0 - (i - start) / (end - start)
48 | elif(mode == "sigmoid"):
49 | phase = (i - start) / (end - start)
50 | rampdown = float(np.exp(-5.0 * phase * phase))
51 | elif(mode == "cosine"):
52 | phase = (i - start) / (end - start)
53 | rampdown = float(.5 * (np.cos(np.pi * phase) + 1))
54 | else:
55 |         raise ValueError("Undefined rampdown mode {0:}".format(mode))
56 | return rampdown
57 |
58 |
--------------------------------------------------------------------------------
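
For example, a sigmoid ramp-up of a consistency-loss weight over the first 4000 iterations (`w_max` is an illustrative value):

from pymic.util.ramps import get_rampup_ratio

w_max = 0.1
for it in [0, 1000, 2000, 4000, 8000]:
    w = w_max * get_rampup_ratio(it, 0, 4000, mode="sigmoid")
    print(it, round(w, 4))
# the ratio starts near exp(-5) ~= 0.007 and reaches 1.0 at iteration 4000
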
/PyMIC/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["flit_core >=3.2,<4"]
3 | build-backend = "flit_core.buildapi"
4 |
5 | [project]
6 | name = "PyMIC"
7 | authors = [{name = "Graziella", email = "graziella@lumache"}]
8 | dynamic = ["version", "description"]
9 |
--------------------------------------------------------------------------------
/PyMIC/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib>=3.1.2
2 | numpy>=1.17.4
3 | pandas>=0.25.3
4 | scikit-image>=0.16.2
5 | scikit-learn>=0.22
6 | scipy>=1.3.3
7 | SimpleITK>=2.0.0
8 | tensorboard>=2.1.0
9 | tensorboardX>=1.9
10 | torch>=1.7.1
11 | torchvision>=0.8.2
12 |
--------------------------------------------------------------------------------
/PyMIC/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import setuptools
3 |
4 | # Get the summary
5 | description = 'An open-source deep learning platform' + \
6 | ' for annotation-efficient medical image computing'
7 |
8 | # Get the long description
9 | with open('README.md', encoding='utf-8') as f:
10 | long_description = f.read()
11 |
12 | setuptools.setup(
13 | name = 'PYMIC',
14 | version = "0.3.0",
15 | author ='PyMIC Consortium',
16 | author_email = 'wguotai@gmail.com',
17 | description = description,
18 | long_description = long_description,
19 | long_description_content_type = 'text/markdown',
20 | url = 'https://github.com/HiLab-git/PyMIC',
21 | license = 'Apache 2.0',
22 | packages = setuptools.find_packages(),
23 | classifiers=[
24 | 'License :: OSI Approved :: Apache Software License',
25 | 'Programming Language :: Python',
26 |         'Programming Language :: Python :: 3',
27 |         'Programming Language :: Python :: 3.6',
28 | ],
29 | python_requires = '>=3.6',
30 | entry_points = {
31 | 'console_scripts': [
32 | 'pymic_run = pymic.net_run.net_run:main',
33 | 'pymic_ssl = pymic.net_run_ssl.ssl_main:main',
34 | 'pymic_wsl = pymic.net_run_wsl.wsl_main:main',
35 | 'pymic_nll = pymic.net_run_nll.nll_main:main',
36 | 'pymic_eval_cls = pymic.util.evaluation_cls:main',
37 | 'pymic_eval_seg = pymic.util.evaluation_seg:main'
38 | ],
39 | },
40 | )
41 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
13 |
14 |
15 | # FPL+: Filtered Pseudo Label-based Unsupervised Cross-Modality Adaptation for 3D Medical Image Segmentation
16 | by [Jianghao Wu](https://jianghaowu.github.io/), et al.
17 |
18 | ## Introduction
19 |
20 | This repository is for our IEEE TMI paper **FPL+: Filtered Pseudo Label-based Unsupervised Cross-Modality Adaptation for 3D Medical Image Segmentation**.
21 |
22 |
23 | 
24 |
25 | ## Data Preparation
26 |
27 | ### Dataset
28 | [Vestibular Schwannoma Segmentation Dataset](https://www.nature.com/articles/s41597-021-01064-w) | [BraTS 2020](https://www.med.upenn.edu/cbica/brats2020/data.html) | [MMWHS](http://www.sdspeople.fudan.edu.cn/zhuangxiahai/0/mmwhs/)
29 |
30 | For the VS dataset, preprocess the original data according to `./data/preprocess_vs.py`.
31 |
32 | ### Cross-domain data augmentation
33 | Train [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) and convert the source-domain data into a source domain-like set and a target domain-like set; refer to the folder `./dataset`.
34 |
35 | ### File Organization
36 | Use `./write_csv.py` to write your data into a `csv` file (a minimal sketch of such a script is given after this file).
37 |
38 | For the VS data, with ceT1 as the source domain and hrT2 as the target domain, the `csv` files can be seen in `./config_dual/data_vs`:
39 | ```
40 | ├──config_dual/data_vs
41 | ├── [train_ceT1_like.csv]
42 | ├──image,label
43 | ├──./dataset/ceT1/img/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1_seg.nii.gz
44 | ├──./dataset/fake_data/ceT1-hrT2-ceT1_cc/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1.nii.gz
45 | ├──./dataset/fake_data/ceT1-hrT2-ceT1_ac/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1.nii.gz
46 | ...
47 | ├── [train_hrT2_like.csv]
48 | ├──image,label
49 | ├──./dataset/fake_data/ceT1-hrT2_cyc/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1.nii.gz
50 | ├──./dataset/fake_data/ceT1-hrT2_auxcyc/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1.nii.gz
51 | ...
52 | ```
53 |
54 | ## Training and Testing
55 |
56 | ### Train the pseudo label generator and get pseudo labels
57 | Write your training config file in `config_dual/vs_t1s_g.cfg`
58 |
59 | ```
60 | export CUDA_VISIBLE_DEVICES=0
61 | export PYTHONPATH=$PYTHONPATH:./PyMIC
62 | ## train pseudo label generator
63 | python ./PyMIC/pymic/net_run_dsbn/net_run.py train ./config_dual/vs_t1s_g.cfg
64 | ## get pseudo label
65 | python ./PyMIC/pymic/net_run_dsbn/net_run.py test ./config_dual/vs_t1s_g.cfg
66 | ## get the pseudo label of fake source image
67 | python ./PyMIC/pymic/net_run_dsbn/net_run.py test ./config_dual/vs_t1s_g_fake.cfg
68 | ## get image-level weights
69 | python ./PyMIC/pymic/net_run_dsbn/net_run.py test ./config_dual/vs_t1s_weights.cfg
70 | ```
71 | Weights are saved to the paths given by `fpl_uncertainty_sorted` and `fpl_uncertainty_weight` in the `[testing]` section; then run:
72 | ```
73 | python data/get_pixel_weight.py
74 | python "data/get image_weight.py"
75 | ```
76 | ### Train the final segmentor
77 | ```
78 | export CUDA_VISIBLE_DEVICES=0
79 | export PYTHONPATH=$PYTHONPATH:./PyMIC
80 | python ./PyMIC/pymic/net_run_dsbn/net_run.py train ./config_dual/vs_t1s_S.cfg
81 | python ./PyMIC/pymic/net_run_dsbn/net_run.py test ./config_dual/vs_t1s_S.cfg
82 | ```
83 |
84 |
85 |
87 |
88 | ## Citation
89 | If you find this project useful for your research, please consider citing:
90 |
91 | ```bibtex
92 | @article{wu2024fpl+,
93 | title={FPL+: Filtered Pseudo Label-based Unsupervised Cross-Modality Adaptation for 3D Medical Image Segmentation},
94 | author={Wu, Jianghao and Guo, Dong and Wang, Guotai and Yue, Qiang and Yu, Huijun and Li, Kang and Zhang, Shaoting},
95 | journal={IEEE Transactions on Medical Imaging},
96 | year={2024},
97 | publisher={IEEE}
98 | }
99 | ```
100 |
--------------------------------------------------------------------------------
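
Since `/write_csv.py` itself is not shown in this listing, a minimal sketch of the README's "File Organization" step: pairing images with labels into an `image,label` csv. The directory layout and `_seg` suffix follow the csv files under `config_dual/data_vs`.

import csv
import os

img_dir = './dataset/hrT2_train/img'
lab_dir = './dataset/hrT2_train/lab'
rows = [['image', 'label']]
for name in sorted(os.listdir(img_dir)):
    if name.endswith('.nii.gz'):
        rows.append([os.path.join(img_dir, name),
                     os.path.join(lab_dir, name.replace('.nii.gz', '_seg.nii.gz'))])
with open('config_dual/data_vs/train_hrT2.csv', 'w', newline='') as f:
    csv.writer(f).writerows(rows)
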
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/__init__.py
--------------------------------------------------------------------------------
/config_dual/data_vs/test_hrT2.csv:
--------------------------------------------------------------------------------
1 | image,label
2 | ./dataset/hrT2_test/vs_gk_9_t2.nii.gz,./dataset/hrT2_test/vs_gk_9_t2_seg.nii.gz
--------------------------------------------------------------------------------
/config_dual/data_vs/train_ceT1.csv:
--------------------------------------------------------------------------------
1 | image,label
2 | ./dataset/ceT1/img/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1_seg.nii.gz
--------------------------------------------------------------------------------
/config_dual/data_vs/train_ceT1_like.csv:
--------------------------------------------------------------------------------
1 | image,label
2 | ./dataset/ceT1/img/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1_seg.nii.gz
3 | ./dataset/fake_data/ceT1-hrT2-ceT1_cc/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1.nii.gz
4 | ./dataset/fake_data/ceT1-hrT2-ceT1_ac/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1.nii.gz
--------------------------------------------------------------------------------
/config_dual/data_vs/train_hrT2-ceT1_cyc.csv:
--------------------------------------------------------------------------------
1 | image,label
2 | ./dataset/fake_data/hrT2-ceT1_train_cyc/vs_gk_98_t2.nii.gz,./dataset/hrT2_train/lab/vs_gk_98_t2_seg.nii.gz
--------------------------------------------------------------------------------
/config_dual/data_vs/train_hrT2.csv:
--------------------------------------------------------------------------------
1 | image,label
2 | ./dataset/hrT2_train/img/vs_gk_98_t2.nii.gz,./dataset/hrT2_train/lab/vs_gk_98_t2_seg.nii.gz
--------------------------------------------------------------------------------
/config_dual/data_vs/train_hrT2_like.csv:
--------------------------------------------------------------------------------
1 | image,label
2 | ./dataset/fake_data/ceT1-hrT2_cyc/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1.nii.gz
3 | ./dataset/fake_data/ceT1-hrT2_auxcyc/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1.nii.gz
--------------------------------------------------------------------------------
/config_dual/data_vs/train_hrT2_pair.csv:
--------------------------------------------------------------------------------
1 | image,label
2 | ./dataset/hrT2_test/vs_gk_98_t2.nii.gz,./dataset/hrT2_test/vs_gk_98_t2_seg.nii.gz
--------------------------------------------------------------------------------
/config_dual/data_vs/valid_hrT2.csv:
--------------------------------------------------------------------------------
1 | image,label
2 | ./dataset/hrT2_valid/vs_gk_95_t2.nii.gz,./dataset/hrT2_valid/vs_gk_95_t2_seg.nii.gz
--------------------------------------------------------------------------------
/config_dual/data_vs/vs_t1s_S.cfg:
--------------------------------------------------------------------------------
1 | [dataset]
2 | # tensor type (float or double)
3 | tensor_type = float
4 | dsbn = True
5 | task_type = seg
6 | root_dir = /
7 | 1_train_csv = config_dual/data_vs/train_ceT1.csv
8 | 1_valid_csv = config_dual/data_vs/valid_ceT1.csv
9 | 2_train_csv = config_dual/data_vs/train_vs_t1s_wi+wp.csv
10 | 2_valid_csv = config_dual/data_vs/valid_hrT2.csv
11 |
12 | test_csv = config_dual/data_vs/test_hrT2.csv
13 |
14 | train_batch_size = 4
15 |
16 | load_pixelwise_weight = False
17 | # modality number
18 | modal_num = 1
19 |
20 | # data transforms
21 | train_transform = [NormalizeWithMeanStd,Pad,RandomCrop, RandomFlip, LabelToProbability]
22 | valid_transform = [NormalizeWithMeanStd,Pad,LabelToProbability]
23 | test_transform = [NormalizeWithMeanStd,Pad]
24 |
25 | NormalizeWithMeanStd_channels = [0]
26 | NormalizeWithMeanStd_mean = None
27 | NormalizeWithMeanStd_std = None
28 | NormalizeWithMeanStd_mask = False
29 | NormalizeWithMeanStd_random_fill = False
30 | NormalizeWithMeanStd_inverse = False
31 |
32 |
33 | Pad_output_size = [28, 128, 128]
34 | Pad_ceil_mode = False
35 | Pad_inverse = True
36 |
37 | RandomCrop_output_size = [28, 128, 128]
38 | RandomCrop_foreground_focus = True
39 | RandomCrop_foreground_ratio = 0.5
40 | Randomcrop_mask_label = [1, 2]
41 | RandomCrop_inverse = False
42 |
43 | RandomFlip_flip_depth = False
44 | RandomFlip_flip_height = True
45 | RandomFlip_flip_width = True
46 | RandomFlip_inverse = False
47 |
48 | LabelToProbability_class_num = 2
49 | LabelToProbability_inverse = False
50 |
51 | [network]
52 | # this section gives parameters for network
53 | # the keys may be different for different networks
54 |
55 | net_type = UNet2D5_dsbn
56 | num_domains = 2
57 |
58 | # number of class, required for segmentation task
59 | class_num = 2
60 | in_chns = 1
61 | feature_chns = [32, 64, 128, 256, 512]
62 | conv_dims = [2, 2, 3, 3, 3]
63 | dropout = [0.0, 0.0, 0.3, 0.4, 0.5]
64 | bilinear = False
65 | deep_supervise = False
66 | aes = False
67 | [training]
68 | aes = False
69 | aes_para = None
70 | train_fpl_uda = True
71 | dis = False
72 | dis_para = None
73 | val_t1 = False
74 | val_t2 = True
75 | dual = False
76 | # list of gpus
77 | gpus = [0]
78 | loss_type = DiceLoss
79 | DiceLoss_enable_pixel_weight = False
80 | DiceLoss_enable_class_weight = False
81 | loss_class_weight = [1, 1]
82 | # for optimizers
83 | optimizer = Adam
84 | learning_rate = 1e-4
85 | momentum = 0.9
86 | weight_decay = 1e-5
87 |
88 | # for lr scheduler (MultiStepLR)
89 | lr_scheduler = MultiStepLR
90 | lr_gamma = 0.5
91 | lr_milestones = [10000,20000,30000,40000]
92 | ckpt_save_dir = ./model_dual/vs_t1s_g
93 | ckpt_save_prefix = dsbn
94 |
95 | # start iter
96 | iter_start = 40000
97 | iter_max = 60000
98 | iter_valid = 500
99 | iter_save = 60000
100 | [testing]
101 | # list of gpus
102 | fpl = False
103 | gpus = [0]
104 | domian_label = 1
105 | ae = None
106 | # checkpoint mode can be [0-latest, 1-best, 2-specified]
107 | ckpt_mode = 1
108 | output_dir = results_dual/
109 | evaluation_mode = True
110 | test_time_dropout = False
111 | # post_process = KeepLargestComponent
112 | # use test time augmentation
113 | tta_mode = 1
114 |
115 | sliding_window_enable = True
116 | sliding_window_size = [28, 128, 128]
117 | sliding_window_stride = [28, 128, 128]
118 | [evaluation]
119 | metric_1 = dice
120 | metric_2 = assd
121 | label_list = [1]
122 | organ_name = tumor
123 |
124 |
125 | ground_truth_folder_root = ./dataset/ceT1_train/lab
126 | segmentation_folder_root = results_dual/vs_t1s_g
127 | test_evaluation_image_pair = config_dual/data_vs/train_hrT2_pair.csv
128 | valid_evaluation_image_pair = config_dual/data_vs/train_hrT2_pair.csv
--------------------------------------------------------------------------------
/config_dual/data_vs/vs_t1s_g.cfg:
--------------------------------------------------------------------------------
1 | [dataset]
2 | # tensor type (float or double)
3 | tensor_type = float
4 | dsbn = True
5 | task_type = seg
6 | root_dir = /
7 | 1_train_csv = config_dual/data_vs/train_ceT1_like.csv
8 | 1_valid_csv = config_dual/data_vs/valid_ceT1.csv
9 | 2_train_csv = config_dual/data_vs/train_hrT2_like.csv
10 | 2_valid_csv = config_dual/data_vs/valid_hrT2.csv
11 |
12 | test_csv = config_dual/data_vs/train_hrT2.csv
13 |
14 | train_batch_size = 4
15 |
16 | load_pixelwise_weight = False
17 | # modality number
18 | modal_num = 1
19 |
20 | # data transforms
21 | train_transform = [NormalizeWithMeanStd,Pad,RandomCrop, RandomFlip, LabelToProbability]
22 | valid_transform = [NormalizeWithMeanStd,Pad,LabelToProbability]
23 | test_transform = [NormalizeWithMeanStd,Pad]
24 |
25 | NormalizeWithMeanStd_channels = [0]
26 | NormalizeWithMeanStd_mean = None
27 | NormalizeWithMeanStd_std = None
28 | NormalizeWithMeanStd_mask = False
29 | NormalizeWithMeanStd_random_fill = False
30 | NormalizeWithMeanStd_inverse = False
31 |
32 |
33 | Pad_output_size = [28, 128, 128]
34 | Pad_ceil_mode = False
35 | Pad_inverse = True
36 |
37 | RandomCrop_output_size = [28, 128, 128]
38 | RandomCrop_foreground_focus = True
39 | RandomCrop_foreground_ratio = 0.5
40 | Randomcrop_mask_label = [1, 2]
41 | RandomCrop_inverse = False
42 |
43 | RandomFlip_flip_depth = False
44 | RandomFlip_flip_height = True
45 | RandomFlip_flip_width = True
46 | RandomFlip_inverse = False
47 |
48 | LabelToProbability_class_num = 2
49 | LabelToProbability_inverse = False
50 |
51 | [network]
52 | # this section gives parameters for network
53 | # the keys may be different for different networks
54 |
55 | net_type = UNet2D5_dsbn
56 | num_domains = 2
57 |
58 | # number of class, required for segmentation task
59 | class_num = 2
60 | in_chns = 1
61 | feature_chns = [32, 64, 128, 256, 512]
62 | conv_dims = [2, 2, 3, 3, 3]
63 | dropout = [0.0, 0.0, 0.3, 0.4, 0.5]
64 | bilinear = False
65 | deep_supervise = False
66 | aes = False
67 | [training]
68 | aes = False
69 | aes_para = None
70 | train_fpl_uda = True
71 | dis = False
72 | dis_para = None
73 | val_t1 = False
74 | val_t2 = True
75 | dual = False
76 | # list of gpus
77 | gpus = [0]
78 | loss_type = DiceLoss
79 | DiceLoss_enable_pixel_weight = False
80 | DiceLoss_enable_class_weight = False
81 | loss_class_weight = [1, 1]
82 | # for optimizers
83 | optimizer = Adam
84 | learning_rate = 1e-4
85 | momentum = 0.9
86 | weight_decay = 1e-5
87 |
88 | # for lr scheduler (MultiStepLR)
89 | lr_scheduler = MultiStepLR
90 | lr_gamma = 0.5
91 | lr_milestones = [10000,20000,30000,40000]
92 | ckpt_save_dir = ./model_dual/vs_t1s_g
93 | ckpt_save_prefix = dsbn
94 |
95 | # start iter
96 | iter_start = 0
97 | iter_max = 40000
98 | iter_valid = 500
99 | iter_save = 40000
100 | [testing]
101 | # list of gpus
102 | fpl = False
103 | gpus = [0]
104 | domian_label = 1
105 | ae = None
106 | # checkpoint mode can be [0-latest, 1-best, 2-specified]
107 | ckpt_mode = 1
108 | output_dir = results_dual/
109 | evaluation_mode = True
110 | test_time_dropout = False
111 | # post_process = KeepLargestComponent
112 | # use test time augmentation
113 | tta_mode = 1
114 |
115 | sliding_window_enable = True
116 | sliding_window_size = [28, 128, 128]
117 | sliding_window_stride = [28, 128, 128]
118 | [evaluation]
119 | metric_1 = dice
120 | metric_2 = assd
121 | label_list = [1]
122 | organ_name = tumor
123 |
124 |
125 | ground_truth_folder_root = ./dataset/ceT1_train/lab
126 | segmentation_folder_root = results_dual/vs_t1s_g
127 | test_evaluation_image_pair = config_dual/data_vs/train_hrT2_pair.csv
128 | valid_evaluation_image_pair = config_dual/data_vs/train_hrT2_pair.csv
129 |
130 |
--------------------------------------------------------------------------------
/config_dual/data_vs/vs_t1s_g_fake.cfg:
--------------------------------------------------------------------------------
1 | [dataset]
2 | # tensor type (float or double)
3 | tensor_type = float
4 | dsbn = True
5 | task_type = seg
6 | root_dir = /
7 | 1_train_csv = config_dual/data_vs/train_ceT1_like.csv
8 | 1_valid_csv = config_dual/data_vs/valid_ceT1.csv
9 | 2_train_csv = config_dual/data_vs/train_hrT2_like.csv
10 | 2_valid_csv = config_dual/data_vs/valid_hrT2.csv
11 |
12 | test_csv = config_dual/data_vs/train_hrT2-ceT1_cyc.csv
13 |
14 | train_batch_size = 4
15 |
16 | load_pixelwise_weight = False
17 | # modality number
18 | modal_num = 1
19 |
20 | # data transforms
21 | train_transform = [NormalizeWithMeanStd,Pad,RandomCrop, RandomFlip, LabelToProbability]
22 | valid_transform = [NormalizeWithMeanStd,Pad,LabelToProbability]
23 | test_transform = [NormalizeWithMeanStd,Pad]
24 |
25 | NormalizeWithMeanStd_channels = [0]
26 | NormalizeWithMeanStd_mean = None
27 | NormalizeWithMeanStd_std = None
28 | NormalizeWithMeanStd_mask = False
29 | NormalizeWithMeanStd_random_fill = False
30 | NormalizeWithMeanStd_inverse = False
31 |
32 |
33 | Pad_output_size = [28, 128, 128]
34 | Pad_ceil_mode = False
35 | Pad_inverse = True
36 |
37 | RandomCrop_output_size = [28, 128, 128]
38 | RandomCrop_foreground_focus = True
39 | RandomCrop_foreground_ratio = 0.5
40 | Randomcrop_mask_label = [1, 2]
41 | RandomCrop_inverse = False
42 |
43 | RandomFlip_flip_depth = False
44 | RandomFlip_flip_height = True
45 | RandomFlip_flip_width = True
46 | RandomFlip_inverse = False
47 |
48 | LabelToProbability_class_num = 2
49 | LabelToProbability_inverse = False
50 |
51 | [network]
52 | # this section gives parameters for network
53 | # the keys may be different for different networks
54 |
55 | net_type = UNet2D5_dsbn
56 | num_domains = 2
57 |
58 | # number of class, required for segmentation task
59 | class_num = 2
60 | in_chns = 1
61 | feature_chns = [32, 64, 128, 256, 512]
62 | conv_dims = [2, 2, 3, 3, 3]
63 | dropout = [0.0, 0.0, 0.3, 0.4, 0.5]
64 | bilinear = False
65 | deep_supervise = False
66 | aes = False
67 | [training]
68 | aes = False
69 | aes_para = None
70 | train_fpl_uda = True
71 | dis = False
72 | dis_para = None
73 | val_t1 = False
74 | val_t2 = True
75 | dual = False
76 | # list of gpus
77 | gpus = [0]
78 | loss_type = DiceLoss
79 | DiceLoss_enable_pixel_weight = False
80 | DiceLoss_enable_class_weight = False
81 | loss_class_weight = [1, 1]
82 | # for optimizers
83 | optimizer = Adam
84 | learning_rate = 1e-4
85 | momentum = 0.9
86 | weight_decay = 1e-5
87 |
88 | # for lr scheduler (MultiStepLR)
89 | lr_scheduler = MultiStepLR
90 | lr_gamma = 0.5
91 | lr_milestones = [10000,20000,30000,40000]
92 | ckpt_save_dir = ./model_dual/vs_t1s_g
93 | ckpt_save_prefix = dsbn
94 |
95 | # start iter
96 | iter_start = 0
97 | iter_max = 40000
98 | iter_valid = 500
99 | iter_save = 40000
100 | [testing]
101 | # list of gpus
102 | fpl = False
103 | gpus = [0]
104 | domian_label = 0
105 | ae = None
106 | # checkpoint mode can be [0-latest, 1-best, 2-specified]
107 | ckpt_mode = 1
108 | output_dir = results_dual/
109 | evaluation_mode = True
110 | test_time_dropout = False
111 | # post_process = KeepLargestComponent
112 | # use test time augmentation
113 | tta_mode = 1
114 |
115 | sliding_window_enable = True
116 | sliding_window_size = [28, 128, 128]
117 | sliding_window_stride = [28, 128, 128]
118 | [evaluation]
119 | metric_1 = dice
120 | metric_2 = assd
121 | label_list = [1]
122 | organ_name = tumor
123 |
124 |
125 | ground_truth_folder_root = ./dataset/ceT1_train/lab
126 | segmentation_folder_root = results_dual/vs_t1s_g
127 | test_evaluation_image_pair = config_dual/data_vs/train_hrT2_pair.csv
128 | valid_evaluation_image_pair = config_dual/data_vs/train_hrT2_pair.csv
129 |
130 |
--------------------------------------------------------------------------------
/config_dual/data_vs/vs_t1s_weights.cfg:
--------------------------------------------------------------------------------
1 | [dataset]
2 | # tensor type (float or double)
3 | tensor_type = float
4 | dsbn = True
5 | task_type = seg
6 | root_dir = /
7 | 1_train_csv = config_dual/data_vs/train_ceT1_like.csv
8 | 1_valid_csv = config_dual/data_vs/valid_ceT1.csv
9 | 2_train_csv = config_dual/data_vs/train_hrT2_like.csv
10 | 2_valid_csv = config_dual/data_vs/valid_hrT2.csv
11 |
12 | test_csv = config_dual/data_vs/train_hrT2.csv
13 |
14 | train_batch_size = 4
15 |
16 | load_pixelwise_weight = False
17 | # modality number
18 | modal_num = 1
19 |
20 | # data transforms
21 | train_transform = [NormalizeWithMeanStd,Pad,RandomCrop, RandomFlip, LabelToProbability]
22 | valid_transform = [NormalizeWithMeanStd,Pad,LabelToProbability]
23 | test_transform = [NormalizeWithMeanStd,Pad]
24 |
25 | NormalizeWithMeanStd_channels = [0]
26 | NormalizeWithMeanStd_mean = None
27 | NormalizeWithMeanStd_std = None
28 | NormalizeWithMeanStd_mask = False
29 | NormalizeWithMeanStd_random_fill = False
30 | NormalizeWithMeanStd_inverse = False
31 |
32 |
33 | Pad_output_size = [28, 128, 128]
34 | Pad_ceil_mode = False
35 | Pad_inverse = True
36 |
37 | RandomCrop_output_size = [28, 128, 128]
38 | RandomCrop_foreground_focus = True
39 | RandomCrop_foreground_ratio = 0.5
40 | Randomcrop_mask_label = [1, 2]
41 | RandomCrop_inverse = False
42 |
43 | RandomFlip_flip_depth = False
44 | RandomFlip_flip_height = True
45 | RandomFlip_flip_width = True
46 | RandomFlip_inverse = False
47 |
48 | LabelToProbability_class_num = 2
49 | LabelToProbability_inverse = False
50 |
51 | [network]
52 | # this section gives parameters for network
53 | # the keys may be different for different networks
54 |
55 | net_type = UNet2D5_dsbn
56 | num_domains = 2
57 |
58 | # number of class, required for segmentation task
59 | class_num = 2
60 | in_chns = 1
61 | feature_chns = [32, 64, 128, 256, 512]
62 | conv_dims = [2, 2, 3, 3, 3]
63 | dropout = [0.0, 0.0, 0.3, 0.4, 0.5]
64 | bilinear = False
65 | deep_supervise = False
66 | aes = False
67 | [training]
68 | aes = False
69 | aes_para = None
70 | train_fpl_uda = True
71 | dis = False
72 | dis_para = None
73 | val_t1 = False
74 | val_t2 = True
75 | dual = False
76 | # list of gpus
77 | gpus = [0]
78 | loss_type = DiceLoss
79 | DiceLoss_enable_pixel_weight = False
80 | DiceLoss_enable_class_weight = False
81 | loss_class_weight = [1, 1]
82 | # for optimizers
83 | optimizer = Adam
84 | learning_rate = 1e-4
85 | momentum = 0.9
86 | weight_decay = 1e-5
87 |
88 | # for lr scheduler (MultiStepLR)
89 | lr_scheduler = MultiStepLR
90 | lr_gamma = 0.5
91 | lr_milestones = [10000,20000,30000,40000]
92 | ckpt_save_dir = ./model_dual/vs_t1s_g
93 | ckpt_save_prefix = dsbn
94 |
95 | # start iter
96 | iter_start = 0
97 | iter_max = 40000
98 | iter_valid = 500
99 | iter_save = 40000
100 | [testing]
101 | # list of gpus
102 | fpl = True
103 | fpl_uncertainty_sorted = dataset/weight/vs_t2s.npy
104 | fpl_uncertainty_weight = dataset/weight/vs_t2s-weight.npy
105 | gpus = [0]
106 | domian_label = 1
107 | ae = None
108 | # checkpoint mode can be [0-latest, 1-best, 2-specified]
109 | ckpt_mode = 1
110 | output_dir = results_dual/
111 | evaluation_mode = True
112 | test_time_dropout = True
113 | # post_process = KeepLargestComponent
114 | # use test time augmentation
115 | tta_mode = 1
116 |
117 | sliding_window_enable = True
118 | sliding_window_size = [28, 128, 128]
119 | sliding_window_stride = [28, 128, 128]
120 | [evaluation]
121 | metric_1 = dice
122 | metric_2 = assd
123 | label_list = [1]
124 | organ_name = tumor
125 |
126 |
127 | ground_truth_folder_root = ./dataset/ceT1_train/lab
128 | segmentation_folder_root = results_dual/vs_t1s_g
129 | test_evaluation_image_pair = config_dual/data_vs/train_hrT2_pair.csv
130 | valid_evaluation_image_pair = config_dual/data_vs/train_hrT2_pair.csv
131 |
132 |
--------------------------------------------------------------------------------
/config_dual/evaluation.cfg:
--------------------------------------------------------------------------------
1 | [evaluation]
2 | # metric = dice
3 | metric = assd
4 | label_list = [1]
5 | organ_name = tumor
6 |
7 | # ground_truth_folder_root = /data2/jianghao/FPL-UDA/vs_data/T1_test/lab/
8 | # segmentation_folder_root = results_dual/m-vs_t2-t1_cyc_i-t1-test
9 | # evaluation_image_pair = config_dual/data_vs/t1_test_pair.csv
10 |
11 | ground_truth_folder_root = /data2/jianghao/FPL-UDA/vs_data/T2_test/lab/
12 | segmentation_folder_root = results_dual/m-vst1s_dsbn_t1+t1-t2-cyc-t1_i-t2-test
13 | evaluation_image_pair = config_dual/data_vs/t2_test_pair.csv
14 |
15 |
--------------------------------------------------------------------------------
/data/get image_weight.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import csv
3 | 
4 | # Image-level weights exported by the pseudo label generator.
5 | # all_names_dict = np.load('dataset/weight/vs_t2s.npy', allow_pickle=True)
6 | all_names_dict = np.load('dataset/weight/cyc121_vst1s-gan.npy', allow_pickle=True)
7 | print(len(all_names_dict))
8 | print(all_names_dict)
9 | 
10 | # Collect the raw weights; a value of exactly 1 is excluded from the statistics.
11 | all_weights = [item[0][0] for item in all_names_dict if item[0][0] != 1]
12 | max_w = max(all_weights)
13 | min_w = min(all_weights)
14 | print('max weight value:', max_w, '; min weight value:', min_w)
15 | 
16 | tra_filenames = []
17 | for i in range(len(all_names_dict)):
18 |     sig_dir = all_names_dict[i][1]
19 |     img_name = sig_dir
20 |     # The pseudo label and pixel-wise weight paths mirror the image path.
21 |     lab_name = sig_dir.replace('./dataset/hrT2_train/img', './results_dual/vs_t1s_g_i-train_hrT2')
22 |     weight_pixel = sig_dir.replace('./dataset/hrT2_train/img', 'dataset/hrT2_pixel-weight')
23 |     image_weight = all_names_dict[i][0][0]
24 |     if image_weight > max_w:
25 |         image_weight = max_w
26 |     # Rescale linearly: the largest raw value maps to 0.01, the smallest to 1.01.
27 |     image_weight = abs((max_w - image_weight) / (max_w - min_w)) + 0.01
28 |     print('image weight:', image_weight)
29 |     tra_filenames.append([img_name, lab_name, weight_pixel, image_weight])
30 | 
31 | tra_output_file = 'config_dual/data_vs/train_vs_t1s_wi+wp.csv'
32 | fields = ['image', 'label', 'pixel_weight', 'image_weight']
33 | with open(tra_output_file, mode='w') as csv_file:
34 |     csv_writer = csv.writer(csv_file, delimiter=',',
35 |                             quotechar='"', quoting=csv.QUOTE_MINIMAL)
36 |     csv_writer.writerow(fields)
37 |     for item in tra_filenames:
38 |         csv_writer.writerow(item)
--------------------------------------------------------------------------------
/data/get_pixel_weight.py:
--------------------------------------------------------------------------------
1 | import SimpleITK as sitk
2 | import numpy as np
3 | import os
4 |
5 | pseudo_target_root = './results_dual/vs_t1s_g_i-train_hrT2'
6 | pseudo_fake_source_root = './results_dual/vs_t1s_g_i-train_hrT2-ceT1_cyc'
7 | t2s = os.listdir(pseudo_target_root)
8 | t2_cycs = os.listdir(pseudo_fake_source_root)
9 | t2_names = [item for item in t2s if '.nii.gz' in item]
10 | t2_cyc_names = [item for item in t2_cycs if '.nii.gz' in item]
11 | assert len(t2_names) == len(t2_cyc_names)
12 | for name in t2_names:
13 | t2_full = os.path.join(pseudo_target_root,name)
14 | t2_cyc_full = os.path.join(pseudo_fake_source_root,name)
15 | t2_full = sitk.ReadImage(t2_full)
16 | t2_cyc_full = sitk.ReadImage(t2_cyc_full)
17 | t2_full = sitk.GetArrayFromImage(t2_full)
18 | t2_cyc_full = sitk.GetArrayFromImage(t2_cyc_full)
19 | assert t2_full.shape == t2_cyc_full.shape
20 | # assert t2_cyc_full.max() == t2_full.max()
21 | both_arr = t2_full+t2_cyc_full
22 | both_arr[both_arr > 1] = 1
23 | and_arr = t2_cyc_full*t2_full
24 | sub_arr = both_arr - and_arr
25 | sub_rev = np.ones_like(sub_arr)
26 | sub_rev = sub_rev-sub_arr*0.5
27 | sub_rev = sitk.GetImageFromArray(sub_rev)
28 | sitk.WriteImage(sub_rev,'./dataset/hrT2_pixel-weight/'+name)
29 |
--------------------------------------------------------------------------------
/data/preprocess_bst.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import SimpleITK as sitk
4 | import numpy as np
5 | 
6 | def winadj_mri(array):
7 | v0 = np.percentile(array, 1)
8 |     v1 = np.percentile(array, 99.9)   # upper intensity window (q must be in [0, 100])
9 | array[array < v0] = v0
10 | array[array > v1] = v1
11 | v0 = array.min()
12 | v1 = array.max()
13 | array = (array - v0) / (v1 - v0) * 2.0 - 1.0
14 | return array
15 | def crop_depth(img,lab):
16 | D,W,H = img.shape
17 | indices = np.where(lab>0)
18 | d00, d11 = indices[0].min(), indices[0].max()
19 | zero_img = img[max(d00-16,0):min(d11+16,D),:,:]
20 | zero_lab = lab[max(d00-16,0):min(d11+16,D),:,:]
21 | return zero_img,zero_lab
22 |
23 | if __name__ == "__main__":
24 | ref_moda = 'flair'
25 | moda = 'flair'
26 | phase = 'train'
27 | roots = '/your/MICCAI_BraTS2020_TrainingData'
28 | img_path = '/your/bst_data/'+ref_moda+'_'+phase+'/img'
29 | save_img = '/your/bst_data/'+moda+'_'+phase+'/img'
30 | save_lab = '/your/bst_data/'+moda+'_'+phase+'/lab'
31 | if(not os.path.exists(save_lab)):
32 | os.mkdir(save_lab)
33 | names = os.listdir(img_path)
34 |
35 |     for i in names:
36 |         case_name = i.split('.nii')[0]
37 |         print(i, case_name)
38 |         case_root = os.path.join(roots, case_name)
39 |         img_obj_ = sitk.ReadImage(case_root + '/' + case_name + '_' + moda + '.nii.gz')
40 |         lab_obj_ = sitk.ReadImage(case_root + '/' + case_name + '_seg' + '.nii.gz')
41 |         lab_obj = sitk.GetArrayFromImage(lab_obj_)
42 |         img_obj = sitk.GetArrayFromImage(img_obj_)
43 |         lab_obj[lab_obj > 0] = 1
44 |         img_obj, lab_obj = crop_depth(img_obj, lab_obj)
45 |         lab_obj = sitk.GetImageFromArray(lab_obj)
46 |         img_obj = sitk.GetImageFromArray(img_obj)
47 |         img_save_dir = os.path.join(save_img, case_name + '.nii.gz')
48 |         lab_save_dir = os.path.join(save_lab, case_name + '.nii.gz')
49 |         sitk.WriteImage(img_obj, img_save_dir)
50 |         sitk.WriteImage(lab_obj, lab_save_dir)
--------------------------------------------------------------------------------
/data/preprocess_mmwhs.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/data/preprocess_mmwhs.py
--------------------------------------------------------------------------------
/dataset/ceT1_train/img/vs_gk_99_t1.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/ceT1_train/img/vs_gk_99_t1.nii.gz
--------------------------------------------------------------------------------
/dataset/ceT1_train/lab/vs_gk_99_t1.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/ceT1_train/lab/vs_gk_99_t1.nii.gz
--------------------------------------------------------------------------------
/dataset/fake_data/ceT1-hrT2-ceT1_ac/vs_gk_99_t1.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/fake_data/ceT1-hrT2-ceT1_ac/vs_gk_99_t1.nii.gz
--------------------------------------------------------------------------------
/dataset/fake_data/ceT1-hrT2-ceT1_cc/vs_gk_99_t1.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/fake_data/ceT1-hrT2-ceT1_cc/vs_gk_99_t1.nii.gz
--------------------------------------------------------------------------------
/dataset/fake_data/ceT1-hrT2_auxcyc/vs_gk_99_t1.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/fake_data/ceT1-hrT2_auxcyc/vs_gk_99_t1.nii.gz
--------------------------------------------------------------------------------
/dataset/fake_data/ceT1-hrT2_cyc/vs_gk_99_t1.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/fake_data/ceT1-hrT2_cyc/vs_gk_99_t1.nii.gz
--------------------------------------------------------------------------------
/dataset/fake_data/hrT2-ceT1_train_cyc/vs_gk_98_t2.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/fake_data/hrT2-ceT1_train_cyc/vs_gk_98_t2.nii.gz
--------------------------------------------------------------------------------
/dataset/hrT2_test/vs_gk_9_t2.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/hrT2_test/vs_gk_9_t2.nii.gz
--------------------------------------------------------------------------------
/dataset/hrT2_test/vs_gk_9_t2_seg.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/hrT2_test/vs_gk_9_t2_seg.nii.gz
--------------------------------------------------------------------------------
/dataset/hrT2_train/img/vs_gk_98_t2.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/hrT2_train/img/vs_gk_98_t2.nii.gz
--------------------------------------------------------------------------------
/dataset/hrT2_train/lab/vs_gk_98_t2.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/hrT2_train/lab/vs_gk_98_t2.nii.gz
--------------------------------------------------------------------------------
/dataset/hrT2_valid/vs_gk_95_t2.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/hrT2_valid/vs_gk_95_t2.nii.gz
--------------------------------------------------------------------------------
/dataset/hrT2_valid/vs_gk_95_t2_seg.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/hrT2_valid/vs_gk_95_t2_seg.nii.gz
--------------------------------------------------------------------------------
/dataset/weight/cyc121_vst1s-gan.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/weight/cyc121_vst1s-gan.npy
--------------------------------------------------------------------------------
/docs/image_info.odt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/docs/image_info.odt
--------------------------------------------------------------------------------
/merge_pixelw.py:
--------------------------------------------------------------------------------
1 | from scipy import ndimage
2 | import SimpleITK as sitk
3 | import numpy as np
4 | import os
5 |
6 | t2_root = '/target/train/data/prediction'
7 | t2_cyc_root = '/target/train/data_cyclegan/prediction'
8 | t2s = os.listdir(t2_root)
9 | t2_cycs = os.listdir(t2_cyc_root)
10 | t2_names = [item for item in t2s if '.nii.gz' in item]
11 | t2_cyc_names = [item for item in t2_cycs if '.nii.gz' in item]
12 | assert len(t2_names) == len(t2_cyc_names)
13 | for name in t2_names:
14 | t2_full = os.path.join(t2_root,name)
15 | t2_cyc_full = os.path.join(t2_cyc_root,name)
16 | t2_full = sitk.ReadImage(t2_full)
17 | t2_cyc_full = sitk.ReadImage(t2_cyc_full)
18 | t2_full = sitk.GetArrayFromImage(t2_full)
19 | t2_cyc_full = sitk.GetArrayFromImage(t2_cyc_full)
20 | assert t2_full.shape == t2_cyc_full.shape
21 |     both_arr = t2_full + t2_cyc_full       # union of the two predictions
22 |     both_arr[both_arr > 1] = 1
23 |     and_arr = t2_cyc_full * t2_full        # intersection of the two predictions
24 |     sub_arr = both_arr - and_arr           # disagreement region
25 |     print(both_arr.sum(), and_arr.sum(), sub_arr.sum(), sub_arr.max())
26 |     sub_rev = np.ones_like(sub_arr)
27 |     sub_rev = sub_rev - sub_arr * 0.5      # weight 1 where they agree, 0.5 where they disagree
28 | sub_rev = sitk.GetImageFromArray(sub_rev)
29 | sitk.WriteImage(sub_rev,'/data2/jianghao/VS/vs_seg2021/script/FPL-UDA/bst_t2s_sub_arr/'+name)
30 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | export PYTHONPATH=$PYTHONPATH:./PyMIC
3 | python ./PyMIC/pymic/net_run_dsbn/net_run.py train ./config_dual/unet2d_dsbn_bst_t2s.cfg
4 | python ./PyMIC/pymic/net_run_dsbn/net_run.py test ./config_dual/unet2d_dsbn_bst_t2s.cfg
5 |
6 |
7 |
--------------------------------------------------------------------------------