├── .gitignore ├── .travis.yml ├── README.md ├── docs ├── Makefile ├── conf.py ├── datasets.rst ├── index.rst ├── patterns.rst ├── readers.rst └── samplers.rst ├── examples ├── 01_brain_segmentation_MRI │ ├── architecture.py │ ├── datasets │ │ ├── inference.py │ │ ├── mappings.py │ │ └── training.py │ ├── inference_canvas.py │ ├── readme.md │ ├── segment.py │ ├── train.py │ ├── train_segment.py │ └── utils.py ├── 02_brain_segmentation_NonAdjLoss_MRI │ ├── architecture.py │ ├── datasets │ │ ├── inference.py │ │ ├── mappings.py │ │ └── training.py │ ├── inference_canvas.py │ ├── loss.py │ ├── readme.md │ ├── segment.py │ ├── train.py │ ├── train_segment.py │ └── utils.py └── scripts │ ├── adjacency │ ├── extract_3d_adjacency.py │ └── extract_3d_adjacency_from_seg.py │ ├── preprocessing │ ├── ibsr_v2.py │ ├── mappings.py │ ├── miccai12.py │ └── oasis.py │ └── report │ ├── plot_boxplot_labels_by_metric.py │ ├── plot_boxplot_patients_by_metric.py │ ├── write_label_by_metric.py │ ├── write_model_by_metric.py │ ├── write_patient_by_metric.py │ └── write_top_label_by_metric.py ├── requirements.txt ├── setup.py ├── test ├── evaluation │ ├── __init__.py │ └── test_dice.py └── patterns │ ├── __init__.py │ └── test_patch2d.py └── torchmed ├── __init__.py ├── datasets ├── __init__.py ├── medcombiner.py ├── medfile.py └── medfolder.py ├── patterns ├── __init__.py └── patch.py ├── readers ├── __init__.py ├── cv.py ├── nib.py ├── pil.py ├── reader.py └── simpleitk.py ├── samplers ├── __init__.py ├── mask_sampler.py └── sampler.py └── utils ├── __init__.py ├── augmentation.py ├── file.py ├── logger_plotter.py ├── loss.py ├── metric.py ├── multiproc.py ├── preprocessing.py └── transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
-------------------------------------------------------------------------------- /.travis.yml: --------------------------------------------------------------------------------
1 | language: python
2 | python:
3 |   - 3.5
4 |   - 3.6
5 | 
6 | install:
7 |   - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
8 |   - bash miniconda.sh -b -p $HOME/miniconda
9 |   - export PATH="$HOME/miniconda/bin:$PATH"
10 |   - hash -r
11 |   - conda config --set always_yes yes --set changeps1 no
12 |   - conda update -q conda
13 |   # Useful for debugging any issues with conda
14 |   - conda info -a
15 | 
16 |   - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION
17 |   - source activate test-environment
18 |   - conda config --append channels simpleitk
19 |   - conda config --prepend channels soumith
20 | 
21 |   # install python 3.5 for simpleITK
22 |   - conda install python=3.5
23 | 
24 |   - conda env export -n test-environment
25 |   - conda install simpleitk=1.0.0
26 |   - conda install pytorch torchvision
27 |   - pip install codecov
28 |   - pip install -r requirements.txt
29 | 
30 | # script:
31 | #- coverage run --source=torchmed -m unittest discover -s test/
32 | 
33 | # after_success:
34 | #- codecov --token=6f6b0796-78bb-4e30-a3e8-3ce76af84a1a
35 | 
36 | notifications:
37 |   slack: pagboat:UVFJS35PcjobkUmztxAygsNG
38 |   on_success: change
39 |   on_failure: always
40 | 
41 | sudo: false
42 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # TorchMed
2 | 
3 | Read and process medical images in PyTorch.
4 | 
5 | [![Build Status](https://travis-ci.com/trypag/pytorch-med.svg?token=W7UTQDqNUe21xtLfiqRm&branch=master)](https://travis-ci.com/trypag/pytorch-med)
6 | [![codecov](https://codecov.io/gh/trypag/pytorch-med/branch/master/graph/badge.svg?token=kL3ASEka4B)](https://codecov.io/gh/trypag/pytorch-med)
7 | 
8 | ---
9 | 
10 | This library is designed as a flexible tool for processing various types of N-dimensional images.
11 | Through a set of image **readers** based on well-known projects (SimpleITK, NiBabel, OpenCV, Pillow)
12 | you will be able to load your data. Once loaded, specific sub-sampling of the original
13 | data is performed with **patterns** (describing what and how to extract) and **samplers**
14 | (checking where to extract).
15 | 
16 | With **readers**, **samplers** and **patterns**, you can compose **datasets** which
17 | are perfectly suited to PyTorch.
18 | 
19 | ## Install
20 | 
21 | From pip:
22 | 
23 | ```bash
24 | pip install torchmed
25 | ```
26 | 
27 | Locally:
28 | 
29 | ```bash
30 | python setup.py install
31 | ```
32 | 
33 | ## Usage
34 | 
35 | ### Reader
36 | 
37 | ```python
38 | >>> import torchmed
39 | 
40 | >>> image = torchmed.readers.SitkReader('prepro_im_mni_bc.nii.gz')
41 | >>> label_map = torchmed.readers.SitkReader('prepro_seg_mni.nii.gz')
42 | # gets image data
43 | >>> image_array = image.to_torch()
44 | >>> label_array = label_map.to_torch()
45 | 
46 | >>> image_array.size()
47 | torch.Size([182, 218, 182])
48 | >>> type(image_array)
49 | <class 'torch.Tensor'>
50 | >>> label_array[0,0,0]
51 | tensor(0.)
52 | # also available for Numpy
53 | >>> type(image.to_numpy())
54 | <class 'numpy.ndarray'>
55 | ```
56 | 
57 | ### Pattern
58 | 
59 | Patterns are useful to specify how the data should be extracted from an image.
60 | It is possible to apply several patterns on one or more images.
61 | 
62 | ```python
63 | >>> import torchmed
64 | 
65 | >>> image_arr = torchmed.readers.SitkReader('prepro_im_mni_bc.nii.gz').to_torch()
66 | >>> square_patch = torchmed.patterns.SquaredSlidingWindow([182, 7, 182], use_padding=False)
67 | # initialize the pattern with the image properties
68 | >>> square_patch.prepare(image_arr)
69 | 
70 | # can_apply checks if a pattern can be applied at a given position
71 | >>> square_patch.can_apply(image_arr, [0,0,0])
72 | False
73 | >>> square_patch.can_apply(image_arr, [91,4,91])
74 | True
75 | >>> square_patch.can_apply(image_arr, [91,3,91])
76 | True
77 | >>> square_patch.can_apply(image_arr, [91,2,91])
78 | False
79 | >>> square_patch.can_apply(image_arr, [91,154,91])
80 | True
81 | 
82 | # to extract a patch at a correct position
83 | >>> sample = square_patch(image_arr, [91,154,91])
84 | >>> sample.size()
85 | torch.Size([182, 7, 182])
86 | ```
87 | 
88 | ### Sampler
89 | 
90 | A multi-processed sampler automatically searches for coordinates where sampling
91 | (pattern extraction) is possible.
92 | 
93 | ```python
94 | >>> from torchmed.readers import SitkReader
95 | >>> from torchmed.samplers import MaskableSampler
96 | >>> from torchmed.patterns import SquaredSlidingWindow
97 | 
98 | # maps a name to each image
99 | >>> file_map = {
100 | ...     'image_ref': SitkReader('prepro_im_mni_bc.nii.gz',
101 | ...                             torch_type='torch.FloatTensor'),
102 | ...     'target': SitkReader('prepro_seg_mni.nii.gz',
103 | ...                          torch_type='torch.LongTensor')
104 | ... }
105 | 
106 | # sliding window pattern
107 | >>> patch2d = SquaredSlidingWindow(patch_size=[182, 7, 182], use_padding=False)
108 | # specify a pattern for each input image
109 | >>> pattern_mapper = {'input': ('image_ref', patch2d),
110 | ...                   'target': ('target', patch2d)}
111 | # multi-processed sampler with offset
112 | >>> sampler = MaskableSampler(pattern_mapper, offset=[91, 1, 91], nb_workers=4)
113 | >>> sampler.build(file_map)
114 | >>> len(sampler)
115 | 212
116 | >>> sample = sampler[0]
117 | >>> type(sample)
118 | <class 'tuple'>
119 | >>> sample[0].size()
120 | torch.Size([3])
121 | >>> sample[1].size()
122 | torch.Size([182, 7, 182])
123 | >>> sample[2].size()
124 | torch.Size([182, 7, 182])
125 | ```
126 | 
127 | ### Dataset
128 | 
129 | `MedFile` and `MedFolder` are iterable datasets, returning samples from the input
130 | data. Here is an example of how to build a `MedFolder` from a list of images.
131 | A `MedFolder` takes as input a list of `MedFile`s.
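
For a single image, a `MedFile` can be assembled directly from a data map and a
sampler. A minimal sketch (file names are placeholders, and we assume `MedFile`
wires the sampler to the data map itself, as in the bundled examples):

```python
from torchmed.datasets import MedFile
from torchmed.patterns import SquaredSlidingWindow
from torchmed.readers import SitkReader
from torchmed.samplers import MaskableSampler

# one reader per image, keyed by the id the pattern mapper refers to
data_map = {'image_ref': SitkReader('prepro_im_mni_bc.nii.gz',
                                    torch_type='torch.FloatTensor')}
patch2d = SquaredSlidingWindow(patch_size=[182, 7, 182], use_padding=False)
sampler = MaskableSampler({'input': ('image_ref', patch2d)},
                          offset=[91, 1, 91], nb_workers=4)

med_file = MedFile(data_map, sampler)
sample = med_file[0]  # indexable dataset: one extracted patch per position
```

Building on this, the snippet below wraps the construction of a whole `MedFolder`: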
132 | 
133 | ```python
134 | import os
135 | 
136 | from torchmed.datasets import MedFile, MedFolder
137 | from torchmed.patterns import SquaredSlidingWindow
138 | from torchmed.readers import SitkReader
139 | from torchmed.samplers import MaskableSampler
140 | from torchmed.utils.transforms import Pad
141 | 
142 | 
143 | class TrainingDataset(object):
144 |     def __init__(self, base_dir, nb_workers):
145 |         self.train_dataset = MedFolder(
146 |             self.generate_medfiles(os.path.join(base_dir, 'train'), nb_workers))
147 | 
148 |     def generate_medfiles(self, dir, nb_workers):
149 |         # database composed of the dirnames listed in allowed_data.txt
150 |         database = open(os.path.join(dir, 'allowed_data.txt'), 'r')
151 |         patient_list = [line.rstrip('\n') for line in database]
152 |         medfiles = []
153 | 
154 |         # builds a list of MedFiles, one for each folder
155 |         for patient in patient_list:
156 |             if patient:
157 |                 patient_dir = os.path.join(dir, patient)
158 |                 patient_data = self.build_patient_data_map(patient_dir)
159 |                 patient_file = MedFile(patient_data, self.build_sampler(nb_workers))
160 |                 medfiles.append(patient_file)
161 | 
162 |         return medfiles
163 | 
164 |     def build_patient_data_map(self, dir):
165 |         # pads each dimension of the image on both sides.
166 |         pad_reflect = Pad(((1, 1), (3, 3), (1, 1)), 'reflect')
167 |         file_map = {
168 |             'image_ref': SitkReader(
169 |                 os.path.join(dir, 'prepro_im_mni_bc.nii.gz'),
170 |                 torch_type='torch.FloatTensor', transform=pad_reflect),
171 |             'target': SitkReader(
172 |                 os.path.join(dir, 'prepro_seg_mni.nii.gz'),
173 |                 torch_type='torch.LongTensor', transform=pad_reflect)
174 |         }
175 | 
176 |         return file_map
177 | 
178 |     def build_sampler(self, nb_workers):
179 |         # sliding window of size [184, 7, 184] without padding
180 |         patch2d = SquaredSlidingWindow(patch_size=[184, 7, 184], use_padding=False)
181 |         # the pattern map links a pattern to the image it applies to
182 |         pattern_mapper = {'input': ('image_ref', patch2d),
183 |                           'target': ('target', patch2d)}
184 | 
185 |         # a fixed offset makes patch sampling faster (doesn't look for all positions)
186 |         return MaskableSampler(pattern_mapper, offset=[92, 1, 92],
187 |                                nb_workers=nb_workers)
188 | ```
189 | 
190 | ### Examples
191 | 
192 | See the `datasets` folder of the examples for a more practical use case.
193 | 
194 | #### Credits
195 | 
196 | Evaluation metrics are mostly based on MedPy.
197 | 
-------------------------------------------------------------------------------- /docs/Makefile: --------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 | 
4 | # You can set these variables from the command line.
5 | SPHINXOPTS    =
6 | SPHINXBUILD   = sphinx-build
7 | SOURCEDIR     = .
8 | BUILDDIR      = _build
9 | 
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 | 
14 | .PHONY: help Makefile
15 | 
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
-------------------------------------------------------------------------------- /docs/conf.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Configuration file for the Sphinx documentation builder.
4 | #
5 | # This file does only contain a selection of the most common options.
For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | import os 16 | import sys 17 | sys.path.insert(0, os.path.abspath('../')) 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = 'TorchMed' 23 | copyright = '2019, Pierre-Antoine Ganaye' 24 | author = 'Pierre-Antoine Ganaye' 25 | 26 | # The short X.Y version 27 | version = '' 28 | # The full version, including alpha/beta/rc tags 29 | release = 'alpha' 30 | 31 | 32 | # -- General configuration --------------------------------------------------- 33 | 34 | # If your documentation needs a minimal Sphinx version, state it here. 35 | # 36 | # needs_sphinx = '1.0' 37 | 38 | # Add any Sphinx extension module names here, as strings. They can be 39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 40 | # ones. 41 | extensions = [ 42 | 'sphinx.ext.autodoc', 43 | 'sphinx.ext.doctest', 44 | 'sphinx.ext.autosummary', 45 | 'sphinx.ext.intersphinx', 46 | 'sphinx.ext.todo', 47 | 'sphinx.ext.coverage', 48 | 'sphinx.ext.napoleon', 49 | 'sphinx.ext.mathjax', 50 | 'sphinx.ext.viewcode', 51 | 'sphinx.ext.githubpages', 52 | ] 53 | 54 | 55 | # Add any paths that contain templates here, relative to this directory. 56 | templates_path = ['_templates'] 57 | 58 | # The suffix(es) of source filenames. 59 | # You can specify multiple suffix as a list of string: 60 | # 61 | # source_suffix = ['.rst', '.md'] 62 | source_suffix = '.rst' 63 | 64 | # The master toctree document. 65 | master_doc = 'index' 66 | 67 | # The language for content autogenerated by Sphinx. Refer to documentation 68 | # for a list of supported languages. 69 | # 70 | # This is also used if you do content translation via gettext catalogs. 71 | # Usually you set "language" from the command line for these cases. 72 | language = None 73 | 74 | # List of patterns, relative to source directory, that match files and 75 | # directories to ignore when looking for source files. 76 | # This pattern also affects html_static_path and html_extra_path. 77 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 78 | 79 | # The name of the Pygments (syntax highlighting) style to use. 80 | pygments_style = 'sphinx' 81 | 82 | napoleon_use_ivar = True 83 | 84 | # Disable docstring inheritance 85 | autodoc_inherit_docstrings = False 86 | 87 | # -- Options for HTML output ------------------------------------------------- 88 | 89 | # The theme to use for HTML and HTML Help pages. See the documentation for 90 | # a list of builtin themes. 91 | # 92 | html_theme = "sphinx_rtd_theme" 93 | 94 | # Theme options are theme-specific and customize the look and feel of a theme 95 | # further. For a list of options available for each theme, see the 96 | # documentation. 97 | # 98 | # html_theme_options = {} 99 | 100 | # Add any paths that contain custom static files (such as style sheets) here, 101 | # relative to this directory. They are copied after the builtin static files, 102 | # so a file named "default.css" will overwrite the builtin "default.css". 
103 | html_static_path = ['_static'] 104 | 105 | # Custom sidebar templates, must be a dictionary that maps document names 106 | # to template names. 107 | # 108 | # The default sidebars (for documents that don't match any pattern) are 109 | # defined by theme itself. Builtin themes are using these templates by 110 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 111 | # 'searchbox.html']``. 112 | # 113 | # html_sidebars = {} 114 | 115 | 116 | # -- Options for HTMLHelp output --------------------------------------------- 117 | 118 | # Output file base name for HTML help builder. 119 | htmlhelp_basename = 'TorchMeddoc' 120 | 121 | 122 | # -- Options for LaTeX output ------------------------------------------------ 123 | 124 | latex_elements = { 125 | # The paper size ('letterpaper' or 'a4paper'). 126 | # 127 | # 'papersize': 'letterpaper', 128 | 129 | # The font size ('10pt', '11pt' or '12pt'). 130 | # 131 | # 'pointsize': '10pt', 132 | 133 | # Additional stuff for the LaTeX preamble. 134 | # 135 | # 'preamble': '', 136 | 137 | # Latex figure (float) alignment 138 | # 139 | # 'figure_align': 'htbp', 140 | } 141 | 142 | # Grouping the document tree into LaTeX files. List of tuples 143 | # (source start file, target name, title, 144 | # author, documentclass [howto, manual, or own class]). 145 | latex_documents = [ 146 | (master_doc, 'TorchMed.tex', 'TorchMed Documentation', 147 | 'Pierre-Antoine Ganaye', 'manual'), 148 | ] 149 | 150 | 151 | # -- Options for manual page output ------------------------------------------ 152 | 153 | # One entry per manual page. List of tuples 154 | # (source start file, name, description, authors, manual section). 155 | man_pages = [ 156 | (master_doc, 'torchmed', 'TorchMed Documentation', 157 | [author], 1) 158 | ] 159 | 160 | 161 | # -- Options for Texinfo output ---------------------------------------------- 162 | 163 | # Grouping the document tree into Texinfo files. List of tuples 164 | # (source start file, target name, title, author, 165 | # dir menu entry, description, category) 166 | texinfo_documents = [ 167 | (master_doc, 'TorchMed', 'TorchMed Documentation', 168 | author, 'TorchMed', 'One line description of project.', 169 | 'Miscellaneous'), 170 | ] 171 | 172 | 173 | # -- Options for Epub output ------------------------------------------------- 174 | 175 | # Bibliographic Dublin Core info. 176 | epub_title = project 177 | 178 | # The unique identifier of the text. This can be a ISBN number 179 | # or the project homepage. 180 | # 181 | # epub_identifier = '' 182 | 183 | # A unique identification for the text. 184 | # 185 | # epub_uid = '' 186 | 187 | # A list of files that should not be packed into the epub file. 188 | epub_exclude_files = ['search.html'] 189 | 190 | 191 | # -- Extension configuration ------------------------------------------------- 192 | 193 | # -- Options for intersphinx extension --------------------------------------- 194 | 195 | # Example configuration for intersphinx: refer to the Python standard library. 196 | intersphinx_mapping = {'https://docs.python.org/': None} 197 | 198 | # -- Options for todo extension ---------------------------------------------- 199 | 200 | # If true, `todo` and `todoList` produce output, else they produce nothing. 
201 | todo_include_todos = True 202 | -------------------------------------------------------------------------------- /docs/datasets.rst: -------------------------------------------------------------------------------- 1 | Datasets 2 | ==================================== 3 | 4 | .. automodule:: torchmed.datasets 5 | :members: 6 | :special-members: __getitem__, __len__ 7 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. TorchMed documentation master file, created by 2 | sphinx-quickstart on Mon Jan 21 17:42:25 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to TorchMed's documentation! 7 | ==================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | 14 | .. toctree:: 15 | :maxdepth: 4 16 | :caption: Contents: 17 | 18 | readers 19 | datasets 20 | patterns 21 | samplers 22 | 23 | 24 | 25 | Indices and tables 26 | ================== 27 | 28 | * :ref:`genindex` 29 | * :ref:`modindex` 30 | * :ref:`search` 31 | -------------------------------------------------------------------------------- /docs/patterns.rst: -------------------------------------------------------------------------------- 1 | Patterns 2 | ==================================== 3 | 4 | .. automodule:: torchmed.patterns 5 | :members: 6 | :special-members: __call__ 7 | -------------------------------------------------------------------------------- /docs/readers.rst: -------------------------------------------------------------------------------- 1 | Readers 2 | ==================================== 3 | 4 | .. automodule:: torchmed.readers 5 | :members: 6 | :inherited-members: 7 | -------------------------------------------------------------------------------- /docs/samplers.rst: -------------------------------------------------------------------------------- 1 | Samplers 2 | ==================================== 3 | 4 | .. 
automodule:: torchmed.samplers
5 |     :members:
6 |     :inherited-members:
7 |     :special-members: __getitem__, __len__
8 | 
-------------------------------------------------------------------------------- /examples/01_brain_segmentation_MRI/architecture.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class ModSegNet(nn.Module):
7 |     """Inspired by SegNet https://arxiv.org/abs/1511.00561
8 | 
9 |     Args:
10 |         num_classes (int): number of classes to segment
11 |         n_init_features (int): number of input features in the first convolution
12 |         drop_rate (float): dropout rate of the last encoder
13 |         filter_config (list of 4 ints): number of output features at each level
14 |     """
15 |     def __init__(self, num_classes, n_init_features=7, drop_rate=0.5,
16 |                  filter_config=(32, 64, 128, 256)):
17 |         super(ModSegNet, self).__init__()
18 | 
19 |         self.encoder1 = _Encoder(n_init_features, filter_config[0])
20 |         self.encoder2 = _Encoder(filter_config[0], filter_config[1])
21 |         self.encoder3 = _Encoder(filter_config[1], filter_config[2])
22 |         self.encoder4 = _Encoder(filter_config[2], filter_config[3], drop_rate)
23 | 
24 |         self.decoder1 = _Decoder(filter_config[3], filter_config[2])
25 |         self.decoder2 = _Decoder(filter_config[2], filter_config[1])
26 |         self.decoder3 = _Decoder(filter_config[1], filter_config[0])
27 |         self.decoder4 = _Decoder(filter_config[0], filter_config[0])
28 | 
29 |         # final classifier (equivalent to a fully connected layer)
30 |         self.classifier = nn.Conv2d(filter_config[0], num_classes, 1)
31 | 
32 |         # init weights
33 |         for m in self.modules():
34 |             if isinstance(m, nn.Conv2d):
35 |                 nn.init.xavier_normal_(m.weight)
36 | 
37 |     def forward(self, x):
38 |         feat_encoder_1 = self.encoder1(x)
39 |         size_2 = feat_encoder_1.size()
40 |         feat_encoder_2, ind_2 = F.max_pool2d(feat_encoder_1, 2, 2,
41 |                                              return_indices=True)
42 | 
43 |         feat_encoder_2 = self.encoder2(feat_encoder_2)
44 |         size_3 = feat_encoder_2.size()
45 |         feat_encoder_3, ind_3 = F.max_pool2d(feat_encoder_2, 2, 2,
46 |                                              return_indices=True)
47 | 
48 |         feat_encoder_3 = self.encoder3(feat_encoder_3)
49 |         size_4 = feat_encoder_3.size()
50 |         feat_encoder_4, ind_4 = F.max_pool2d(feat_encoder_3, 2, 2,
51 |                                              return_indices=True)
52 | 
53 |         feat_encoder_4 = self.encoder4(feat_encoder_4)
54 |         size_5 = feat_encoder_4.size()
55 |         feat_encoder_5, ind_5 = F.max_pool2d(feat_encoder_4, 2, 2,
56 |                                              return_indices=True)
57 | 
58 |         feat_decoder = self.decoder1(feat_encoder_5, feat_encoder_4, ind_5, size_5)
59 |         feat_decoder = self.decoder2(feat_decoder, feat_encoder_3, ind_4, size_4)
60 |         feat_decoder = self.decoder3(feat_decoder, feat_encoder_2, ind_3, size_3)
61 |         feat_decoder = self.decoder4(feat_decoder, feat_encoder_1, ind_2, size_2)
62 | 
63 |         return F.log_softmax(self.classifier(feat_decoder), dim=1)
64 | 
65 | 
66 | class _Encoder(nn.Module):
67 |     """Encoder layer encodes the features along the contracting path
68 |     (left side). The drop_rate parameter is used in accordance with the
69 |     paper and the official Caffe model.
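    Each block is two 3x3 convolutions with padding 1, so it preserves the
    spatial resolution of its input; the 2x2 max pooling applied between
    encoders (see ModSegNet.forward) is what halves it.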
70 | 
71 |     Args:
72 |         n_in_feat (int): number of input features
73 |         n_out_feat (int): number of output features
74 |         drop_rate (float): dropout rate at the end of the block
75 |     """
76 |     def __init__(self, n_in_feat, n_out_feat, drop_rate=0):
77 |         super(_Encoder, self).__init__()
78 | 
79 |         layers = [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
80 |                   nn.ReLU(inplace=True),
81 |                   nn.Conv2d(n_out_feat, n_out_feat, 3, 1, 1),
82 |                   nn.ReLU(inplace=True)]
83 | 
84 |         if drop_rate > 0:
85 |             layers += [nn.Dropout(drop_rate)]
86 | 
87 |         self.features = nn.Sequential(*layers)
88 | 
89 |     def forward(self, x):
90 |         return self.features(x)
91 | 
92 | 
93 | class _Decoder(nn.Module):
94 |     """Decoder layer: unpools the incoming features with the max-pooling
95 |     indices saved during encoding, concatenates them with the features of
96 |     the corresponding encoder (skip-connections), and refines the result
97 |     with an encoder block. Unpooling with the stored indices restores the
98 |     spatial resolution of the matching encoder level (cf. the size argument
99 |     passed to max_unpool2d).
100 | 
101 |     Args:
102 |         n_in_feat (int): number of input features
103 |         n_out_feat (int): number of output features
104 |     """
105 |     def __init__(self, n_in_feat, n_out_feat):
106 |         super(_Decoder, self).__init__()
107 | 
108 |         self.encoder = _Encoder(n_in_feat * 2, n_out_feat)
109 | 
110 |     def forward(self, x, feat_encoder, indices, size):
111 |         unpooled = F.max_unpool2d(x, indices, 2, 2, 0, size)
112 |         feat = torch.cat([unpooled, feat_encoder], 1)
113 |         feat = self.encoder(feat)
114 | 
115 |         return feat
116 | 
-------------------------------------------------------------------------------- /examples/01_brain_segmentation_MRI/datasets/inference.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | 
4 | from torchmed.datasets import MedFile
5 | from torchmed.samplers import MaskableSampler
6 | from torchmed.patterns import SquaredSlidingWindow
7 | from torchmed.readers import SitkReader
8 | from torchmed.utils.transforms import Pad
9 | 
10 | 
11 | class MICCAI2012MedFile(object):
12 |     def __init__(self, data, nb_workers):
13 |         # a data map which specifies the image to read and how to read it
14 |         # the transform pads each dimension of the image on both sides.
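        # e.g. with a ((1, 1), (0, 0), (1, 1)) reflect padding, a [182, 218, 182]
        # MNI volume becomes [184, 218, 184], so the [184, 1, 184] sliding window
        # built below fits at every slice position (sizes are indicative, they
        # depend on your data).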
15 |         patient_data = {
16 |             'image_ref': SitkReader(
17 |                 os.path.join(data, 'prepro_im_mni_bc.nii.gz'),
18 |                 torch_type='torch.FloatTensor',
19 |                 transform=Pad(((1, 1), (0, 0), (1, 1)), 'reflect'))
20 |         }
21 |         # medfile dataset takes a data map, a Sampler and a transform
22 |         self.test_data = MedFile(patient_data, self.build_sampler(nb_workers),
23 |                                  transform=lambda t: t.permute(1, 0, 2))
24 | 
25 |         # init all the images before multiprocessing
26 |         self.test_data._sampler._coordinates.share_memory_()
27 |         for k, v in self.test_data._sampler._data.items():
28 |             v._torch_init()
29 | 
30 |     def build_sampler(self, nb_workers):
31 |         # sliding window of size [184, 1, 184] without padding
32 |         patch2d = SquaredSlidingWindow(patch_size=[184, 1, 184], use_padding=False)
33 |         # the pattern map links a pattern to the image it applies to
34 |         pattern_mapper = {'input': ('image_ref', patch2d)}
35 | 
36 |         return MaskableSampler(pattern_mapper, nb_workers=nb_workers)
37 | 
-------------------------------------------------------------------------------- /examples/01_brain_segmentation_MRI/datasets/mappings.py: --------------------------------------------------------------------------------
1 | 
2 | class Miccai12Mapping(object):
3 |     def __init__(self):
4 |         self.all_labels = [0, 4, 11, 23, 30, 31, 32, 35, 36, 37, 38, 39, 40,
5 |                            41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 55,
6 |                            56, 57, 58, 59, 60, 61, 62, 63, 64, 66, 69, 71, 72,
7 |                            73, 75, 76, 100, 101, 102, 103, 104, 105, 106, 107,
8 |                            108, 109, 112, 113, 114, 115, 116, 117, 118, 119,
9 |                            120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
10 |                            132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
11 |                            142, 143, 144, 145, 146, 147, 148, 149, 150, 151,
12 |                            152, 153, 154, 155, 156, 157, 160, 161, 162, 163,
13 |                            164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
14 |                            174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
15 |                            184, 185, 186, 187, 190, 191, 192, 193, 194, 195,
16 |                            196, 197, 198, 199, 200, 201, 202, 203, 204, 205,
17 |                            206, 207]
18 |         self.ignore_labels = [1, 2, 3] + \
19 |             [5, 6, 7, 8, 9, 10] + \
20 |             [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22] + \
21 |             [24, 25, 26, 27, 28, 29] + \
22 |             [33, 34] + [42, 43] + [53, 54] + \
23 |             [63, 64, 65, 66, 67, 68] + [70, 74] + \
24 |             [80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
25 |              90, 91, 92, 93, 94, 95, 96, 97, 98, 99] + \
26 |             [110, 111, 126, 127, 130, 131, 158, 159, 188, 189]
27 |         self.overall_labels = set(self.all_labels).difference(
28 |             set(self.ignore_labels))
29 | 
30 |         self.cortical_labels = [x for x in self.overall_labels if x >= 100]
31 |         self.non_cortical_labels = \
32 |             [x for x in self.overall_labels if x > 0 and x < 100]
33 | 
34 |         self.map = {v: k for k, v in enumerate(self.overall_labels)}
35 |         self.reversed_map = {k: v for k, v in enumerate(self.overall_labels)}
36 |         self.nb_classes = len(self.overall_labels)
37 | 
38 |     def __getitem__(self, index):
39 |         return self.map[index]
40 | 
-------------------------------------------------------------------------------- /examples/01_brain_segmentation_MRI/datasets/training.py: --------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import random
4 | import torch
5 | 
6 | from torchmed.datasets import MedFile, MedFolder
7 | from torchmed.samplers import MaskableSampler
8 | from torchmed.patterns import SquaredSlidingWindow
9 | from torchmed.readers import SitkReader
10 | from torchmed.utils.transforms import Pad
11 | from torchmed.utils.augmentation import elastic_deformation_2d
12 | 
13 | 
14 | class
MICCAI2012Dataset(object):
15 |     def __init__(self, base_dir, nb_workers):
16 |         def transform_train(tensor):
17 |             return tensor.permute(1, 0, 2)
18 | 
19 |         def transform_target(tensor):
20 |             return tensor.permute(1, 0, 2)[0]
21 | 
22 |         self.train_dataset = MedFolder(
23 |             self.generate_medfiles(os.path.join(base_dir, 'train'), nb_workers),
24 |             transform=transform_train, target_transform=transform_target,
25 |             paired_transform=self.elastic_transform)
26 |         self.validation_dataset = MedFolder(
27 |             self.generate_medfiles(os.path.join(base_dir, 'validation'), nb_workers),
28 |             transform=transform_train, target_transform=transform_target,
29 |             paired_transform=self.elastic_transform)
30 | 
31 |         # init all the images before multiprocessing
32 |         for medfile in self.train_dataset._medfiles:
33 |             medfile._sampler._coordinates.share_memory_()
34 |             for k, v in medfile._sampler._data.items():
35 |                 v._torch_init()
36 | 
37 |         # init all the images before multiprocessing
38 |         for medfile in self.validation_dataset._medfiles:
39 |             medfile._sampler._coordinates.share_memory_()
40 |             for k, v in medfile._sampler._data.items():
41 |                 v._torch_init()
42 | 
43 |         # read the cumulative volume of each label
44 |         df = pd.read_csv(os.path.join(base_dir, 'train/class_log.csv'), sep=';', index_col=0)
45 |         self.class_freq = torch.from_numpy(df['volume'].values).float()
46 | 
47 |     def generate_medfiles(self, dir, nb_workers):
48 |         # database composed of the dirnames listed in allowed_data.txt
49 |         database = open(os.path.join(dir, 'allowed_data.txt'), 'r')
50 |         patient_list = [line.rstrip('\n') for line in database]
51 |         medfiles = []
52 | 
53 |         # builds a list of MedFiles, one for each folder
54 |         for patient in patient_list:
55 |             if patient:
56 |                 patient_dir = os.path.join(dir, patient)
57 |                 patient_data = self.build_patient_data_map(patient_dir)
58 |                 patient_file = MedFile(patient_data, self.build_sampler(nb_workers))
59 |                 medfiles.append(patient_file)
60 | 
61 |         return medfiles
62 | 
63 |     def build_patient_data_map(self, dir):
64 |         # pads each dimension of the image on both sides.
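        # the keys of this map ('image_ref', 'target') are the ids the
        # pattern_mapper in build_sampler refers to; the ((1, 1), (0, 0), (1, 1))
        # reflect padding grows a [182, 218, 182] volume to [184, 218, 184] so
        # that the [184, 1, 184] window below fits (sizes are indicative).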
65 | pad_reflect = Pad(((1, 1), (0, 0), (1, 1)), 'reflect') 66 | file_map = { 67 | 'image_ref': SitkReader( 68 | os.path.join(dir, 'prepro_im_mni_bc.nii.gz'), 69 | torch_type='torch.FloatTensor', transform=pad_reflect), 70 | 'target': SitkReader( 71 | os.path.join(dir, 'prepro_seg_mni.nii.gz'), 72 | torch_type='torch.LongTensor', transform=pad_reflect) 73 | } 74 | 75 | return file_map 76 | 77 | def build_sampler(self, nb_workers): 78 | # sliding window of size [184, 1, 184] without padding 79 | patch2d = SquaredSlidingWindow(patch_size=[184, 1, 184], use_padding=False) 80 | # pattern map links image id to a Sampler 81 | pattern_mapper = {'input': ('image_ref', patch2d), 82 | 'target': ('target', patch2d)} 83 | 84 | # add a fixed offset to make patch sampling faster (doesn't look for all positions) 85 | return MaskableSampler(pattern_mapper, offset=[92, 1, 92], 86 | nb_workers=nb_workers) 87 | 88 | def elastic_transform(self, data, label): 89 | # elastic deformation 90 | if random.random() > 0.4: 91 | data_label = torch.cat([data, label.unsqueeze(0).float()], 0) 92 | data_label = elastic_deformation_2d( 93 | data_label, 94 | data_label.shape[1] * 1.05, # intensity of the deformation 95 | data_label.shape[1] * 0.05, # smoothing of the deformation 96 | 0, # order of bspline interp 97 | mode='nearest') # border mode 98 | 99 | data = data_label[0].unsqueeze(0) 100 | label = data_label[1].long() 101 | 102 | return data, label 103 | -------------------------------------------------------------------------------- /examples/01_brain_segmentation_MRI/inference_canvas.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import pandas as pd 5 | import collections 6 | 7 | from torchmed.utils.metric import multiclass, dice, hd, assd, precision 8 | from torchmed.readers import SitkReader 9 | 10 | 11 | class InferenceCanvas(object): 12 | def __init__(self, args, inference_fn, data_fn, model): 13 | self.args = args 14 | self.dataset_fn = data_fn 15 | self.inference_fn = inference_fn 16 | self.model = model 17 | 18 | self.file_names = {'image': 'prepro_im_mni_bc.nii.gz', 19 | 'label': 'prepro_seg_mni.nii.gz', 20 | 'segmentation': 'segmentation.nii.gz'} 21 | self.metrics = {'dice': dice, 22 | 'hausdorff': hd, 23 | 'mean_surface_distance': assd, 24 | 'precision': precision 25 | } 26 | 27 | def __call__(self): 28 | print("=> started segmentation script") 29 | # if output dir does not exists, create it 30 | if not os.path.isdir(self.args.output): 31 | os.makedirs(self.args.output) 32 | 33 | print("=> loading the architecture") 34 | model = self.model.cuda() 35 | 36 | print("=> loading trained model at {}".format(self.args.model)) 37 | # load model parameters 38 | torch.backends.cudnn.benchmark = True 39 | checkpoint = torch.load(self.args.model) 40 | try: 41 | model.load_state_dict(checkpoint['state_dict']) 42 | except RuntimeError as e: 43 | model = torch.nn.DataParallel(model).cuda() 44 | model.load_state_dict(checkpoint['state_dict']) 45 | model = model.module 46 | 47 | print("=> segmentation output at {}".format(self.args.output)) 48 | test_times = [] 49 | if os.path.isfile(os.path.join(self.args.data, 'allowed_data.txt')): 50 | allowed_data_file = open( 51 | os.path.join(self.args.data, 'allowed_data.txt'), 'r') 52 | patient_list = [line.rstrip('\n') for line in allowed_data_file] 53 | 54 | for patient in patient_list: 55 | if patient: 56 | patient_dir = os.path.join(self.args.data, patient) 57 | patient_out = 
os.path.join(self.args.output, patient) 58 | os.makedirs(patient_out) 59 | test_time = self.segment_metric_plot(model, 60 | patient_dir, 61 | patient_out) 62 | test_times.append(test_time) 63 | 64 | else: 65 | test_times.append( 66 | self.segment_metric_plot(model, self.args.data, 67 | self.args.output)) 68 | 69 | # write test time to file 70 | time_file = os.path.join(self.args.output, 'test_time.csv') 71 | time_report = open(time_file, 'a') 72 | time_report.write('image_id;minutes\n') 73 | for time_id in range(0, len(test_times)): 74 | time_report.write('{};{:.5f}\n'.format( 75 | time_id, test_times[time_id])) 76 | time_report.write('{};{:.5f}\n'.format( 77 | 'average', sum(test_times) / len(test_times))) 78 | time_report.flush() 79 | 80 | def segment_metric_plot(self, model, patient_dir, patient_out): 81 | patient_seg = os.path.join(patient_out, 82 | self.file_names['segmentation']) 83 | 84 | # segmentation 85 | test_time = self.segment_one_patient(model, patient_dir, patient_out) 86 | print('-- segmented {} in {:.2f}s'.format(patient_dir, test_time)) 87 | 88 | # if ground truth is available, use metrics 89 | if self.args.wo_metrics: 90 | patient_map = os.path.join(patient_dir, 91 | self.file_names['label']) 92 | 93 | self.save_error_map(patient_dir, patient_out) 94 | 95 | # evaluate metrics 96 | ref_img = SitkReader(patient_map).to_numpy() 97 | seg_img = SitkReader(patient_seg).to_numpy() 98 | results, undec_structs = multiclass(seg_img, ref_img, 99 | self.metrics.values()) 100 | metrics_results = zip(self.metrics.keys(), results) 101 | 102 | m = collections.OrderedDict(sorted(metrics_results, key=lambda x: x[0])) 103 | df = pd.DataFrame.from_dict(m) 104 | df.to_csv(os.path.join(patient_out, 'metrics_report.csv'), ';') 105 | 106 | if len(undec_structs) > 0: 107 | df = pd.DataFrame(undec_structs, columns=["class_id"]) 108 | df.to_csv(os.path.join(patient_out, 'undetected_classes.csv'), ';') 109 | 110 | return test_time 111 | 112 | def segment_one_patient(self, model, data, output): 113 | # Data loading code 114 | medcomp = self.dataset_fn(data, self.args.batch_size).test_data 115 | loader = torch.utils.data.DataLoader(medcomp, 116 | batch_size=self.args.batch_size, 117 | shuffle=False, 118 | num_workers=5, 119 | pin_memory=True) 120 | 121 | lab = SitkReader(os.path.join(data, self.file_names['image']), 122 | torch_type='torch.LongTensor') 123 | lab_array = lab.to_numpy() 124 | lab_array.fill(0) 125 | 126 | start_time = time.time() 127 | probability_maps = self.inference_fn(model, loader, lab_array) 128 | end_time = time.time() 129 | 130 | # save label map 131 | lab.to_image_file(os.path.join(output, self.file_names['segmentation'])) 132 | 133 | """ 134 | generation of probability maps, use only if you want to visualize 135 | probabilities for a given label, otherwise it will generate all 136 | the probabilities maps 137 | """ 138 | if len(probability_maps) > 0: 139 | os.makedirs(os.path.join(output, "probability_maps")) 140 | 141 | img = SitkReader(os.path.join(data, self.file_names['image'])) 142 | img_array = img.to_numpy() 143 | for map_id in range(0, len(probability_maps)): 144 | prob_file = os.path.join(output, 145 | "probability_maps/label_{}.img".format(map_id)) 146 | img_array.fill(0) 147 | img_array[...] 
= probability_maps[map_id]
148 |                 img.to_image_file(prob_file)
149 | 
150 |         return (end_time - start_time) / 60
151 | 
152 |     def save_error_map(self, data, output):
153 |         lab = SitkReader(os.path.join(data, self.file_names['label']))
154 |         seg = SitkReader(os.path.join(output, self.file_names['segmentation']))
155 |         lab_array = lab.to_numpy()
156 |         seg_array = seg.to_numpy()
157 | 
158 |         seg_array[seg_array == lab_array] = 0
159 | 
160 |         # save label map
161 |         seg.to_image_file(os.path.join(output, 'error_map.img'))
162 | 
-------------------------------------------------------------------------------- /examples/01_brain_segmentation_MRI/readme.md: --------------------------------------------------------------------------------
1 | ## Code organization
2 | 
3 | - datasets : converts input data into iterable datasets for training and inference.
4 |   - `training.py` : builds the datasets for training and validation.
5 |   - `inference.py` : builds a dataset for inference.
6 |   - `mappings.py` : mapping of image labels.
7 | - `architecture.py` : architecture of the CNN.
8 | - `inference_canvas.py` : inference and metrics evaluation for the test dataset.
9 | - `segment.py` : script for segmenting images based on a model.
10 | - `train_segment.py` : training + segmentation of test dataset.
11 | - `train.py` : training script.
12 | 
13 | The output folder will contain:
14 | 
15 | - figures plotting the various segmentation metrics and losses.
16 | - logs of metrics and losses for each iteration on train and validation.
17 | - `checkpoint.pth.tar` : last epoch model's parameters.
18 | - `model_best_dice.pth.tar` : best performing model's parameters.
19 | 
-------------------------------------------------------------------------------- /examples/01_brain_segmentation_MRI/segment.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | 
4 | from architecture import ModSegNet
5 | from inference_canvas import InferenceCanvas
6 | from datasets.mappings import Miccai12Mapping
7 | from datasets.inference import MICCAI2012MedFile
8 | 
9 | 
10 | parser = argparse.ArgumentParser(
11 |     description='PyTorch Automatic Segmentation (inference mode)')
12 | parser.add_argument('data', metavar='DIR',
13 |                     help='path to dataset')
14 | parser.add_argument('model', metavar='MODEL',
15 |                     help='path to a trained model')
16 | parser.add_argument('output', metavar='OUTPUT',
17 |                     help='path to the output segmentation map')
18 | 
19 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
20 |                     help='number of data loading workers (default: 4)')
21 | parser.add_argument('-b', '--batch-size', default=16, type=int,
22 |                     metavar='N', help='mini-batch size (default: 16)')
23 | parser.add_argument('--wo-metrics', action='store_false',
24 |                     help='disable the evaluation metrics (dice, assd)')
25 | 
26 | 
27 | def main():
28 |     global args
29 |     args = parser.parse_args()
30 |     model = ModSegNet(num_classes=Miccai12Mapping().nb_classes,
31 |                       n_init_features=1).cuda()
32 |     inference_canvas = InferenceCanvas(args, infer_segmentation_map,
33 |                                        MICCAI2012MedFile, model)
34 |     inference_canvas()
35 | 
36 | 
37 | def infer_segmentation_map(model, data_loader, label_map):
38 |     probability_maps = []
39 | 
40 |     with torch.no_grad():
41 |         for position, input in data_loader:
42 |             output = model(input.cuda())
43 |             _, predicted = output.data.max(1)
44 | 
45 |             # for each element of the batch
46 |             for i in range(0, predicted.size(0)):
47 |                 y = position[i][1]
48 |                 label_map[:, y, :] =
predicted[i].cpu().numpy()[1:-1, 1:-1]
49 | 
50 |     return probability_maps
51 | 
52 | 
53 | if __name__ == '__main__':
54 |     main()
55 | 
-------------------------------------------------------------------------------- /examples/01_brain_segmentation_MRI/train_segment.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import os
4 | import random
5 | import string
6 | import time
7 | 
8 | 
9 | parser = argparse.ArgumentParser(
10 |     description='PyTorch Automatic Segmentation Training and Inference')
11 | parser.add_argument('data', metavar='DATA_DIR', help='path to the data dir')
12 | parser.add_argument('output_dir', default='', metavar='OUTPUT_DIR',
13 |                     help='path to the output directory (default: current dir)')
14 | 
15 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
16 |                     help='number of data loading workers (default: 4)')
17 | parser.add_argument('--epochs', default=200, type=int, metavar='N',
18 |                     help='number of total epochs to run (default: 200)')
19 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
20 |                     help='manual epoch number (useful on restarts)')
21 | parser.add_argument('-b', '--batch-size', default=64, type=int,
22 |                     metavar='N', help='mini-batch size (default: 64)')
23 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
24 |                     metavar='LR', help='initial learning rate (default: 0.01)')
25 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
26 |                     help='momentum')
27 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
28 |                     metavar='W', help='weight decay (default: 1e-4)')
29 | parser.add_argument('--print-freq', '-p', default=10, type=int,
30 |                     metavar='N', help='print frequency (default: 10)')
31 | parser.add_argument('--resume', default=None, type=str, metavar='PATH',
32 |                     help='path to latest checkpoint (default: none)')
33 | 
34 | 
35 | def main():
36 |     args = parser.parse_args()
37 | 
38 |     code_dir = os.path.dirname(os.path.realpath(__file__))
39 |     train_script = os.path.join(code_dir, 'train.py')
40 |     inference_script = os.path.join(code_dir, 'segment.py')
41 |     report_script_dir = os.path.join(code_dir, '../scripts/report/')
42 |     report_scripts = ['plot_boxplot_labels_by_metric',
43 |                       'plot_boxplot_patients_by_metric',
44 |                       'write_patient_by_metric', 'write_label_by_metric']
45 | 
46 |     exp_id = os.path.basename(code_dir)
47 |     now = datetime.datetime.now()
48 |     dir_name = '{}_{}_{}@{}-{}_{}_{}'.format(now.year, now.month, now.day,
49 |                                              now.hour, now.minute,
50 |                                              exp_id, id_generator(2))
51 |     output_dir = os.path.join(args.output_dir, dir_name)
52 |     os.makedirs(output_dir)
53 | 
54 |     #####
55 |     #
56 |     # Training
57 |     #
58 |     #####
59 |     print('--> The output folder is {}'.format(output_dir))
60 |     print('--> Started train script')
61 |     ret = os.system('python -u {} {} {} -j {} -b {} --epochs {} --lr {} --exp-id {} {}'.format(
62 |         train_script,
63 |         args.data,
64 |         output_dir,
65 |         args.workers,
66 |         args.batch_size,
67 |         args.epochs,
68 |         args.lr,
69 |         exp_id,
70 |         '--resume ' + args.resume if args.resume else ''
71 |     ))
72 | 
73 |     # in case training ended with an error
74 |     if os.WEXITSTATUS(ret) != 0:
75 |         return -1
76 | 
77 |     print('Sleeping for 5 seconds before segmentation.
(read/write sync)') 78 | time.sleep(5) 79 | 80 | ##### 81 | # 82 | # Segmentation 83 | # 84 | ##### 85 | output_dir_seg = os.path.join(output_dir, 'segmentations') 86 | os.makedirs(output_dir_seg) 87 | 88 | segment_command = 'python -u {} {} {} {}'.format( 89 | inference_script, 90 | os.path.join(args.data, 'test'), 91 | os.path.join(output_dir, 'model_best_dice.pth.tar'), 92 | output_dir_seg 93 | ) 94 | 95 | print(segment_command) 96 | os.system(segment_command) 97 | 98 | ##### 99 | # 100 | # Reporting 101 | # 102 | ##### 103 | output_dir_report = os.path.join(output_dir, 'reports') 104 | for r_script in report_scripts: 105 | report_command = 'python -u {} {} {}'.format( 106 | os.path.join(report_script_dir, r_script + '.py'), 107 | output_dir_seg, 108 | output_dir_report 109 | ) 110 | os.system(report_command) 111 | time.sleep(1) 112 | 113 | 114 | def id_generator(size=6, chars=string.ascii_uppercase + string.digits): 115 | return ''.join(random.choice(chars) for _ in range(size)) 116 | 117 | 118 | if __name__ == '__main__': 119 | main() 120 | -------------------------------------------------------------------------------- /examples/01_brain_segmentation_MRI/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import numpy as np 3 | import os 4 | import shutil 5 | import torch 6 | 7 | from torchmed.utils.metric import dice as dc 8 | from torchmed.utils.metric import jaccard, multiclass 9 | 10 | 11 | def write_config(model, args, train_size, val_size): 12 | num_params = 0 13 | for m in model.modules(): 14 | if isinstance(m, torch.nn.Conv2d): 15 | num_params += m.weight.data.numel() 16 | 17 | configfile = os.path.join(args.output_dir, 'config.txt') 18 | cfg_f = open(configfile, "a") 19 | cfg_f.write('\ntraining with {} patches\n' 20 | 'validating with {} patches\n' 21 | .format(train_size * args.batch_size, 22 | val_size * args.batch_size)) 23 | cfg_f.write(('project: {}\n' + 24 | 'number of workers: {}\n' + 25 | 'number of epochs: {}\n' + 26 | 'starting epoch: {}\n' + 27 | 'batch size: {}\n' + 28 | 'learning rate: {:.6f}\n' + 29 | 'momentum: {:.5f}\n' + 30 | 'weight-decay: {:.5f}\n' + 31 | 'number of parameters: {}\n') 32 | .format(args.exp_id, args.workers, 33 | args.epochs, args.start_epoch, 34 | args.batch_size, args.lr, 35 | args.momentum, args.weight_decay, num_params) 36 | ) 37 | cfg_f.write('\nstarted training at {}\n'.format(datetime.datetime.now())) 38 | cfg_f.flush() 39 | 40 | 41 | def write_end_config(args, elapsed_time): 42 | configfile = os.path.join(args.output_dir, 'config.txt') 43 | cfg_f = open(configfile, "a") 44 | cfg_f.write('stopped training at {}\n'.format(datetime.datetime.now())) 45 | cfg_f.write('elapsed time : {:.2f} hours or {:.2f} days.' 
46 |                 .format((elapsed_time) / (60 * 60),
47 |                         (elapsed_time) / (60 * 60 * 24)))
48 |     cfg_f.flush()
49 | 
50 | 
51 | def update_figures(log_plot):
52 |     # plot avg train loss_meter
53 |     log_plot.add_line('cross_entropy', 'average_train.csv', 'epoch', 'cross_entropy_loss', "#1f77b4")
54 |     log_plot.add_line('dice', 'average_train.csv', 'epoch', 'dice_loss', "#ff7f0e")
55 |     log_plot.plot('losses_train.png', 'epoch', 'loss')
56 | 
57 |     # plot avg validation loss_meter
58 |     log_plot.add_line('cross_entropy', 'average_validation.csv', 'epoch', 'cross_entropy_loss', "#1f77b4")
59 |     log_plot.add_line('dice', 'average_validation.csv', 'epoch', 'dice_loss', "#ff7f0e")
60 |     log_plot.plot('losses_validation.png', 'epoch', 'loss')
61 | 
62 |     # plot learning rate
63 |     log_plot.add_line('learning_rate', 'learning_rate.csv',
64 |                       'epoch', 'lr', '#1f77b4')
65 |     log_plot.plot('learning_rate.png', 'epoch', 'learning rate')
66 | 
67 |     # plot dice
68 |     log_plot.add_line('train', 'average_train.csv', 'epoch', 'dice_metric', '#1f77b4')
69 |     log_plot.add_line('validation', 'average_validation.csv',
70 |                       'epoch', 'dice_metric', '#ff7f0e')
71 |     log_plot.plot('average_dice.png', 'epoch', 'dice', max_y=1)
72 | 
73 |     # plot iou
74 |     log_plot.add_line('train', 'average_train.csv', 'epoch', 'iou_metric', '#1f77b4')
75 |     log_plot.add_line('validation', 'average_validation.csv',
76 |                       'epoch', 'iou_metric', '#ff7f0e')
77 |     log_plot.plot('average_iou.png', 'epoch', 'iou', max_y=1)
78 | 
79 | 
80 | def save_checkpoint(state, is_best, output_dir):
81 |     filename = os.path.join(output_dir, 'checkpoint.pth.tar')
82 |     bestfile = os.path.join(output_dir, 'best_log.txt')
83 |     torch.save(state, filename)
84 |     if is_best:
85 |         bestfile_f = open(bestfile, "a")
86 |         bestfile_f.write('epoch:{:>5d} dice:{:>7.4f} IoU:{:>7.4f}\n'.format(
87 |             state['epoch'], state['dice_metric'], state['iou_metric']))
88 |         bestfile_f.flush()
89 |         shutil.copyfile(filename,
90 |                         os.path.join(output_dir, 'model_best_dice.pth.tar'))
91 | 
92 | 
93 | def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter=1,
94 |                       max_iter=100, power=0.9):
95 |     """Polynomial decay of the learning rate.
96 | 
97 |     :param init_lr: base learning rate
98 |     :param iter: current iteration
99 |     :param lr_decay_iter: how frequently decay occurs, default is 1
100 |     :param max_iter: number of maximum iterations
101 |     :param power: polynomial power
102 |     """
103 |     if iter % lr_decay_iter or iter > max_iter:
104 |         return optimizer
105 | 
106 |     lr = init_lr * (1 - iter / max_iter)**power
107 |     for param_group in optimizer.param_groups:
108 |         param_group['lr'] = lr
109 | 
110 |     return lr
111 | 
112 | 
113 | def eval_metrics(segmentation, reference):
114 |     results, undec_labels = multiclass(segmentation, reference, [dc, jaccard])
115 |     return list(map(lambda l: sum(l.values()) / len(l), results))
116 | 
-------------------------------------------------------------------------------- /examples/02_brain_segmentation_NonAdjLoss_MRI/architecture.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class ModSegNet(nn.Module):
7 |     """Inspired by SegNet
8 | 
9 |     Args:
10 |         num_classes (int): number of classes to segment
11 |         n_init_features (int): number of input features in the first convolution
12 |         drop_rate (float): dropout rate of the last encoder
13 |         filter_config (list of 4 ints): number of output features at each level
14 |     """
15 |     def __init__(self, num_classes,
n_init_features=7, drop_rate=0.5,
16 |                  filter_config=(32, 64, 128, 256)):
17 |         super(ModSegNet, self).__init__()
18 | 
19 |         self.encoder1 = _Encoder(n_init_features, filter_config[0])
20 |         self.encoder2 = _Encoder(filter_config[0], filter_config[1])
21 |         self.encoder3 = _Encoder(filter_config[1], filter_config[2])
22 |         self.encoder4 = _Encoder(filter_config[2], filter_config[3], drop_rate)
23 | 
24 |         self.decoder1 = _Decoder(filter_config[3], filter_config[2])
25 |         self.decoder2 = _Decoder(filter_config[2], filter_config[1])
26 |         self.decoder3 = _Decoder(filter_config[1], filter_config[0])
27 |         self.decoder4 = _Decoder(filter_config[0], filter_config[0])
28 | 
29 |         # final classifier (equivalent to a fully connected layer)
30 |         self.classifier = nn.Conv2d(filter_config[0], num_classes, 1)
31 | 
32 |         # init weights
33 |         for m in self.modules():
34 |             if isinstance(m, nn.Conv2d):
35 |                 nn.init.xavier_normal_(m.weight)
36 | 
37 |     def forward(self, x):
38 |         feat_encoder_1 = self.encoder1(x)
39 |         size_2 = feat_encoder_1.size()
40 |         feat_encoder_2, ind_2 = F.max_pool2d(feat_encoder_1, 2, 2,
41 |                                              return_indices=True)
42 | 
43 |         feat_encoder_2 = self.encoder2(feat_encoder_2)
44 |         size_3 = feat_encoder_2.size()
45 |         feat_encoder_3, ind_3 = F.max_pool2d(feat_encoder_2, 2, 2,
46 |                                              return_indices=True)
47 | 
48 |         feat_encoder_3 = self.encoder3(feat_encoder_3)
49 |         size_4 = feat_encoder_3.size()
50 |         feat_encoder_4, ind_4 = F.max_pool2d(feat_encoder_3, 2, 2,
51 |                                              return_indices=True)
52 | 
53 |         feat_encoder_4 = self.encoder4(feat_encoder_4)
54 |         size_5 = feat_encoder_4.size()
55 |         feat_encoder_5, ind_5 = F.max_pool2d(feat_encoder_4, 2, 2,
56 |                                              return_indices=True)
57 | 
58 |         feat_decoder = self.decoder1(feat_encoder_5, feat_encoder_4, ind_5, size_5)
59 |         feat_decoder = self.decoder2(feat_decoder, feat_encoder_3, ind_4, size_4)
60 |         feat_decoder = self.decoder3(feat_decoder, feat_encoder_2, ind_3, size_3)
61 |         feat_decoder = self.decoder4(feat_decoder, feat_encoder_1, ind_2, size_2)
62 | 
63 |         return F.log_softmax(self.classifier(feat_decoder), dim=1)
64 | 
65 | 
66 | class _Encoder(nn.Module):
67 |     """Encoder layer encodes the features along the contracting path
68 |     (left side). The drop_rate parameter is used in accordance with the
69 |     paper and the official Caffe model.
70 | 
71 |     Args:
72 |         n_in_feat (int): number of input features
73 |         n_out_feat (int): number of output features
74 |         drop_rate (float): dropout rate at the end of the block
75 |     """
76 |     def __init__(self, n_in_feat, n_out_feat, drop_rate=0):
77 |         super(_Encoder, self).__init__()
78 | 
79 |         layers = [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
80 |                   nn.ReLU(inplace=True),
81 |                   nn.Conv2d(n_out_feat, n_out_feat, 3, 1, 1),
82 |                   nn.ReLU(inplace=True)]
83 | 
84 |         if drop_rate > 0:
85 |             layers += [nn.Dropout(drop_rate)]
86 | 
87 |         self.features = nn.Sequential(*layers)
88 | 
89 |     def forward(self, x):
90 |         return self.features(x)
91 | 
92 | 
93 | class _Decoder(nn.Module):
94 |     """Decoder layer: unpools the incoming features with the max-pooling
95 |     indices saved during encoding, concatenates them with the features of
96 |     the corresponding encoder (skip-connections), and refines the result
97 |     with an encoder block. Unpooling with the stored indices restores the
98 |     spatial resolution of the matching encoder level (cf. the size argument
99 |     passed to max_unpool2d).
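    With the default filter_config, the bottleneck carries 256 features;
    each decoder then doubles the spatial size while reducing the feature
    count back down to 32 before the final 1x1 classifier.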
100 | 
101 |     Args:
102 |         n_in_feat (int): number of input features
103 |         n_out_feat (int): number of output features
104 |     """
105 |     def __init__(self, n_in_feat, n_out_feat):
106 |         super(_Decoder, self).__init__()
107 | 
108 |         self.encoder = _Encoder(n_in_feat * 2, n_out_feat)
109 | 
110 |     def forward(self, x, feat_encoder, indices, size):
111 |         unpooled = F.max_unpool2d(x, indices, 2, 2, 0, size)
112 |         feat = torch.cat([unpooled, feat_encoder], 1)
113 |         feat = self.encoder(feat)
114 | 
115 |         return feat
116 | 
-------------------------------------------------------------------------------- /examples/02_brain_segmentation_NonAdjLoss_MRI/datasets/inference.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | 
4 | from torchmed.datasets import MedFile
5 | from torchmed.samplers import MaskableSampler
6 | from torchmed.patterns import SquaredSlidingWindow
7 | from torchmed.readers import SitkReader
8 | from torchmed.utils.transforms import Pad
9 | 
10 | 
11 | class MICCAI2012MedFile(object):
12 |     def __init__(self, data, nb_workers):
13 |         # a data map which specifies the image to read and how to read it
14 |         # the transform pads each dimension of the image on both sides.
15 |         patient_data = {
16 |             'image_ref': SitkReader(
17 |                 os.path.join(data, 'prepro_im_mni_bc.nii.gz'),
18 |                 torch_type='torch.FloatTensor',
19 |                 transform=Pad(((1, 1), (3, 3), (1, 1)), 'reflect'))
20 |         }
21 |         # medfile dataset takes a data map, a Sampler and a transform
22 |         self.test_data = MedFile(patient_data, self.build_sampler(nb_workers),
23 |                                  transform=lambda t: t.permute(1, 0, 2))
24 | 
25 |         # init all the images before multiprocessing
26 |         self.test_data._sampler._coordinates.share_memory_()
27 |         for k, v in self.test_data._sampler._data.items():
28 |             v._torch_init()
29 | 
30 |     def build_sampler(self, nb_workers):
31 |         # sliding window of size [184, 7, 184] without padding
32 |         patch2d = SquaredSlidingWindow(patch_size=[184, 7, 184], use_padding=False)
33 |         # the pattern map links a pattern to the image it applies to
34 |         pattern_mapper = {'input': ('image_ref', patch2d)}
35 | 
36 |         return MaskableSampler(pattern_mapper, nb_workers=nb_workers)
37 | 
-------------------------------------------------------------------------------- /examples/02_brain_segmentation_NonAdjLoss_MRI/datasets/mappings.py: --------------------------------------------------------------------------------
1 | 
2 | class Miccai12Mapping(object):
3 |     def __init__(self):
4 |         self.all_labels = [0, 4, 11, 23, 30, 31, 32, 35, 36, 37, 38, 39, 40,
5 |                            41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 55,
6 |                            56, 57, 58, 59, 60, 61, 62, 63, 64, 66, 69, 71, 72,
7 |                            73, 75, 76, 100, 101, 102, 103, 104, 105, 106, 107,
8 |                            108, 109, 112, 113, 114, 115, 116, 117, 118, 119,
9 |                            120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
10 |                            132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
11 |                            142, 143, 144, 145, 146, 147, 148, 149, 150, 151,
12 |                            152, 153, 154, 155, 156, 157, 160, 161, 162, 163,
13 |                            164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
14 |                            174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
15 |                            184, 185, 186, 187, 190, 191, 192, 193, 194, 195,
16 |                            196, 197, 198, 199, 200, 201, 202, 203, 204, 205,
17 |                            206, 207]
18 |         self.ignore_labels = [1, 2, 3] + \
19 |             [5, 6, 7, 8, 9, 10] + \
20 |             [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22] + \
21 |             [24, 25, 26, 27, 28, 29] + \
22 |             [33, 34] + [42, 43] + [53, 54] + \
23 |             [63, 64, 65, 66, 67, 68] + [70, 74] + \
24 |             [80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
25 |              90, 91, 92, 93, 94, 95,
96, 97, 98, 99] + \ 26 |             [110, 111, 126, 127, 130, 131, 158, 159, 188, 189] 27 |         self.overall_labels = set(self.all_labels).difference( 28 |             set(self.ignore_labels)) 29 | 30 |         self.cortical_labels = [x for x in self.overall_labels if x >= 100] 31 |         self.non_cortical_labels = \ 32 |             [x for x in self.overall_labels if x > 0 and x < 100] 33 | 34 |         self.map = {v: k for k, v in enumerate(self.overall_labels)} 35 |         self.reversed_map = {k: v for k, v in enumerate(self.overall_labels)} 36 |         self.nb_classes = len(self.overall_labels) 37 | 38 |     def __getitem__(self, index): 39 |         return self.map[index] 40 | -------------------------------------------------------------------------------- /examples/02_brain_segmentation_NonAdjLoss_MRI/datasets/training.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pandas as pd 4 | import random 5 | import torch 6 | 7 | from torchmed.datasets import MedFile, MedFolder 8 | from torchmed.samplers import MaskableSampler 9 | from torchmed.patterns import SquaredSlidingWindow 10 | from torchmed.readers import SitkReader 11 | from torchmed.utils.transforms import Pad 12 | from torchmed.utils.augmentation import elastic_deformation_2d 13 | import torchmed.utils.transforms as transforms 14 | 15 | 16 | class MICCAI2012Dataset(object): 17 |     def __init__(self, base_dir, nb_workers): 18 |         self.train_dataset = MedFolder( 19 |             generate_medfiles(os.path.join(base_dir, 'train'), nb_workers), 20 |             transform=transform_train, target_transform=transform_target, 21 |             paired_transform=elastic_transform) 22 |         self.validation_dataset = MedFolder( 23 |             generate_medfiles(os.path.join(base_dir, 'validation'), nb_workers), 24 |             transform=transform_train, target_transform=transform_target, 25 |             paired_transform=elastic_transform) 26 | 27 |         # init all the images before multiprocessing 28 |         for medfile in self.train_dataset._medfiles: 29 |             medfile._sampler._coordinates.share_memory_() 30 |             for k, v in medfile._sampler._data.items(): 31 |                 v._torch_init() 32 | 33 |         # init all the images before multiprocessing 34 |         for medfile in self.validation_dataset._medfiles: 35 |             medfile._sampler._coordinates.share_memory_() 36 |             for k, v in medfile._sampler._data.items(): 37 |                 v._torch_init() 38 | 39 |         # read the cumulated volume of each label 40 |         df = pd.read_csv(os.path.join(base_dir, 'train/class_log.csv'), sep=';', index_col=0) 41 |         self.class_freq = torch.from_numpy(df['volume'].values).float() 42 | 43 |         # read ground truth adjacency matrix 44 |         adjacency_mat_path = os.path.join(base_dir, 'train/graph.csv') 45 |         self.adjacency_mat = torch.from_numpy(np.loadtxt(adjacency_mat_path, delimiter=';')) 46 | 47 | 48 | class SemiDataset(object): 49 |     def __init__(self, base_dir, nb_workers): 50 |         def transform_semi(tensor): 51 |             return tensor.permute(1, 0, 2) 52 | 53 |         self.train_dataset = MedFolder( 54 |             generate_medfiles(base_dir, nb_workers, with_target=False), 55 |             transform=transform_semi) 56 | 57 |         # init all the images before multiprocessing 58 |         for medfile in self.train_dataset._medfiles: 59 |             medfile._sampler._coordinates.share_memory_() 60 |             for k, v in medfile._sampler._data.items(): 61 |                 v._torch_init() 62 | 63 | 64 | def build_patient_data_map(dir, with_target): 65 |     # pads each dimension of the image on both sides.
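    # note: assuming MNI-registered volumes of 182x218x182 voxels (the MNI152
    # 1mm template, an assumption not stated in this file), the (1, 1) paddings
    # bring the two spatial dims to the 184x184 patch size of the
    # [184, 7, 184] sliding window below, and (3, 3) extends the slice axis so
    # a full 7-slice context exists at the first and last slices.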
66 |     pad_reflect = Pad(((1, 1), (3, 3), (1, 1)), 'reflect') 67 |     file_map = { 68 |         'image_ref': SitkReader( 69 |             os.path.join(dir, 'prepro_im_mni_bc.nii.gz'), 70 |             torch_type='torch.FloatTensor', transform=pad_reflect) 71 |     } 72 |     if with_target: 73 |         file_map['target'] = SitkReader( 74 |             os.path.join(dir, 'prepro_seg_mni.nii.gz'), 75 |             torch_type='torch.LongTensor', transform=pad_reflect) 76 | 77 |     return file_map 78 | 79 | 80 | def build_sampler(nb_workers, with_target): 81 |     # sliding window of size [184, 7, 184] without padding 82 |     patch2d = SquaredSlidingWindow(patch_size=[184, 7, 184], use_padding=False) 83 |     # pattern map links image id to a Sampler 84 |     pattern_mapper = {'input': ('image_ref', patch2d)} 85 |     if with_target: 86 |         pattern_mapper['target'] = ('target', patch2d) 87 | 88 |     # add a fixed offset to make patch sampling faster (doesn't look for all positions) 89 |     return MaskableSampler(pattern_mapper, offset=[92, 1, 92], 90 |                            nb_workers=nb_workers) 91 | 92 | 93 | def elastic_transform(data, label): 94 |     # apply elastic deformation with probability 0.6 95 |     if random.random() > 0.4: 96 |         data_label = torch.cat([data, label.unsqueeze(0).float()], 0) 97 |         data_label = elastic_deformation_2d( 98 |             data_label, 99 |             data_label.shape[1] * 1.05,  # intensity of the deformation 100 |             data_label.shape[1] * 0.05,  # smoothing of the deformation 101 |             0,  # order of bspline interp 102 |             mode='nearest')  # border mode 103 | 104 |         data = data_label[0:7] 105 |         label = data_label[7].long() 106 | 107 |     return data, label 108 | 109 | 110 | def generate_medfiles(dir, nb_workers, data_map_fn=build_patient_data_map, 111 |                       sampler_fn=build_sampler, with_target=True): 112 |     # database composed of the directory names listed in allowed_data.txt 113 |     database = open(os.path.join(dir, 'allowed_data.txt'), 'r') 114 |     patient_list = [line.rstrip('\n') for line in database] 115 |     medfiles = [] 116 | 117 |     # builds a list of MedFiles, one for each folder 118 |     for patient in patient_list: 119 |         if patient: 120 |             patient_dir = os.path.join(dir, patient) 121 |             patient_data = data_map_fn(patient_dir, with_target) 122 |             patient_file = MedFile(patient_data, sampler_fn(nb_workers, with_target)) 123 |             medfiles.append(patient_file) 124 | 125 |     return medfiles 126 | 127 | 128 | def transform_train(tensor): 129 |     return tensor.permute(1, 0, 2) 130 | 131 | 132 | def transform_target(tensor): 133 |     return tensor.permute(1, 0, 2)[3] 134 | -------------------------------------------------------------------------------- /examples/02_brain_segmentation_NonAdjLoss_MRI/inference_canvas.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import pandas as pd 5 | import collections 6 | 7 | from torchmed.utils.metric import multiclass, dice, hd, assd, precision 8 | from torchmed.readers import SitkReader 9 | 10 | 11 | class InferenceCanvas(object): 12 |     def __init__(self, args, inference_fn, data_fn, model): 13 |         self.args = args 14 |         self.dataset_fn = data_fn 15 |         self.inference_fn = inference_fn 16 |         self.model = model 17 | 18 |         self.file_names = {'image': 'prepro_im_mni_bc.nii.gz', 19 |                            'label': 'prepro_seg_mni.nii.gz', 20 |                            'segmentation': 'segmentation.nii.gz'} 21 |         self.metrics = {'dice': dice, 22 |                         'hausdorff': hd, 23 |                         'mean_surface_distance': assd, 24 |                         'precision': precision 25 |                         } 26 | 27 |     def __call__(self): 28 |         print("=> started segmentation script") 29 |         # if the output dir does not exist, create it 30 |         if not os.path.isdir(self.args.output): 31 |
os.makedirs(self.args.output) 32 | 33 | print("=> loading the architecture") 34 | model = self.model.cuda() 35 | 36 | print("=> loading trained model at {}".format(self.args.model)) 37 | # load model parameters 38 | torch.backends.cudnn.benchmark = True 39 | checkpoint = torch.load(self.args.model) 40 | try: 41 | model.load_state_dict(checkpoint['state_dict']) 42 | except RuntimeError as e: 43 | model = torch.nn.DataParallel(model).cuda() 44 | model.load_state_dict(checkpoint['state_dict']) 45 | model = model.module 46 | 47 | print("=> segmentation output at {}".format(self.args.output)) 48 | test_times = [] 49 | if os.path.isfile(os.path.join(self.args.data, 'allowed_data.txt')): 50 | allowed_data_file = open( 51 | os.path.join(self.args.data, 'allowed_data.txt'), 'r') 52 | patient_list = [line.rstrip('\n') for line in allowed_data_file] 53 | 54 | for patient in patient_list: 55 | if patient: 56 | patient_dir = os.path.join(self.args.data, patient) 57 | patient_out = os.path.join(self.args.output, patient) 58 | os.makedirs(patient_out) 59 | test_time = self.segment_metric_plot(model, 60 | patient_dir, 61 | patient_out) 62 | test_times.append(test_time) 63 | 64 | else: 65 | test_times.append( 66 | self.segment_metric_plot(model, self.args.data, 67 | self.args.output)) 68 | 69 | # write test time to file 70 | time_file = os.path.join(self.args.output, 'test_time.csv') 71 | time_report = open(time_file, 'a') 72 | time_report.write('image_id;minutes\n') 73 | for time_id in range(0, len(test_times)): 74 | time_report.write('{};{:.5f}\n'.format( 75 | time_id, test_times[time_id])) 76 | time_report.write('{};{:.5f}\n'.format( 77 | 'average', sum(test_times) / len(test_times))) 78 | time_report.flush() 79 | 80 | def segment_metric_plot(self, model, patient_dir, patient_out): 81 | patient_seg = os.path.join(patient_out, 82 | self.file_names['segmentation']) 83 | 84 | # segmentation 85 | test_time = self.segment_one_patient(model, patient_dir, patient_out) 86 | print('-- segmented {} in {:.2f}s'.format(patient_dir, test_time)) 87 | 88 | # if ground truth is available, use metrics 89 | if self.args.wo_metrics: 90 | patient_map = os.path.join(patient_dir, 91 | self.file_names['label']) 92 | 93 | self.save_error_map(patient_dir, patient_out) 94 | 95 | # evaluate metrics 96 | ref_img = SitkReader(patient_map).to_numpy() 97 | seg_img = SitkReader(patient_seg).to_numpy() 98 | results, undec_structs = multiclass(seg_img, ref_img, 99 | self.metrics.values()) 100 | metrics_results = zip(self.metrics.keys(), results) 101 | 102 | m = collections.OrderedDict(sorted(metrics_results, key=lambda x: x[0])) 103 | df = pd.DataFrame.from_dict(m) 104 | df.to_csv(os.path.join(patient_out, 'metrics_report.csv'), ';') 105 | 106 | if len(undec_structs) > 0: 107 | df = pd.DataFrame(undec_structs, columns=["class_id"]) 108 | df.to_csv(os.path.join(patient_out, 'undetected_classes.csv'), ';') 109 | 110 | return test_time 111 | 112 | def segment_one_patient(self, model, data, output): 113 | # Data loading code 114 | medcomp = self.dataset_fn(data, self.args.batch_size).test_data 115 | loader = torch.utils.data.DataLoader(medcomp, 116 | batch_size=self.args.batch_size, 117 | shuffle=False, 118 | num_workers=5, 119 | pin_memory=True) 120 | 121 | lab = SitkReader(os.path.join(data, self.file_names['image']), 122 | torch_type='torch.LongTensor') 123 | lab_array = lab.to_numpy() 124 | lab_array.fill(0) 125 | 126 | start_time = time.time() 127 | probability_maps = self.inference_fn(model, loader, lab_array) 128 | end_time = 
time.time() 129 | 130 |         # save label map 131 |         lab.to_image_file(os.path.join(output, self.file_names['segmentation'])) 132 | 133 |         """ 134 |         Generation of probability maps: enable only if you want to visualize 135 |         the probabilities of a given label, otherwise it will write every 136 |         probability map to disk. 137 |         """ 138 |         if len(probability_maps) > 0: 139 |             os.makedirs(os.path.join(output, "probability_maps")) 140 | 141 |             img = SitkReader(os.path.join(data, self.file_names['image'])) 142 |             img_array = img.to_numpy() 143 |             for map_id in range(0, len(probability_maps)): 144 |                 prob_file = os.path.join(output, 145 |                                          "probability_maps/label_{}.img".format(map_id)) 146 |                 img_array.fill(0) 147 |                 img_array[...] = probability_maps[map_id] 148 |                 img.to_image_file(prob_file) 149 | 150 |         return (end_time - start_time) / 60 151 | 152 |     def save_error_map(self, data, output): 153 |         lab = SitkReader(os.path.join(data, self.file_names['label'])) 154 |         seg = SitkReader(os.path.join(output, self.file_names['segmentation'])) 155 |         lab_array = lab.to_numpy() 156 |         seg_array = seg.to_numpy() 157 | 158 |         seg_array[seg_array == lab_array] = 0 159 | 160 |         # save label map 161 |         seg.to_image_file(os.path.join(output, 'error_map.img')) 162 | -------------------------------------------------------------------------------- /examples/02_brain_segmentation_NonAdjLoss_MRI/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class AdjacencyEstimator(torch.nn.Module): 5 |     """Estimates the adjacency graph of labels based on probability maps. 6 | 7 |     Parameters 8 |     ---------- 9 |     nb_labels : int 10 |         number of structures segmented. 11 | 12 |     """ 13 |     def __init__(self, nb_labels): 14 |         super(AdjacencyEstimator, self).__init__() 15 | 16 |         # constant 2D convolution, needs constant weights and no gradient 17 |         # apply the same convolution filter to all labels 18 |         layer = torch.nn.Conv2d(in_channels=nb_labels, out_channels=nb_labels, 19 |                                 kernel_size=3, stride=1, padding=0, 20 |                                 bias=False, groups=nb_labels) 21 |         layer.weight.data.fill_(0) 22 | 23 |         canvas = torch.Tensor(3, 3).fill_(1) 24 |         # fill 3x3 filters with ones 25 |         for i in range(0, nb_labels): 26 |             layer.weight.data[i, 0, :, :] = canvas 27 | 28 |         # exclude parameters from the subgraph 29 |         for param in layer.parameters(): 30 |             param.requires_grad = False 31 | 32 |         self._conv_layer = layer 33 |         # zero padding to recover the same resolution after convolution 34 |         self._pad_layer = torch.nn.ZeroPad2d(1) 35 | 36 |     def forward(self, image): 37 |         # padding of tensor of size batch x k x W x H 38 |         p_tild = self._pad_layer(image) 39 |         # apply constant convolution and normalize by size of kernel 40 |         p_tild = self._conv_layer(p_tild) / 9 41 | 42 |         # normalization factor 43 |         norm_factor = image.size()[0] * image.size()[2] * image.size()[3] 44 | 45 |         # old deprecated formulation 46 |         # graph = torch.Tensor(image.size(1), image.size(1)).cuda() 47 |         # graph.fill_(0) 48 |         # for i in range(image.size(1)): 49 |         #     p_tild_i = image[:, i, :, :].unsqueeze(1) 50 |         #     # element product of image.exp() and p_tild. Sum over batch, W, H 51 |         #     graph[:, i] = (image * p_tild_i).sum(0).sum(1).sum(1) 52 | 53 |         # torch v1.0 einstein notation replaces following loop 54 |         return torch.einsum('nihw,njhw->ij', image, p_tild) / norm_factor 55 | 56 | 57 | class LambdaControl: 58 |     """Utility to optimize the value of the lambda parameter, which is the 59 |     weighting term of the NonAdjLoss.
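    The schedule implemented below increases `lambda_` geometrically as long
    as the validation dice stays within 0.01 of its reference value, and backs
    off (smaller `lambda_` and `lambda_factor`) when dice drops or nan losses
    appear; see :func:`update` for the exact rules.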
60 | 61 |     Parameters 62 |     ---------- 63 |     graph_gt : :class:`torch.Tensor` 64 |         ground truth adjacency map. 65 |     tuning_epoch : int 66 |         index of the epoch at which we stop increasing `lambda_`. 67 |         For example, for a total training time of 200 epochs, tuning_epoch 68 |         could be 180. 69 |         To disable this tuning step, set the parameter to 70 |         the total number of epochs (e.g. 200 in the previous example). 71 | 72 |     Attributes 73 |     ---------- 74 |     nb_conn_ab_max : int 75 |         number of label pairs that are adjacent (non-zero entries) in `graph_gt`. 76 |     lambda_ : float 77 |         weighting term of the NonAdjLoss. 78 |     lambda_factor : float 79 |         multiplicative factor applied to `lambda_` when it is increased by :func:`update`. 80 |     train_nan_count : int 81 |         counter of the number of nan values caught so far. 82 |     last_good_epoch : int 83 |         Index of the last epoch with a good seg metric. 84 |         Good = decrease should not be more than 0.01 compared to `good_seg_metric_value`. 85 |     good_seg_metric_value : float 86 |         value of the segmentation metric at epoch 0. This is used as a reference 87 |         point in order to monitor if the segmentation quality decreases during 88 |         training. This is ultimately used to control `lambda_`. 89 |     update_counter : int 90 |         number of iterations without updating `lambda_`. `lambda_` is updated 91 |         when `update_counter` reaches 5. 92 |     graph_gt : :class:`torch.Tensor` 93 |         ground truth adjacency map. 94 |     tuning_epoch : int 95 |         index of the epoch at which we stop increasing `lambda_`. 96 |         For example, for a total training time of 200 epochs, tuning_epoch 97 |         could be 180. 98 |         To disable this tuning step, set the parameter to 99 |         the total number of epochs (e.g. 200 in the previous example). 100 | 101 |     """ 102 |     def __init__(self, graph_gt, tuning_epoch): 103 |         self.graph_gt = graph_gt 104 |         self.nb_conn_ab_max = (graph_gt > 0).sum() 105 |         self.lambda_ = 1 106 |         self.lambda_factor = 1.3 107 |         self.train_nan_count = 0 108 |         self.last_good_epoch = 0 109 |         self.good_seg_metric_value = None 110 |         self.update_counter = 0 111 |         self.tuning_epoch = tuning_epoch 112 | 113 |     def get_config(self): 114 |         if self.good_seg_metric_value is None: 115 |             return (self.graph_gt, self.nb_conn_ab_max, self.lambda_, False) 116 |         else: 117 |             return (self.graph_gt, self.nb_conn_ab_max, self.lambda_, True) 118 | 119 |     def update(self, epoch, avg_seg_loss_train, avg_nonadjloss_train, 120 |                avg_seg_loss_val, has_train_nan): 121 |         """Update the lambda parameter according to metrics from the previous epoch. 122 | 123 |         Parameters 124 |         ---------- 125 |         epoch : int 126 |             index of the epoch. 127 |         avg_seg_loss_train : float 128 |             average training segmentation loss per image per epoch. 129 |         avg_nonadjloss_train : float 130 |             average training NonAdjLoss per image per epoch. 131 |         avg_seg_loss_val : float 132 |             average validation segmentation metric per image per epoch 133 |             (dice-like, higher is better, as used by the rules below). 134 |         has_train_nan : bool 135 |             flag indicating if nan values have appeared during last training epoch. 136 | 137 |         Returns 138 |         ------- 139 |         int 140 |             0 if training should continue, -1 to stop training, -2 to restart from a previous checkpoint.
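        Examples
        --------
        A minimal sketch of the intended training-loop integration
        (`train_one_epoch`, `validate` and `reload_last_good_checkpoint` are
        hypothetical helpers, not part of this repository)::

            ctrl = LambdaControl(graph_gt, tuning_epoch=180)
            for epoch in range(200):
                _, _, lambda_, use_nonadj = ctrl.get_config()
                seg_loss, nonadj_loss, has_nan = train_one_epoch(model, lambda_, use_nonadj)
                val_dice = validate(model)
                status = ctrl.update(epoch, seg_loss, nonadj_loss, val_dice, has_nan)
                if status == -1:
                    break  # lambda_factor fell below 1, stop training
                elif status == -2:
                    reload_last_good_checkpoint(model)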
141 | 142 |         """ 143 |         # Init lambda after first epoch 144 |         if epoch == 0: 145 |             self.lambda_ = 0.3 * (avg_seg_loss_train / avg_nonadjloss_train) 146 |             self.good_seg_metric_value = avg_seg_loss_val 147 |         elif epoch == self.tuning_epoch: 148 |             self.update_counter = 4 149 | 150 |         # if no issue was detected, check dice and update 151 |         if has_train_nan is False and self.train_nan_count < 3: 152 |             # automatic update of lambda 153 |             if epoch < self.tuning_epoch: 154 |                 if self.good_seg_metric_value - avg_seg_loss_val >= 0.02: 155 |                     print('--High decrease in dice detected, no lambda update for 5 epochs') 156 |                     self.update_counter = 0 157 |                 elif self.good_seg_metric_value - avg_seg_loss_val >= 0.01: 158 |                     print('--Small decrease in dice detected') 159 | 160 |                 elif epoch > 0: 161 |                     if self.update_counter >= 4: 162 |                         print('--Increase lambda') 163 |                         self.lambda_ *= self.lambda_factor 164 |                         self.update_counter = 0 165 |                     else: 166 |                         self.update_counter += 1 167 | 168 |             self.train_nan_count = 0 169 |         else: 170 |             dice_diff = self.good_seg_metric_value - avg_seg_loss_val 171 |             if dice_diff > 0: 172 |                 if self.update_counter >= 4: 173 |                     print('--Decrease lambda to improve dice {:.4f}'.format(dice_diff)) 174 |                     self.lambda_ *= 0.9 175 |                     self.update_counter = 0 176 |                 else: 177 |                     print('--Dice is not good enough {:.4f}'.format(dice_diff)) 178 |                     self.update_counter += 1 179 |             else: 180 |                 self.update_counter = 0 181 | 182 |         if avg_seg_loss_val >= self.good_seg_metric_value - 0.01: 183 |             print('--Logging as good epoch') 184 |             self.last_good_epoch = epoch 185 | 186 |         # if an error was detected decrease factor 187 |         elif has_train_nan is True and self.train_nan_count < 3: 188 |             print('--Decreasing lambda factor') 189 |             self.lambda_ *= 0.9 190 |             self.lambda_factor *= 0.98 191 |             self.update_counter = 0 192 |             self.train_nan_count += 1 193 | 194 |         if epoch - self.last_good_epoch >= 5 and (epoch - self.last_good_epoch) % 5 == 0: 195 |             if epoch - self.last_good_epoch >= 10: 196 |                 print('--Insufficient dice for 10 epochs, reboot') 197 |                 print('--Decreasing lambda factor') 198 |                 self.lambda_ *= 0.9 199 |                 return -2 200 |             else: 201 |                 print('--Decreasing lambda because of decreasing dice') 202 |                 self.lambda_ *= 0.95 203 | 204 |             self.lambda_factor *= 0.98 205 |             self.update_counter = 0 206 | 207 |         if self.lambda_factor < 1: 208 |             print('lambda_factor is < 1, ending.') 209 |             return -1 210 |         elif self.train_nan_count >= 3: 211 |             print('--Too many errors, restart from a previous epoch') 212 |             return -2 213 |         else: 214 |             return 0 215 | -------------------------------------------------------------------------------- /examples/02_brain_segmentation_NonAdjLoss_MRI/readme.md: -------------------------------------------------------------------------------- 1 | ## Code organization 2 | 3 | - datasets : converts input data into iterable datasets for training and inference. 4 |   - `training.py` : builds the datasets for training and validation. 5 |   - `inference.py` : builds a dataset for inference. 6 |   - `mappings.py` : mapping of image labels. 7 | - `architecture.py` : architecture of the CNN. 8 | - `inference_canvas.py` : inference and metrics evaluation for the test dataset. 9 | - `segment.py` : script for segmenting images based on a model. 10 | - `train_segment.py` : training + segmentation of the test dataset. 11 | - `train.py` : training script. 12 | 13 | The output folder will contain : 14 | 15 | - figures plotting the various segmentation metrics and losses. 16 | - logs of metrics and losses for each iteration on train and validation.
17 | - `checkpoint.pth.tar` : last epoch model's parameters. 18 | - `model_best_dice.pth.tar` : best performing model's parameters. 19 | 20 | ## About NonAdjLoss 21 | 22 | [Semi-supervised Learning for Segmentation Under Semantic 23 | Constraint](https://link.springer.com/chapter/10.1007/978-3-030-00931-1_68), 24 | Pierre-Antoine Ganaye, Michaël Sdika, Hugues Benoit-Cattin. MICCAI 2018 25 | 26 | Image segmentation based on convolutional neural networks is proving to be a 27 | powerful and efficient solution for medical applications. However, the lack of 28 | annotated data, presence of artifacts and variability in appearance can still 29 | result in inconsistencies during inference. We choose to take advantage of 30 | the invariant nature of anatomical structures, by enforcing a semantic constraint 31 | to improve the robustness of the segmentation. The proposed solution is applied 32 | to a brain structures segmentation task, where the output of the network is 33 | constrained to satisfy a known adjacency graph of the brain regions. This 34 | criterion is introduced during training through an original penalization 35 | loss named NonAdjLoss. With the help of a new metric, we show that the 36 | proposed approach significantly reduces abnormalities produced during the 37 | segmentation. Additionally, we demonstrate that our framework can be used 38 | in a semi-supervised way, opening a path to better generalization to unseen data. 39 | 40 | The implementation of the NonAdjLoss is contained in : 41 | 42 | - `../scripts/adjacency/extract_3d_adjacency.py` : extracts and sums adjacencies from ground truth 43 | segmentation maps. 44 | - `loss.py` : adjacency estimator and tuning of the lambda weighting parameter. 45 | 46 | ## Pre-trained models 47 | 48 | You can download the trained models [here](https://drive.google.com/drive/folders/1OWlQlzhjOgl1GuaRKibRyyo_ce-CtSln?usp=sharing) 49 | for the 2D baseline, 2D + NonAdjLoss and 2D + NonAdjLoss with semi-supervision. 50 | 51 | ### How to 52 | 53 | In practice this non-adjacency penalization is applied to a pre-trained model, 54 | because it's simpler to optimize a model that's already good at segmenting structures. 55 | Enforcing this penalization can then be seen as a form of fine-tuning: 56 | first you train your CNN for segmentation, then fine-tune it to produce 57 | segmentations that are also correct with respect to the adjacency matrix. This 58 | matrix can be extracted from the ground truth label maps, or specified by hand. 59 | 60 | In this example the adjacency matrix should be named `graph.csv` and located 61 | in the train directory.
Then your dataset should look something like this : 62 | 63 | ```bash 64 | (pytorch-v1.0) [ganaye@iv-ms-593 miccai]$ ll 65 | total 12 66 | drwxrwxr-x 22 ganaye creatis 4096 Jan 26  2018 test 67 | drwxrwxr-x 14 ganaye creatis 4096 Feb  2 18:41 train 68 | drwxrwxr-x  7 ganaye creatis 4096 Jan 15  2018 validation 69 | 70 | (pytorch-v1.0) [ganaye@iv-ms-593 miccai]$ ll train/ 71 | total 516 72 | drwxrwxr-x 2 ganaye creatis   4096 Jan 23 11:09 1000 73 | drwxrwxr-x 2 ganaye creatis   4096 Jan 12  2018 1001 74 | drwxrwxr-x 2 ganaye creatis   4096 Jan 12  2018 1002 75 | drwxrwxr-x 2 ganaye creatis   4096 Jan 12  2018 1006 76 | drwxrwxr-x 2 ganaye creatis   4096 Jan 12  2018 1007 77 | drwxrwxr-x 2 ganaye creatis   4096 Jan 12  2018 1008 78 | drwxrwxr-x 2 ganaye creatis   4096 Jan 12  2018 1009 79 | drwxrwxr-x 2 ganaye creatis   4096 Jan 12  2018 1010 80 | drwxrwxr-x 2 ganaye creatis   4096 Jan 12  2018 1011 81 | drwxrwxr-x 2 ganaye creatis   4096 Jan 12  2018 1012 82 | -rw-rw-r-- 1 ganaye creatis      5 Jan 15  2018 allowed_data.txt 83 | -rw-rw-r-- 1 ganaye creatis   1288 Jan 15  2018 class_log.csv 84 | -rw-rw-r-- 1 ganaye creatis 455625 Jan 15  2018 graph.csv 85 | -rw-rw-r-- 1 ganaye creatis   3030 Jan 15  2018 graph.png 86 | -rw-rw-r-- 1 ganaye creatis     71 Jan 15  2018 stats_log.txt 87 | 88 | (pytorch-v1.0) [ganaye@iv-ms-593 miccai]$ ll train/1000 89 | total 41864 90 | -rw-rw-r-- 1 ganaye creatis 16340661 Jan 15  2018 im_mni_bc.nii.gz 91 | -rw-rw-r-- 1 ganaye creatis  8589615 Jan 15  2018 im_mni.nii.gz 92 | -rw-rw-r-- 1 ganaye creatis      125 Jan 15  2018 mni_aff_transf.c3dmat 93 | -rw-rw-r-- 1 ganaye creatis      190 Jan 15  2018 mni_aff_transf.mat 94 | -rw-rw-r-- 1 ganaye creatis 16761108 Jan 15  2018 prepro_im_mni_bc.nii.gz 95 | -rw-rw-r-- 1 ganaye creatis   344875 Jan 15  2018 prepro_seg_mni.nii.gz 96 | -rw-rw-r-- 1 ganaye creatis   344193 Jan 15  2018 seg_mni.nii.gz 97 | 98 | ``` 99 | 100 | ### General advice 101 | 102 | Before training your first network with this penalization, we advise you to 103 | compare the ground truth adjacencies to the ones produced by your baseline model. 104 | This gives you a first impression of how well the baseline already respects the 105 | adjacency constraint, and how much improvement you can expect from this work.
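A quick way to make that comparison is to count the label pairs that touch in the baseline's segmentations but are not adjacent in the ground truth graph. A minimal sketch (the two `.csv` paths are assumptions; they correspond to the files written by the extraction scripts described below and by `extract_3d_adjacency_from_seg.py`):

```python
import numpy as np

# adjacency matrices saved by the extraction scripts (semicolon-separated)
graph_gt = np.loadtxt('train/graph.csv', delimiter=';')
graph_baseline = np.loadtxt('segmentations/adjacency_n_1.csv', delimiter=';')

# pairs that touch in the baseline but are forbidden by the ground truth
violations = (graph_baseline > 0) & (graph_gt == 0)
ratio = violations.sum() / (graph_gt == 0).sum()
print('{} forbidden adjacencies ({:.2%} of non-adjacent pairs)'.format(
    violations.sum(), ratio))
```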
106 | 107 | Here is how to extract an adjacency matrix from an image; this example is 108 | taken from `extract_3d_adjacency.py` : 109 | 110 | ```python 111 | # adjacency matrix with one additional dimension for discarded label (-1) 112 | img_adj = torch.FloatTensor(args.nb_labels + 1, args.nb_labels + 1).zero_() 113 | # reads a segmentation map 114 | label = SitkReader(train_patient + '/prepro_seg_mni.nii.gz') 115 | 116 | # image array from the reader 117 | label_array = label.to_torch().long() 118 | # re-label discarded label (-1) by the last positive integer 119 | label_array[label_array == -1] = args.nb_labels 120 | 121 | # extract adjacency matrix from the image and fill in the matrix 122 | image2graph3d_patch(label_array, img_adj.numpy(), args.nb_labels, args.n_size) 123 | # discard last positive label (discarded label) 124 | img_adj = img_adj[:-1, :-1] 125 | ``` 126 | 127 | ### Training 128 | 129 | To start training with a pre-trained model, use the following command : 130 | 131 | ```bash 132 | python train_segment.py /mnt/hdd/datasets/processed/miccai /mnt/hdd/datasets/models/final/ -j 4 -b 2 --lr 0.001 --resume ~/model_best_dice.pth.tar 133 | ``` 134 | 135 | For more options : 136 | 137 | ```bash 138 | (pytorch-v1.0) [ganaye@iv-ms-593 02_segmentation_NonAdjLoss]$ python train_segment.py --help 139 | usage: train_segment.py [-h] [-j N] [--epochs N] [--start-epoch N] [-b N] 140 |                         [--lr LR] [--momentum M] [--weight-decay W] 141 |                         [--print-freq N] [--resume PATH] 142 |                         DATA_DIR OUTPUT_DIR 143 | 144 | PyTorch Automatic Segmentation Training and Inference 145 | 146 | positional arguments: 147 |   DATA_DIR              path to the data dir 148 |   OUTPUT_DIR            path to the output directory (default: current dir) 149 | 150 | optional arguments: 151 |   -h, --help            show this help message and exit 152 |   -j N, --workers N     number of data loading workers (default: 4) 153 |   --epochs N            number of total epochs to run (default: 200) 154 |   --start-epoch N       manual epoch number (useful on restarts) 155 |   -b N, --batch-size N  mini-batch size (default: 64) 156 |   --lr LR, --learning-rate LR 157 |                         initial learning rate (default: 0.01) 158 |   --momentum M          momentum 159 |   --weight-decay W, --wd W 160 |                         weight decay (default: 1e-4) 161 |   --print-freq N, -p N  print frequency (default: 10) 162 |   --resume PATH         path to latest checkpoint (default: none) 163 | ``` 164 | -------------------------------------------------------------------------------- /examples/02_brain_segmentation_NonAdjLoss_MRI/segment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from architecture import ModSegNet 5 | from inference_canvas import InferenceCanvas 6 | from datasets.mappings import Miccai12Mapping 7 | from datasets.inference import MICCAI2012MedFile 8 | 9 | 10 | parser = argparse.ArgumentParser( 11 |     description='PyTorch Automatic Segmentation (inference mode)') 12 | parser.add_argument('data', metavar='DIR', 13 |                     help='path to dataset') 14 | parser.add_argument('model', metavar='MODEL', 15 |                     help='path to a trained model') 16 | parser.add_argument('output', metavar='OUTPUT', 17 |                     help='path to the output segmentation map') 18 | 19 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 20 |                     help='number of data loading workers (default: 4)') 21 | parser.add_argument('-b', '--batch-size', default=16, type=int, 22 |                     metavar='N', help='mini-batch size (default: 16)') 23 | parser.add_argument('--wo-metrics', action='store_false', 24 |                     help='whether to use metrics (dice,
assd) or not') 25 | 26 | 27 | def main(): 28 | global args 29 | args = parser.parse_args() 30 | model = ModSegNet(num_classes=Miccai12Mapping().nb_classes, 31 | n_init_features=7).cuda() 32 | inference_canvas = InferenceCanvas(args, infer_segmentation_map, 33 | MICCAI2012MedFile, model) 34 | inference_canvas() 35 | 36 | 37 | def infer_segmentation_map(model, data_loader, label_map): 38 | probability_maps = [] 39 | 40 | with torch.no_grad(): 41 | for position, input in data_loader: 42 | output = model(input.cuda()) 43 | _, predicted = output.data.max(1) 44 | 45 | # for each element of the batch 46 | for i in range(0, predicted.size(0)): 47 | y = position[i][1] 48 | label_map[:, y - 3, :] = predicted[i].cpu().numpy()[1:-1, 1:-1] 49 | 50 | return probability_maps 51 | 52 | 53 | if __name__ == '__main__': 54 | main() 55 | -------------------------------------------------------------------------------- /examples/02_brain_segmentation_NonAdjLoss_MRI/train_segment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | import random 5 | import string 6 | import time 7 | 8 | 9 | parser = argparse.ArgumentParser( 10 | description='PyTorch Automatic Segmentation Training and Inference') 11 | parser.add_argument('data', metavar='DATA_DIR', help='path to the data dir') 12 | parser.add_argument('output_dir', default='', metavar='OUTPUT_DIR', 13 | help='path to the output directory (default: current dir)') 14 | 15 | parser.add_argument('--semi-data-dir', metavar='SEMI_DATA_DIR', default=None, 16 | help='path to the dataset for semi-supervision') 17 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 18 | help='number of data loading workers (default: 4)') 19 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 20 | help='number of total epochs to run (default: 200)') 21 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 22 | help='manual epoch number (useful on restarts)') 23 | parser.add_argument('-b', '--batch-size', default=64, type=int, 24 | metavar='N', help='mini-batch size (default: 64)') 25 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, 26 | metavar='LR', help='initial learning rate (default: 0.01)') 27 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 28 | help='momentum') 29 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 30 | metavar='W', help='weight decay (default: 1e-4)') 31 | parser.add_argument('--print-freq', '-p', default=10, type=int, 32 | metavar='N', help='print frequency (default: 10)') 33 | parser.add_argument('--resume', default=None, type=str, metavar='PATH', 34 | help='path to latest checkpoint (default: none)') 35 | 36 | 37 | def main(): 38 | args = parser.parse_args() 39 | 40 | code_dir = os.path.dirname(os.path.realpath(__file__)) 41 | train_script = os.path.join(code_dir, 'train.py') 42 | inference_script = os.path.join(code_dir, 'segment.py') 43 | report_script_dir = os.path.join(code_dir, '../scripts/report/') 44 | report_scripts = ['plot_boxplot_labels_by_metric', 45 | 'plot_boxplot_patients_by_metric', 46 | 'write_patient_by_metric', 'write_label_by_metric'] 47 | 48 | exp_id = os.path.basename(code_dir) 49 | now = datetime.datetime.now() 50 | dir_name = '{}_{}_{}@{}-{}_{}_{}'.format(now.year, now.month, now.day, 51 | now.hour, now.minute, 52 | exp_id, id_generator(2)) 53 | output_dir = os.path.join(args.output_dir, dir_name) 54 | 
os.makedirs(output_dir) 55 | 56 | ##### 57 | # 58 | # Training 59 | # 60 | ##### 61 | print('--> The output folder is {}'.format(output_dir)) 62 | print('--> Started train script') 63 | ret = os.system('python -u {} {} {} -j {} -b {} --epochs {} --lr {}' 64 | ' --exp-id {} --resume {} --semi-data-dir {}'.format( 65 | train_script, 66 | args.data, 67 | output_dir, 68 | args.workers, 69 | args.batch_size, 70 | args.epochs, 71 | args.lr, 72 | exp_id, 73 | args.resume, 74 | args.semi_data_dir 75 | )) 76 | 77 | # in case training ended with an error 78 | if os.WEXITSTATUS(ret) != 0: 79 | return -1 80 | 81 | print('Sleeping for 5 seconds before segmentation. (read/write sync)') 82 | time.sleep(5) 83 | 84 | ##### 85 | # 86 | # Segmentation 87 | # 88 | ##### 89 | output_dir_seg = os.path.join(output_dir, 'segmentations') 90 | os.makedirs(output_dir_seg) 91 | 92 | segment_command = 'python -u {} {} {} {}'.format( 93 | inference_script, 94 | os.path.join(args.data, 'test'), 95 | os.path.join(output_dir, 'model_best_dice.pth.tar'), 96 | output_dir_seg 97 | ) 98 | 99 | print(segment_command) 100 | os.system(segment_command) 101 | 102 | ##### 103 | # 104 | # Reporting 105 | # 106 | ##### 107 | output_dir_report = os.path.join(output_dir, 'reports') 108 | for r_script in report_scripts: 109 | report_command = 'python -u {} {} {}'.format( 110 | os.path.join(report_script_dir, r_script + '.py'), 111 | output_dir_seg, 112 | output_dir_report 113 | ) 114 | os.system(report_command) 115 | time.sleep(1) 116 | 117 | 118 | def id_generator(size=6, chars=string.ascii_uppercase + string.digits): 119 | return ''.join(random.choice(chars) for _ in range(size)) 120 | 121 | 122 | if __name__ == '__main__': 123 | main() 124 | -------------------------------------------------------------------------------- /examples/02_brain_segmentation_NonAdjLoss_MRI/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import numpy as np 3 | import os 4 | import pandas 5 | import shutil 6 | import torch 7 | 8 | from torchmed.utils.metric import dice as dc 9 | from torchmed.utils.metric import jaccard, multiclass 10 | 11 | 12 | def write_config(model, args, train_size, val_size): 13 | num_params = 0 14 | for m in model.modules(): 15 | if isinstance(m, torch.nn.Conv2d): 16 | num_params += m.weight.data.numel() 17 | 18 | configfile = os.path.join(args.output_dir, 'config.txt') 19 | cfg_f = open(configfile, "a") 20 | cfg_f.write('\ntraining with {} patches\n' 21 | 'validating with {} patches\n' 22 | .format(train_size * args.batch_size, 23 | val_size * args.batch_size)) 24 | cfg_f.write(('project: {}\n' + 25 | 'number of workers: {}\n' + 26 | 'number of epochs: {}\n' + 27 | 'starting epoch: {}\n' + 28 | 'batch size: {}\n' + 29 | 'learning rate: {:.6f}\n' + 30 | 'momentum: {:.5f}\n' + 31 | 'weight-decay: {:.5f}\n' + 32 | 'number of parameters: {}\n') 33 | .format(args.exp_id, args.workers, 34 | args.epochs, args.start_epoch, 35 | args.batch_size, args.lr, 36 | args.momentum, args.weight_decay, num_params) 37 | ) 38 | cfg_f.write('\nstarted training at {}\n'.format(datetime.datetime.now())) 39 | cfg_f.flush() 40 | 41 | 42 | def write_end_config(args, elapsed_time): 43 | configfile = os.path.join(args.output_dir, 'config.txt') 44 | cfg_f = open(configfile, "a") 45 | cfg_f.write('stopped training at {}\n'.format(datetime.datetime.now())) 46 | cfg_f.write('elapsed time : {:.2f} hours or {:.2f} days.' 
47 |                 .format((elapsed_time) / (60 * 60), 48 |                         (elapsed_time) / (60 * 60 * 24))) 49 |     cfg_f.flush() 50 | 51 | 52 | def update_figures(log_plot): 53 |     # plot avg train loss_meter 54 |     log_plot.add_line('cross_entropy', 'average_train.csv', 'epoch', 'cross_entropy_loss', "#1f77b4") 55 |     log_plot.add_line('dice', 'average_train.csv', 'epoch', 'dice_loss', "#ff7f0e") 56 |     log_plot.plot('losses_train.png', 'epoch', 'loss') 57 | 58 |     # plot avg validation loss_meter 59 |     log_plot.add_line('cross_entropy', 'average_validation.csv', 'epoch', 'cross_entropy_loss', "#1f77b4") 60 |     log_plot.add_line('dice', 'average_validation.csv', 'epoch', 'dice_loss', "#ff7f0e") 61 |     log_plot.plot('losses_validation.png', 'epoch', 'loss') 62 | 63 |     # plot learning rate 64 |     log_plot.add_line('learning_rates', 'learning_rates.csv', 65 |                       'epoch', 'lr', '#1f77b4') 66 |     log_plot.plot('learning_rate.png', 'epoch', 'learning rate') 67 | 68 |     # plot dice 69 |     log_plot.add_line('train', 'average_train.csv', 'epoch', 'dice_metric', '#1f77b4') 70 |     log_plot.add_line('validation', 'average_validation.csv', 71 |                       'epoch', 'dice_metric', '#ff7f0e') 72 |     log_plot.plot('average_dice.png', 'epoch', 'dice', max_y=1) 73 | 74 |     # plot iou 75 |     log_plot.add_line('train', 'average_train.csv', 'epoch', 'iou_metric', '#1f77b4') 76 |     log_plot.add_line('validation', 'average_validation.csv', 77 |                       'epoch', 'iou_metric', '#ff7f0e') 78 |     log_plot.plot('average_iou.png', 'epoch', 'iou', max_y=1) 79 | 80 |     # plot nonadjloss 81 |     log_plot.add_line('nonadjloss', 'average_validation.csv', 82 |                       'epoch', 'nonadj_loss', '#ff7f0e') 83 |     log_plot.plot('average_nonadjloss.png', 'epoch', 'NonAdjLoss') 84 | 85 | 86 | def save_checkpoint(state, epoch, output_dir): 87 |     filename = os.path.join(output_dir, 'checkpoint_' + str(epoch) + '.pth.tar') 88 |     bestfile = os.path.join(output_dir, 'best_log.txt') 89 |     torch.save(state, filename) 90 | 91 |     bestfile_f = open(bestfile, "a") 92 |     bestfile_f.write('epoch:{:>5d} dice:{:>7.4f} IoU:{:>7.4f} NonAdjLoss:{:>7.4e}\n'.format( 93 |         state['epoch'], state['dice_metric'], state['iou_metric'], state['nonadjloss'])) 94 |     bestfile_f.flush() 95 | 96 | 97 | def load_checkpoint(model, filename, args):  # args provides the optimizer hyper-parameters 98 |     checkpoint = torch.load(filename) 99 |     # curr_epoch = checkpoint['epoch'] + 1 100 |     model.load_state_dict(checkpoint['state_dict']) 101 |     return torch.optim.SGD(model.parameters(), args.lr, 102 |                            momentum=args.momentum, 103 |                            weight_decay=args.weight_decay) 104 | 105 | 106 | def find_best_model(model_path): 107 |     df = pandas.read_csv(os.path.join(model_path, 'logs/average_validation.csv'), 108 |                          sep=';') 109 | 110 |     # reference dice of first iteration without graph 111 |     ref_dice = df.iloc[0]['dice_metric'] 112 |     best_dices = df[df['dice_metric'] >= (ref_dice - 0.005)] 113 |     best_epoch = best_dices.loc[best_dices['nonadj_loss'].idxmin()] 114 |     epoch_id = int(best_epoch['epoch']) 115 | 116 |     log = open(os.path.join(model_path, 'log_best_graph.txt'), "w") 117 |     log.write('iteration: {} dice: {:.4f} nonAdjLoss: {:.4e}\n' 118 |               .format(epoch_id, best_epoch['dice_metric'], best_epoch['nonadj_loss'])) 119 |     log.flush() 120 |     shutil.copy(os.path.join(model_path, ('checkpoint_' + str(epoch_id) + '.pth.tar')), 121 |                 os.path.join(model_path, 'model_best_dice.pth.tar')) 122 | 123 | 124 | def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter=1, 125 |                       max_iter=100, power=0.9): 126 |     """Polynomial decay of learning rate 127 |         :param init_lr: base learning rate 128 |         :param iter: current iteration 129 |         :param lr_decay_iter: how frequently decay occurs, default is 1 130 |         :param max_iter: maximum number of iterations 131 |         :param power: polynomial power 132 | 133 |     """ 134 |     if iter % lr_decay_iter or iter > max_iter: 135 |         return optimizer  # no decay at this iteration 136 | 137 |     lr = init_lr * (1 - iter / max_iter)**power 138 |     for param_group in optimizer.param_groups: 139 |         param_group['lr'] = lr 140 | 141 |     return lr  # note: returns the new lr here, but the optimizer above 142 | 143 | 144 | def eval_metrics(segmentation, reference): 145 |     results, undec_labels = multiclass(segmentation, reference, [dc, jaccard]) 146 |     return list(map(lambda l: sum(l.values()) / len(l), results)) 147 | -------------------------------------------------------------------------------- /examples/scripts/adjacency/extract_3d_adjacency.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import matplotlib.pyplot 4 | import numpy as np 5 | import os 6 | import torch 7 | from torch.utils.data import Dataset 8 | from torchmed.readers import SitkReader 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('data_dir', metavar='DIR_DATA', 12 |                     help='path to the image dataset') 13 | parser.add_argument('output_dir', default='', metavar='OUTPUT_DIR', 14 |                     help='path to the output directory (default: current dir)') 15 | parser.add_argument('nb_labels', default=21, metavar='N_LABELS', type=int, 16 |                     help='number of labels in the dataset') 17 | parser.add_argument('--n-size', default=1, type=int, metavar='SIZE', 18 |                     help='size of the neighborhood') 19 | parser.add_argument('--discard_label', default=None, type=int, metavar='INT', 20 |                     help='label to discard') 21 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 22 |                     help='number of data loading workers (default: 4)') 23 | 24 | 25 | def main(): 26 |     global args, nb_classes 27 |     args = parser.parse_args() 28 | 29 |     nb_classes = args.nb_labels 30 |     image_dataset = ImageDataset(args.data_dir) 31 | 32 |     print("=> building the train dataset") 33 |     data_loader = torch.utils.data.DataLoader( 34 |         image_dataset, 35 |         batch_size=1, 36 |         shuffle=False, 37 |         num_workers=args.workers, 38 |         pin_memory=True) 39 | 40 |     adj_mat = extract_from_data(data_loader, args.discard_label) 41 |     save2png(adj_mat, os.path.join(args.output_dir, 'adjacency_n_' + str(args.n_size))) 42 | 43 | 44 | def extract_from_data(data_loader, discard_label): 45 |     nonAdjLoss_arr = torch.FloatTensor(nb_classes, nb_classes).zero_().cuda() 46 |     if discard_label is not None: 47 |         adjacencyLayer = AdjacencyEstimator(nb_classes + 1, args.n_size * 2 + 1).train().cuda() 48 |     else: 49 |         adjacencyLayer = AdjacencyEstimator(nb_classes, args.n_size * 2 + 1).train().cuda() 50 | 51 |     for i, (p, target) in enumerate(data_loader): 52 |         print(str(i) + ' / ' + str(len(data_loader)), target.size(), p) 53 |         if discard_label is not None: 54 |             target[target == discard_label] = nb_classes 55 |         target_gpu = target.cuda() 56 | 57 |         if discard_label is not None: 58 |             nonAdjLoss_arr += adjacencyLayer(target_gpu)[:-1, :-1] 59 |         else: 60 |             nonAdjLoss_arr += adjacencyLayer(target_gpu) 61 | 62 |     return nonAdjLoss_arr.cpu() 63 | 64 | 65 | class ImageDataset(Dataset): 66 |     def __init__(self, base_dir): 67 |         database = open(os.path.join(base_dir, 'allowed_data.txt'), 'r') 68 |         patient_list = [line.rstrip('\n') for line in database] 69 |         self.medfiles = [] 70 | 71 |         for patient in patient_list: 72 |             if patient: 73 |                 patient_dir = os.path.join(base_dir, patient) 74 |                 r = SitkReader(os.path.join(patient_dir, 'prepro_seg.nii.gz'), 75 |
torch_type='torch.LongTensor') 76 |                 self.medfiles.append((patient, r)) 77 | 78 |     def __len__(self): 79 |         return len(self.medfiles) 80 | 81 |     def __getitem__(self, idx): 82 |         return (self.medfiles[idx][0], self.medfiles[idx][1].to_torch()) 83 | 84 | 85 | class AdjacencyEstimator(torch.nn.Module): 86 |     """Estimates the adjacency graph of labels based on probability maps. 87 | 88 |     Parameters 89 |     ---------- 90 |     nb_labels : int 91 |         number of structures segmented. 92 | 93 |     """ 94 |     def __init__(self, nb_labels, kernel_size=3): 95 |         super(AdjacencyEstimator, self).__init__() 96 | 97 |         # constant 3D convolution, needs constant weights and no gradient 98 |         # apply the same convolution filter to all labels 99 |         layer = torch.nn.Conv3d(in_channels=nb_labels, out_channels=nb_labels, 100 |                                 kernel_size=kernel_size, stride=1, padding=0, 101 |                                 bias=False, groups=nb_labels) 102 |         layer.weight.data.fill_(0) 103 | 104 |         canvas = torch.Tensor(kernel_size, kernel_size, kernel_size).fill_(1) 105 |         # fill filters with ones 106 |         for i in range(0, nb_labels): 107 |             layer.weight.data[i, 0] = canvas 108 | 109 |         # exclude parameters from the subgraph 110 |         for param in layer.parameters(): 111 |             param.requires_grad = False 112 | 113 |         self._conv_layer = layer 114 |         # replicate padding to recover the same resolution after convolution 115 |         self._pad_layer = torch.nn.ReplicationPad3d((kernel_size - 1) // 2) 116 | 117 |     def forward(self, target): 118 |         target_size = list(target.size()) 119 |         target_size.insert(1, self._conv_layer.in_channels) 120 |         one_hot_size = torch.Size(target_size) 121 | 122 |         image = torch.FloatTensor(one_hot_size).zero_().cuda() 123 |         image.scatter_(1, target.unsqueeze(1), 1) 124 | 125 |         # padding of tensor of size batch x k x D x H x W 126 |         p_tild = self._pad_layer(image) 127 |         # apply constant convolution (no normalization here) 128 |         p_tild = self._conv_layer(p_tild) 129 | 130 |         return torch.einsum('nidhw,njdhw->ij', image, p_tild) 131 | 132 | 133 | def save2png(image, output): 134 |     # save to file 135 |     np.savetxt(output + '.csv', image.float().numpy(), delimiter=';') 136 | 137 |     # save binary image 138 |     brain = (image > 0).float().numpy() 139 |     brain = cv2.cvtColor(brain, cv2.COLOR_GRAY2BGR) 140 |     brain = cv2.normalize(brain, brain, 0, 1, cv2.NORM_MINMAX) 141 |     matplotlib.pyplot.imsave(output + '.png', brain) 142 | 143 | 144 | if __name__ == '__main__': 145 |     main() 146 | -------------------------------------------------------------------------------- /examples/scripts/adjacency/extract_3d_adjacency_from_seg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | 5 | import cv2 6 | import matplotlib.pyplot 7 | import torch 8 | from torch.utils.data import Dataset 9 | from torchmed.readers import SitkReader 10 | 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('data_dir', metavar='DIR_DATA', 14 |                     help='path to the segmentation dataset') 15 | parser.add_argument('output_dir', default='', metavar='OUTPUT_DIR', 16 |                     help='path to the output directory (default: current dir)') 17 | parser.add_argument('graph', metavar='FILE', help='path to GT graph') 18 | parser.add_argument('--n-size', default=1, type=int, metavar='SIZE', 19 |                     help='size of the neighborhood') 20 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 21 |                     help='number of data loading workers (default: 4)') 22 | 23 | 24 | def main(): 25 |     global args, nb_classes 26 |     args = parser.parse_args()
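    # the log produced below contains, per segmentation: 'unique' = number of
    # label pairs violating the GT graph, 'cumulated' = their summed adjacency
    # mass, m1 = unique normalized by the number of non-adjacent GT pairs, and
    # m2 = cumulated normalized by a contour-size statistic computed from
    # ContourEstimator (see extract_from_data below).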
27 | 28 |     if not os.path.exists(args.output_dir): 29 |         os.makedirs(args.output_dir) 30 | 31 |     image_dataset = ImageDataset(args.data_dir) 32 |     graph_master = np.loadtxt(args.graph, delimiter=';') 33 |     nb_classes = graph_master.shape[0] 34 | 35 |     print("=> building the train dataset") 36 |     data_loader = torch.utils.data.DataLoader( 37 |         image_dataset, 38 |         batch_size=1, 39 |         shuffle=False, 40 |         num_workers=args.workers, 41 |         pin_memory=True) 42 | 43 |     with torch.no_grad(): 44 |         mat = extract_from_data(data_loader, graph_master) 45 |     save2png(mat, os.path.join(args.output_dir, 'adjacency_n_' + str(args.n_size))) 46 | 47 | 48 | def extract_from_data(data_loader, graph_master): 49 |     adjacencyLayer = AdjacencyEstimator(nb_classes, args.n_size * 2 + 1).train().cuda() 50 |     spatialAdjacencyLayer = ContourEstimator(nb_classes, args.n_size * 2 + 1).train().cuda() 51 |     adjacencyLayer.eval() 52 |     spatialAdjacencyLayer.eval() 53 | 54 |     log = open(os.path.join(args.output_dir, 'log.txt'), "w") 55 |     log.write('id;unique;cumulated;m1;m2\n') 56 |     m1_list = [] 57 |     m2_list = [] 58 | 59 |     for i, (p, target, p_name) in enumerate(data_loader): 60 |         print(str(i) + ' / ' + str(len(data_loader)), target.size(), p) 61 | 62 |         target_gpu = target.cuda() 63 |         nonAdjLoss_arr = adjacencyLayer(target_gpu) 64 |         contour_size = spatialAdjacencyLayer(target_gpu).cpu().numpy()[0] 65 |         contour_size = (contour_size == 0).sum() 66 | 67 |         nonAdjLoss_arr = nonAdjLoss_arr.cpu() 68 |         nonAdjLoss_arr[torch.from_numpy(graph_master) >= 1] = 0 69 | 70 |         m1 = (nonAdjLoss_arr > 0).sum().item() / (graph_master == 0).sum() 71 |         m2 = 0 72 |         if contour_size > 0: 73 |             m2 = nonAdjLoss_arr.sum().item() / contour_size.item() 74 | 75 |         m1_list.append(m1) 76 |         m2_list.append(m2) 77 | 78 |         log.write('{};{};{};{};{}\n'.format(p_name[0], 79 |                   (nonAdjLoss_arr > 0).sum(), 80 |                   nonAdjLoss_arr.sum(), m1, m2)) 81 |         log.flush() 82 | 83 |     log.write('total;{} +- {};{} +- {}'.format(np.mean(m1_list), 84 |               np.std(m1_list), np.mean(m2_list), 85 |               np.std(m2_list))) 86 |     log.flush() 87 | 88 |     return nonAdjLoss_arr.cpu() 89 | 90 | 91 | class AdjacencyEstimator(torch.nn.Module): 92 |     """Estimates the adjacency graph of labels based on probability maps. 93 | 94 |     Parameters 95 |     ---------- 96 |     nb_labels : int 97 |         number of structures segmented.
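    Notes
    -----
    Unlike the 2D estimator in `loss.py`, `forward` here takes an integer
    label volume of shape (batch, D, H, W) and one-hot encodes it internally
    before applying the constant convolution.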
98 | 99 |     """ 100 |     def __init__(self, nb_labels, kernel_size=3): 101 |         super(AdjacencyEstimator, self).__init__() 102 | 103 |         # constant 3D convolution, needs constant weights and no gradient 104 |         # apply the same convolution filter to all labels 105 |         layer = torch.nn.Conv3d(in_channels=nb_labels, out_channels=nb_labels, 106 |                                 kernel_size=kernel_size, stride=1, padding=0, 107 |                                 bias=False, groups=nb_labels) 108 |         layer.weight.data.fill_(0) 109 | 110 |         canvas = torch.Tensor(kernel_size, kernel_size, kernel_size).fill_(1) 111 |         # fill filters with ones 112 |         for i in range(0, nb_labels): 113 |             layer.weight.data[i, 0] = canvas 114 | 115 |         # exclude parameters from the subgraph 116 |         for param in layer.parameters(): 117 |             param.requires_grad = False 118 | 119 |         self._conv_layer = layer 120 |         # zero padding to recover the same resolution after convolution 121 |         # (replication padding kept commented out for reference) 122 |         # self._pad_layer = torch.nn.ReplicationPad3d((kernel_size - 1) // 2) 123 |         self._pad_layer = torch.nn.ConstantPad3d((kernel_size - 1) // 2, 0) 124 | 125 |     def forward(self, target): 126 |         target_size = list(target.size()) 127 |         target_size.insert(1, self._conv_layer.in_channels) 128 |         target_size = torch.Size(target_size) 129 |         image = torch.FloatTensor(target_size).zero_().cuda() 130 |         image.scatter_(1, target.unsqueeze(1), 1) 131 | 132 |         # padding of tensor of size batch x k x D x H x W 133 |         p_tild = self._pad_layer(image) 134 |         # apply constant convolution (no normalization here) 135 |         p_tild = self._conv_layer(p_tild) 136 | 137 |         return torch.einsum('nidhw,njdhw->ij', image, p_tild) 138 | 139 | 140 | class ContourEstimator(torch.nn.Module): 141 |     """Estimates a voxel-wise map of contacts between different labels, 142 |     used to measure the size of contours between structures. 143 | 144 |     Parameters 145 |     ---------- 146 |     nb_labels : int 147 |         number of structures segmented.
148 | 149 |     """ 150 |     def __init__(self, nb_labels, kernel_size=3): 151 |         super(ContourEstimator, self).__init__() 152 | 153 |         # constant 3D convolution, needs constant weights and no gradient 154 |         # apply the same convolution filter to all labels 155 |         layer = torch.nn.Conv3d(in_channels=nb_labels, out_channels=nb_labels, 156 |                                 kernel_size=kernel_size, stride=1, padding=0, 157 |                                 bias=False, groups=nb_labels) 158 |         layer.weight.data.fill_(0) 159 | 160 |         canvas = torch.Tensor(kernel_size, kernel_size, kernel_size).fill_(1) 161 |         # fill filters with ones 162 |         for i in range(0, nb_labels): 163 |             layer.weight.data[i, 0] = canvas 164 | 165 |         # exclude parameters from the subgraph 166 |         for param in layer.parameters(): 167 |             param.requires_grad = False 168 | 169 |         self._conv_layer = layer 170 |         # zero padding to recover the same resolution after convolution 171 |         self._pad_layer = torch.nn.ConstantPad3d((kernel_size - 1) // 2, 0) 172 | 173 |     def forward(self, target): 174 |         target_size = list(target.size()) 175 |         target_size.insert(1, self._conv_layer.in_channels) 176 |         target_size = torch.Size(target_size) 177 |         image = torch.FloatTensor(target_size).zero_().cuda() 178 |         image.scatter_(1, target.unsqueeze(1), 1) 179 |         ret = torch.FloatTensor(target.size()).zero_().cuda() 180 | 181 |         # padding of tensor of size batch x k x D x H x W 182 |         p_tild = self._pad_layer(image) 183 |         # apply constant convolution and normalize by kernel volume (27 for a 3x3x3 kernel) 184 |         p_tild = self._conv_layer(p_tild) / float(self._conv_layer.kernel_size[0] ** 3) 185 | 186 |         for i in range(self._conv_layer.in_channels): 187 |             for j in range(self._conv_layer.in_channels): 188 |                 if i != j: 189 |                     ret[0, :, :, :] += p_tild[0, i] * image[0, j]  # assumes batch size 1 190 |         return ret 191 |         # return torch.einsum('nidhw,njdhw->dhw', image, p_tild) 192 | 193 | 194 | def save2png(image, output): 195 |     # save to file 196 |     np.savetxt(output + '.csv', image.float().numpy(), delimiter=';') 197 | 198 |     # save binary image 199 |     brain = (image > 0).float().numpy() 200 |     brain = cv2.cvtColor(brain, cv2.COLOR_GRAY2BGR) 201 |     brain = cv2.normalize(brain, brain, 0, 1, cv2.NORM_MINMAX) 202 |     matplotlib.pyplot.imsave(output + '.png', brain) 203 | 204 | 205 | class ImageDataset(Dataset): 206 |     def __init__(self, base_dir): 207 |         self.medfiles = [] 208 | 209 |         for patient in sorted(os.listdir(base_dir)): 210 |             p_dir = os.path.join(base_dir, patient) 211 |             if os.path.exists(os.path.join(p_dir, 'segmentation.hdr')): 212 |                 seg_file = os.path.join(p_dir, 'segmentation.hdr') 213 |             elif os.path.exists(os.path.join(p_dir, 'segmentation.nii.gz')): 214 |                 seg_file = os.path.join(p_dir, 'segmentation.nii.gz') 215 |             else: 216 |                 raise Exception('Segmentation file does not exist (.nii.gz or .hdr)') 217 | 218 |             if os.path.isdir(p_dir) and os.path.exists(seg_file): 219 |                 r = SitkReader(seg_file, torch_type='torch.LongTensor') 220 |                 self.medfiles.append((patient, r, patient)) 221 | 222 |     def __len__(self): 223 |         return len(self.medfiles) 224 | 225 |     def __getitem__(self, idx): 226 |         return (self.medfiles[idx][0], self.medfiles[idx][1].to_torch(), 227 |                 self.medfiles[idx][2]) 228 | 229 | 230 | if __name__ == '__main__': 231 |     main() 232 | -------------------------------------------------------------------------------- /examples/scripts/preprocessing/mappings.py: -------------------------------------------------------------------------------- 1 | 2 | class MiccaiMapping(object): 3 |     def __init__(self): 4 |         self.all_labels = [0, 4, 11, 23, 30, 31, 32, 35, 36, 37, 38, 39, 40, 5 |                            41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 55, 6 |                            56,
57, 58, 59, 60, 61, 62, 63, 64, 66, 69, 71, 72, 7 |                            73, 75, 76, 100, 101, 102, 103, 104, 105, 106, 107, 8 |                            108, 109, 112, 113, 114, 115, 116, 117, 118, 119, 9 |                            120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 10 |                            132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 11 |                            142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 12 |                            152, 153, 154, 155, 156, 157, 160, 161, 162, 163, 13 |                            164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 14 |                            174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 15 |                            184, 185, 186, 187, 190, 191, 192, 193, 194, 195, 16 |                            196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 17 |                            206, 207] 18 |         self.ignore_labels = [1, 2, 3] + \ 19 |             [5, 6, 7, 8, 9, 10] + \ 20 |             [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22] + \ 21 |             [24, 25, 26, 27, 28, 29] + \ 22 |             [33, 34] + [42, 43] + [53, 54] + \ 23 |             [63, 64, 65, 66, 67, 68] + [70, 74] + \ 24 |             [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 25 |              90, 91, 92, 93, 94, 95, 96, 97, 98, 99] + \ 26 |             [110, 111, 126, 127, 130, 131, 158, 159, 188, 189] 27 |         self.overall_labels = set(self.all_labels).difference( 28 |             set(self.ignore_labels)) 29 | 30 |         self.cortical_labels = [x for x in self.overall_labels if x >= 100] 31 |         self.non_cortical_labels = \ 32 |             [x for x in self.overall_labels if x > 0 and x < 100] 33 | 34 |         self.map = {v: k for k, v in enumerate(self.overall_labels)} 35 |         self.reversed_map = {k: v for k, v in enumerate(self.overall_labels)} 36 |         self.nb_classes = len(self.overall_labels) 37 | 38 |     def __getitem__(self, index): 39 |         return self.map[index] 40 | 41 | 42 | class OASISMapping(object): 43 |     def __init__(self): 44 |         self.avoid_train_patients = [ 45 |             '0061', '0080', '0092', '0145', '0150', '0156', '0191', '0202', 46 |             '0230', '0236', '0239', '0249', '0285', '0353', '0368' 47 |         ] 48 | 49 |         self.avoid_test_patients = [ 50 |             '0101', '0111', '0117', '0379', '0395', '0101', '0111', '0117', 51 |             '0379', '0395', '0091', '0417', '0040', '0282', '0331', '0456', 52 |             '0300', '0220', '0113', '0083' 53 |         ] 54 | 55 |         self.avoid_patients = self.avoid_train_patients + self.avoid_test_patients 56 | 57 | 58 | class IbsrMapping(object): 59 |     def __init__(self): 60 |         self.all_labels = [0, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 61 |                            18, 24, 26, 28, 29, 30, 41, 42, 43, 44, 46, 47, 48, 62 |                            49, 50, 51, 52, 53, 54, 58, 60, 61, 62, 72] 63 | 64 |         # structures : undetermined (29,61), vessel (72,30), 5th ventricle (62) 65 |         self.ignore_labels = [29, 61, 72, 30, 62] 66 |         self.overall_labels = sorted(list(set(self.all_labels).difference( 67 |             set(self.ignore_labels)))) 68 | 69 |         self.map = {} 70 |         index = 0 71 |         for v in self.overall_labels: 72 |             if v == 49: 73 |                 self.map.update({49: self.map[48]}) 74 |             else: 75 |                 self.map.update({v: index}) 76 |                 index += 1 77 | 78 |         self.reversed_map = {k: v for k, v in enumerate(self.overall_labels)} 79 |         self.nb_classes = len(self.overall_labels) - 1 80 | 81 |     def __getitem__(self, index): 82 |         return self.map[index] 83 | -------------------------------------------------------------------------------- /examples/scripts/preprocessing/oasis.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import time 5 | import shutil 6 | 7 | from torchmed.utils.multiproc import parallelize_system_calls 8 | from torchmed.readers import SitkReader 9 | 10 | from mappings import OASISMapping 11 | 12 | 13 | parser = argparse.ArgumentParser( 14 |     description='Dataset extractor and pre-processing for the OASIS dataset') 15 |
parser.add_argument('data', metavar='DIR', help='path to the oasis dataset') 16 | parser.add_argument('output', metavar='DIR', help='path to output dataset') 17 | parser.add_argument('-n', '--nb-workers', default=2, type=int, metavar='N', 18 | help='Number of workers for the parallel execution of Affine Registration') 19 | parser.add_argument('--data-split-ratio', default=0.70, type=float, metavar='N', 20 | help='Percentage of train data vs validation data') 21 | 22 | 23 | def main(): 24 | args = parser.parse_args() 25 | mapping = OASISMapping() 26 | 27 | print('\n' + 28 | '\n', 29 | ('##### Automatic Segmentation of brain MRI #####') + '\n', 30 | ('##### By Pierre-Antoine Ganaye #####') + '\n', 31 | ('##### CREATIS Laboratory #####'), 32 | '\n' * 2, 33 | ('The dataset can be downloaded at https://www.oasis-brains.org/#data') + '\n', 34 | ('This script will preprocess the OASIS-1 T1-w brain MRI dataset') + '\n', 35 | ('-------------------------------------------------------------') + '\n') 36 | 37 | if not os.path.exists(args.output): 38 | os.makedirs(args.output) 39 | 40 | os.makedirs(os.path.join(args.output, 'train')) 41 | os.makedirs(os.path.join(args.output, 'validation')) 42 | 43 | # only work with scans from session 1 "OAS1" 44 | sessions = [session for session in os.listdir(args.data) if session.startswith('OAS1')] 45 | filtered_sessions = [s for s in sessions if s[5:9] not in mapping.avoid_patients] 46 | filtered_sessions = sorted(filtered_sessions) 47 | 48 | # allowed data 49 | command_list = [] 50 | command_list2 = [] 51 | 52 | print(('1/ --> Reading train and validation directories')) 53 | print(('2/ --> Creating corresponding tree hierarchy')) 54 | print(('3/ --> Copying files to destination folders')) 55 | for patient_session in filtered_sessions: 56 | # create destination dirs 57 | output_dir = os.path.join(args.output, patient_session) 58 | if not os.path.exists(output_dir): 59 | os.makedirs(output_dir) 60 | 61 | # get name of the image, pick first scan of each session 62 | session_files = os.path.join(args.data, patient_session) 63 | raw_files = os.path.join(session_files, 'RAW') 64 | scan = [scan for scan in os.listdir(raw_files) if scan.endswith('_mpr-1_anon.img')] 65 | 66 | # write input and destination to file 67 | input_scan = os.path.join(raw_files, scan[0]) 68 | registered_scan = os.path.join(output_dir, 'im_mni.nii.gz') 69 | registered_scan_bc = os.path.join(output_dir, 'im_mni_bc.nii.gz') 70 | aff_trans = os.path.join(output_dir, 'mni_aff_transf.mat') 71 | 72 | # affine registration with flirt 73 | command = ('flirt -searchrx -180 180 -searchry -180 180' 74 | ' -searchrz -180 180 -in {} -ref $FSLDIR/data/standard/MNI152_T1_1mm.nii.gz ' 75 | '-interp trilinear -o {} -omat {}').format( 76 | input_scan, registered_scan, aff_trans 77 | ) 78 | command_list.append(command) 79 | 80 | # bias field correction 81 | command = ('N4BiasFieldCorrection -d 3 -i {}' 82 | ' -o {} -s 4 -b 200').format( 83 | registered_scan, registered_scan_bc) 84 | command_list2.append(command) 85 | 86 | print(('4/ --> Affine registration to MNI space and bias field correction')) 87 | print('Can take several hours depending on the number of workers ({})'.format( 88 | args.nb_workers 89 | )) 90 | parallelize_system_calls(args.nb_workers, command_list) 91 | time.sleep(30) 92 | parallelize_system_calls(max(1, args.nb_workers // 2), command_list2) 93 | time.sleep(30) 94 | 95 | print('5/ --> Mean centering and reduction') 96 | # estimate mean and variance on the first 50 images 97 | n_mean_estimate = 50
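    # The block below approximates dataset-level statistics by averaging
    # per-volume means and variances over the first `n_mean_estimate` scans,
    # then standardizes every volume with the usual z-score:
    #
    #   standardized = (x - mean) / sqrt(var)
    #
    # (illustrative summary of the code that follows, assuming at least
    # n_mean_estimate sessions are available)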
98 | mean, var = (0, 0) 99 | for patient_session in filtered_sessions[0:n_mean_estimate]: 100 | output_dir = os.path.join(args.output, patient_session) 101 | registered_scan_bc = os.path.join(output_dir, 'im_mni_bc.nii.gz') 102 | 103 | brain = SitkReader(registered_scan_bc, torch_type='torch.FloatTensor').to_torch() 104 | mean += brain.mean() 105 | var += brain.var() 106 | 107 | mean = mean / n_mean_estimate 108 | var = var / n_mean_estimate 109 | 110 | for patient_session in filtered_sessions: 111 | output_dir = os.path.join(args.output, patient_session) 112 | registered_scan_bc = os.path.join(output_dir, 'im_mni_bc.nii.gz') 113 | registered_scan_norm = os.path.join(output_dir, 'prepro_im_mni_bc.nii.gz') 114 | 115 | brain = SitkReader(registered_scan_bc, torch_type='torch.FloatTensor') 116 | brain_array = brain.to_torch() 117 | brain_array[...] = (brain_array - mean) / math.sqrt(var) 118 | brain.to_image_file(registered_scan_norm) 119 | 120 | train_dir = os.path.join(args.output, 'train') 121 | val_dir = os.path.join(args.output, 'validation') 122 | 123 | stats_log = open(os.path.join(train_dir, 'stats_log.txt'), "w") 124 | stats_log.write('average mean: {:.10f}\n' 125 | 'average standard deviation: {:.10f}'.format(mean, math.sqrt(var))) 126 | stats_log.flush() 127 | 128 | # split into train validation test 129 | train_patient_number = math.floor(args.data_split_ratio * len(filtered_sessions)) 130 | train_patients = filtered_sessions[:train_patient_number] 131 | validation_patients = filtered_sessions[train_patient_number:] 132 | 133 | allowed_train = open(os.path.join(train_dir, 'allowed_data.txt'), "w") 134 | allowed_val = open(os.path.join(val_dir, 'allowed_data.txt'), "w") 135 | 136 | # move train 137 | for patient_session in train_patients: 138 | output_dir = os.path.join(args.output, patient_session) 139 | dest_dir = os.path.join(train_dir, patient_session) 140 | allowed_train.write(patient_session + '\n') 141 | shutil.move(output_dir, dest_dir) 142 | allowed_train.flush() 143 | 144 | # move validation 145 | for patient_session in validation_patients: 146 | output_dir = os.path.join(args.output, patient_session) 147 | dest_dir = os.path.join(val_dir, patient_session) 148 | allowed_val.write(patient_session + '\n') 149 | shutil.move(output_dir, dest_dir) 150 | allowed_val.flush() 151 | 152 | 153 | if __name__ == '__main__': 154 | main() 155 | -------------------------------------------------------------------------------- /examples/scripts/report/plot_boxplot_labels_by_metric.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import pandas 5 | import matplotlib as mpl 6 | 7 | # agg backend is used to create plot as a .png file 8 | mpl.use('agg') 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | parser = argparse.ArgumentParser( 13 | description='Evaluation of metrics with boxplots') 14 | parser.add_argument('data', metavar='DIR', 15 | help='path to root segmentation dir') 16 | parser.add_argument('output', metavar='DIR', 17 | help='path to output dir') 18 | 19 | 20 | def main(): 21 | global args 22 | args = parser.parse_args() 23 | 24 | if not os.path.isdir(args.output): 25 | os.makedirs(args.output) 26 | 27 | df = None 28 | # iterate over segmentation folders to gather data 29 | for file in os.listdir(args.data): 30 | filepath = os.path.join(args.data, file) 31 | if os.path.isdir(filepath) and not file.startswith('__'): 32 | metric_file = os.path.join(filepath, 'metrics_report.csv') 33 | 
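                # aggregate per-patient CSVs column-wise: read_and_aggregate
                # concatenates along axis=1, so rows remain labels and every
                # patient contributes one column per metric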
df = read_and_aggregate(metric_file, df) 34 | 35 | if df is not None: 36 | col_names = pandas.read_csv(metric_file, sep=';', index_col=0).columns.values 37 | 38 | for metric_name in col_names: 39 | fig, ax1 = plt.subplots(nrows=1, ncols=1, 40 | figsize=(50, 10), dpi=100) 41 | plt.subplots_adjust(hspace=0.2) 42 | 43 | plot_boxplot(ax1, 111, df[metric_name], 44 | 'boxplot of the ' + metric_name + ' for each label', 45 | 'labels', 46 | metric_name, 47 | None, None) 48 | 49 | outfile = os.path.join(args.output, 'boxplot_labels_' + metric_name + '.png') 50 | fig.savefig(outfile, bbox_inches='tight') 51 | plt.clf() 52 | 53 | 54 | def read_and_aggregate(filename, df=None): 55 | ''' 56 | Read a csv file and aggregates the result into a pandas dataframe 57 | ''' 58 | 59 | read_df = pandas.read_csv(filename, sep=';', index_col=0) 60 | 61 | if df is None: 62 | return read_df 63 | else: 64 | return pandas.concat([df, read_df], axis=1) 65 | 66 | 67 | def plot_boxplot(ax, d, df, title, x_axis, y_axis, 68 | y_scale=None, y_precision=None): 69 | samples = [] 70 | for index, row in df.iterrows(): 71 | samples.append(row.tolist()) 72 | 73 | # Create the boxplot 74 | bp = ax.boxplot(samples, patch_artist=True) 75 | 76 | # change outline color, fill color and linewidth of the boxes 77 | for box in bp['boxes']: 78 | # change outline color 79 | box.set(color='#7570b3', linewidth=2) 80 | # change fill color 81 | box.set(facecolor='#1b9e77') 82 | 83 | # change color and linewidth of the whiskers 84 | for whisker in bp['whiskers']: 85 | whisker.set(color='#7570b3', linewidth=2) 86 | 87 | # change color and linewidth of the caps 88 | for cap in bp['caps']: 89 | cap.set(color='#7570b3', linewidth=2) 90 | 91 | # change color and linewidth of the medians 92 | for median in bp['medians']: 93 | median.set(color='#b2df8a', linewidth=2) 94 | 95 | # change the style of fliers and their fill 96 | for flier in bp['fliers']: 97 | flier.set(marker='o', color='#e7298a', alpha=0.5) 98 | 99 | # Remove top axes and right axes ticks 100 | ax.get_xaxis().tick_bottom() 101 | ax.get_yaxis().tick_left() 102 | 103 | # set y axis scale 104 | if y_scale is not None: 105 | assert(isinstance(y_scale, list)) 106 | start, end = y_scale 107 | else: 108 | start, end = ax.get_ylim() 109 | start = 0 110 | 111 | ax.set_ylim((start, end + ((end - start) / 20))) 112 | 113 | # increase y axis precision 114 | if y_precision is not None: 115 | ax.yaxis.set_ticks(np.arange(start, end, y_precision)) 116 | 117 | # add grid 118 | ax.grid(True, color='lightgrey', alpha=0.5) 119 | gridlines = ax.get_xgridlines() + ax.get_ygridlines() 120 | 121 | # rotate labels on the x axis 122 | ax.set_xticklabels(range(0, len(samples)), rotation=35) 123 | 124 | # set title 125 | ax.set_title(title) 126 | ax.set_xlabel(x_axis) 127 | ax.set_ylabel(y_axis) 128 | 129 | # add mean value on top on each label column 130 | pos = np.arange(len(samples)) + 1 131 | means = [np.mean(s) for s in samples] 132 | upperLabels = [str(np.round(s, 2)) for s in means] 133 | upperLabels = [l[:4] for l in upperLabels] 134 | weights = ['bold', 'semibold'] 135 | boxColors = ['darkkhaki', 'royalblue'] 136 | top = end + ((end - start) / 20) 137 | for tick, label in zip(range(len(samples)), ax.get_xticklabels()): 138 | k = tick % 2 139 | ax.text(pos[tick], top - (top * 0.05), upperLabels[tick], 140 | horizontalalignment='center', size='x-small', weight=weights[k], 141 | color='#7570b3') 142 | 143 | 144 | if __name__ == '__main__': 145 | main() 146 | 
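A minimal usage sketch for the report scripts in this folder (hypothetical paths; assumes every segmentation folder contains a `metrics_report.csv` produced by the evaluation step):

    # one boxplot per metric, one box per label
    python plot_boxplot_labels_by_metric.py /path/to/segmentations /path/to/plots
    # per-label mean/std tables plus an average_metrics.txt summary
    python write_label_by_metric.py /path/to/segmentations /path/to/reports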
-------------------------------------------------------------------------------- /examples/scripts/report/plot_boxplot_patients_by_metric.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import pandas 5 | import matplotlib as mpl 6 | 7 | # agg backend is used to create plot as a .png file 8 | mpl.use('agg') 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | parser = argparse.ArgumentParser( 13 | description='Evaluation of metrics with boxplots') 14 | parser.add_argument('data', metavar='DIR', 15 | help='path to root segmentation dir') 16 | parser.add_argument('output', metavar='DIR', 17 | help='path to output dir') 18 | 19 | 20 | def main(): 21 | global args 22 | args = parser.parse_args() 23 | 24 | if not os.path.isdir(args.output): 25 | os.makedirs(args.output) 26 | 27 | # iterate over segmentation folders to gather data 28 | for file in os.listdir(args.data): 29 | filepath = os.path.join(args.data, file) 30 | if os.path.isdir(filepath) and not file.startswith('__'): 31 | metric_file = os.path.join(filepath, 'metrics_report.csv') 32 | col_names = pandas.read_csv(metric_file, sep=';', index_col=0).columns.values 33 | break 34 | 35 | for metric_name in col_names: 36 | df = None 37 | p_names = [] 38 | # iterate over segmentation folders to gather data 39 | for file in os.listdir(args.data): 40 | filepath = os.path.join(args.data, file) 41 | if os.path.isdir(filepath) and not file.startswith('__'): 42 | p_names.append(file) 43 | metric_file = os.path.join(filepath, 'metrics_report.csv') 44 | patient_m = pandas.read_csv(metric_file, sep=';', index_col=0)[metric_name] 45 | if df is None: 46 | df = patient_m 47 | else: 48 | df = pandas.concat([df, patient_m], axis=1) 49 | df.columns = p_names 50 | 51 | fig, ax1 = plt.subplots(nrows=1, ncols=1, 52 | figsize=(50, 10), dpi=100) 53 | plt.subplots_adjust(hspace=0.2) 54 | 55 | plot_boxplot(ax1, 111, df, 56 | 'boxplot of the ' + metric_name + ' for each patients', 57 | 'patients', 58 | metric_name, 59 | None, None) 60 | 61 | outfile = os.path.join(args.output, 'boxplot_patients_by_' + metric_name + '.png') 62 | fig.savefig(outfile, bbox_inches='tight') 63 | plt.clf() 64 | 65 | 66 | def read_and_aggregate(filename, df=None): 67 | ''' 68 | Read a csv file and aggregates the result into a pandas dataframe 69 | ''' 70 | 71 | read_df = pandas.read_csv(filename, sep=';', index_col=0) 72 | 73 | if df is None: 74 | return read_df 75 | else: 76 | return pandas.concat([df, read_df], axis=1) 77 | 78 | 79 | def plot_boxplot(ax, d, df, title, x_axis, y_axis, 80 | y_scale=None, y_precision=None): 81 | samples = [] 82 | # for index, row in df.iterrows(): 83 | # samples.append(row.tolist()) 84 | for col in df: 85 | samples.append(df[col].tolist()) 86 | 87 | # Create the boxplot 88 | bp = ax.boxplot(samples, patch_artist=True) 89 | 90 | # change outline color, fill color and linewidth of the boxes 91 | for box in bp['boxes']: 92 | # change outline color 93 | box.set(color='#7570b3', linewidth=2) 94 | # change fill color 95 | box.set(facecolor='#1b9e77') 96 | 97 | # change color and linewidth of the whiskers 98 | for whisker in bp['whiskers']: 99 | whisker.set(color='#7570b3', linewidth=2) 100 | 101 | # change color and linewidth of the caps 102 | for cap in bp['caps']: 103 | cap.set(color='#7570b3', linewidth=2) 104 | 105 | # change color and linewidth of the medians 106 | for median in bp['medians']: 107 | median.set(color='#b2df8a', linewidth=2) 108 | 109 | # change the 
style of fliers and their fill 110 | for flier in bp['fliers']: 111 | flier.set(marker='o', color='#e7298a', alpha=0.5) 112 | 113 | # Remove top axes and right axes ticks 114 | ax.get_xaxis().tick_bottom() 115 | ax.get_yaxis().tick_left() 116 | 117 | # set y axis scale 118 | if y_scale is not None: 119 | assert(isinstance(y_scale, list)) 120 | start, end = y_scale 121 | else: 122 | start, end = ax.get_ylim() 123 | start = 0 124 | 125 | ax.set_ylim((start, end + ((end - start) / 20))) 126 | 127 | # increase y axis precision 128 | if y_precision is not None: 129 | ax.yaxis.set_ticks(np.arange(start, end, y_precision)) 130 | 131 | # add grid 132 | ax.grid(True, color='lightgrey', alpha=0.5) 133 | gridlines = ax.get_xgridlines() + ax.get_ygridlines() 134 | 135 | # rotate labels on the x axis 136 | ax.set_xticklabels(df.columns, rotation=35) 137 | 138 | # set title 139 | ax.set_title(title) 140 | ax.set_xlabel(x_axis) 141 | ax.set_ylabel(y_axis) 142 | 143 | # add mean value on top on each label column 144 | pos = np.arange(len(samples)) + 1 145 | means = [np.mean(s) for s in samples] 146 | upperLabels = [str(np.round(s, 2)) for s in means] 147 | upperLabels = [l[:4] for l in upperLabels] 148 | weights = ['bold', 'semibold'] 149 | boxColors = ['darkkhaki', 'royalblue'] 150 | top = end + ((end - start) / 20) 151 | for tick, label in zip(range(len(samples)), ax.get_xticklabels()): 152 | k = tick % 2 153 | ax.text(pos[tick], top - (top * 0.05), upperLabels[tick], 154 | horizontalalignment='center', size='x-small', weight=weights[k], 155 | color='#7570b3') 156 | 157 | 158 | if __name__ == '__main__': 159 | main() 160 | -------------------------------------------------------------------------------- /examples/scripts/report/write_label_by_metric.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import pandas 5 | 6 | parser = argparse.ArgumentParser( 7 | description='Make model reports') 8 | parser.add_argument('data', metavar='DIR', 9 | help='path to the segmentation dir') 10 | parser.add_argument('output', metavar='OUT', 11 | help='path to the output dir') 12 | 13 | 14 | def main(): 15 | args = parser.parse_args() 16 | 17 | if not os.path.isdir(args.output): 18 | os.makedirs(args.output) 19 | 20 | df = None 21 | undetected_classes = [] 22 | for seg in os.listdir(args.data): 23 | seg_dir = os.path.join(args.data, seg) 24 | if os.path.isdir(seg_dir): 25 | report_file = os.path.join(seg_dir, 'metrics_report.csv') 26 | undec_classes = os.path.join(seg_dir, 'undetected_classes.csv') 27 | df = read_and_aggregate(report_file, df) 28 | if os.path.isfile(undec_classes): 29 | undetected_classes.append((seg, read_and_aggregate(undec_classes)['class_id'].tolist())) 30 | 31 | # name of each metric 32 | metrics_name = sorted(list(set(df.columns.values))) 33 | log = open(os.path.join(args.output, 'average_metrics.txt'), "w") 34 | 35 | # mean by metric over columns 36 | for metric in metrics_name: 37 | m_df = df[metric].mean(1) 38 | std_df = df[metric].std(1) 39 | # write mean and std of classes to csv 40 | file_name = os.path.join(args.output, metric + '_by_label.csv') 41 | mean_std = [] 42 | for n in range(m_df.shape[0]): 43 | mean_std.append('{:.3f} += {:.2f}'.format(m_df[n], std_df[n])) 44 | pandas.DataFrame(mean_std).to_csv( 45 | file_name, sep=';', header=['mean += std'], 46 | index=True, index_label='label_id') 47 | 48 | # write mean and std of each metric to file 49 | avg = m_df.mean(0) 50 | std = m_df.std(0) 51 
| log.write('{}: {:.3f} += {:.3f}\n'.format(metric, avg, std)) 52 | 53 | if len(undetected_classes) > 0: 54 | log.write('\nUndetected Structures:\n') 55 | for seg, classes in undetected_classes: 56 | log.write('{}: {}\n'.format(seg, ', '.join(map(str, classes)))) 57 | log.flush() 58 | 59 | 60 | def read_and_aggregate(filename, df=None): 61 | read_df = pandas.read_csv(filename, sep=';', index_col=0) 62 | 63 | if df is None: 64 | return read_df 65 | else: 66 | return pandas.concat([df, read_df], axis=1) 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /examples/scripts/report/write_model_by_metric.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import pandas 5 | 6 | parser = argparse.ArgumentParser( 7 | description='Make model reports') 8 | parser.add_argument('data', metavar='DIR', 9 | help='path to the segmentation dir') 10 | parser.add_argument('output', metavar='OUT', 11 | help='path to the output dir') 12 | 13 | 14 | def main(): 15 | args = parser.parse_args() 16 | 17 | if not os.path.isdir(args.output): 18 | os.makedirs(args.output) 19 | 20 | models = os.listdir(args.data) 21 | models_df = [] 22 | # for each model 23 | for model in models: 24 | model_dir = os.path.join(args.data, model) 25 | if os.path.isdir(model_dir): 26 | 27 | # collect class metrics 28 | df = None 29 | undetected_classes = [] 30 | data_dir = os.path.join(model_dir, 'segmentations') 31 | for seg in os.listdir(data_dir): 32 | seg_dir = os.path.join(data_dir, seg) 33 | if os.path.isdir(seg_dir): 34 | report_file = os.path.join(seg_dir, 'metrics_report.csv') 35 | undec_classes = os.path.join(seg_dir, 'undetected_classes.csv') 36 | df = read_and_aggregate(report_file, df) 37 | if os.path.isfile(undec_classes): 38 | undetected_classes.append((seg, read_and_aggregate(undec_classes)['class_id'].tolist())) 39 | models_df.append((model, df, undetected_classes)) 40 | 41 | # name of each metric 42 | metrics_name = sorted(list(set(df.columns.values))) 43 | log = open(os.path.join(args.output, 'log.csv'), "w") 44 | models_df = sorted(models_df, key=lambda k: k[0]) 45 | 46 | # average of metrics for each model 47 | log.write('model;{}'.format(';'.join(metrics_name))) 48 | for name, df, undetected in models_df: 49 | # mean by metric over columns 50 | log.write('\n{}'.format(name)) 51 | for metric in metrics_name: 52 | m_df = df[metric].mean(1) 53 | 54 | # write mean and std of each metric to file 55 | avg = m_df.mean(0) 56 | std = m_df.std(0) 57 | log.write(';{:.3f} += {:.3f}'.format(avg, std)) 58 | log.write('\n') 59 | 60 | for name, df, undetected in models_df: 61 | if len(undetected) > 0: 62 | log.write('\n {} Undetected Structures:\n'.format(name)) 63 | for seg, classes in undetected: 64 | log.write('{}: {}\n'.format(seg, ', '.join(map(str, classes)))) 65 | log.flush() 66 | 67 | 68 | def read_and_aggregate(filename, df=None): 69 | read_df = pandas.read_csv(filename, sep=';', index_col=0) 70 | 71 | if df is None: 72 | return read_df 73 | else: 74 | return pandas.concat([df, read_df], axis=1) 75 | 76 | 77 | if __name__ == '__main__': 78 | main() 79 | -------------------------------------------------------------------------------- /examples/scripts/report/write_patient_by_metric.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import
pandas 5 | 6 | parser = argparse.ArgumentParser( 7 | description='Make model reports') 8 | parser.add_argument('data', metavar='DIR', 9 | help='path to the segmentation dir') 10 | parser.add_argument('output', metavar='OUT', 11 | help='path to the output dir') 12 | 13 | 14 | def main(): 15 | args = parser.parse_args() 16 | 17 | if not os.path.isdir(args.output): 18 | os.makedirs(args.output) 19 | 20 | avgs = None 21 | patient_name = [] 22 | patients = sorted(os.listdir(args.data)) 23 | for seg in patients: 24 | seg_dir = os.path.join(args.data, seg) 25 | if os.path.isdir(seg_dir): 26 | patient_name.append(seg) 27 | report_file = os.path.join(seg_dir, 'metrics_report.csv') 28 | m_df = pandas.read_csv(report_file, sep=';', index_col=0) 29 | 30 | if avgs is None: 31 | avgs = m_df.mean(0).values 32 | stds = m_df.std(0).values 33 | else: 34 | avgs = np.vstack((avgs, m_df.mean(0).values)) 35 | stds = np.vstack((stds, m_df.std(0).values)) 36 | 37 | nb_patient, nb_metric = avgs.shape 38 | # build one row of formatted 'mean += std' strings per patient 39 | metrics = [] 40 | for p in range(nb_patient): 41 | row = ['{:.3f} += {:.2f}'.format(avgs[p, m], stds[p, m]) 42 | for m in range(nb_metric)] 43 | metrics.append(row) 44 | 45 | df = pandas.DataFrame(metrics) 46 | df.columns = [m_df.columns.values] 47 | df['name'] = patient_name 48 | 49 | # reorder columns to put the name first 50 | cols = df.columns.tolist() 51 | cols = cols[-1:] + cols[:-1] 52 | df = df[cols] 53 | 54 | # name of each metric 55 | report_file = os.path.join(args.output, 'metric_by_patient.csv') 56 | df.to_csv(report_file, sep=";", index=False) 57 | 58 | 59 | if __name__ == '__main__': 60 | main() 61 | -------------------------------------------------------------------------------- /examples/scripts/report/write_top_label_by_metric.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import pandas 5 | 6 | parser = argparse.ArgumentParser( 7 | description='Make model reports') 8 | parser.add_argument('data', metavar='DIR', 9 | help='path to the segmentation dir') 10 | parser.add_argument('output', metavar='OUT', 11 | help='path to the output dir') 12 | parser.add_argument('metric', metavar='METRIC NAME', 13 | help='name of the metric') 14 | parser.add_argument('-k', '--k-top', metavar='N', default=5, type=int, 15 | help='number of elements to retain') 16 | parser.add_argument('-t', '--is-top', action='store_false', 17 | help='if set, retain the k highest instead of the k smallest') 18 | 19 | 20 | def main(): 21 | args = parser.parse_args() 22 | 23 | if not os.path.isdir(args.output): 24 | os.makedirs(args.output) 25 | 26 | df = None 27 | undetected_classes = [] 28 | for seg in os.listdir(args.data): 29 | seg_dir = os.path.join(args.data, seg) 30 | if os.path.isdir(seg_dir): 31 | report_file = os.path.join(seg_dir, 'metrics_report.csv') 32 | df = read_and_aggregate(report_file, df) 33 | 34 | # name of each metric 35 | metrics_name = sorted(list(set(df.columns.values))) 36 | 37 | if args.metric in metrics_name: 38 | m_df = df[args.metric].mean(1) 39 | std_df = df[args.metric].std(1) 40 | 41 | idx = np.argsort(m_df) 42 | idx = idx[:args.k_top] if args.is_top else list(reversed(idx))[:args.k_top] 43 | 44 | # write mean and std of classes to csv 45 | f_name = str(args.k_top) + ('_smallest_' if args.is_top else '_highest_') + args.metric 46 | file_name = os.path.join(args.output, f_name + '.csv') 47 | mean_std = [] 48 | for i in idx: 49 | mean_std.append('{:.3f} +=
{:.2f}'.format(m_df[i], std_df[i])) 50 | o_df = pandas.DataFrame({'mean': mean_std, 'label_id': idx}) 51 | o_df.to_csv( 52 | file_name, sep=';', 53 | index=True, index_label='order') 54 | 55 | 56 | def read_and_aggregate(filename, df=None): 57 | read_df = pandas.read_csv(filename, sep=';', index_col=0) 58 | 59 | if df is None: 60 | return read_df 61 | else: 62 | return pandas.concat([df, read_df], axis=1) 63 | 64 | 65 | if __name__ == '__main__': 66 | main() 67 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.0.0 2 | nibabel 3 | SimpleITK 4 | Pillow 5 | numpy 6 | pandas 7 | scipy 8 | matplotlib 9 | opencv-python 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | description = "A companion library for deep learning on medical imaging" 6 | 7 | # Read in requirements 8 | requirements = open('requirements.txt').readlines() 9 | requirements = [r.strip() for r in requirements] 10 | 11 | setuptools.setup( 12 | name="torchmed", 13 | version="0.0.1a", 14 | author="Pierre-Antoine Ganaye", 15 | description=description, 16 | long_description=long_description, 17 | long_description_content_type="text/markdown", 18 | url="https://github.com/trypag/pytorch-med", 19 | packages=setuptools.find_packages(), 20 | python_requires='>=3', 21 | install_requires=requirements, 22 | license='GNU GPLv3', 23 | classifiers=[ 24 | "Programming Language :: Python :: 3", 25 | "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", 26 | "Operating System :: POSIX", 27 | ], 28 | ) 29 | -------------------------------------------------------------------------------- /test/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trypag/torchmed/87e8f8d398fc877c80b03c85df1f98f7e00ccc2a/test/evaluation/__init__.py -------------------------------------------------------------------------------- /test/evaluation/test_dice.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | 4 | # from torchmed.utils.evaluation import Dice 5 | # 6 | # 7 | # class Patch2d_pad(unittest.TestCase): 8 | # def test_measure(self): 9 | # measure = Dice().measure 10 | # arr = np.empty([10, 10]) 11 | # arr0 = np.copy(arr) 12 | # 13 | # arr.fill(10) 14 | # arr0.fill(0) 15 | # 16 | # self.assertEqual(measure(arr, arr), 1) 17 | # self.assertEqual(measure(arr, arr0), 0) 18 | 19 | 20 | if __name__ == '__main__': 21 | unittest.main() 22 | -------------------------------------------------------------------------------- /test/patterns/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trypag/torchmed/87e8f8d398fc877c80b03c85df1f98f7e00ccc2a/test/patterns/__init__.py -------------------------------------------------------------------------------- /test/patterns/test_patch2d.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import MagicMock 3 | from unittest.mock import patch 4 | 5 | # import numpy as np 6 | # from numpy.testing import assert_array_equal as array_equal 7 | #
import torch 8 | # 9 | # from torchmed.sampler.patterns.patch_2d import Patch_2d 10 | # from torchmed.readers import BasicReader 11 | # 12 | # 13 | # class Patch2d_pad(unittest.TestCase): 14 | # def setUp(self): 15 | # self.image2d = np.arange(0, 25).reshape(1, 5, 5) 16 | # self.torch_image2d = torch.from_numpy(np.tile(self.image2d, (4, 1, 1))) 17 | # 18 | # @patch('torchmed.readers.simple_itk.BasicReader') 19 | # def get_patch(self, x, y, patch_size, mock_reader): 20 | # pattern = Patch_2d(patch_size, True, 'x', None) 21 | # mock_reader.to_tensor = MagicMock(return_value=self.torch_image2d) 22 | # mock_reader.to_numpy = MagicMock( 23 | # return_value=np.tile(self.image2d, (4, 1, 1))) 24 | # pattern.init_reader({'image': mock_reader}) 25 | # return pattern(0, x, y).numpy() 26 | # 27 | # def test_full_patch(self): 28 | # patch = self.get_patch(2, 2, 5) 29 | # array_equal(patch, self.image2d.reshape(5, 5)) 30 | # 31 | # def test_1x1_patch(self): 32 | # patch_center = self.get_patch(2, 2, 1) 33 | # patch_top_left = self.get_patch(0, 0, 1) 34 | # patch_top_right = self.get_patch(0, 4, 1) 35 | # patch_bottom_left = self.get_patch(4, 0, 1) 36 | # patch_bottom_right = self.get_patch(4, 4, 1) 37 | # 38 | # array_equal(patch_center, np.arange(12, 13).reshape(1, 1)) 39 | # array_equal(patch_top_left, np.arange(0, 1).reshape(1, 1)) 40 | # array_equal(patch_top_right, np.arange(4, 5).reshape(1, 1)) 41 | # array_equal(patch_bottom_left, np.arange(20, 21).reshape(1, 1)) 42 | # array_equal(patch_bottom_right, np.arange(24, 25).reshape(1, 1)) 43 | # 44 | # def test_3x3_patch(self): 45 | # truth_center = np.array([[6, 7, 8], 46 | # [11, 12, 13], 47 | # [16, 17, 18]]) 48 | # truth_top_left = np.array([[0, 1, 2], 49 | # [5, 6, 7], 50 | # [10, 11, 12]]) 51 | # truth_top_right = np.array([[2, 3, 4], 52 | # [7, 8, 9], 53 | # [12, 13, 14]]) 54 | # truth_bottom_left = np.array([[10, 11, 12], 55 | # [15, 16, 17], 56 | # [20, 21, 22]]) 57 | # truth_bottom_right = np.array([[12, 13, 14], 58 | # [17, 18, 19], 59 | # [22, 23, 24]]) 60 | # 61 | # patch_center = self.get_patch(2, 2, 3) 62 | # patch_top_left = self.get_patch(1, 1, 3) 63 | # patch_top_right = self.get_patch(1, 3, 3) 64 | # patch_bottom_left = self.get_patch(3, 1, 3) 65 | # patch_bottom_right = self.get_patch(3, 3, 3) 66 | # 67 | # array_equal(patch_center, truth_center) 68 | # array_equal(patch_top_left, truth_top_left) 69 | # array_equal(patch_top_right, truth_top_right) 70 | # array_equal(patch_bottom_left, truth_bottom_left) 71 | # array_equal(patch_bottom_right, truth_bottom_right) 72 | # 73 | # def test_3x3_patch_one_side_padding(self): 74 | # truth_left_middle = np.array([[0, 5, 6], 75 | # [0, 10, 11], 76 | # [0, 15, 16]]) 77 | # truth_right_middle = np.array([[8, 9, 0], 78 | # [13, 14, 0], 79 | # [18, 19, 0]]) 80 | # truth_top_middle = np.array([[0, 0, 0], 81 | # [1, 2, 3], 82 | # [6, 7, 8]]) 83 | # truth_bottom_middle = np.array([[16, 17, 18], 84 | # [21, 22, 23], 85 | # [0, 0, 0]]) 86 | # 87 | # patch_left_middle = self.get_patch(2, 0, 3) 88 | # patch_right_middle = self.get_patch(2, 4, 3) 89 | # patch_top_middle = self.get_patch(0, 2, 3) 90 | # patch_bottom_middle = self.get_patch(4, 2, 3) 91 | # 92 | # array_equal(patch_left_middle, truth_left_middle) 93 | # array_equal(patch_right_middle, truth_right_middle) 94 | # array_equal(patch_top_middle, truth_top_middle) 95 | # array_equal(patch_bottom_middle, truth_bottom_middle) 96 | # 97 | # def test_3x3_patch_two_side_padding(self): 98 | # truth_top_left = np.array([[0, 0, 0], 99 | # [0, 0, 1], 100 
| # [0, 5, 6]]) 101 | # truth_top_right = np.array([[0, 0, 0], 102 | # [3, 4, 0], 103 | # [8, 9, 0]]) 104 | # truth_bottom_left = np.array([[0, 15, 16], 105 | # [0, 20, 21], 106 | # [0, 0, 0]]) 107 | # truth_bottom_right = np.array([[18, 19, 0], 108 | # [23, 24, 0], 109 | # [0, 0, 0]]) 110 | # 111 | # patch_top_left = self.get_patch(0, 0, 3) 112 | # patch_top_right = self.get_patch(0, 4, 3) 113 | # patch_bottom_left = self.get_patch(4, 0, 3) 114 | # patch_bottom_right = self.get_patch(4, 4, 3) 115 | # 116 | # array_equal(patch_top_left, truth_top_left) 117 | # array_equal(patch_top_right, truth_top_right) 118 | # array_equal(patch_bottom_left, truth_bottom_left) 119 | # array_equal(patch_bottom_right, truth_bottom_right) 120 | 121 | 122 | if __name__ == '__main__': 123 | unittest.main() 124 | -------------------------------------------------------------------------------- /torchmed/__init__.py: -------------------------------------------------------------------------------- 1 | name = "torchmed" 2 | 3 | from .datasets import MedFile, MedFolder 4 | from .samplers.sampler import Sampler 5 | from .patterns.patch import Pattern 6 | from .readers.reader import Reader 7 | 8 | __all__ = ['MedFile', 'MedFolder', 'Sampler', 'Pattern', 'Reader'] 9 | -------------------------------------------------------------------------------- /torchmed/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .medfile import MedFile 2 | from .medfolder import MedFolder 3 | from .medcombiner import MedCombiner 4 | 5 | __all__ = ['MedFile', 'MedFolder', 'MedCombiner'] 6 | -------------------------------------------------------------------------------- /torchmed/datasets/medcombiner.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | 4 | class MedCombiner(Dataset): 5 | """Combines a list of MedFolder/MedFile/Dataset into a single Dataset. 6 | 7 | Parameters 8 | ---------- 9 | dataset_list : Iterable[:class:`torch.utils.data.Dataset`] 10 | iterable of Datasets to assemble and iterate over. 11 | get_composer : (Iterable[:class:`torch.utils.data.Dataset`], int) -> Any 12 | function returning an object based on Datasets and an index. 13 | len_composer : Iterable[:class:`torch.utils.data.Dataset`] -> int 14 | function returning the length of the dataset based on Datasets. 15 | 16 | """ 17 | def __init__(self, dataset_list, get_composer, len_composer): 18 | self._datasets = dataset_list 19 | self._get = get_composer 20 | self._len = len_composer 21 | 22 | def __getitem__(self, index): 23 | """Gets at index, response is defined by get_composer. 24 | """ 25 | return self._get(self._datasets, index) 26 | 27 | def __len__(self): 28 | """Number of samples that can be extracted, defined by len_composer 29 | on the MedFolders. 30 | """ 31 | return self._len(self._datasets) 32 | -------------------------------------------------------------------------------- /torchmed/datasets/medfile.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | 4 | class MedFile(Dataset): 5 | """Storage and utility class to process an image. Takes some input data, 6 | a way to sub-sample it and access to the samples. 7 | 8 | Parameters 9 | ---------- 10 | data_map : dict 11 | a dictionnary that maps a key referencing an image, to a value being 12 | one of the :mod:`torchmed.readers`. 
`image_ref` is a mandatory key 13 | referencing the main image to be read. 14 | sampler : :class:`torchmed.Sampler` 15 | a sampler defining how samples are extracted from the data (see 16 | :mod:`torchmed.samplers`). 17 | transform : Any -> Any 18 | function performing a transformation (ex: data augmentation) 19 | on the data returned by `sampler`. 20 | target_transform : Any -> Any 21 | function performing a transformation (ex: data augmentation) 22 | on the label (target) returned by `sampler`. 23 | paired_transform: Any -> Any 24 | function performing a transformation on the data tensor AND the 25 | label tensor. The transformation is applied after `transform` and 26 | `target_transform`. 27 | 28 | """ 29 | def __init__(self, data_map, sampler, transform=None, target_transform=None, 30 | paired_transform=None): 31 | self._sampler = sampler 32 | self._sampler.build(data_map) 33 | 34 | self._transform = transform 35 | self._target_transform = target_transform 36 | self._paired_transform = paired_transform 37 | 38 | def __getitem__(self, index): 39 | """Returns a sample at the corresponding index. 40 | 41 | Parameters 42 | ---------- 43 | index : int 44 | index of the sample to get. 45 | 46 | Returns 47 | ------- 48 | tuple(objects) 49 | tuple composed of : the spatial position of sampling, the sample, 50 | and the label if possible. 51 | 52 | """ 53 | # Samples the data pattern, target pattern and the position of sampling 54 | position, data, label = self._sampler[index] 55 | 56 | # to apply a transform on the image and/or target 57 | if self._transform is not None: 58 | data = self._transform(data) 59 | 60 | if label is not None: 61 | if self._target_transform is not None: 62 | label = self._target_transform(label) 63 | 64 | if self._paired_transform is not None: 65 | data, label = self._paired_transform(data, label) 66 | 67 | return position, data, label 68 | else: 69 | return position, data 70 | 71 | def __len__(self): 72 | """Number of samples in the MedFile. 73 | """ 74 | return len(self._sampler) 75 | -------------------------------------------------------------------------------- /torchmed/datasets/medfolder.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | 4 | class MedFolder(Dataset): 5 | """Packs a list of MedFiles into one dataset. 6 | Makes it easier to iterate over a list of MedFiles. 7 | 8 | Parameters 9 | ---------- 10 | medfiles : Iterable[:class:`torchmed.datasets.MedFile`] 11 | list of MedFiles. 12 | transform : Any -> Any 13 | function performing a transformation (ex: data augmentation) 14 | on the data returned by a :class:`torchmed.Sampler`. This transform 15 | overwrites the default transform of the MedFiles. 16 | target_transform : Any -> Any 17 | function performing a transformation (ex: data augmentation) 18 | on the label (target) returned by a :class:`torchmed.Sampler`. This 19 | `target_transform` overwrites the default `target_transform` of the MedFiles. 20 | paired_transform: Any -> Any 21 | function performing a transformation on the data tensor AND the 22 | label tensor. The transformation is applied after `transform` and 23 | `target_transform`.This `paired_transform` overwrites the default 24 | `paired_transform` of the MedFiles. 
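    Example (illustrative sketch; ``build_medfile`` is a hypothetical helper
    returning a :class:`torchmed.datasets.MedFile` per patient directory):

    >>> medfiles = [build_medfile(d) for d in patient_dirs]
    >>> folder = MedFolder(medfiles)
    >>> len(folder)  # total number of samples over all MedFiles
    >>> position, data, label = folder[0]  # when a label is available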
25 | 26 | """ 27 | def __init__(self, medfiles, transform=None, target_transform=None, 28 | paired_transform=None): 29 | self._medfiles = medfiles 30 | self._map_file_index = [] 31 | self._dataset_size = 0 32 | 33 | for index in range(0, len(self._medfiles)): 34 | self._medfiles[index]._transform = transform 35 | self._medfiles[index]._target_transform = target_transform 36 | self._medfiles[index]._paired_transform = paired_transform 37 | # dataset's size is increased by the size of the medfile 38 | self._dataset_size += len(self._medfiles[index]) 39 | self._map_file_index.append((index, self._dataset_size)) 40 | 41 | def __getitem__(self, index): 42 | """Returns a sample at the corresponding index. 43 | 44 | Parameters 45 | ---------- 46 | index : int 47 | index of the sample to get. 48 | 49 | Returns 50 | ------- 51 | tuple(objects) 52 | tuple composed of : the spatial position of sampling, the sample, 53 | and the label if possible (see :func:`torchmed.datasets.MedFile.__getitem__`) 54 | 55 | """ 56 | start_id = 0 57 | patient_id = 0 58 | for file_id, limit in self._map_file_index: 59 | if limit > index: 60 | patient_id = file_id 61 | start_ind = start_id 62 | break 63 | else: 64 | start_id = limit 65 | 66 | return self._medfiles[patient_id][index - start_ind] 67 | 68 | def __len__(self): 69 | """Number of samples in the MedFolder (sum of all the samples in the MedFiles). 70 | """ 71 | return self._dataset_size 72 | -------------------------------------------------------------------------------- /torchmed/patterns/__init__.py: -------------------------------------------------------------------------------- 1 | from .patch import SquaredSlidingWindow 2 | 3 | __all__ = ['SquaredSlidingWindow'] 4 | -------------------------------------------------------------------------------- /torchmed/patterns/patch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Pattern(object): 5 | """ 6 | Abstracts a way to extract some informations from an image. 7 | `can_apply` returns True is the sample can be extracted at the given 8 | position of the image, and False otherwise. A `__call__` to the pattern 9 | will return the information of interest. 10 | Allows to write patterns of sampling independently and combine them flawlessly. 11 | """ 12 | def prepare(self, image): 13 | raise NotImplementedError() 14 | 15 | def __call__(self, image, position): 16 | raise NotImplementedError() 17 | 18 | def can_apply(self, image, position): 19 | raise NotImplementedError() 20 | 21 | 22 | class SquaredSlidingWindow(Pattern): 23 | """Square patch with padding sliding over the image (stride=1). 24 | 25 | Starts extraction from top left pixel to the bottom right. :func:`can_apply` 26 | tests if the patch can be extracted at the corresponding position. 27 | If `use_padding` is true, the latter condition will always be valid. 28 | First call :func:`__init__`, then :func:`prepare` to apply the pattern 29 | on the data. 30 | 31 | Parameters 32 | ---------- 33 | patch_size : int or list 34 | size of the patch to extract. 35 | use_padding : bool 36 | whether to pad the image. 37 | pad_value : int 38 | padding fill value. 
39 | 40 | """ 41 | def __init__(self, patch_size, use_padding=False, pad_value=0): 42 | if isinstance(patch_size, (list, tuple)): 43 | assert(len(patch_size) > 0) 44 | assert(all(isinstance(n, int) for n in patch_size)) 45 | else: 46 | assert(isinstance(patch_size, int)) 47 | assert(isinstance(use_padding, bool)) 48 | assert(isinstance(pad_value, int)) 49 | 50 | self._patch_size = patch_size 51 | self._use_padding = use_padding 52 | self._pad_value = pad_value 53 | 54 | def prepare(self, image): 55 | """Initialize the pattern by matching the image properties to the 56 | pattern rules. 57 | 58 | Parameters 59 | ---------- 60 | image : :class:`torch.Tensor` 61 | an N dimension array. 62 | 63 | """ 64 | if isinstance(self._patch_size, int): 65 | self._patch_size = [self._patch_size] * image.ndimension() 66 | else: 67 | self._patch_size = list(self._patch_size) 68 | 69 | if any(x <= 0 for x in self._patch_size): 70 | raise ValueError('The patch size must be at least 1 ' 71 | 'in any dimension.') 72 | elif len(self._patch_size) != image.ndimension(): 73 | raise ValueError('The patch dimensionality must be equal ' 74 | 'the image dimensionality.') 75 | 76 | self._pads = [(int(p / 2), int(p / 2 - (p - 1) % 2)) 77 | for p in self._patch_size] 78 | 79 | def can_apply(self, image, position): 80 | """Returns if the pattern can be applied at the given position. 81 | 82 | Parameters 83 | ---------- 84 | image : :class:`torch.Tensor` 85 | image tensor. 86 | position : list[int] 87 | each axis coordinate. 88 | 89 | Returns 90 | ------- 91 | bool 92 | True if the pattern can be applied at the given position, 93 | False otherwise. 94 | 95 | """ 96 | if self._use_padding: 97 | return True 98 | else: 99 | is_size_ok = True 100 | for dim_size, pad_size, pos in zip(list(image.size()), 101 | self._pads, position): 102 | is_size_ok = dim_size - pad_size[1] > pos >= pad_size[0] 103 | if not is_size_ok: 104 | is_size_ok = False 105 | break 106 | return is_size_ok 107 | 108 | def __call__(self, image, position): 109 | """Get specified image pattern at given spatial position. 110 | 111 | Parameters 112 | ---------- 113 | image : :class:`torch.Tensor` 114 | image tensor. 115 | position : list[int] 116 | list of spatial coordinates. 117 | 118 | Returns 119 | ------- 120 | :class:`torch.tensor` 121 | extracted pattern from the image. 
122 | 123 | """ 124 | slices = () 125 | for dim, p in enumerate(position): 126 | slices += (slice(p - self._pads[dim][0], 127 | p + self._pads[dim][1] + 1),) 128 | 129 | # if no padding, narrow on both axis and return the view 130 | if not self._use_padding: 131 | return image[slices] 132 | else: 133 | # find the correct padding values 134 | pad_values = [] 135 | for dim, s in enumerate(slices): 136 | crop_start = s.start 137 | crop_end = s.stop 138 | if crop_start < 0: 139 | pad_values.append((dim, 0, abs(crop_start))) 140 | crop_start = 0 141 | if crop_end > image.size(dim): 142 | pad_values.append((dim, 1, crop_end - image.size(dim))) 143 | crop_end = image.size(dim) 144 | 145 | image = image.narrow(dim, crop_start, crop_end - crop_start) 146 | 147 | # pad each dimension 148 | for dim, side, pad in pad_values: 149 | constant_s = [] 150 | constant_e = [] 151 | shape = list(image.size()) 152 | shape[dim] = 1 153 | if side == 0: 154 | constant_s = [torch.zeros(shape, dtype=image.dtype) 155 | .fill_(self._pad_value)] * pad.item() 156 | else: 157 | constant_e = [torch.zeros(shape, dtype=image.dtype) 158 | .fill_(self._pad_value)] * pad.item() 159 | 160 | image = torch.cat(constant_s + [image] + constant_e, dim) 161 | 162 | return image 163 | -------------------------------------------------------------------------------- /torchmed/readers/__init__.py: -------------------------------------------------------------------------------- 1 | from .cv import CvReader 2 | from .nib import NibReader 3 | from .pil import PILReader 4 | from .simpleitk import SitkReader 5 | 6 | __all__ = ['CvReader', 'NibReader', 'PILReader', 'SitkReader'] 7 | -------------------------------------------------------------------------------- /torchmed/readers/cv.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 5 | from .reader import Reader 6 | 7 | 8 | class CvReader(Reader): 9 | """Image reader based on OpenCV. 10 | 11 | On call to :func:`to_torch` or :func:`to_numpy`, loads an image with openCV's 12 | :func:`cv2.imread` and returns a :class:`torch.Tensor`. The image is 13 | loaded on the first call and **kept in memory during the life of the reader**. 14 | 15 | A casting type for numpy and torch can optionally be specified, via 16 | `numpy_type` and `torch_type`. The casting is performed only once, 17 | during image reading. 18 | 19 | Args: 20 | path (string): a path to the image. 21 | transform (:class:`torch.Tensor` -> :class:`torch.Tensor`, optional): 22 | function to apply on the :class:`torch.Tensor` after image reading. 23 | numpy_type (:class:`numpy.dtype`, optional): numpy type to cast the 24 | :class:`numpy.ndarray` returned by :func:`cv2.imread`. 25 | torch_type (:class:`torch.dtype`, optional): torch type to cast the 26 | :class:`numpy.ndarray`. 27 | shared_memory (bool, optional): whether to move the underlying tensor 28 | storage into `shared memory 29 | `_. 30 | 31 | .. note:: 32 | A modification in the tensor/array returned by :func:`to_torch` or 33 | :func:`to_numpy` will also affect the underlying storage. If you want 34 | a copy, with torch call :func:`.clone` and with numpy use :func:`.copy`. 
35 | 36 | Examples: 37 | >>> im = torchmed.readers.CvReader("Screenshot.png") 38 | >>> im_array = im.to_torch() 39 | >>> im_array.size() 40 | torch.Size([928, 1421, 3]) 41 | >>> # update the underlying storage 42 | >>> im_array[0, 0, 0] = 666 43 | >>> # save the modified array back to a file 44 | >>> im.to_image_file('test.png') 45 | >>> # loads the newly saved image 46 | >>> im2 = torchmed.readers.CvReader('test.png') 47 | >>> im2_array = im2.to_torch() 48 | >>> im_array.size() == im2_array.size() 49 | True 50 | >>> im2_array[0, 0, 0] 51 | tensor(666.) 52 | 53 | """ 54 | def __init__(self, path, transform=None, numpy_type=None, 55 | torch_type=None, shared_memory=True): 56 | super().__init__(path, transform, numpy_type, torch_type, shared_memory) 57 | 58 | def _torch_init(self): 59 | cv_image = cv2.imread(self._path) 60 | super()._torch_init(cv_image) 61 | 62 | def to_image_file(self, path, params=None): 63 | """Saves the underlying tensor back to an image 64 | 65 | Parameters 66 | ---------- 67 | path : string 68 | a path to the output image. 69 | params : list, optional 70 | a list of format-specific flags and values passed to 71 | :func:`cv2.imwrite` (e.g. ``[cv2.IMWRITE_JPEG_QUALITY, 90]``). 72 | 73 | Returns 74 | ------- 75 | bool 76 | True in case of success, False otherwise, as returned 77 | by :func:`cv2.imwrite`. 78 | 79 | """ 80 | assert(isinstance(path, str)) 81 | assert(len(path) > 0) 82 | 83 | if params is not None: 84 | assert(isinstance(params, (list, tuple))) 85 | return cv2.imwrite(path, self.to_numpy(), params) 86 | 87 | return cv2.imwrite(path, self.to_numpy()) 88 | -------------------------------------------------------------------------------- /torchmed/readers/nib.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import nibabel as nib 3 | import torch 4 | 5 | from .reader import Reader 6 | 7 | 8 | class NibReader(Reader): 9 | """Image reader based on NiBabel. 10 | 11 | On call to :func:`to_torch` or :func:`to_numpy`, loads an image with NiBabel's 12 | :func:`nibabel.load` and returns a :class:`torch.Tensor`. The image is 13 | loaded on the first call and **kept in memory during the life of the reader**. 14 | 15 | A casting type for numpy and torch can optionally be specified, via 16 | `numpy_type` and `torch_type`. The casting is performed only once, 17 | during image reading. 18 | 19 | Args: 20 | path (string): a path to the image. 21 | transform (:class:`torch.Tensor` -> :class:`torch.Tensor`, optional): 22 | function to apply on the :class:`torch.Tensor` after image reading. 23 | numpy_type (:class:`numpy.dtype`, optional): numpy type to cast the 24 | :class:`numpy.ndarray` returned by :func:`nibabel.get_data`. 25 | torch_type (:class:`torch.dtype`, optional): torch type to cast the 26 | :class:`numpy.ndarray`. 27 | shared_memory (bool, optional): whether to move the underlying tensor 28 | storage into `shared memory 29 | `_. 30 | 31 | .. note:: 32 | A modification in the tensor/array returned by :func:`to_torch` or 33 | :func:`to_numpy` will also affect the underlying storage. If you want 34 | a copy, with torch call :func:`.clone` and with numpy use :func:`.copy`.
35 | 36 | Examples: 37 | >>> im = torchmed.readers.NibReader("prepro_im_mni_bc.nii.gz") 38 | >>> im_array = im.to_torch() 39 | >>> im_array.size() 40 | torch.Size([182, 218, 182]) 41 | >>> # update the underlying storage 42 | >>> im_array[0, 0, 0] = 555 43 | >>> # save the modified array back to a file 44 | >>> im.to_image_file('test_img.nii.gz') 45 | >>> # loads the newly saved image 46 | >>> im2 = torchmed.readers.NibReader('test_img.nii.gz') 47 | >>> im2_array = im2.to_torch() 48 | >>> im_array.size() == im2_array.size() 49 | True 50 | >>> im2_array[0, 0, 0] 51 | tensor(555.) 52 | """ 53 | def __init__(self, path, transform=None, numpy_type=None, 54 | torch_type=None, shared_memory=True): 55 | super().__init__(path, transform, numpy_type, torch_type, shared_memory) 56 | 57 | def _torch_init(self): 58 | nib_image = nib.load(self._path) 59 | numpy_array = nib_image.get_data() 60 | super()._torch_init(numpy_array) 61 | 62 | def to_image_file(self, path): 63 | """Saves the underlying tensor back to an image 64 | 65 | Parameters 66 | ---------- 67 | path : string 68 | a path to the output image. 69 | 70 | Returns 71 | ------- 72 | None or Exception 73 | Returns an Exception in case of failure by :func:`nib.save` 74 | and None in case of success. 75 | 76 | """ 77 | assert(isinstance(path, str)) 78 | assert(len(path) > 0) 79 | 80 | # load the original image to get metadata 81 | nib_image = nib.load(self._path) 82 | 83 | # check that incoming data has the same resolution as the original one 84 | data = nib_image.get_data() 85 | assert(data.shape == self.to_numpy().shape) 86 | 87 | # copy the array 88 | data[...] = self.to_numpy() 89 | 90 | return nib.save(nib_image, path) 91 | -------------------------------------------------------------------------------- /torchmed/readers/pil.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | 5 | from .reader import Reader 6 | 7 | 8 | class PILReader(Reader): 9 | """Image reader based on Pillow. 10 | 11 | On call to :func:`to_torch` or :func:`to_numpy`, loads an image with PIL's 12 | :func:`PIL.Image.open` and returns a :class:`torch.Tensor`. The image is 13 | loaded on the first call and **kept in memory during the life of the reader**. 14 | 15 | A casting type for numpy and torch can optionally be specified, via 16 | `numpy_type` and `torch_type`. The casting is performed only once, 17 | during image reading. 18 | 19 | Args: 20 | path (string): a path to the image. 21 | transform (:class:`torch.Tensor` -> :class:`torch.Tensor`, optional): 22 | function to apply on the :class:`torch.Tensor` after image reading. 23 | numpy_type (:class:`numpy.dtype`, optional): numpy type to cast the 24 | image returned by :func:`PIL.Image.open`. 25 | torch_type (:class:`torch.dtype`, optional): torch type to cast the 26 | :class:`numpy.ndarray`. 27 | shared_memory (bool, optional): whether to move the underlying tensor 28 | storage into `shared memory 29 | `_. 30 | 31 | .. note:: 32 | A modification in the tensor/array returned by :func:`to_torch` or 33 | :func:`to_numpy` will also affect the underlying storage. If you want 34 | a copy, with torch call :func:`.clone` and with numpy use :func:`.copy`. 
35 | 36 | Examples: 37 | >>> im = torchmed.readers.PILReader("Screenshot.png") 38 | >>> im_array = im.to_torch() 39 | >>> im_array.size() 40 | torch.Size([928, 1421, 4]) 41 | >>> # update the underlying storage 42 | >>> im_array[0, 0, 0] = 666 43 | >>> # save the modified array back to a file 44 | >>> im.to_image_file('test.png') 45 | >>> # loads the newly saved image 46 | >>> im2 = torchmed.readers.PILReader('test.png') 47 | >>> im2_array = im2.to_torch() 48 | >>> im_array.size() == im2_array.size() 49 | True 50 | >>> im2_array[0, 0, 0] 51 | tensor(666.) 52 | 53 | """ 54 | def __init__(self, path, transform=None, numpy_type=None, 55 | torch_type=None, shared_memory=True): 56 | super().__init__(path, transform, numpy_type, torch_type, shared_memory) 57 | 58 | def _torch_init(self): 59 | pil_image = Image.open(self._path) 60 | super()._torch_init(pil_image) 61 | 62 | def to_image_file(self, path, cast_type=None): 63 | """Saves the underlying tensor back to an image 64 | 65 | Parameters 66 | ---------- 67 | path : string 68 | a path to the output image. 69 | cast_type : type, optional 70 | if you want to cast the pixel type before writing the image. 71 | 72 | Returns 73 | ------- 74 | None or Exception 75 | Returns an Exception in case of failure by :func:`.save` 76 | and None in case of success. 77 | 78 | """ 79 | assert(isinstance(path, str)) 80 | assert(len(path) > 0) 81 | 82 | if cast_type is not None: 83 | assert(isinstance(cast_type, type)) 84 | im = Image.fromarray(self.to_numpy().astype(cast_type)) 85 | else: 86 | im = Image.fromarray(self.to_numpy()) 87 | 88 | return im.save(path) 89 | -------------------------------------------------------------------------------- /torchmed/readers/reader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class Reader(object): 6 | """Image reader parent class. 7 | 8 | On call to :func:`to_torch` or :func:`to_numpy`, loads an image and 9 | returns a :class:`torch.Tensor`. The image is loaded only once on the first 10 | call and **kept in memory during the life of the reader**. 11 | 12 | A casting type for numpy and torch can optionally be specified. 13 | 14 | Args: 15 | path (string): a path to the image. 16 | transform (:class:`torch.Tensor` -> :class:`torch.Tensor`, optional): 17 | a function to apply on the :class:`torch.Tensor` in :func:`_torch_init`. 18 | numpy_type (:class:`numpy.dtype`, optional): a numpy type to cast the 19 | :class:`numpy.ndarray` returned by reader. 20 | torch_type (:class:`torch.dtype`, optional): a torch type to cast the `numpy.ndarray`. 21 | shared_memory (bool, optional): whether to move the underlying tensor 22 | storage into `shared memory 23 | `_.
24 | """ 25 | def __init__(self, path, transform, numpy_type, torch_type, shared_memory): 26 | assert(isinstance(path, str) and len(path) > 0) 27 | self._path = path 28 | self._transform = transform 29 | 30 | self._torch_tensor = None 31 | self._numpy_type = numpy_type 32 | self._torch_type = torch_type 33 | self._shared_memory = shared_memory 34 | 35 | def _torch_init(self, numpy_array): 36 | # cast to numpy type if necessary 37 | if self._numpy_type is not None: 38 | numpy_array = np.array(numpy_array, dtype=self._numpy_type) 39 | else: 40 | numpy_array = np.array(numpy_array) 41 | 42 | # numpy array of type int16 should be converted to int32, 43 | # because torch does not handles np.int16 type 44 | if numpy_array.dtype == np.int16: 45 | self._torch_tensor = torch.from_numpy( 46 | numpy_array.astype(np.int32)) 47 | else: 48 | self._torch_tensor = torch.from_numpy(numpy_array) 49 | 50 | if self._torch_type is not None: 51 | self._torch_tensor = self._torch_tensor.type(self._torch_type) 52 | 53 | if self._transform is not None: 54 | self._torch_tensor = self._transform(self._torch_tensor) 55 | 56 | # Moves the underlying storage to shared memory. 57 | if self._shared_memory: 58 | self._torch_tensor.share_memory_() 59 | 60 | def to_numpy(self): 61 | """Returns a numpy array of the image. 62 | """ 63 | if self._torch_tensor is None: 64 | self._torch_init() 65 | 66 | return self._torch_tensor.numpy() 67 | 68 | def to_torch(self): 69 | """Returns a torch tensor of the image. 70 | 71 | """ 72 | if self._torch_tensor is None: 73 | self._torch_init() 74 | 75 | return self._torch_tensor 76 | -------------------------------------------------------------------------------- /torchmed/readers/simpleitk.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import SimpleITK as sitk 3 | import torch 4 | 5 | from .reader import Reader 6 | 7 | 8 | class SitkReader(Reader): 9 | """Image reader based on SimpleITK. 10 | 11 | On call to :func:`to_torch` or :func:`to_numpy`, loads an image with SimpleITK's 12 | :func:`sitk.ReadImage` and returns a :class:`torch.Tensor`. The image is 13 | loaded on the first call and **kept in memory during the life of the reader**. 14 | 15 | A casting type for numpy and torch can optionally be specified, via 16 | `numpy_type` and `torch_type`. The casting is performed only once, 17 | during image reading. 18 | 19 | Args: 20 | path (string): a path to the image. 21 | transform (:class:`torch.Tensor` -> :class:`torch.Tensor`, optional): 22 | function to apply on the :class:`torch.Tensor` after image reading. 23 | numpy_type (:class:`numpy.dtype`, optional): numpy type to cast the 24 | array returned by :func:`sitk.GetArrayFromImage`. 25 | torch_type (:class:`torch.dtype`, optional): torch type to cast the 26 | :class:`numpy.ndarray`. 27 | shared_memory (bool, optional): whether to move the underlying tensor 28 | storage into `shared memory 29 | `_. 30 | 31 | .. note:: 32 | A modification in the tensor/array returned by :func:`to_torch` or 33 | :func:`to_numpy` will also affect the underlying storage. If you want 34 | a copy, with torch call :func:`.clone` and with numpy use :func:`.copy`. 

    Examples:
        >>> im = torchmed.readers.SitkReader("prepro_im_mni_bc.nii.gz")
        >>> im_array = im.to_torch()
        >>> im_array.size()
        torch.Size([182, 218, 182])
        >>> # update the underlying storage
        >>> im_array[0, 0, 0] = 666
        >>> # save the modified array back to a file
        >>> im.to_image_file('test_img.nii.gz')
        >>> # loads the newly saved image
        >>> im2 = torchmed.readers.SitkReader('test_img.nii.gz')
        >>> im2_array = im2.to_torch()
        >>> im_array.size() == im2_array.size()
        True
        >>> im2_array[0, 0, 0]
        tensor(666.)
    """
    def __init__(self, path, transform=None, numpy_type=None,
                 torch_type=None, shared_memory=True):
        super().__init__(path, transform, numpy_type, torch_type, shared_memory)

    def _torch_init(self):
        itk_image = sitk.ReadImage(self._path)
        numpy_array = sitk.GetArrayFromImage(itk_image)
        super()._torch_init(numpy_array)

    def to_image_file(self, path, cast_type=None):
        """Saves the underlying tensor back to an image file.

        Parameters
        ----------
        path : string
            a path to the output image.
        cast_type : int, optional
            a SimpleITK pixel ID to cast to before writing the image.

        Returns
        -------
        None
            Returns None on success; :func:`sitk.WriteImage` raises an
            exception on failure.

        """
        assert(isinstance(path, str))
        assert(len(path) > 0)

        if cast_type is not None:
            assert(isinstance(cast_type, int))

        # re-read the original image to recover its spatial metadata
        itk_image = sitk.ReadImage(self._path)
        image = sitk.GetImageFromArray(self.to_numpy())
        image.CopyInformation(itk_image)

        if cast_type is None:
            itk_image = sitk.Cast(image,
                                  itk_image.GetPixelID())
        else:
            itk_image = sitk.Cast(image, cast_type)

        return sitk.WriteImage(itk_image, path)
--------------------------------------------------------------------------------
/torchmed/samplers/__init__.py:
--------------------------------------------------------------------------------
from .mask_sampler import MaskableSampler

__all__ = ['MaskableSampler']
--------------------------------------------------------------------------------
/torchmed/samplers/mask_sampler.py:
--------------------------------------------------------------------------------
from itertools import product
import torch
from .sampler import Sampler


class MaskableSampler(Sampler):
    """Multi-processed sampler.

    Evaluates possible sampling positions with :func:`get_positions` and
    returns samples with :func:`__getitem__`. When :func:`Sampler.build` is
    called, a list of spatial coordinates is given to :func:`get_positions`
    for testing.

    An optional mask can be supplied to restrict the sampling area.

    Parameters
    ----------
    pattern_map : dict
        mapping of a descriptor name (key) to the corresponding
        (input name, pattern) tuple (value); see the note and the example below.
    offset : int, list[int]
        offset from one coordinate to the next, equivalent to the stride in
        CNNs. If the offset is an int, the same offset is used for all
        dimensions; if it is a list of ints, each dimension gets its own offset.
    nb_workers : int
        number of processes used during sampling evaluation.


    .. note::
        the `pattern_map` keys should be prefixed by `input_` for inputs and
        `target_` for targets. The value associated with each key should be a
        tuple composed of: the name of the input on which the pattern is
        applied, and the pattern itself.
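
    Example
    -------
    An illustrative sketch; the file paths are hypothetical and `my_pattern`
    stands for any concrete :class:`torchmed.Pattern` instance::

        data = {'image_ref': torchmed.readers.SitkReader('t1.nii.gz'),
                'target': torchmed.readers.SitkReader('segmentation.nii.gz')}
        pattern_map = {'input_mri': ('image_ref', my_pattern),
                       'target_seg': ('target', my_pattern)}

        sampler = MaskableSampler(pattern_map, offset=[1, 1, 1], nb_workers=4)
        sampler.build(data)
        position, sample, label = sampler[0]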
    """
    def __init__(self, pattern_map, offset=1, nb_workers=1):
        super(MaskableSampler, self).__init__(pattern_map, offset, nb_workers)

    def __getitem__(self, index):
        # extract the data pattern, target pattern and the position of extraction
        position = self._coordinates[index]

        result = []
        target = []
        for desc, composition in self._pattern_map.items():
            input_name, pattern = composition
            if desc.startswith('input'):
                if isinstance(input_name, list):
                    sub_result = []
                    for n in input_name:
                        sub_result.append(pattern(self._data[n].to_torch(), position))
                    result.append(sub_result)
                else:
                    result.append(pattern(self._data[input_name].to_torch(), position))
            elif desc.startswith('target'):
                target.append(pattern(self._data[input_name].to_torch(), position))

        result = result[0] if len(result) == 1 else result
        target = target[0] if len(target) == 1 else target

        if len(target) == 0:
            # if no pattern is specified for the target,
            # extract the pixel value if it is available, otherwise None
            if 'target' in self._data:
                target = self._data['target'].to_torch()[tuple(position)]
            else:
                target = None

        # return the position of extraction with the extracted data & target
        return self._coordinates[index], result, target

    def get_positions(self, positions):
        """Evaluates valid sampling coordinates.

        For each position, checks whether the patterns are applicable; if so,
        the position is added to the dataset. Allocates a tensor of the
        maximum possible size, which is returned along with the index of the
        last valid element, so that the caller can extract only the relevant
        part of it.

        Parameters
        ----------
        positions : tuple(list[int], ..)
            tuple containing lists of int, each list for a dimension, each
            int for a coordinate.

        Returns
        -------
        tuple(:class:`torch.ShortTensor`, int)
            tensor of coordinates and the number of valid elements in it.
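
        Example
        -------
        An illustrative `positions` tuple for a 3D image, testing every
        second coordinate along the first axis (values are hypothetical)::

            positions = ([0, 2, 4], [0, 1, 2, 3], [0, 1, 2, 3])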

        """
        # if there is a mask, use it
        if 'image_mask' in self._data.keys():
            img_array = self._data['image_mask'].to_torch()
            use_mask = True
        elif 'image_ref' in self._data.keys():
            img_array = self._data['image_ref'].to_torch()
            use_mask = False
        else:
            raise ValueError("data map must contain at least a reference image.")

        max_coord_nb = 1
        for n_coord in [len(l) for l in positions]:
            max_coord_nb *= n_coord
        coordinates = torch.ShortTensor(max_coord_nb, len(positions))

        index = 0
        for position in product(*positions):
            if not use_mask or img_array[position] == 1:

                # for each pixel, see if the patterns are applicable;
                # if so, store the position for future extraction
                can_extract = []
                for desc, composition in self._pattern_map.items():
                    input_name, pattern = composition
                    if desc.startswith('input'):
                        can_extract.append(pattern.can_apply(
                            self._data[input_name].to_torch(), position))

                # if all of the patterns can be extracted
                if len(can_extract) > 0 and all(can_extract):
                    coordinates[index] = torch.ShortTensor(position)
                    index += 1

        return coordinates, index
--------------------------------------------------------------------------------
/torchmed/samplers/sampler.py:
--------------------------------------------------------------------------------
from collections import OrderedDict
import multiprocessing as mp
import torch
import torchmed


class Sampler(object):
    def __init__(self, pattern_map, offset=1, nb_workers=1):
        # checks on pattern_map
        assert(isinstance(pattern_map, dict))
        assert(len(pattern_map) > 0)
        for k, v in pattern_map.items():
            assert(isinstance(k, str))
            assert(isinstance(v[0], str))
            assert(isinstance(v[1], torchmed.Pattern))

        # checks on nb_workers
        assert(isinstance(nb_workers, int))
        assert(nb_workers > 0)

        # checks on offset
        if isinstance(offset, list):
            assert(all(isinstance(n, int) for n in offset))
        else:
            assert(isinstance(offset, int))
            assert(offset > 0)

        self._coordinates = None
        self._nb_workers = nb_workers
        self._offset = offset
        self._pattern_map = OrderedDict(
            sorted(pattern_map.items(), key=lambda t: t[0]))

    def __len__(self):
        """Number of sampling coordinates."""
        return len(self._coordinates)

    def build(self, data):
        """Evaluates the valid sampling coordinates for all images, with respect
        to the pattern map and the data.

        Parameters
        ----------
        data : dict
            data dictionary.


        .. note::
            the `data` dictionary must contain at least one key named
            `image_ref`, which will be used as the main reference for the
            image size.
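
        Example
        -------
        An illustrative data map (paths are hypothetical); with
        :class:`MaskableSampler`, the optional `image_mask` entry restricts
        sampling to voxels where the mask equals 1::

            data = {'image_ref': torchmed.readers.SitkReader('t1.nii.gz'),
                    'image_mask': torchmed.readers.SitkReader('mask.nii.gz')}
            sampler.build(data)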

        """
        from random import randint

        assert('image_ref' in data.keys())

        self._data = data
        for _, pattern_conf in self._pattern_map.items():
            input_name, pattern = pattern_conf
            pattern.prepare(self._data[input_name].to_torch())

        ref_image_size = self._data['image_ref'].to_torch().size()

        # set extraction offset
        if isinstance(self._offset, int):
            self._offset = [self._offset] * len(ref_image_size)
        if len(self._offset) != len(ref_image_size):
            raise ValueError('The offset dimensionality must equal '
                             'the image dimensionality.')

        patch_positions = [list(range(0, ref_image_size[dim], offset))
                           for dim, offset in enumerate(self._offset)]

        # evaluate a split size for dim 0, useful for large images
        first_axis_size = len(patch_positions[0])
        ideal_split_size = int(first_axis_size // self._nb_workers)
        split_size = randint(ideal_split_size // 3, ideal_split_size // 2)

        # if split_size == 0, the default split size is 1
        split_size = split_size if split_size > 0 else 1

        splits = []
        split_index = 0
        while split_index < first_axis_size:
            # if the next split is too big for the dataset, resize it
            if split_index + split_size > first_axis_size:
                split_size = first_axis_size - split_index

            splits.append((split_index, split_index + split_size))
            split_index += split_size

        # if we have more than one worker, we can use multiprocessing/queues
        if self._nb_workers > 1:
            # create queues
            task_queue = mp.JoinableQueue()
            done_queue = mp.Queue()

            for split in splits:
                # give the correct extraction positions to the worker
                patch_pos = (patch_positions[0][split[0]: split[1]],
                             *patch_positions[1:])
                task_queue.put(patch_pos)
            # add None objects so workers can stop when there is no more work
            [task_queue.put(None) for i in range(0, self._nb_workers - 1)]

            # start worker processes
            producers = []
            end_events = [mp.Event() for i in range(0, self._nb_workers - 1)]
            for i in range(0, self._nb_workers - 1):
                process = mp.Process(target=self._sampler_worker,
                                     args=(task_queue,
                                           done_queue,
                                           end_events[i])
                                     )
                process.start()
                producers.append(process)

            """
            Read results from the workers and wait until all of them have
            finished. Each worker returns None when it has ended, so we count
            how many Nones we have received before merging all the
            result arrays together.
            """
            result_arrays = []
            nb_ended_workers = 0
            while nb_ended_workers != self._nb_workers - 1:
                worker_result = done_queue.get()
                if worker_result is None:
                    nb_ended_workers += 1
                else:
                    result_arrays.append(worker_result)

            # concatenate all the results
            if len(result_arrays) == 0:
                self._coordinates = torch.ShortTensor(0, 0)
            else:
                self._coordinates = torch.cat(result_arrays, 0)

            # we can set free all the background processes
            [end_events[i].set() for i in range(0, self._nb_workers - 1)]

            # at this point all the processes have already ended, we can close
            # them by calling join on each one, and terminate them properly
            for process in producers:
                process.join()

            # join and close queues
            done_queue.close()
            done_queue.join_thread()
            task_queue.close()
            task_queue.join_thread()

        # if one process: evaluate all the splits one after the other and concat
        else:
            for split in splits:
                patch_pos = (patch_positions[0][split[0]: split[1]],
                             *patch_positions[1:])
                coords, nb_elems = self.get_positions(patch_pos)
                if self._coordinates is None and nb_elems > 0:
                    self._coordinates = coords[0:nb_elems, ]
                elif nb_elems > 0:
                    self._coordinates = torch.cat(
                        (self._coordinates, coords[0:nb_elems, ]), 0)

    def _sampler_worker(self, task_queue, done_queue, end_event):
        """
        Each worker is given a task queue to read from, and a result queue to
        write results to. The worker is given the positions of extraction,
        then it returns a result array, picks up a new task and so on. Once the
        worker picks up a None object it means there is no more work to do,
        thus the worker loop is interrupted.
        """
        worker_result = None
        while True:
            task_args = task_queue.get()
            if task_args is None:
                task_queue.task_done()
                if worker_result is not None:
                    done_queue.put(worker_result)
                done_queue.put(None)
                end_event.wait()
                break
            else:
                coords, nb_elems = self.get_positions(task_args)

                if worker_result is None and nb_elems > 0:
                    worker_result = coords[0:nb_elems, ]
                elif nb_elems > 0:
                    worker_result = torch.cat(
                        (worker_result, coords[0:nb_elems, ]), 0)

                task_queue.task_done()
--------------------------------------------------------------------------------
/torchmed/utils/__init__.py:
--------------------------------------------------------------------------------
__all__ = ['file', 'augmentation', 'logger_plotter', 'loss', 'metric', 'transforms']
--------------------------------------------------------------------------------
/torchmed/utils/augmentation.py:
--------------------------------------------------------------------------------
import cv2
import math
import numpy as np
from scipy.ndimage.interpolation import map_coordinates
import torch


def resize_2d(image, size, fx, fy, interp=cv2.INTER_CUBIC):
    """Resize an image to the desired size with a specific interpolator.

    image: (ndarray) 3d ndarray of size CxHxW.
    size: (tuple) output size of the image.
    fx: (float) scale factor along the horizontal axis, used when `size` is None.
    fy: (float) scale factor along the vertical axis, used when `size` is None.
    interp: (int) OpenCV interpolation type.
    """
    image = image.numpy()
    # pass fx/fy as keywords: the third positional argument of cv2.resize
    # is the destination image, not a scale factor
    res_img = cv2.resize(image.swapaxes(0, 2), size, fx=fx, fy=fy,
                         interpolation=interp).swapaxes(0, 2)
    return torch.from_numpy(res_img)
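
# A hedged usage sketch of resize_2d (shapes and values are illustrative);
# when an explicit output size is given, the fx/fy scale factors are ignored:
img = torch.rand(3, 256, 256)                 # CxHxW tensor
small = resize_2d(img, (128, 128), 0, 0)      # -> torch.Size([3, 128, 128])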

def center_rotate_2d(image, angle, scale=1., interp=cv2.INTER_NEAREST,
                     border_mode=cv2.BORDER_REPLICATE, border_value=0):
    """Apply a center rotation on a 2d image with a defined angle and scale.
    It supports 2d images with multiple channels.

    image: (ndarray) 3d ndarray of size CxHxW.
    angle: (float) rotation angle in degrees.
    scale: (float) scaling factor.
    interp: (int) OpenCV interpolation type.
    """
    image = image.numpy()
    h, w = image.shape[-2:]

    # ask OpenCV for the rotation matrix; the center is given in (x, y) order
    rot_mat = cv2.getRotationMatrix2D((w // 2, h // 2), angle, scale)
    # apply the affine transform on the image; dsize is in (width, height) order
    res_img = cv2.warpAffine(image, rot_mat, (w, h),
                             flags=interp, borderMode=border_mode,
                             borderValue=border_value)
    return torch.from_numpy(res_img)


def elastic_deformation_2d(image, alpha, sigma, order=0, mode='constant',
                           constant=0, random_state=None):
    """Elastic deformation of images as described in [Simard2003]_
    Simard, Steinkraus and Platt, "Best Practices for Convolutional Neural
    Networks applied to Visual Document Analysis".
    Based on https://www.kaggle.com/bguberfain/elastic-transform-for-data-augmentation

    image: (ndarray) 3d ndarray of size CxHxW.
    alpha: (number) intensity of the deformation.
    sigma: (number) sigma for smoothing the transformation.
    order: (int) coordinate remapping: order of the spline interpolation.
    mode: (str) coordinate remapping: interpolation type.
    constant: (int) constant value if mode is 'constant'.
    random_state: (RandomState) numpy random state.
    """
    image = image.numpy()
    if random_state is None:
        random_state = np.random.RandomState(None)

    shape = image.shape

    # random displacement field
    def_x = random_state.rand(*shape[-2:]) * 2 - 1
    def_y = random_state.rand(*shape[-2:]) * 2 - 1

    # smooth the displacement field of the x, y axes
    dx = cv2.GaussianBlur(def_x, (0, 0), sigma) * alpha
    dy = cv2.GaussianBlur(def_y, (0, 0), sigma) * alpha

    # repeat the displacement field for each channel
    dx = np.repeat(dx[np.newaxis, :], shape[0], axis=0)
    dy = np.repeat(dy[np.newaxis, :], shape[0], axis=0)

    # grid of coordinates
    x, z, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]),
                          np.arange(shape[2]))

    indices = (z.reshape(-1, 1), np.reshape(x + dx, (-1, 1)),
               np.reshape(y + dy, (-1, 1)))

    # pass `constant` through as cval so the documented parameter takes effect
    def_img = map_coordinates(image, indices, order=order,
                              mode=mode, cval=constant).reshape(shape)
    return torch.from_numpy(def_img)
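
# A hedged usage sketch of elastic_deformation_2d (the alpha/sigma values are
# illustrative and should be tuned to the image resolution):
img = torch.rand(3, 256, 256)                   # CxHxW tensor
warped = elastic_deformation_2d(img, alpha=34, sigma=4)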

def max_crop_size_for_rotation(final_crop_size, max_rotation_angle):
    """
    In order to find L, the size of the smallest axis-aligned square containing
    a rotated square of size final_crop_size, we decompose L into x and y:

        L = x + y
        x = sin(max_rotation_angle) * final_crop_size
        y = sin(180 - 90 - max_rotation_angle) * final_crop_size
    """
    x = math.sin(math.radians(max_rotation_angle)) * final_crop_size
    y = math.sin(math.radians(180 - 90 - max_rotation_angle)) * final_crop_size

    return math.ceil(x + y)
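
# Worked example: a 100-pixel crop rotated by up to 30 degrees needs
# sin(30)*100 + sin(60)*100 = 50 + 86.6, rounded up to a 137-pixel source crop:
max_crop_size_for_rotation(100, 30)   # -> 137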
--------------------------------------------------------------------------------
/torchmed/utils/file.py:
--------------------------------------------------------------------------------
from shutil import copyfile
import os
import numpy as np
import cv2
import matplotlib.pyplot


def copy_file(source, destination=None, append='', prepend=''):
    """
    Copy a file to a destination folder/file.
    If the destination is not specified, the file is copied into the same
    directory as the source.
    """
    if destination is None:
        destination = os.path.dirname(source)
    filename = os.path.basename(source)
    file_wo_ext = os.path.splitext(filename)[0]
    ext = os.path.splitext(filename)[1]

    # if the destination is a directory, we can prepend/append to the filename
    if os.path.isdir(destination):
        copyfile(source, os.path.join(destination,
                                      append + file_wo_ext + prepend + ext))
    else:
        copyfile(source, destination)


def export_to_file(tensor, file):
    """
    Export a torch tensor to an image file.
    """
    from scipy import misc
    misc.imsave(file, tensor.numpy())


def write_image_segmentation(img, target, nb_classes, output_dir):
    def get_annotated_image(tensor, n_labels, colors):
        temp = tensor.numpy()
        r = temp.copy()
        g = temp.copy()
        b = temp.copy()

        for l in range(0, n_labels):
            r[temp == l] = colors[l, 0]
            g[temp == l] = colors[l, 1]
            b[temp == l] = colors[l, 2]

        # for unwanted labels
        r[temp == -1] = 255
        g[temp == -1] = 255
        b[temp == -1] = 255

        rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
        rgb[:, :, 0] = r
        rgb[:, :, 1] = g
        rgb[:, :, 2] = b

        return rgb

    def get_spaced_colors(n):
        max_value = 16581375  # 255**3
        interval = int(max_value / n)
        colors = [hex(i)[2:].zfill(6) for i in range(0, max_value, interval)]
        return [(int(c[:2], 16), int(c[2:4], 16), int(c[4:], 16)) for c in colors]

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if target is not None:
        seg_colors = np.array(get_spaced_colors(nb_classes))
    for n in range(0, img.size(0)):
        brain = cv2.cvtColor(img[n].numpy(), cv2.COLOR_GRAY2BGR)
        res = cv2.normalize(brain, brain, 0, 255, cv2.NORM_MINMAX)
        if target is not None:
            seg_brain = get_annotated_image(target[n], nb_classes, seg_colors)
            res = np.concatenate([res, seg_brain], axis=1).astype(int)

        output_file = os.path.join(output_dir, str(n) + '.png')
        matplotlib.pyplot.imsave(output_file, res)
--------------------------------------------------------------------------------
/torchmed/utils/logger_plotter.py:
--------------------------------------------------------------------------------
import os
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


class LoggerPlotter(object):
    def __init__(self, output_path, log_names, figure_names):
        self.dir_fig = os.path.join(output_path, 'figures')
        self.figures = {name: os.path.join(self.dir_fig, name)
                        for name in figure_names}

        self.dir_log = os.path.join(output_path, 'logs')
        os.makedirs(self.dir_fig)
        os.makedirs(self.dir_log)

        self.metrics = {}
        self.logs = {}
        for name in log_names:
            path = os.path.join(self.dir_log, name)
            self.logs.update({name: (path, open(path, "a"))})

    def log(self, log_name, line):
        self.logs[log_name][1].write(line + '\n')
        self.logs[log_name][1].flush()

    def add_line(self, line_name, log_name,
                 x_attribute, y_attribute,
                 color, linewidth=1, alpha=1):
        df = pd.read_csv(self.logs[log_name][0], sep=";")
        plt.figure(1, figsize=(10, 7), dpi=100, edgecolor='b')

        line, = plt.plot(df[x_attribute], df[y_attribute], label=line_name)
        plt.setp(line, color=color, linewidth=linewidth, alpha=alpha)

    def plot(self, fig_name, x_label, y_label, max_x=None, max_y=None):
        if max_x is not None:
            plt.xlim(xmax=max_x)
        if max_y is not None:
            plt.ylim(top=max_y)

        plt.axis([0, max_x, 0, max_y])
        plt.xlabel(x_label)
        plt.ylabel(y_label)

        plt.grid(alpha=0.6, linestyle='dotted')
        plt.legend()
        plt.savefig(self.figures[fig_name])
        plt.clf()

    def write_val_metrics(self, it, file):
        line = '{:.3f}'.format(it)
        for key, metric in self.metrics.items():
            line += (';' + metric.report_format).format(metric.val)
        self.log(file, line)

    def write_avg_metrics(self, it, file):
        line = '{:d}'.format(it)
        for key, metric in self.metrics.items():
            line += (';' + metric.report_format).format(metric.avg)
        self.log(file, line)

    def add_metric(self, metric_dict):
        for m in metric_dict:
            self.metrics.update({m.id_name: m})

    def print_metrics(self, epoch, iteration, total_iteration, phase='train'):
        phase_str = '@ Train ' if phase == 'train' else '# Test_ '
        sep = ' |' if phase == 'train' else ' /'
        epoch_str = '[{0:^5}-{1:>5}/{2:<5}]'.format(epoch, iteration, total_iteration)
        metric_str = ''
        for key, metric in self.metrics.items():
            tmp_str = (sep + ' {} = ' + metric.raw_format + ' (' + metric.avg_format + ')')
            metric_str += tmp_str.format(metric.display_name, metric.val, metric.avg)
        print(phase_str + epoch_str + metric_str)

    def clear_metrics(self):
        for key, metric in self.metrics.items():
            metric.reset()


class MetricLogger(object):
    """Computes and stores the average and current value."""
    # the default format strings carry braces so they are valid str.format
    # templates when concatenated by LoggerPlotter
    def __init__(self, id_name, display_name, raw_format='{:.2e}',
                 avg_format='{:.3e}', report_format='{:.5f}'):
        self.reset()
        self.id_name = id_name
        self.display_name = display_name
        self.raw_format = raw_format
        self.avg_format = avg_format
        self.report_format = report_format

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
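
# A hedged usage sketch of MetricLogger's running average (values illustrative):
m = MetricLogger('loss', 'Loss')
m.update(0.50, n=2)     # two samples with loss 0.50
m.update(0.20, n=2)     # two samples with loss 0.20
print(m.val, m.avg)     # -> 0.2 0.35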
--------------------------------------------------------------------------------
/torchmed/utils/loss.py:
--------------------------------------------------------------------------------
def dice_loss(output, target, ignore_index=None):
    """
    output : NxCxHxW Variable
    target : NxHxW LongTensor
    ignore_index : int index to ignore from the loss
    """
    encoded_target = output.detach() * 0
    if ignore_index is not None:
        # zero out the ignored label before one-hot encoding
        mask = target == ignore_index
        target = target.clone()
        target[mask] = 0
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
        mask = mask.unsqueeze(1).expand_as(encoded_target)
        encoded_target[mask] = 0
    else:
        encoded_target.scatter_(1, target.unsqueeze(1), 1)

    numerator = 2 * (output * encoded_target).sum(0).sum(1).sum(1)
    denominator = output + encoded_target

    if ignore_index is not None:
        denominator[mask] = 0
    denominator = denominator.sum(0).sum(1).sum(1)
    loss = 1 - (numerator / denominator)

    return loss.sum() / loss.size(0)
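
# A hedged usage sketch (shapes follow the docstring; a softmax is assumed to
# be applied beforehand so that `output` holds per-class probabilities):
import torch
probs = torch.softmax(torch.randn(2, 4, 8, 8), dim=1)   # NxCxHxW
labels = torch.randint(0, 4, (2, 8, 8))                 # NxHxW
loss = dice_loss(probs, labels)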
--------------------------------------------------------------------------------
/torchmed/utils/multiproc.py:
--------------------------------------------------------------------------------
import multiprocessing as mp
import os


def parallelize_system_calls(nb_workers, commands):
    # create queues
    task_queue = mp.JoinableQueue()
    done_queue = mp.Queue()

    if isinstance(commands, (list, tuple)):
        for command in commands:
            task_queue.put(command)
    else:
        task_queue.put(commands)

    # add None objects so workers can stop when there is no more work
    [task_queue.put(None) for i in range(0, nb_workers - 1)]

    # start worker processes
    producers = []
    end_events = [mp.Event() for i in range(0, nb_workers - 1)]
    for i in range(0, nb_workers - 1):
        process = mp.Process(target=sampler_worker,
                             args=(task_queue,
                                   done_queue,
                                   end_events[i])
                             )
        process.start()
        producers.append(process)

    nb_ended_workers = 0
    while nb_ended_workers != nb_workers - 1:
        worker_result = done_queue.get()
        if worker_result is None:
            nb_ended_workers += 1

    # we can set free all the background processes
    [end_events[i].set() for i in range(0, nb_workers - 1)]

    # at this point all the processes have already ended, we can close
    # them by calling join on each one, and terminate them properly
    for process in producers:
        process.join()

    # join and close queues
    done_queue.close()
    done_queue.join_thread()
    task_queue.close()
    task_queue.join_thread()


def sampler_worker(task_queue, done_queue, end_event):
    """
    Each worker is given a task queue to read from, and a result queue to
    write results to. The worker is given a shell command to execute with
    :func:`os.system`, runs it, picks up a new task and so on. Once the worker
    picks up a None object it means there is no more work to do, thus the
    worker loop is broken.
    """
    worker_result = None
    while True:
        task_args = task_queue.get()
        if task_args is None:
            task_queue.task_done()
            if worker_result is not None:
                done_queue.put(worker_result)
            done_queue.put(None)
            end_event.wait()
            break
        else:
            os.system(task_args)
            task_queue.task_done()
--------------------------------------------------------------------------------
/torchmed/utils/preprocessing.py:
--------------------------------------------------------------------------------
import SimpleITK as sitk


def N4BiasFieldCorrection(image, destination, nb_iteration=50):
    inputImage = sitk.ReadImage(image)

    inputImage = sitk.Cast(inputImage, sitk.sitkFloat32)
    corrector = sitk.N4BiasFieldCorrectionImageFilter()
    corrector.SetMaximumNumberOfIterations(nb_iteration)

    output = corrector.Execute(inputImage)
    sitk.WriteImage(output, destination)
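
# A hedged usage sketch combining the two utilities above (paths and worker
# count are illustrative; since worker processes are spawned, run this under
# an `if __name__ == '__main__':` guard on spawn-based platforms):
commands = ['gzip -k subject_%d.nii' % i for i in range(4)]
parallelize_system_calls(2, commands)
N4BiasFieldCorrection('t1.nii.gz', 't1_n4.nii.gz')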
--------------------------------------------------------------------------------
/torchmed/utils/transforms.py:
--------------------------------------------------------------------------------
import numpy as np
import torch


class CenterSquareCrops(object):
    """Returns a set of center crops at different resolutions.

    Args:
        resolutions (Iterable[int]): list of resolutions to extract.

    Example:
        >>> CenterSquareCrops([25, 51, 75])
    """
    def __init__(self, resolutions):
        self.resolutions = resolutions

    def __call__(self, tensor):
        crops = []

        w, h = tensor.size()[-2:]
        for res in self.resolutions:
            x = int(round((w - res) / 2.))
            y = int(round((h - res) / 2.))
            crops.append(tensor[..., x:x + res, y:y + res])

        return crops


class Pad(object):
    def __init__(self, pad, mode='constant', fill_value=None):
        self.pad = pad
        self.mode = mode
        self.fill_value = fill_value

    def __call__(self, input):
        input = input.numpy()
        if self.mode == 'constant':
            x = np.pad(input, self.pad, self.mode,
                       constant_values=self.fill_value)
        else:
            x = np.pad(input, self.pad, self.mode)

        return torch.from_numpy(x)


class Unsqueeze(object):
    def __init__(self, dimension):
        self.dimension = dimension

    def __call__(self, input):
        if isinstance(input, list):
            for i, item in enumerate(input):
                input[i] = item.unsqueeze(self.dimension)
            return input
        else:
            return input.unsqueeze(self.dimension)


class Squeeze(object):
    def __init__(self, dimension):
        self.dimension = dimension

    def __call__(self, input):
        if isinstance(input, list):
            for i, item in enumerate(input):
                input[i] = item.squeeze(self.dimension)
            return input
        else:
            return input.squeeze(self.dimension)


class RemoveLastIndex(object):
    def __init__(self, dimension):
        self.dimension = dimension

    def __call__(self, input):
        if isinstance(input, list):
            for i, item in enumerate(input):
                input[i] = input[i].narrow(self.dimension,
                                           0,
                                           input[i].size(self.dimension) - 1)
            return input
        else:
            return input.narrow(self.dimension,
                                0,
                                input.size(self.dimension) - 1)


class TransformNFirst(object):
    def __init__(self, transform, n):
        self.transform = transform
        self.n = n

    def __call__(self, tensor_list):
        ret_list = []
        for i, tensor in enumerate(tensor_list):
            if i < self.n:
                ret_list.append(self.transform(tensor))
            else:
                ret_list.append(tensor)

        return ret_list
--------------------------------------------------------------------------------
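
# A closing illustration of the transforms above (shapes are illustrative):
import torch
pad = Pad(((0, 0), (2, 2), (2, 2)), fill_value=0)
crops = CenterSquareCrops([8, 16])
x = pad(torch.zeros(1, 12, 12))        # -> torch.Size([1, 16, 16])
small, large = crops(x)                # -> [1, 8, 8] and [1, 16, 16]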