├── inferno ├── utils │ ├── __init__.py │ ├── math_utils.py │ ├── exceptions.py │ ├── test_utils.py │ ├── model_utils.py │ ├── io_utils.py │ ├── partial_cls.py │ └── torch_utils.py ├── version.py ├── inferno.py ├── extensions │ ├── containers │ │ ├── __init__.py │ │ └── sequential.py │ ├── metrics │ │ ├── __init__.py │ │ ├── cremi_score.py │ │ └── base.py │ ├── initializers │ │ ├── __init__.py │ │ ├── presets.py │ │ └── base.py │ ├── models │ │ └── __init__.py │ ├── optimizers │ │ ├── __init__.py │ │ ├── ranger.py │ │ ├── annealed_adam.py │ │ └── adam.py │ ├── criteria │ │ ├── __init__.py │ │ ├── elementwise_measures.py │ │ ├── core.py │ │ └── regularized.py │ ├── layers │ │ ├── identity.py │ │ ├── activations.py │ │ ├── normalization.py │ │ ├── __init__.py │ │ ├── convolutional_blocks.py │ │ ├── sampling.py │ │ └── device.py │ └── __init__.py ├── io │ ├── __init__.py │ ├── core │ │ ├── __init__.py │ │ ├── data_utils.py │ │ ├── base.py │ │ └── concatenate.py │ ├── transform │ │ └── __init__.py │ ├── volumetric │ │ └── __init__.py │ └── box │ │ ├── __init__.py │ │ └── binary_blobs.py ├── trainers │ ├── __init__.py │ └── callbacks │ │ ├── __init__.py │ │ ├── logging │ │ ├── __init__.py │ │ └── base.py │ │ ├── tqdmstub.py │ │ ├── gradients.py │ │ ├── console.py │ │ └── tqdm.py └── __init__.py ├── docs ├── authors.rst ├── history.rst ├── readme.rst ├── contributing.rst ├── .gitignore ├── inferno-apidoc │ ├── modules.rst │ ├── inferno.io.rst │ ├── inferno.extensions.rst │ ├── inferno.trainers.rst │ ├── inferno.rst │ ├── inferno.extensions.containers.rst │ ├── inferno.extensions.initializers.rst │ ├── inferno.extensions.optimizers.rst │ ├── inferno.trainers.callbacks.logging.rst │ ├── inferno.io.volumetric.rst │ ├── inferno.io.box.rst │ ├── inferno.io.core.rst │ ├── inferno.io.transform.rst │ ├── inferno.extensions.criteria.rst │ ├── inferno.extensions.metrics.rst │ ├── inferno.trainers.callbacks.rst │ ├── inferno.utils.rst │ └── inferno.extensions.layers.rst ├── 
graphics │ ├── tentative_logo.pdf │ ├── plain_tentative_logo.svg │ └── tentative_logo.svg ├── _templates │ ├── layout.html │ └── template_module.rst ├── zbibliography.rst ├── examples.rst ├── refs.bib ├── environment.yml ├── index.rst └── installation.rst ├── requirements.txt ├── tests ├── __init__.py ├── test_io │ ├── __init__.py │ ├── test_box │ │ ├── __init__.py │ │ ├── test_cityscapes.py │ │ └── test_camvid.py │ ├── test_core │ │ ├── __init__.py │ │ ├── test_concatenate.py │ │ └── test_zip.py │ └── test_volumetric │ │ ├── __init__.py │ │ ├── test_volume_loader.py │ │ └── test_lazy_volume_loader.py ├── test_utils │ ├── __init__.py │ ├── test_model_utils.py │ ├── test_train_utils.py │ └── test_partial_cls.py ├── test_extensions │ ├── __init__.py │ ├── test_models │ │ ├── __init__.py │ │ ├── test_unet.py │ │ └── test_res_unet.py │ ├── test_layers │ │ ├── test_activations.py │ │ ├── test_convolutional.py │ │ ├── test_device.py │ │ ├── test_reshape.py │ │ └── deprecated │ │ │ └── building_blocks.py │ ├── test_criteria │ │ ├── test_core.py │ │ ├── test_elementwise_measures.py │ │ └── test_set_similarity_measures.py │ ├── test_metrics │ │ └── categorical.py │ └── test_containers │ │ └── test_graph.py └── test_training │ ├── __init__.py │ └── test_callbacks │ ├── __init__.py │ ├── test_logging │ ├── __init__.py │ ├── test_base.py │ └── test_tensorboard.py │ ├── test_scheduling.py │ ├── test_base.py │ └── test_essentials.py ├── examples ├── README.txt ├── trainer.py └── regularized_mnist.py ├── readthedocs.yml ├── add2path.sh ├── MANIFEST.in ├── conda-recipe ├── build.sh └── meta.yaml ├── requirements_dev.txt ├── .editorconfig ├── .github └── ISSUE_TEMPLATE.md ├── LICENSE ├── HISTORY.rst ├── .gitignore ├── AUTHORS.rst ├── .travis.yml ├── setup.py ├── Makefile └── CONTRIBUTING.rst /inferno/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/docs/authors.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../AUTHORS.rst 2 | -------------------------------------------------------------------------------- /docs/history.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../HISTORY.rst 2 | -------------------------------------------------------------------------------- /docs/readme.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | -------------------------------------------------------------------------------- /inferno/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.4.0' 2 | -------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../CONTRIBUTING.rst 2 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | /inferno.rst 2 | /inferno.*.rst 3 | /modules.rst 4 | -------------------------------------------------------------------------------- /inferno/inferno.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Main module.""" 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dill 2 | pyyaml 3 | scipy>=0.13.0 4 | h5py 5 | numpy>=1.8 6 | scikit-image -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Unit test package for 
inferno.""" 4 | -------------------------------------------------------------------------------- /inferno/extensions/containers/__init__.py: -------------------------------------------------------------------------------- 1 | from .graph import * 2 | from .sequential import * 3 | -------------------------------------------------------------------------------- /inferno/extensions/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .categorical import * 2 | from .arand import * 3 | -------------------------------------------------------------------------------- /inferno/extensions/initializers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .presets import * 3 | 4 | -------------------------------------------------------------------------------- /tests/test_io/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Unit test package for inferno.""" 4 | -------------------------------------------------------------------------------- /tests/test_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Unit test package for inferno.""" 4 | -------------------------------------------------------------------------------- /examples/README.txt: -------------------------------------------------------------------------------- 1 | 2 | .. 
_examples-index: 3 | 4 | Gallery of Examples 5 | =================== 6 | 7 | -------------------------------------------------------------------------------- /tests/test_extensions/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Unit test package for inferno.""" 4 | -------------------------------------------------------------------------------- /tests/test_io/test_box/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Unit test package for inferno.""" 4 | -------------------------------------------------------------------------------- /tests/test_training/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Unit test package for inferno.""" 4 | -------------------------------------------------------------------------------- /inferno/extensions/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet import UNet, UNetBase 2 | from .res_unet import ResBlockUNet 3 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | conda: 2 | file: docs/environment.yml 3 | python: 4 | version: 3.5 5 | pip_install: false -------------------------------------------------------------------------------- /tests/test_io/test_core/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Unit test package for inferno.""" 4 | -------------------------------------------------------------------------------- /docs/inferno-apidoc/modules.rst: -------------------------------------------------------------------------------- 1 | inferno 2 | ======= 3 | 4 | .. 
toctree:: 5 | :maxdepth: 4 6 | 7 | inferno 8 | -------------------------------------------------------------------------------- /tests/test_io/test_volumetric/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Unit test package for inferno.""" 4 | -------------------------------------------------------------------------------- /add2path.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Run this script from within the directory. 3 | export PYTHONPATH=${PYTHONPATH}:${PWD} -------------------------------------------------------------------------------- /docs/graphics/tentative_logo.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inferno-pytorch/inferno/HEAD/docs/graphics/tentative_logo.pdf -------------------------------------------------------------------------------- /inferno/io/__init__.py: -------------------------------------------------------------------------------- 1 | from . import box 2 | from . import core 3 | from . import transform 4 | from . 
import volumetric -------------------------------------------------------------------------------- /tests/test_extensions/test_models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Unit test package for inferno.""" 4 | -------------------------------------------------------------------------------- /tests/test_training/test_callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Unit test package for inferno.""" 4 | -------------------------------------------------------------------------------- /tests/test_training/test_callbacks/test_logging/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Unit test package for inferno.""" 4 | -------------------------------------------------------------------------------- /inferno/io/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import SyncableDataset 2 | from .zip import Zip, ZipReject 3 | from .concatenate import Concatenate 4 | -------------------------------------------------------------------------------- /inferno/io/transform/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Transform, Compose 2 | from . import generic 3 | from . import image 4 | from . import volume 5 | -------------------------------------------------------------------------------- /inferno/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from . import basic 2 | from . import callbacks 3 | from . 
basic import Trainer 4 | __all__ = ['basic','callbacks','Trainer'] -------------------------------------------------------------------------------- /inferno/extensions/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .adam import Adam 2 | from .annealed_adam import AnnealedAdam 3 | from .ranger import Ranger, RangerQH, RangerVA 4 | -------------------------------------------------------------------------------- /docs/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {# layout.html #} 2 | {# Import the theme's layout. #} 3 | {% extends "!layout.html" %} 4 | 5 | {% set css_files = css_files + ['_static/pygments.css'] %} -------------------------------------------------------------------------------- /docs/zbibliography.rst: -------------------------------------------------------------------------------- 1 | .. _inferno_bibliography: 2 | 3 | Bibliography 4 | ============================ 5 | 6 | The bibliography: 7 | 8 | .. bibliography:: refs.bib 9 | :style: alpha -------------------------------------------------------------------------------- /inferno/io/volumetric/__init__.py: -------------------------------------------------------------------------------- 1 | from .volume import VolumeLoader, HDF5VolumeLoader, TIFVolumeLoader from .lazy_volume_loader import LazyHDF5VolumeLoader, LazyZarrVolumeLoader, LazyN5VolumeLoader -------------------------------------------------------------------------------- /docs/examples.rst: -------------------------------------------------------------------------------- 1 | .. _inferno_examples_gallery: 2 | 3 | Inferno Examples Gallery 4 | ============================ 5 | 6 | 7 | .. 
toctree:: 8 | :maxdepth: 5 9 | 10 | ../auto_examples/index 11 | 12 | -------------------------------------------------------------------------------- /inferno/extensions/criteria/__init__.py: -------------------------------------------------------------------------------- 1 | from .set_similarity_measures import * 2 | from .elementwise_measures import * 3 | from .core import * 4 | from .regularized import * 5 | 6 | __all__ = ['set_similarity_measures', 'elementwise_measures','core','regularized'] -------------------------------------------------------------------------------- /inferno/extensions/layers/identity.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | __all__ = ['identity'] 3 | _all = __all__ 4 | 5 | class Identity(nn.Module): 6 | def __init__(self): 7 | super(Identity, self).__init__() 8 | 9 | def forward(self, x): 10 | return x -------------------------------------------------------------------------------- /inferno/io/core/data_utils.py: -------------------------------------------------------------------------------- 1 | 2 | def implements_sync_primitives(dataset): 3 | return hasattr(dataset, 'sync_with') and callable(getattr(dataset, 'sync_with')) 4 | 5 | 6 | def defines_base_sequence(dataset): 7 | return hasattr(dataset, 'base_sequence') and dataset.base_sequence is not None 8 | -------------------------------------------------------------------------------- /docs/refs.bib: -------------------------------------------------------------------------------- 1 | 2 | @inproceedings{alush_2013_simbad, 3 | title={Break and Conquer: Efficient Correlation Clustering for Image Segmentation}, 4 | author={Alush, Amir and Goldberger, Jacob}, 5 | booktitle={2nd International Workshop on Similarity-Based Pattern Analysis and Recognition}, 6 | year={2013} 7 | } 8 | -------------------------------------------------------------------------------- /inferno/extensions/optimizers/ranger.py: 
-------------------------------------------------------------------------------- 1 | # easy support for additional ranger optimizers from 2 | # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 3 | try: 4 | from ranger import Ranger, RangerVA, RangerQH 5 | except ImportError: 6 | Ranger = None 7 | RangerVA = None 8 | RangerQH = None 9 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include AUTHORS.rst 2 | include CONTRIBUTING.rst 3 | include HISTORY.rst 4 | include LICENSE 5 | include README.rst 6 | 7 | recursive-include tests * 8 | recursive-exclude * __pycache__ 9 | recursive-exclude * *.py[co] 10 | 11 | recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif 12 | -------------------------------------------------------------------------------- /conda-recipe/build.sh: -------------------------------------------------------------------------------- 1 | PY_VER=$(python -c "import sys; print('{}.{}'.format(*sys.version_info[:2]))") 2 | 3 | # Install python modules 4 | mkdir -p ${PREFIX}/inferno 5 | cp -r inferno/* ${PREFIX}/inferno 6 | echo "${PREFIX}" > ${PREFIX}/lib/python${PY_VER}/site-packages/inferno.pth 7 | python -m compileall ${PREFIX}/inferno 8 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | pip==8.1.2 2 | bumpversion==0.5.3 3 | wheel==0.29.0 4 | watchdog==0.8.3 5 | flake8==2.6.0 6 | tox==2.3.1 7 | coverage==4.1 8 | Sphinx==1.4.8 9 | cryptography==1.7 10 | PyYAML==5.1 11 | dill 12 | pyyaml 13 | scipy>=0.13.0 14 | h5py 15 | scikit-image 16 | sphinx-gallery 17 | sphinxcontrib-napoleon 18 | sphinxcontrib-inlinesyntaxhighlight 19 | sphinx_rtd_theme -------------------------------------------------------------------------------- /.editorconfig: 
-------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /inferno/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Top-level package for inferno.""" 4 | 5 | from . import extensions 6 | from . import io 7 | from . import trainers 8 | from . import utils 9 | from .version import __version__ 10 | 11 | 12 | __all__ = ['extensions', 'io', 'trainers', 'utils'] 13 | 14 | __author__ = """Nasim Rahaman""" 15 | __email__ = 'nasim.rahaman@iwr.uni-heidelberg.de' 16 | -------------------------------------------------------------------------------- /inferno/extensions/__init__.py: -------------------------------------------------------------------------------- 1 | from . import containers 2 | from . import criteria 3 | from . import initializers 4 | from . import layers 5 | from . import metrics 6 | from . import optimizers 7 | from . import models 8 | # Backward support 9 | from . import models as model 10 | 11 | __all__ = ['containers', 'criteria', 'initializers', 'layers', 'metrics', 'optimizers', 12 | 'models', 'model'] -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | * inferno version: 2 | * Python version: 3 | * Operating System: 4 | 5 | ### Description 6 | 7 | Describe what you were trying to get done. 
8 | Tell us what happened, what went wrong, and what you expected to happen. 9 | 10 | ### What I Did 11 | 12 | ``` 13 | Paste the command(s) you ran and the output. 14 | If there was a crash, please include the traceback here. 15 | ``` 16 | -------------------------------------------------------------------------------- /docs/inferno-apidoc/inferno.io.rst: -------------------------------------------------------------------------------- 1 | inferno.io package 2 | ================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | inferno.io.box 10 | inferno.io.core 11 | inferno.io.transform 12 | inferno.io.volumetric 13 | 14 | Module contents 15 | --------------- 16 | 17 | .. automodule:: inferno.io 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | -------------------------------------------------------------------------------- /inferno/io/box/__init__.py: -------------------------------------------------------------------------------- 1 | """Things that work out of the box. 
;)""" 2 | 3 | from .camvid import CamVid, get_camvid_loaders 4 | from .cityscapes import Cityscapes, get_cityscapes_loaders 5 | from .cifar import get_cifar10_loaders, get_cifar100_loaders 6 | 7 | 8 | __all__ = [ 9 | 'CamVid','get_camvid_loaders', 'Cityscapes', 'get_cityscapes_loaders', 10 | 'get_cifar10_loaders','get_cifar100_loaders' 11 | ] -------------------------------------------------------------------------------- /tests/test_extensions/test_layers/test_activations.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import inferno.extensions.layers.activations as activations 4 | 5 | 6 | class ActivationTest(unittest.TestCase): 7 | def test_selu(self): 8 | x = torch.rand(100) 9 | y = activations.SELU()(x) 10 | self.assertEqual(list(x.size()), list(y.size())) 11 | 12 | 13 | if __name__ == '__main__': 14 | unittest.main() 15 | -------------------------------------------------------------------------------- /docs/environment.yml: -------------------------------------------------------------------------------- 1 | name: inferno_docs 2 | 3 | channels: 4 | - soumith 5 | - anaconda 6 | 7 | dependencies: 8 | - python==3.5 9 | - pytorch>=0.1.12 10 | - torchvision 11 | - scikit-image 12 | - pip: 13 | - scipy>=0.13.0 14 | - h5py 15 | - scikit-image 16 | - pyyaml 17 | - dill 18 | - sphinx-gallery 19 | - sphinxcontrib-napoleon 20 | - sphinxcontrib-bibtex 21 | - sphinxcontrib-inlinesyntaxhighlight 22 | -------------------------------------------------------------------------------- /inferno/extensions/metrics/cremi_score.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .voi import voi 3 | from .arand import adapted_rand 4 | 5 | 6 | # TODO build metrics object 7 | 8 | 9 | def cremi_metrics(seg, gt, no_seg_ignore=True): 10 | if no_seg_ignore: 11 | if 0 in seg: 12 | seg += 1 13 | vi_s, vi_m = voi(seg, gt) 14 | rand = 1. 
- adapted_rand(seg, gt)[0] 15 | cs = np.sqrt((vi_s + vi_m) * rand) 16 | return cs, vi_s, vi_m, rand 17 | -------------------------------------------------------------------------------- /inferno/trainers/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['CallbackEngine', 'Callback', 'Console', 'essentials', 'scheduling', 'gradients'] 2 | 3 | from .base import CallbackEngine, Callback 4 | from .console import Console 5 | from . import essentials 6 | from . import scheduling 7 | from . import gradients 8 | 9 | try: 10 | from .tqdm import TQDMProgressBar 11 | __all__.append('TQDMProgressBar') 12 | except ImportError: 13 | from .tqdmstub import TQDMProgressBar 14 | -------------------------------------------------------------------------------- /inferno/trainers/callbacks/logging/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['get_logger'] 2 | try: 3 | INFERNO_WITH_TENSORBOARD_LOGGER = True 4 | from .tensorboard import TensorboardLogger 5 | __all__.append('TensorboardLogger') 6 | except ImportError: 7 | INFERNO_WITH_TENSORBOARD_LOGGER = False 8 | 9 | 10 | def get_logger(name): 11 | if name in globals(): 12 | return globals().get(name) 13 | else: 14 | raise NotImplementedError("Logger not found.") 15 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to inferno's documentation! 2 | ====================================== 3 | 4 | Contents: 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | readme 10 | installation 11 | usage 12 | examples 13 | contributing 14 | inferno-apidoc/modules 15 | authors 16 | history 17 | zbibliography 18 | 19 | .. 
automodule:: inferno 20 | 21 | Indices and tables 22 | ================== 23 | 24 | * :ref:`genindex` 25 | * :ref:`modindex` 26 | * :ref:`search` 27 | -------------------------------------------------------------------------------- /inferno/trainers/callbacks/tqdmstub.py: -------------------------------------------------------------------------------- 1 | from .base import Callback 2 | 3 | class TQDMProgressBar(Callback): 4 | def __init__(self, *args, **kwargs): 5 | super(TQDMProgressBar, self).__init__(*args, **kwargs) 6 | 7 | def bind_trainer(self, *args, **kwargs): 8 | super(TQDMProgressBar, self).bind_trainer(*args, **kwargs) 9 | self.trainer.console.warning("tqdm is not installed. will fall back to normal stdout console.") 10 | 11 | def begin_of_fit(self, **_): 12 | pass 13 | -------------------------------------------------------------------------------- /inferno/extensions/layers/activations.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | from ...utils.torch_utils import where 4 | 5 | __all__ = ['SELU'] 6 | _all = __all__ 7 | 8 | class SELU(nn.Module): 9 | def forward(self, input): 10 | return self.selu(input) 11 | 12 | @staticmethod 13 | def selu(x): 14 | alpha = 1.6732632423543772848170429916717 15 | scale = 1.0507009873554804934193349852946 16 | # noinspection PyTypeChecker 17 | return scale * where(x >= 0, x, alpha * F.elu(x)) -------------------------------------------------------------------------------- /docs/inferno-apidoc/inferno.extensions.rst: -------------------------------------------------------------------------------- 1 | inferno.extensions package 2 | ========================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. 
toctree:: 8 | 9 | inferno.extensions.containers 10 | inferno.extensions.criteria 11 | inferno.extensions.initializers 12 | inferno.extensions.layers 13 | inferno.extensions.metrics 14 | inferno.extensions.optimizers 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. automodule:: inferno.extensions 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/inferno-apidoc/inferno.trainers.rst: -------------------------------------------------------------------------------- 1 | inferno.trainers package 2 | ======================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | inferno.trainers.callbacks 10 | 11 | Submodules 12 | ---------- 13 | 14 | inferno.trainers.basic module 15 | ----------------------------- 16 | 17 | .. automodule:: inferno.trainers.basic 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. 
automodule:: inferno.trainers 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /inferno/extensions/layers/normalization.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BatchNormND(nn.Module): 5 | def __init__(self, dim, num_features, 6 | eps=1e-5, momentum=0.1, 7 | affine=True,track_running_stats=True): 8 | super(BatchNormND, self).__init__() 9 | assert dim in [1, 2, 3] 10 | self.bn = getattr(nn, 'BatchNorm{}d'.format(dim))(num_features=num_features, 11 | eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) 12 | 13 | def forward(self, x): 14 | return self.bn(x) -------------------------------------------------------------------------------- /docs/inferno-apidoc/inferno.rst: -------------------------------------------------------------------------------- 1 | inferno package 2 | =============== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | inferno.extensions 10 | inferno.io 11 | inferno.trainers 12 | inferno.utils 13 | 14 | Submodules 15 | ---------- 16 | 17 | inferno.inferno module 18 | ---------------------- 19 | 20 | .. automodule:: inferno.inferno 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | 26 | Module contents 27 | --------------- 28 | 29 | .. 
automodule:: inferno 30 | :members: 31 | :undoc-members: 32 | :show-inheritance: 33 | -------------------------------------------------------------------------------- /tests/test_extensions/test_criteria/test_core.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class TestCore(unittest.TestCase): 7 | def test_as_2d_criterion(self): 8 | from inferno.extensions.criteria.core import As2DCriterion 9 | 10 | prediction = torch.FloatTensor(2, 10, 100, 100).uniform_() 11 | prediction = nn.Softmax2d()(prediction) 12 | target = torch.LongTensor(2, 100, 100).fill_(0) 13 | criterion = As2DCriterion(nn.CrossEntropyLoss()) 14 | criterion(prediction, target) 15 | 16 | 17 | if __name__ == '__main__': 18 | unittest.main() 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache Software License 2.0 3 | 4 | Copyright (c) 2017, Inferno Developers 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | 18 | -------------------------------------------------------------------------------- /tests/test_extensions/test_layers/test_convolutional.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from inferno.utils.model_utils import ModelTester 4 | 5 | 6 | class TestConvolutional(unittest.TestCase): 7 | @unittest.skipIf(not torch.cuda.is_available(), "GPU not available.") 8 | def test_bn_relu_depthwise_conv2d_pyinn(self): 9 | from inferno.extensions.layers.convolutional import BNReLUDepthwiseConv2D 10 | model = BNReLUDepthwiseConv2D(10, 'auto', 3) 11 | ModelTester((1, 10, 100, 100), 12 | (1, 10, 100, 100)).cuda()(model) 13 | self.assertTrue(model.depthwise) 14 | self.assertEqual(model.conv.groups, 10) 15 | 16 | 17 | if __name__ == '__main__': 18 | unittest.main() 19 | -------------------------------------------------------------------------------- /tests/test_extensions/test_criteria/test_elementwise_measures.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import inferno.extensions.criteria.elementwise_measures as em 3 | import torch 4 | 5 | 6 | class TestElementwiseMeasures(unittest.TestCase): 7 | def test_weighted_mse_loss(self): 8 | input = torch.zeros(10, 10) 9 | target = torch.ones(10, 10) 10 | loss = em.WeightedMSELoss(positive_class_weight=2.)(input, target) 11 | self.assertAlmostEqual(loss.item(), 2., delta=1e-5) 12 | target = torch.zeros(10, 10) 13 | input = torch.ones(10, 10) 14 | loss = em.WeightedMSELoss(positive_class_weight=2.)(input, target) 15 | self.assertAlmostEqual(loss.item(), 1., delta=1e-5) 16 | 17 | 18 | if __name__ == '__main__': 19 | unittest.main() 20 | -------------------------------------------------------------------------------- /inferno/extensions/containers/sequential.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from ...utils 
import python_utils as pyu 3 | 4 | 5 | __all__ = ['Sequential1', 'Sequential2'] 6 | 7 | 8 | class Sequential1(nn.Sequential): 9 | """Like torch.nn.Sequential, but with a few extra methods.""" 10 | def __len__(self): 11 | return len(self._modules.values()) 12 | 13 | 14 | class Sequential2(Sequential1): 15 | """Another sequential container. 16 | Identical to torch.nn.Sequential, except that modules may return multiple outputs and 17 | accept multiple inputs. 18 | """ 19 | def forward(self, *input): 20 | for module in self._modules.values(): 21 | input = pyu.to_iterable(module(*pyu.to_iterable(input))) 22 | return pyu.from_iterable(input) 23 | -------------------------------------------------------------------------------- /docs/inferno-apidoc/inferno.extensions.containers.rst: -------------------------------------------------------------------------------- 1 | inferno.extensions.containers package 2 | ===================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | inferno.extensions.containers.graph module 8 | ------------------------------------------ 9 | 10 | .. automodule:: inferno.extensions.containers.graph 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | inferno.extensions.containers.sequential module 16 | ----------------------------------------------- 17 | 18 | .. automodule:: inferno.extensions.containers.sequential 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | 24 | Module contents 25 | --------------- 26 | 27 | ..
automodule:: inferno.extensions.containers 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /docs/inferno-apidoc/inferno.extensions.initializers.rst: -------------------------------------------------------------------------------- 1 | inferno.extensions.initializers package 2 | ======================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | inferno.extensions.initializers.base module 8 | ------------------------------------------- 9 | 10 | .. automodule:: inferno.extensions.initializers.base 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | inferno.extensions.initializers.presets module 16 | ---------------------------------------------- 17 | 18 | .. automodule:: inferno.extensions.initializers.presets 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. automodule:: inferno.extensions.initializers 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /docs/inferno-apidoc/inferno.extensions.optimizers.rst: -------------------------------------------------------------------------------- 1 | inferno.extensions.optimizers package 2 | ===================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | inferno.extensions.optimizers.adam module 8 | ----------------------------------------- 9 | 10 | .. automodule:: inferno.extensions.optimizers.adam 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | inferno.extensions.optimizers.annealed\_adam module 16 | --------------------------------------------------- 17 | 18 | .. automodule:: inferno.extensions.optimizers.annealed_adam 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. 
automodule:: inferno.extensions.optimizers 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /docs/inferno-apidoc/inferno.trainers.callbacks.logging.rst: -------------------------------------------------------------------------------- 1 | inferno.trainers.callbacks.logging package 2 | ========================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | inferno.trainers.callbacks.logging.base module 8 | ---------------------------------------------- 9 | 10 | .. automodule:: inferno.trainers.callbacks.logging.base 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | inferno.trainers.callbacks.logging.tensorboard module 16 | ----------------------------------------------------- 17 | 18 | .. automodule:: inferno.trainers.callbacks.logging.tensorboard 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. 
automodule:: inferno.trainers.callbacks.logging 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /HISTORY.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | History 3 | ======= 4 | 5 | 0.1.0 (2017-08-24) 6 | ------------------ 7 | 8 | * First early release on PyPI 9 | 10 | 0.1.1 (2017-08-24) 11 | ------------------ 12 | 13 | * Version Increment 14 | 15 | 0.1.2 (2017-08-24) 16 | ------------------ 17 | 18 | * Version Increment 19 | 20 | 21 | 0.1.3 (2017-08-24) 22 | ------------------ 23 | 24 | * Updated Documentation 25 | 26 | 0.1.4 (2017-08-24) 27 | ------------------ 28 | 29 | * travis auto-deployment on pypi 30 | 31 | 32 | 0.1.5 (2017-08-24) 33 | ------------------ 34 | 35 | * travis changes to run unittest 36 | 37 | 38 | 0.1.6 (2017-08-24) 39 | ------------------ 40 | 41 | * travis missing packages for unittesting 42 | * fixed inconsistent version numbers 43 | 44 | 0.1.7 (2017-08-25) 45 | ------------------ 46 | 47 | * setup.py critical bugfix in install procedure 48 | 49 | 50 | 51 | CURRENT CHANGES 52 | ----------------- 53 | * Flexible Unet 54 | -------------------------------------------------------------------------------- /tests/test_utils/test_model_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import inferno.utils.model_utils as mu 3 | from inferno.utils.exceptions import ShapeError 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class ModelUtilTester(unittest.TestCase): 9 | def test_model_tester(self): 10 | model = mu.ModelTester((1, 10, 32, 32), (1, 20, 32, 32))(nn.Conv2d(10, 20, 3, padding=1)) 11 | with self.assertRaises(ShapeError): 12 | mu.ModelTester((1, 10, 32, 32), (1, 30, 32, 32))(model) 13 | 14 | @unittest.skipUnless(torch.cuda.is_available(), "need cuda") 15 | def test_model_tester_cuda(self): 16 | tester = mu.ModelTester((1,
10, 32, 32), (1, 20, 32, 32)).cuda() 17 | model = tester(nn.Conv2d(10, 20, 3, padding=1).cuda()) 18 | with self.assertRaises(ShapeError): 19 | mu.ModelTester((1, 10, 32, 32), (1, 30, 32, 32)).cuda()(model) 20 | 21 | if __name__ == '__main__': 22 | unittest.main() 23 | -------------------------------------------------------------------------------- /inferno/extensions/layers/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [] 2 | from .activations import * 3 | from .convolutional import * 4 | from .device import * 5 | from .reshape import * 6 | from .convolutional_blocks import * 7 | 8 | ####################################################### 9 | # the following is to make the sphinx example 10 | # gallery makes proper cross-references 11 | from .activations import _all as _activations_all 12 | from .convolutional import _all as _convolutional_all 13 | from .device import _all as _device_all 14 | from .reshape import _all as _reshape_all 15 | from .convolutional_blocks import _all as _convolutional_blocks_all 16 | from .identity import _all as _identity_all 17 | 18 | __all__.extend(_activations_all) 19 | __all__.extend(_convolutional_all) 20 | __all__.extend(_device_all) 21 | __all__.extend(_reshape_all) 22 | __all__.extend(_convolutional_blocks_all) 23 | __all__.extend(_identity_all) 24 | 25 | _all = __all__ 26 | -------------------------------------------------------------------------------- /docs/inferno-apidoc/inferno.io.volumetric.rst: -------------------------------------------------------------------------------- 1 | inferno.io.volumetric package 2 | ============================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | inferno.io.volumetric.lazy\_volume\_loader module 8 | ------------------------------------------------- 9 | 10 | .. 
automodule:: inferno.io.volumetric.lazy_volume_loader 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | inferno.io.volumetric.volume module 16 | ----------------------------------- 17 | 18 | .. automodule:: inferno.io.volumetric.volume 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | inferno.io.volumetric.volumetric\_utils module 24 | ---------------------------------------------- 25 | 26 | .. automodule:: inferno.io.volumetric.volumetric_utils 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | 32 | Module contents 33 | --------------- 34 | 35 | .. automodule:: inferno.io.volumetric 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | -------------------------------------------------------------------------------- /inferno/utils/math_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def max_allowed_ds_steps(shape, factor): 4 | """How often can a shape be down-sampled by a given factor 5 | such that non of the divisions will give non-integers. 
6 | 7 | Args: 8 | shape (listlike): tensor shape 9 | factor (integer): downsample factor 10 | 11 | Returns: 12 | int: maximum allowed downsample operations 13 | """ 14 | def max_allowed_ds_steps_impl(size, factor): 15 | 16 | current_size = float(size) 17 | allowed_steps = 0 18 | while(True): 19 | 20 | new_size = current_size / float(factor) 21 | if(new_size >=1 and new_size.is_integer()): 22 | 23 | current_size = new_size 24 | allowed_steps += 1 25 | else: 26 | break 27 | return allowed_steps 28 | 29 | min_steps = float('inf') 30 | 31 | for s in shape: 32 | min_steps = int(min(min_steps, max_allowed_ds_steps_impl(s, factor))) 33 | 34 | return min_steps 35 | -------------------------------------------------------------------------------- /inferno/utils/exceptions.py: -------------------------------------------------------------------------------- 1 | """Exceptions and Error Handling""" 2 | 3 | 4 | def assert_(condition, message='', exception_type=AssertionError): 5 | """Like assert, but with arbitrary exception types.""" 6 | if not condition: 7 | raise exception_type(message) 8 | 9 | 10 | # ------ VALUE ERRORS ------ 11 | 12 | 13 | class ShapeError(ValueError): 14 | pass 15 | 16 | 17 | class FrequencyValueError(ValueError): 18 | pass 19 | 20 | 21 | class DeviceError(ValueError): 22 | pass 23 | 24 | 25 | class NotSetError(ValueError): 26 | pass 27 | 28 | 29 | # ------ TYPE ERRORS ------ 30 | 31 | 32 | class NotTorchModuleError(TypeError): 33 | pass 34 | 35 | 36 | class FrequencyTypeError(TypeError): 37 | pass 38 | 39 | 40 | class DTypeError(TypeError): 41 | pass 42 | 43 | 44 | # ------ LOOKUP ERRORS ------ 45 | 46 | 47 | class ClassNotFoundError(LookupError): 48 | pass 49 | 50 | 51 | # ------ NOT-IMPLEMENTED ERRORS ------ 52 | 53 | 54 | class NotUnwrappableError(NotImplementedError): 55 | pass -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 
| # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | 58 | # PyBuilder 59 | target/ 60 | 61 | # pyenv python configuration file 62 | .python-version 63 | -------------------------------------------------------------------------------- /docs/inferno-apidoc/inferno.io.box.rst: -------------------------------------------------------------------------------- 1 | inferno.io.box package 2 | ====================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | inferno.io.box.binary\_blobs module 8 | ----------------------------------- 9 | 10 | .. automodule:: inferno.io.box.binary_blobs 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | inferno.io.box.camvid module 16 | ---------------------------- 17 | 18 | .. automodule:: inferno.io.box.camvid 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | inferno.io.box.cifar module 24 | --------------------------- 25 | 26 | .. 
automodule:: inferno.io.box.cifar 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | inferno.io.box.cityscapes module 32 | -------------------------------- 33 | 34 | .. automodule:: inferno.io.box.cityscapes 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | 40 | Module contents 41 | --------------- 42 | 43 | .. automodule:: inferno.io.box 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | -------------------------------------------------------------------------------- /docs/inferno-apidoc/inferno.io.core.rst: -------------------------------------------------------------------------------- 1 | inferno.io.core package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | inferno.io.core.base module 8 | --------------------------- 9 | 10 | .. automodule:: inferno.io.core.base 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | inferno.io.core.concatenate module 16 | ---------------------------------- 17 | 18 | .. automodule:: inferno.io.core.concatenate 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | inferno.io.core.data\_utils module 24 | ---------------------------------- 25 | 26 | .. automodule:: inferno.io.core.data_utils 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | inferno.io.core.zip module 32 | -------------------------- 33 | 34 | .. automodule:: inferno.io.core.zip 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | 40 | Module contents 41 | --------------- 42 | 43 | .. 
automodule:: inferno.io.core 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | -------------------------------------------------------------------------------- /inferno/extensions/metrics/base.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class Metric(object): 4 | 5 | def forward(self, *args, **kwargs): 6 | raise NotImplementedError 7 | 8 | def __call__(self, prediction, target, **kwargs): 9 | # We might have listlike predictions (e.g. multi-scale) 10 | # If so, we evaluate the metric on the first prediction, 11 | # which should be at the original scale 12 | if isinstance(prediction, (list, tuple)): 13 | prediction = prediction[0] 14 | # same is true for the target 15 | if isinstance(target, (list, tuple)): 16 | target = target[0] 17 | # Make sure prediction and target live on the same device. 18 | # If they don't, move target to the right device. 19 | if not prediction.is_cuda: 20 | # Move to CPU 21 | target = target.cpu() 22 | else: 23 | # Find device to move to 24 | device_ordinal = prediction.get_device() 25 | target = target.cuda(device_ordinal) 26 | return self.forward(prediction, target, **kwargs) 27 | -------------------------------------------------------------------------------- /conda-recipe/meta.yaml: -------------------------------------------------------------------------------- 1 | package: 2 | name: inferno 3 | 4 | {% set tagged_version = GIT_DESCRIBE_TAG|replace("v","")|replace("-", ".") %} 5 | 6 | # If we're using a non-tagged revision, append '.postN' to the version 7 | {% if GIT_DESCRIBE_NUMBER|int != 0 %} 8 | {% set tagged_version = tagged_version + '.post' + GIT_DESCRIBE_NUMBER %} 9 | {% endif %} 10 | 11 | version: {{tagged_version}} 12 | 13 | source: 14 | path: .. 
15 | 16 | build: 17 | number: 1 18 | string: py_{{PKG_BUILDNUM}}_g{{GIT_FULL_HASH[:7]}} 19 | 20 | requirements: 21 | build: 22 | - python {{PY_VER}}* 23 | run: 24 | - python {{PY_VER}}* 25 | - pytorch 26 | - torchvision 27 | - pyyaml 28 | - scipy 29 | - scikit-image 30 | - scikit-learn 31 | - h5py 32 | - dill 33 | - networkx 1.11 34 | - tensorboardx 35 | - sphinx_rtd_theme 36 | 37 | 38 | test: 39 | imports: 40 | - inferno 41 | 42 | about: 43 | license: Apache License 2.0 44 | summary: A utility library around PyTorch 45 | -------------------------------------------------------------------------------- /tests/test_io/test_core/test_concatenate.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class ConcatenateTest(unittest.TestCase): 5 | def test_concatenate(self): 6 | from inferno.io.core import Concatenate 7 | from torch.utils.data.dataset import Dataset 8 | 9 | with self.assertRaises(AssertionError): 10 | cated = Concatenate([1, 2, 3], [4, 5, 6, 7]) 11 | 12 | class ListDataset(list, Dataset): 13 | pass 14 | 15 | dataset_1 = ListDataset([1, 2, 3, 4]) 16 | dataset_2 = ListDataset([5, 6, 7]) 17 | dataset_3 = ListDataset([8, 9, 10, 11, 12]) 18 | 19 | cated = Concatenate(dataset_1, dataset_2, dataset_3) 20 | self.assertEqual(len(cated), 12) 21 | 22 | # Try to fetch 23 | self.assertEqual(cated[2], 3) 24 | self.assertEqual(cated[4], 5) 25 | self.assertEqual(cated[6], 7) 26 | self.assertEqual(cated[10], 11) 27 | self.assertEqual(cated[11], 12) 28 | 29 | with self.assertRaises(AssertionError): 30 | _ = cated[12] 31 | 32 | if __name__ == '__main__': 33 | unittest.main() 34 | -------------------------------------------------------------------------------- /docs/inferno-apidoc/inferno.io.transform.rst: -------------------------------------------------------------------------------- 1 | inferno.io.transform package 2 | ============================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | 
inferno.io.transform.base module 8 | -------------------------------- 9 | 10 | .. automodule:: inferno.io.transform.base 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | inferno.io.transform.generic module 16 | ----------------------------------- 17 | 18 | .. automodule:: inferno.io.transform.generic 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | inferno.io.transform.image module 24 | --------------------------------- 25 | 26 | .. automodule:: inferno.io.transform.image 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | inferno.io.transform.volume module 32 | ---------------------------------- 33 | 34 | .. automodule:: inferno.io.transform.volume 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | 40 | Module contents 41 | --------------- 42 | 43 | .. automodule:: inferno.io.transform 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | -------------------------------------------------------------------------------- /tests/test_training/test_callbacks/test_logging/test_base.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from inferno.trainers.callbacks.logging.base import Logger 3 | from inferno.trainers.basic import Trainer 4 | from os.path import join, dirname 5 | 6 | 7 | class DummyLogger(Logger): 8 | def end_of_training_iteration(self, **_): 9 | pass 10 | 11 | 12 | class TestLogger(unittest.TestCase): 13 | ROOT = dirname(__file__) 14 | 15 | def test_serialization(self): 16 | trainer = Trainer()\ 17 | .build_logger(logger=DummyLogger())\ 18 | .save_to_directory(join(self.ROOT, 'saves')) 19 | trainer.save() 20 | # Unserialize 21 | trainer = Trainer().load(from_directory=join(self.ROOT, 'saves')) 22 | # Check if the loggers are consistent 23 | logger_from_trainer = trainer._logger 24 | logger_from_callback_engine = \ 25 | next(iter(trainer.callbacks._callback_registry['end_of_training_iteration'])) 26 | 
self.assertIs(logger_from_trainer, logger_from_callback_engine) 27 | self.assertIs(logger_from_callback_engine.trainer, trainer) 28 | 29 | 30 | if __name__ == '__main__': 31 | unittest.main() -------------------------------------------------------------------------------- /docs/_templates/template_module.rst: -------------------------------------------------------------------------------- 1 | {{ fullname }} 2 | {{ underline }} 3 | 4 | .. automodule:: {{ fullname }} 5 | 6 | {% block functions %} 7 | {% if functions %} 8 | 9 | Functions 10 | ================== 11 | 12 | {% for item in functions %} 13 | 14 | .. autofunction:: {{ item }} 15 | 16 | .. include:: backreferences/{{fullname}}.{{item}}.examples 17 | 18 | .. raw:: html 19 | 20 |
21 | 22 | {%- endfor %} 23 | {% endif %} 24 | {% endblock %} 25 | 26 | {% block classes %} 27 | {% if classes %} 28 | 29 | Classes 30 | ------- 31 | 32 | {% for item in classes %} 33 | .. autoclass:: {{ item }} 34 | :members: 35 | 36 | .. include:: backreferences/{{fullname}}.{{item}}.examples 37 | 38 | .. raw:: html 39 | 40 |
41 | 42 | 43 | {%- endfor %} 44 | {% endif %} 45 | {% endblock %} 46 | 47 | {% block exceptions %} 48 | {% if exceptions %} 49 | 50 | Exceptions 51 | ---------- 52 | 53 | .. autosummary:: 54 | {% for item in exceptions %} 55 | {{ item }} 56 | {%- endfor %} 57 | {% endif %} 58 | {% endblock %} -------------------------------------------------------------------------------- /inferno/io/core/base.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataset import Dataset 2 | 3 | 4 | class SyncableDataset(Dataset): 5 | def __init__(self, base_sequence=None): 6 | self.base_sequence = base_sequence 7 | 8 | def sync_with(self, dataset): 9 | if hasattr(dataset, 'base_sequence'): 10 | self.base_sequence = dataset.base_sequence 11 | return self 12 | 13 | def __len__(self): 14 | if self.base_sequence is None: 15 | raise RuntimeError("Class {} does not specify a base sequence. Either specify " 16 | "one by assigning to self.base_sequence or override the " 17 | "__len__ method.".format(self.__class__.__name__)) 18 | else: 19 | return len(self.base_sequence) 20 | 21 | 22 | class IndexSpec(object): 23 | """ 24 | Class to wrap any extra index information a `Dataset` object might want to send back. 25 | This could be useful in (say) inference, where we would wish to (asynchronously) know 26 | more about the current input. 
27 | """ 28 | def __init__(self, index=None, base_sequence_at_index=None): 29 | self.index = index 30 | self.base_sequence_at_index = base_sequence_at_index 31 | 32 | def __int__(self): 33 | return int(self.index) 34 | -------------------------------------------------------------------------------- /docs/inferno-apidoc/inferno.extensions.criteria.rst: -------------------------------------------------------------------------------- 1 | inferno.extensions.criteria package 2 | =================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | inferno.extensions.criteria.core module 8 | --------------------------------------- 9 | 10 | .. automodule:: inferno.extensions.criteria.core 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | inferno.extensions.criteria.elementwise\_measures module 16 | -------------------------------------------------------- 17 | 18 | .. automodule:: inferno.extensions.criteria.elementwise_measures 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | inferno.extensions.criteria.regularized module 24 | ---------------------------------------------- 25 | 26 | .. automodule:: inferno.extensions.criteria.regularized 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | inferno.extensions.criteria.set\_similarity\_measures module 32 | ------------------------------------------------------------ 33 | 34 | .. automodule:: inferno.extensions.criteria.set_similarity_measures 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | 40 | Module contents 41 | --------------- 42 | 43 | .. 
automodule:: inferno.extensions.criteria 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | -------------------------------------------------------------------------------- /tests/test_extensions/test_layers/test_device.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from inferno.extensions.layers.device import DeviceTransfer, OnDevice 3 | import torch 4 | 5 | 6 | class TransferTest(unittest.TestCase): 7 | @unittest.skipIf(not torch.cuda.is_available(), "GPU not available.") 8 | def test_device_transfer(self): 9 | if not torch.cuda.is_available(): 10 | return 11 | # Build transfer model 12 | transfer = DeviceTransfer('cpu') 13 | x = torch.rand(10, 10).cuda() 14 | y = transfer(x) 15 | loss = y.mean() 16 | loss.backward() 17 | self.assertFalse(y.data.is_cuda) 18 | self.assertIsNotNone(x.grad) 19 | self.assertTrue(x.grad.data.is_cuda) 20 | 21 | @unittest.skipIf(not torch.cuda.is_available(), "GPU not available.") 22 | def test_on_device(self): 23 | if not torch.cuda.is_available(): 24 | return 25 | # Build variable on the GPU 26 | x = torch.rand(1, 10) 27 | # Build model over multiple devices 28 | multi_device_model = torch.nn.Sequential(OnDevice(torch.nn.Linear(10, 10), 'cuda'), 29 | OnDevice(torch.nn.Linear(10, 10), 'cpu')) 30 | y = multi_device_model(x) 31 | self.assertIsInstance(y.data, torch.FloatTensor) 32 | 33 | 34 | if __name__ == '__main__': 35 | unittest.main() 36 | -------------------------------------------------------------------------------- /inferno/trainers/callbacks/logging/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | from ..base import Callback 3 | 4 | 5 | class Logger(Callback): 6 | """ 7 | A special callback for logging. 8 | 9 | Loggers are special because they're required to be serializable, whereas other 10 | callbacks have no such guarantees. 
In this regard, they are jointly handled by 11 | trainers and the callback engine. 12 | """ 13 | def __init__(self, log_directory=None): 14 | super(Logger, self).__init__() 15 | self._log_directory = None 16 | if log_directory is not None: 17 | self.set_log_directory(log_directory) 18 | 19 | @property 20 | def log_directory(self): 21 | if self._log_directory is not None: 22 | return self._log_directory 23 | elif self.trainer is not None and self.trainer._log_directory is not None: 24 | return self.trainer._log_directory 25 | else: 26 | raise RuntimeError("No log directory found.") 27 | 28 | @log_directory.setter 29 | def log_directory(self, value): 30 | self.set_log_directory(value) 31 | 32 | def set_log_directory(self, log_directory): 33 | assert isinstance(log_directory, str) 34 | if not os.path.isdir(log_directory): 35 | assert not os.path.exists(log_directory) 36 | os.makedirs(log_directory) 37 | self._log_directory = log_directory 38 | return self 39 | -------------------------------------------------------------------------------- /tests/test_training/test_callbacks/test_scheduling.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from inferno.trainers.callbacks.scheduling import ManualLR 3 | from torch import nn 4 | from torch.optim import Adam 5 | 6 | 7 | class TestSchedulers(unittest.TestCase): 8 | 9 | def test_manual_lr(self): 10 | class DummyTrainer(object): 11 | def __init__(self): 12 | self.iteration_count = 0 13 | self.epoch_count = 0 14 | self.optimizer = Adam(nn.Linear(10, 10).parameters(), lr=1.) 15 | 16 | manual_lr = ManualLR([((100, 'iterations'), 0.5), 17 | ((200, 'iterations'), 0.5), 18 | ((200, 'iterations'), 0.1)]) 19 | trainer = DummyTrainer() 20 | manual_lr._trainer = trainer 21 | 22 | manual_lr.end_of_training_iteration() 23 | self.assertEqual(trainer.optimizer.param_groups[0]['lr'], 1.)
24 | trainer.iteration_count = 100 25 | manual_lr.end_of_training_iteration() 26 | self.assertEqual(trainer.optimizer.param_groups[0]['lr'], 0.5) 27 | trainer.iteration_count = 200 28 | manual_lr.end_of_training_iteration() 29 | self.assertEqual(trainer.optimizer.param_groups[0]['lr'], 0.025) 30 | trainer.iteration_count = 300 31 | self.assertEqual(trainer.optimizer.param_groups[0]['lr'], 0.025) 32 | 33 | if __name__ == '__main__': 34 | unittest.main() 35 | -------------------------------------------------------------------------------- /docs/inferno-apidoc/inferno.extensions.metrics.rst: -------------------------------------------------------------------------------- 1 | inferno.extensions.metrics package 2 | ================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | inferno.extensions.metrics.arand module 8 | --------------------------------------- 9 | 10 | .. automodule:: inferno.extensions.metrics.arand 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | inferno.extensions.metrics.base module 16 | -------------------------------------- 17 | 18 | .. automodule:: inferno.extensions.metrics.base 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | inferno.extensions.metrics.categorical module 24 | --------------------------------------------- 25 | 26 | .. automodule:: inferno.extensions.metrics.categorical 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | inferno.extensions.metrics.cremi\_score module 32 | ---------------------------------------------- 33 | 34 | .. automodule:: inferno.extensions.metrics.cremi_score 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | inferno.extensions.metrics.voi module 40 | ------------------------------------- 41 | 42 | .. automodule:: inferno.extensions.metrics.voi 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | 48 | Module contents 49 | --------------- 50 | 51 | .. 
automodule:: inferno.extensions.metrics 52 | :members: 53 | :undoc-members: 54 | :show-inheritance: 55 | -------------------------------------------------------------------------------- /inferno/extensions/criteria/elementwise_measures.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from ...utils.exceptions import assert_ 3 | 4 | 5 | class WeightedMSELoss(nn.Module): 6 | NEGATIVE_CLASS_WEIGHT = 1. 7 | 8 | def __init__(self, positive_class_weight=1., positive_class_value=1., size_average=True): 9 | super(WeightedMSELoss, self).__init__() 10 | assert_(positive_class_weight >= 0, 11 | "Positive class weight can't be less than zero, got {}." 12 | .format(positive_class_weight), 13 | ValueError) 14 | self.mse = nn.MSELoss(size_average=size_average) 15 | self.positive_class_weight = positive_class_weight 16 | self.positive_class_value = positive_class_value 17 | 18 | def forward(self, input, target): 19 | # Get a mask 20 | positive_class_mask = target.data.eq(self.positive_class_value).type_as(target.data) 21 | # Get differential weights (positive_weight - negative_weight, 22 | # i.e. subtract 1, assuming the negative weight is gauged at 1) 23 | weight_differential = (positive_class_mask 24 | .mul_(self.positive_class_weight - self.NEGATIVE_CLASS_WEIGHT)) 25 | # Get final weight by adding weight differential to a tensor with negative weights 26 | weights = weight_differential.add_(self.NEGATIVE_CLASS_WEIGHT) 27 | # `weights` should be positive if NEGATIVE_CLASS_WEIGHT is not messed with. 
28 | sqrt_weights = weights.sqrt_() 29 | return self.mse(input * sqrt_weights, target * sqrt_weights) 30 | -------------------------------------------------------------------------------- /docs/inferno-apidoc/inferno.trainers.callbacks.rst: -------------------------------------------------------------------------------- 1 | inferno.trainers.callbacks package 2 | ================================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | inferno.trainers.callbacks.logging 10 | 11 | Submodules 12 | ---------- 13 | 14 | inferno.trainers.callbacks.base module 15 | -------------------------------------- 16 | 17 | .. automodule:: inferno.trainers.callbacks.base 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | inferno.trainers.callbacks.console module 23 | ----------------------------------------- 24 | 25 | .. automodule:: inferno.trainers.callbacks.console 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | 30 | inferno.trainers.callbacks.essentials module 31 | -------------------------------------------- 32 | 33 | .. automodule:: inferno.trainers.callbacks.essentials 34 | :members: 35 | :undoc-members: 36 | :show-inheritance: 37 | 38 | inferno.trainers.callbacks.scheduling module 39 | -------------------------------------------- 40 | 41 | .. automodule:: inferno.trainers.callbacks.scheduling 42 | :members: 43 | :undoc-members: 44 | :show-inheritance: 45 | 46 | inferno.trainers.callbacks.tqdm module 47 | -------------------------------------- 48 | 49 | .. automodule:: inferno.trainers.callbacks.tqdm 50 | :members: 51 | :undoc-members: 52 | :show-inheritance: 53 | 54 | inferno.trainers.callbacks.tqdmstub module 55 | ------------------------------------------ 56 | 57 | .. automodule:: inferno.trainers.callbacks.tqdmstub 58 | :members: 59 | :undoc-members: 60 | :show-inheritance: 61 | 62 | 63 | Module contents 64 | --------------- 65 | 66 | .. 
automodule:: inferno.trainers.callbacks 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | -------------------------------------------------------------------------------- /docs/inferno-apidoc/inferno.utils.rst: -------------------------------------------------------------------------------- 1 | inferno.utils package 2 | ===================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | inferno.utils.exceptions module 8 | ------------------------------- 9 | 10 | .. automodule:: inferno.utils.exceptions 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | inferno.utils.io\_utils module 16 | ------------------------------ 17 | 18 | .. automodule:: inferno.utils.io_utils 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | inferno.utils.math\_utils module 24 | -------------------------------- 25 | 26 | .. automodule:: inferno.utils.math_utils 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | inferno.utils.model\_utils module 32 | --------------------------------- 33 | 34 | .. automodule:: inferno.utils.model_utils 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | inferno.utils.python\_utils module 40 | ---------------------------------- 41 | 42 | .. automodule:: inferno.utils.python_utils 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | inferno.utils.test\_utils module 48 | -------------------------------- 49 | 50 | .. automodule:: inferno.utils.test_utils 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | inferno.utils.torch\_utils module 56 | --------------------------------- 57 | 58 | .. automodule:: inferno.utils.torch_utils 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | inferno.utils.train\_utils module 64 | --------------------------------- 65 | 66 | .. automodule:: inferno.utils.train_utils 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | 72 | Module contents 73 | --------------- 74 | 75 | .. 
automodule:: inferno.utils 76 | :members: 77 | :undoc-members: 78 | :show-inheritance: 79 | -------------------------------------------------------------------------------- /docs/graphics/plain_tentative_logo.svg: -------------------------------------------------------------------------------- 1 | 2 | image/svg+xmlINFERN 27 | -------------------------------------------------------------------------------- /tests/test_io/test_volumetric/test_volume_loader.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | from shutil import rmtree 4 | 5 | import numpy as np 6 | import h5py 7 | 8 | 9 | class TestVolumeLoader(unittest.TestCase): 10 | shape = (100, 100, 100) 11 | def setUp(self): 12 | self.data = np.random.rand(*self.shape) 13 | 14 | def test_loader(self): 15 | from inferno.io.volumetric import VolumeLoader 16 | loader = VolumeLoader(self.data, 17 | window_size=(10, 10, 10), 18 | stride=(10, 10, 10), return_index_spec=True) 19 | for batch, idx in loader: 20 | slice_ = loader.base_sequence[int(idx)] 21 | expected = self.data[slice_] 22 | self.assertEqual(batch.shape, expected.shape) 23 | self.assertTrue(np.allclose(batch, expected)) 24 | 25 | 26 | class TestHDF5VolumeLoader(unittest.TestCase): 27 | shape = (100, 100, 100) 28 | def setUp(self): 29 | try: 30 | os.mkdir('./tmp') 31 | except OSError: 32 | pass 33 | self.data = np.random.rand(*self.shape) 34 | with h5py.File('./tmp/data.h5') as f: 35 | f.create_dataset('data', data=self.data) 36 | 37 | def tearDown(self): 38 | try: 39 | rmtree('./tmp') 40 | except OSError: 41 | pass 42 | 43 | def test_hdf5_loader(self): 44 | from inferno.io.volumetric import HDF5VolumeLoader 45 | loader = HDF5VolumeLoader('./tmp/data.h5', 'data', 46 | window_size=(10, 10, 10), 47 | stride=(10, 10, 10), return_index_spec=True) 48 | for batch, idx in loader: 49 | slice_ = loader.base_sequence[int(idx)] 50 | expected = self.data[slice_] 51 | self.assertEqual(batch.shape, 
expected.shape) 52 | self.assertTrue(np.allclose(batch, expected)) 53 | 54 | 55 | 56 | if __name__ == '__main__': 57 | unittest.main() 58 | -------------------------------------------------------------------------------- /inferno/extensions/optimizers/annealed_adam.py: -------------------------------------------------------------------------------- 1 | from .adam import Adam 2 | 3 | 4 | class AnnealedAdam(Adam): 5 | """Implements Adam algorithm with learning rate annealing and optional L1 penalty. 6 | 7 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 8 | 9 | Arguments: 10 | params (iterable): iterable of parameters to optimize or dicts defining 11 | parameter groups 12 | lr (float, optional): learning rate (default: 1e-3) 13 | betas (Tuple[float, float], optional): coefficients used for computing 14 | running averages of gradient and its square (default: (0.9, 0.999)) 15 | eps (float, optional): term added to the denominator to improve 16 | numerical stability (default: 1e-8) 17 | lambda_l1 (float, optional): L1 penalty (default: 0) 18 | weight_decay (float, optional): L2 penalty (weight decay) (default: 0) 19 | lr_decay(float, optional): decay learning rate by this factor after every step 20 | (default: 1.) 21 | 22 | .. _Adam\: A Method for Stochastic Optimization: 23 | https://arxiv.org/abs/1412.6980 24 | """ 25 | 26 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 27 | lambda_l1=0, weight_decay=0, lr_decay=1.): 28 | defaults = dict(lr=lr, betas=betas, eps=eps, 29 | lambda_l1=lambda_l1, weight_decay=weight_decay, 30 | lr_decay=lr_decay) 31 | super(AnnealedAdam, self).__init__(params, **defaults) 32 | 33 | def step(self, closure=None): 34 | """Performs a single optimization step. 35 | 36 | Arguments: 37 | closure (callable, optional): A closure that reevaluates the model 38 | and returns the loss. 
39 | """ 40 | # Do an optimization step 41 | super(AnnealedAdam, self).step(closure=closure) 42 | # Update learning rate 43 | for group in self.param_groups: 44 | group['lr'] *= group['lr_decay'] 45 | -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | Credits 3 | ======= 4 | 5 | Development Lead 6 | ---------------- 7 | 8 | * `Nasim Rahaman `_ @ `Image Analysis and Learning Lab `_ , `Heidelberg Collaboratory for Image Processing `_ , 9 | 10 | 11 | Contributors 12 | ------------ 13 | 14 | In no particular order, 15 | * `Steffen Wolf `_ @ 16 | `Image Analysis and Learning Lab `_ , 17 | `Heidelberg Collaboratory for Image Processing `_ , 18 | * `Maurice Weiler `_ @ 19 | `Amsterdam Machine Learning Lab `_ , 20 | `University of Amsterdam `_ , 21 | * `Constantin Pape `_ @ 22 | `Image Analysis and Learning Lab `_ , 23 | `Heidelberg Collaboratory for Image Processing `_ , 24 | * `Sven Peter `_ @ 25 | `Image Analysis and Learning Lab `_ , 26 | `Heidelberg Collaboratory for Image Processing `_ , 27 | * `Manuel Haussmann `_ @ 28 | `Image Analysis and Learning Lab `_ , 29 | `Heidelberg Collaboratory for Image Processing `_ , 30 | * `Thorsten Beier `_ @ 31 | `Image Analysis and Learning Lab `_ , 32 | `Heidelberg Collaboratory for Image Processing `_ , 33 | * `Benjamin Striner `_ @ 34 | `Machine Learning Department `_ , 35 | `Carnegie Mellon University `_ , 36 | -------------------------------------------------------------------------------- /inferno/utils/test_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data.dataset import TensorDataset 3 | from torch.utils.data.dataloader import DataLoader 4 | import numpy as np 5 | 6 | 7 | def generate_random_data(num_samples, shape, num_classes, hardness=0.3, dtype=None): 8 | """Generate a random dataset with a given 
hardness and number of classes.""" 9 | dataset_input = np.zeros((num_samples,) + shape, dtype=dtype) 10 | dataset_target = np.random.randint(num_classes, size=num_samples) 11 | for sample_num in range(num_samples): 12 | dataset_input[sample_num] = np.random.normal(loc=dataset_target[sample_num], 13 | scale=(1 - hardness), 14 | size=shape) 15 | return dataset_input, dataset_target 16 | 17 | 18 | def generate_random_dataset(num_samples, shape, num_classes, hardness=0.3, dtype=None): 19 | """Generate a random dataset with a given hardness and number of classes.""" 20 | # Generate numpy arrays 21 | dataset_input, dataset_target = generate_random_data(num_samples, shape, num_classes, 22 | hardness=hardness, dtype=dtype) 23 | # Convert to tensor and build dataset 24 | dataset = TensorDataset(torch.from_numpy(dataset_input), 25 | torch.from_numpy(dataset_target)) 26 | return dataset 27 | 28 | 29 | def generate_random_dataloader(num_samples, shape, num_classes, hardness=0.3, dtype=None, 30 | batch_size=1, shuffle=False, num_workers=0, pin_memory=False): 31 | """Generate a loader with a random dataset of given hardness and number of classes.""" 32 | dataset = generate_random_dataset(num_samples, shape, num_classes, hardness=hardness, 33 | dtype=dtype) 34 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, 35 | num_workers=num_workers, pin_memory=pin_memory) 36 | return dataloader 37 | -------------------------------------------------------------------------------- /inferno/trainers/callbacks/gradients.py: -------------------------------------------------------------------------------- 1 | from ...utils.train_utils import Frequency 2 | from ...utils.exceptions import assert_, FrequencyValueError 3 | from .base import Callback 4 | 5 | 6 | class LogOutputGradients(Callback): 7 | """Logs the gradient of the network output""" 8 | 9 | def __init__(self, frequency): 10 | super(LogOutputGradients, self).__init__() 11 | self.log_every = frequency 12 | 
self.registered = False 13 | self.hook_handle = None 14 | 15 | @property 16 | def log_every(self): 17 | return self._log_every 18 | 19 | @log_every.setter 20 | def log_every(self, value): 21 | self._log_every = Frequency(value, 'iterations') 22 | assert_(self.log_every.is_consistent, 23 | "Log frequency is not consistent.", 24 | FrequencyValueError) 25 | 26 | def hook(self, module, grad_input, grad_output): 27 | 28 | #remove hook if trainer does not exits 29 | if self.trainer is None: 30 | self.hook_handle.remove() 31 | return 32 | 33 | if self.log_every.match(iteration_count=self.trainer.iteration_count, 34 | epoch_count=self.trainer.epoch_count, 35 | persistent=True, match_zero=True): 36 | self.trainer.update_state('output_gradient', grad_output[0].detach().float().clone().cpu()) 37 | 38 | def add_hook(self): 39 | self.hook_handle = self.trainer.model.register_backward_hook(self.hook) 40 | 41 | def begin_of_fit(self, **kwargs): 42 | self._trainer.logger.observe_state("output_gradient", 43 | observe_while='training') 44 | self.add_hook() 45 | 46 | def begin_of_save(self, **_): 47 | # remove hook from model, because you can't pickle it. 
48 | if self.hook_handle is not None: 49 | self.hook_handle.remove() 50 | self.hook_handle = None 51 | 52 | def end_of_save(self, **_): 53 | # add hook after model save 54 | self.add_hook() 55 | 56 | -------------------------------------------------------------------------------- /tests/test_utils/test_train_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import inferno.utils.train_utils as tu 3 | import numpy as np 4 | 5 | 6 | class FrequencyTest(unittest.TestCase): 7 | def test_from_string(self): 8 | frequency = tu.Frequency.from_string('10 epochs') 9 | self.assertFalse(frequency.match(epoch_count=9)) 10 | self.assertTrue(frequency.match(epoch_count=10)) 11 | frequency = tu.Frequency.from_string('1 iteration') 12 | self.assertEqual(frequency.units, 'iterations') 13 | self.assertTrue(frequency.match(iteration_count=10)) 14 | frequency = tu.Frequency.from_string('never') 15 | self.assertFalse(frequency.match(epoch_count=9)) 16 | frequency = tu.Frequency.from_string('inf epochs') 17 | self.assertFalse(frequency.match(epoch_count=9)) 18 | 19 | def test_from_tuple(self): 20 | frequency = tu.Frequency.build_from((np.inf, 'epoch')) 21 | self.assertFalse(frequency.match(epoch_count=9)) 22 | self.assertFalse(frequency.match(epoch_count=10)) 23 | 24 | def test_is_consistent(self): 25 | frequency = tu.Frequency.build_from('10 epochs') 26 | frequency._units = 'banana' 27 | self.assertFalse(frequency.is_consistent) 28 | 29 | def test_init(self): 30 | frequency = tu.Frequency() 31 | self.assertEqual(frequency.value, np.inf) 32 | self.assertEqual(frequency.units, frequency.UNIT_PRIORITY) 33 | 34 | def test_duration(self): 35 | duration = tu.Duration.build_from((3, 'iterations')) 36 | self.assertFalse(duration.match(iteration_count=2)) 37 | self.assertFalse(duration.match(iteration_count=3)) 38 | self.assertTrue(duration.match(iteration_count=3, when_equal_return=True)) 39 | 
self.assertTrue(duration.match(iteration_count=4)) 40 | self.assertEqual(duration.compare(iteration_count=1, epoch_count=3).get('iterations'), 41 | 2) 42 | with self.assertRaises(ValueError): 43 | duration.match(epoch_count=2) 44 | 45 | 46 | if __name__ == '__main__': 47 | unittest.main() -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | dist: xenial 4 | 5 | python: 6 | - 3.7 7 | 8 | env: 9 | - PYTORCH_CONDA="pytorch" TORCHVISION_CONDA="torchvision" TORCHVISION_CHANNEL=pytorch 10 | 11 | install: 12 | - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; 13 | - bash miniconda.sh -b -p $HOME/miniconda 14 | - export PATH="$HOME/miniconda/bin:$PATH" 15 | - conda config --set always_yes yes --set changeps1 no 16 | - conda update -q conda 17 | - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION 18 | - source activate test-environment 19 | - conda install -c conda-forge networkx h5py scikit-image pyyaml dill tensorboardx 20 | - conda install -c pytorch $PYTORCH_CONDA 21 | - conda install -c $TORCHVISION_CHANNEL $TORCHVISION_CONDA 22 | 23 | deploy: 24 | provider: pypi 25 | user: nasimrahaman 26 | password: 27 | secure: !!binary | 28 | bWwzZitLcEpibHBaUWNhUVA4UUlGa2JsZDQxVkx3eFlkY1FiYWJqYkFvWm5pdDErRzlKRXZFM0hR 29 | ZE15V0tIWm5JQlJRSGlveXdYNjAzQVc1UFV3ZjNBOG0zc21vK3RaZjVSYnM5aE5ySE93ajBXc1N4 30 | akNHNGhOSnF6UnBDY2kwakxPeWhxaEwxQkR0empSaFdJbWVlOE81RDVPY2pSdGw1TDQ3QjhwVGor 31 | TVREdlpSYTVFd2xNNXdadTJYWFVXL3ZQY0VLZE9xckFoVk5PSHpkTTh5MGM1S1lHaS9nNThVK2JO 32 | OVp5RkFROVpuOEY3YmxPdzBQZnAvL202ZUkxamlKSmxhaE13UU4zV2tJRWRpNklVSTE0RUp1ck5s 33 | Q28xL2kzNER0dGVkZzI0eVhULzcxRFl5Y0pZQWMrcWtoa1VVVUo4NEZKV3JjUjNqTnF5bVI3Ykty 34 | cFJrR3JydjV0dUpGUnBhc2NIdEdKVUswMkdJWEJUc3JJWGg4bS9oRGtMaVJaMExBeitJQWR4b2tF 35 | 
MzB0OWppZ0x5VXFSMmxnVmNvZERzRWZMRnJEMTBHeTJVS2FueVhlYmpsck9qK3V5S1dtZm5UTXg4 36 | bGNzN09HWEZiUmo2K0ZuYTg5a00xN3poSXhzc3pSMnRGSVJwamV4a0gzZUpyZlpYY1daTFZ3QnV0 37 | clUwZW10VEsxeGFmOGFjNTd3Wll1R3JXNEZJT1h2bmxoeS9pV0FMVlE4YnVFZFFjQnJ5YWFiRjUy 38 | RkZvZk1SUnp3aDFhZ3Q3cUxVa0FIbXVuZ1NYQWZxMUlOTkVNYXRTcFVJUURJM3huWmNPeTNhSWFP 39 | YkVpSlFHY1lrWlhXZ1Z2cVdvcktPOW53a29Hem5BSm1HRVZHYU11dDYwaGg2SGU1MVJPTll3WHc9 40 | on: 41 | all_branches: false 42 | tags: true 43 | 44 | script: 45 | - source activate test-environment 46 | - python setup.py install 47 | - python -m unittest discover -s tests -v 48 | -------------------------------------------------------------------------------- /tests/test_extensions/test_criteria/test_set_similarity_measures.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | 4 | 5 | class SetSimilarityTest(unittest.TestCase): 6 | def get_dummy_variables(self): 7 | x = torch.zeros(3, 2, 100, 100).uniform_() 8 | y = torch.zeros(3, 2, 100, 100).uniform_() 9 | return x, y 10 | 11 | def get_dummy_variables_with_channels_and_classes(self): 12 | # (batch_size, channels, classes, ...) 
13 | x = torch.zeros(3, 2, 5, 100, 100).uniform_() 14 | y = torch.zeros(3, 2, 5, 100, 100).uniform_() 15 | return x, y 16 | 17 | 18 | class TestSorensenDice(SetSimilarityTest): 19 | # noinspection PyCallingNonCallable 20 | def test_channelwise(self): 21 | from inferno.extensions.criteria.set_similarity_measures import SorensenDiceLoss 22 | x, y = self.get_dummy_variables() 23 | channelwise = SorensenDiceLoss(channelwise=True) 24 | not_channelwise = SorensenDiceLoss(channelwise=False) 25 | # Compute expected channelwise loss 26 | expected_channelwise_loss = \ 27 | not_channelwise(x[:, 0, ...], y[:, 0, ...]) + \ 28 | not_channelwise(x[:, 1, ...], y[:, 1, ...]) 29 | # Compute channelwise 30 | channelwise_loss = channelwise(x, y) 31 | # Compare 32 | self.assertAlmostEqual(expected_channelwise_loss.item(), channelwise_loss.item()) 33 | 34 | 35 | class TestGeneralizedSorensenDice(SetSimilarityTest): 36 | def test_channelwise(self): 37 | from inferno.extensions.criteria.set_similarity_measures import GeneralizedDiceLoss 38 | x, y = self.get_dummy_variables_with_channels_and_classes() 39 | channelwise = GeneralizedDiceLoss(channelwise=True) 40 | not_channelwise = GeneralizedDiceLoss(channelwise=False) 41 | # Compute channelwise loss and expected one: 42 | channelwise_loss = channelwise(x, y) 43 | expected_channelwise_loss = \ 44 | not_channelwise(x[:, 0, ...], y[:, 0, ...]) + \ 45 | not_channelwise(x[:, 1, ...], y[:, 1, ...]) 46 | # Compare 47 | self.assertAlmostEqual(expected_channelwise_loss.item(), channelwise_loss.item()) 48 | 49 | 50 | if __name__ == '__main__': 51 | unittest.main() 52 | -------------------------------------------------------------------------------- /tests/test_extensions/test_metrics/categorical.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from inferno.extensions.metrics import IOU 4 | 5 | 6 | class TestCategorical(unittest.TestCase): 7 | def test_iou_basic(self): 8 | # 
from one hot 9 | predicted_image = torch.zeros(*(2, 10, 10)) 10 | predicted_image[:, 0:4, 0:4] = 1 11 | target_image = torch.zeros(*(2, 10, 10)) 12 | target_image[:, 0:3, 0:3] = 1 13 | expected_iou = (3 * 3)/(4 * 4) 14 | iou = IOU()(predicted_image[None, ...], target_image[None, ...]) 15 | self.assertAlmostEqual(iou, expected_iou, places=4) 16 | 17 | def test_iou_with_ignore_class(self): 18 | predicted_image = torch.zeros(*(2, 10, 10)) 19 | predicted_image[0, 0:4, 0:4] = 1 20 | target_image = torch.zeros(*(2, 10, 10)) 21 | target_image[:, 0:3, 0:3] = 1 22 | expected_iou = (3 * 3) / (4 * 4) 23 | iou = IOU(ignore_class=1)(predicted_image[None, ...], target_image[None, ...]) 24 | self.assertAlmostEqual(iou, expected_iou, places=4) 25 | 26 | def test_multiclass_iou(self): 27 | predicted_image = torch.zeros(*(2, 10, 10)) 28 | predicted_image[0, 0:4, 0:4] = 1 29 | target_image = torch.zeros(*(2, 10, 10)) 30 | target_image[:, 0:3, 0:3] = 1 31 | iou_class_0 = (3 * 3) / (4 * 4) 32 | iou_class_1 = 0 33 | expected_mean_iou = 0.5 * (iou_class_0 + iou_class_1) 34 | iou = IOU()(predicted_image[None, ...], target_image[None, ...]) 35 | self.assertAlmostEqual(iou, expected_mean_iou, places=4) 36 | 37 | def test_multiclass_iou_with_ignore_class(self): 38 | predicted_image = torch.zeros(*(3, 10, 10)) 39 | predicted_image[0, 0:4, 0:4] = 1 40 | # Have the third plane be crap 41 | predicted_image[2, :, :] = 1 42 | target_image = torch.zeros(*(3, 10, 10)) 43 | target_image[:, 0:3, 0:3] = 1 44 | iou_class_0 = (3 * 3) / (4 * 4) 45 | iou_class_1 = 0 46 | expected_mean_iou = 0.5 * (iou_class_0 + iou_class_1) 47 | iou = IOU(ignore_class=-1)(predicted_image[None, ...], target_image[None, ...]) 48 | self.assertAlmostEqual(iou, expected_mean_iou, places=4) 49 | 50 | if __name__ == '__main__': 51 | unittest.main() -------------------------------------------------------------------------------- /tests/test_io/test_core/test_zip.py: 
-------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class ZipTest(unittest.TestCase): 5 | def test_zip_minimal(self): 6 | """Minimal test with python lists as iterators.""" 7 | from inferno.io.core import Zip 8 | from torch.utils.data.dataset import Dataset 9 | 10 | with self.assertRaises(TypeError): 11 | zipped = Zip([1, 2, 3], [4, 5, 6, 7]) 12 | 13 | # This is required because Zip checks if its inputs are actually torch datasets 14 | class ListDataset(list, Dataset): 15 | pass 16 | 17 | dataset_1 = ListDataset([1, 2, 3, 4]) 18 | dataset_2 = ListDataset([5, 6, 7, 8, 9]) 19 | zipped = Zip(dataset_1, dataset_2) 20 | self.assertEqual(len(zipped), 4) 21 | 22 | fetched = zipped[1] 23 | self.assertEqual(fetched, [2, 6]) 24 | 25 | with self.assertRaises(IndexError): 26 | fetched = zipped[4] 27 | 28 | def test_zip_sync(self): 29 | """Test synchronization mechanics.""" 30 | # TODO 31 | 32 | def test_zip_reject(self): 33 | from inferno.io.core import ZipReject 34 | from torch.utils.data.dataset import Dataset 35 | 36 | # This is required because Zip checks if its inputs are actually torch datasets 37 | class ListDataset(list, Dataset): 38 | pass 39 | 40 | def rejection_criterion(sample_1, sample_2): 41 | return sample_1 < sample_2 42 | 43 | dataset_1 = ListDataset([1, 2, 3, 4]) 44 | dataset_2 = ListDataset([2, 1, 3, 4]) 45 | dataset_3 = ListDataset([0, 1, 2, 3]) 46 | 47 | zipped = ZipReject(dataset_1, dataset_2, dataset_3, 48 | rejection_criterion=rejection_criterion, 49 | random_jump_after_reject=False, 50 | rejection_dataset_indices=[0, 1]) 51 | fetched = zipped[0] 52 | self.assertSequenceEqual(fetched, [2, 1, 1]) 53 | 54 | zipped = ZipReject(dataset_1, dataset_2, dataset_3, 55 | rejection_criterion=rejection_criterion, 56 | rejection_dataset_indices=[1, 0]) 57 | fetched = zipped[0] 58 | self.assertSequenceEqual(fetched, [1, 2, 0]) 59 | 60 | 61 | if __name__ == '__main__': 62 | unittest.main() 63 | 
-------------------------------------------------------------------------------- /tests/test_extensions/test_models/test_unet.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch.cuda as cuda 3 | from inferno.utils.model_utils import ModelTester, MultiscaleModelTester 4 | from inferno.extensions.models import UNet 5 | 6 | class _MultiscaleUNet(UNet): 7 | def conv_op_factory(self, in_channels, out_channels, part, index): 8 | return super(_MultiscaleUNet, self).conv_op_factory(in_channels, out_channels, part, index)[0], True 9 | 10 | def forward(self, input): 11 | x = self._initial_conv(input) 12 | x = list(super(UNet, self).forward(x)) 13 | x[-1] = self._output(x[-1]) 14 | return tuple(x) 15 | 16 | 17 | class UNetTest(unittest.TestCase): 18 | def test_unet_2d(self): 19 | tester = ModelTester((1, 1, 256, 256), (1, 1, 256, 256)) 20 | if cuda.is_available(): 21 | tester.cuda() 22 | tester(UNet(1, 1, dim=2, initial_features=32)) 23 | 24 | def test_unet_3d(self): 25 | tester = ModelTester((1, 1, 16, 64, 64), (1, 1, 16, 64, 64)) 26 | if cuda.is_available(): 27 | tester.cuda() 28 | # test default unet 3d 29 | tester(UNet(1, 1, dim=3, initial_features=8)) 30 | 31 | def test_monochannel_unet_3d(self): 32 | nc = 2 33 | class _UNetMonochannel(_MultiscaleUNet): 34 | def _get_num_channels(self, depth): 35 | return nc 36 | 37 | shapes = [(1, nc, 16, 64, 64), (1, nc, 8, 32, 32), (1, nc, 4, 16, 16), (1, nc, 2, 8, 8), (1, nc, 1, 4, 4), 38 | (1, nc, 2, 8, 8), (1, nc, 4, 16, 16), (1, nc, 8, 32, 32), (1, 1, 16, 64, 64)] 39 | tester = MultiscaleModelTester((1, 1, 16, 64, 64), shapes) 40 | if cuda.is_available(): 41 | tester.cuda() 42 | tester(_UNetMonochannel(1, 1, dim=3, initial_features=8)) 43 | 44 | def test_inverse_pyramid_unet_2d(self): 45 | class _UNetInversePyramid(_MultiscaleUNet): 46 | def _get_num_channels(self, depth): 47 | return [13, 12, 11][depth - 1] 48 | 49 | shapes = [(1, 13, 16, 64), (1, 12, 8, 32), 
(1, 11, 4, 16), (1, 11, 2, 8), 50 | (1, 12, 4, 16), (1, 13, 8, 32), (1, 1, 16, 64)] 51 | tester = MultiscaleModelTester((1, 1, 16, 64), shapes) 52 | if cuda.is_available(): 53 | tester.cuda() 54 | tester(_UNetInversePyramid(1, 1, dim=2, depth=3, initial_features=8)) 55 | 56 | 57 | if __name__ == '__main__': 58 | unittest.main() 59 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """The setup script.""" 5 | 6 | from setuptools import setup, find_packages 7 | import runpy 8 | __version__ = runpy.run_path('inferno/version.py')['__version__'] 9 | 10 | 11 | with open('README.rst') as readme_file: 12 | readme = readme_file.read() 13 | 14 | with open('HISTORY.rst') as history_file: 15 | history = history_file.read() 16 | 17 | requirements = [ 18 | # TODO: put package requirements here 19 | "pip>=8.1.2", 20 | "torch>=0.1.12", 21 | "dill", 22 | "pyyaml", 23 | "scipy>=0.13.0", 24 | "h5py", 25 | "numpy>=1.8", 26 | "scikit-image", 27 | "torchvision", 28 | "tqdm" 29 | ] 30 | 31 | 32 | setup_requirements = [ 33 | 'pytest-runner' 34 | ] 35 | 36 | test_requirements = [ 37 | 'pytest', 'unittest' 38 | ] 39 | 40 | dependency_links = [ 41 | 'http://download.pytorch.org/whl/cu75/torch-0.2.0.post1-cp35-cp35m-manylinux1_x86_64.whl#egg=torch-0.2.0' 42 | ] 43 | 44 | setup( 45 | name='inferno-pytorch', 46 | version=__version__, 47 | description="Inferno is a little library providing utilities and convenience functions/classes around PyTorch.", 48 | long_description=readme + '\n\n' + history, 49 | author="Nasim Rahaman", 50 | author_email='nasim.rahaman@iwr.uni-heidelberg.de', 51 | url='https://github.com/inferno-pytorch/inferno', 52 | packages=find_packages(where='.', exclude=["*.tests", "*.tests.*", 53 | "tests.*", "tests", 54 | "__pycache__", "*.pyc"]), 55 | dependency_links=dependency_links, 56 | 
include_package_data=True, 57 | install_requires=requirements, 58 | license="Apache Software License 2.0", 59 | zip_safe=False, 60 | keywords='inferno pytorch torch deep learning cnn deep-pyromania', 61 | classifiers=[ 62 | # How mature is this project? Common values are\ 63 | # 2 - Pre-Alpha', 64 | # 3 - Alpha, 65 | # 4 - Beta, 66 | # 5 - Production/Stable 67 | 'Development Status :: 2 - Pre-Alpha', 68 | # Indicate who your project is intended for 69 | 'Intended Audience :: Science/Research', 70 | 'License :: OSI Approved :: Apache Software License', 71 | 'Natural Language :: English', 72 | 'Programming Language :: Python :: 3.5', 73 | 'Programming Language :: Python :: 3.6' 74 | ], 75 | test_suite='test', 76 | tests_require=test_requirements, 77 | setup_requires=setup_requirements, 78 | ) 79 | -------------------------------------------------------------------------------- /docs/inferno-apidoc/inferno.extensions.layers.rst: -------------------------------------------------------------------------------- 1 | inferno.extensions.layers package 2 | ================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | inferno.extensions.layers.activations module 8 | -------------------------------------------- 9 | 10 | .. automodule:: inferno.extensions.layers.activations 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | inferno.extensions.layers.building\_blocks module 16 | ------------------------------------------------- 17 | 18 | .. automodule:: inferno.extensions.layers.building_blocks 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | inferno.extensions.layers.convolutional module 24 | ---------------------------------------------- 25 | 26 | .. automodule:: inferno.extensions.layers.convolutional 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | inferno.extensions.layers.device module 32 | --------------------------------------- 33 | 34 | .. 
automodule:: inferno.extensions.layers.device 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | inferno.extensions.layers.identity module 40 | ----------------------------------------- 41 | 42 | .. automodule:: inferno.extensions.layers.identity 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | inferno.extensions.layers.prefab module 48 | --------------------------------------- 49 | 50 | .. automodule:: inferno.extensions.layers.prefab 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | inferno.extensions.layers.res\_unet module 56 | ------------------------------------------ 57 | 58 | .. automodule:: inferno.extensions.layers.res_unet 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | inferno.extensions.layers.reshape module 64 | ---------------------------------------- 65 | 66 | .. automodule:: inferno.extensions.layers.reshape 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | inferno.extensions.layers.sampling module 72 | ----------------------------------------- 73 | 74 | .. automodule:: inferno.extensions.layers.sampling 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | inferno.extensions.layers.unet\_base module 80 | ------------------------------------------- 81 | 82 | .. automodule:: inferno.extensions.layers.unet_base 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | 88 | Module contents 89 | --------------- 90 | 91 | .. 
automodule:: inferno.extensions.layers 92 | :members: 93 | :undoc-members: 94 | :show-inheritance: 95 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: clean clean-test clean-pyc clean-build docs help 2 | .DEFAULT_GOAL := help 3 | define BROWSER_PYSCRIPT 4 | import os, webbrowser, sys 5 | try: 6 | from urllib import pathname2url 7 | except ImportError: 8 | from urllib.request import pathname2url 9 | 10 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) 11 | endef 12 | export BROWSER_PYSCRIPT 13 | 14 | define PRINT_HELP_PYSCRIPT 15 | import re, sys 16 | 17 | for line in sys.stdin: 18 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) 19 | if match: 20 | target, help = match.groups() 21 | print("%-20s %s" % (target, help)) 22 | endef 23 | export PRINT_HELP_PYSCRIPT 24 | BROWSER := python -c "$$BROWSER_PYSCRIPT" 25 | 26 | help: 27 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) 28 | 29 | clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts 30 | 31 | 32 | clean-build: ## remove build artifacts 33 | rm -fr build/ 34 | rm -fr dist/ 35 | rm -fr .eggs/ 36 | find . -name '*.egg-info' -exec rm -fr {} + 37 | find . -name '*.egg' -exec rm -f {} + 38 | 39 | clean-pyc: ## remove Python file artifacts 40 | find . -name '*.pyc' -exec rm -f {} + 41 | find . -name '*.pyo' -exec rm -f {} + 42 | find . -name '*~' -exec rm -f {} + 43 | find .
-name '__pycache__' -exec rm -fr {} + 44 | 45 | clean-test: ## remove test and coverage artifacts 46 | rm -fr .tox/ 47 | rm -f .coverage 48 | rm -fr htmlcov/ 49 | 50 | lint: ## check style with flake8 51 | flake8 inferno tests 52 | 53 | test: ## run tests quickly with the default Python 54 | 55 | python setup.py test 56 | 57 | test-all: ## run tests on every Python version with tox 58 | tox 59 | 60 | coverage: ## check code coverage quickly with the default Python 61 | coverage run --source inferno setup.py test 62 | coverage report -m 63 | coverage html 64 | $(BROWSER) htmlcov/index.html 65 | 66 | docs: ## generate Sphinx HTML documentation, including API docs 67 | rm -f docs/inferno.rst 68 | rm -f docs/modules.rst 69 | sphinx-apidoc -o docs/ inferno 70 | $(MAKE) -C docs clean 71 | $(MAKE) -C docs html 72 | $(BROWSER) docs/_build/html/index.html 73 | 74 | servedocs: docs ## compile the docs watching for changes 75 | watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D . 76 | 77 | release: clean ## package and upload a release 78 | python setup.py sdist upload 79 | python setup.py bdist_wheel upload 80 | 81 | dist: clean ## builds source and wheel package 82 | python setup.py sdist 83 | python setup.py bdist_wheel 84 | ls -l dist 85 | 86 | install: clean ## install the package to the active Python's site-packages 87 | python setup.py install 88 | -------------------------------------------------------------------------------- /examples/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trainer Example 3 | ================================ 4 | 5 | This example should illustrate how to use the trainer class.
6 | 7 | """ 8 | 9 | import torch.nn as nn 10 | from inferno.io.box.cifar import get_cifar10_loaders 11 | from inferno.trainers.basic import Trainer 12 | from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger 13 | from inferno.extensions.layers import ConvELU2D 14 | from inferno.extensions.layers import Flatten 15 | from inferno.utils.python_utils import ensure_dir 16 | 17 | from inferno.extensions.layers import SELU 18 | 19 | ################################################## 20 | # change directories to your needs 21 | LOG_DIRECTORY = ensure_dir('log') 22 | SAVE_DIRECTORY = ensure_dir('save') 23 | DATASET_DIRECTORY = ensure_dir('dataset') 24 | 25 | ################################################## 26 | # shall the CIFAR-10 dataset be downloaded 27 | DOWNLOAD_CIFAR = True 28 | USE_CUDA = True 29 | 30 | ################################################## 31 | # Build torch model 32 | model = nn.Sequential( 33 | ConvELU2D(in_channels=3, out_channels=256, kernel_size=3), 34 | nn.MaxPool2d(kernel_size=2, stride=2), 35 | ConvELU2D(in_channels=256, out_channels=256, kernel_size=3), 36 | nn.MaxPool2d(kernel_size=2, stride=2), 37 | ConvELU2D(in_channels=256, out_channels=256, kernel_size=3), 38 | nn.MaxPool2d(kernel_size=2, stride=2), 39 | Flatten(), 40 | nn.Linear(in_features=(256 * 4 * 4), out_features=10), 41 | # no Softmax here: CrossEntropyLoss expects raw logits and applies log-softmax itself 42 | ) 43 | 44 | ################################################## 45 | # data loaders 46 | train_loader, validate_loader = get_cifar10_loaders(DATASET_DIRECTORY, 47 | download=DOWNLOAD_CIFAR) 48 | 49 | ################################################## 50 | # Build trainer 51 | trainer = Trainer(model) 52 | trainer.build_criterion('CrossEntropyLoss') 53 | trainer.build_metric('CategoricalError') 54 | trainer.build_optimizer('Adam') 55 | trainer.validate_every((2, 'epochs')) 56 | trainer.save_every((5, 'epochs')) 57 | trainer.save_to_directory(SAVE_DIRECTORY) 58 | trainer.set_max_num_epochs(10) 59 |
trainer.build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'), 60 | log_images_every='never'), 61 | log_directory=LOG_DIRECTORY) 62 | 63 | ################################################## 64 | # Bind loaders 65 | trainer.bind_loader('train', train_loader) 66 | trainer.bind_loader('validate', validate_loader) 67 | 68 | ################################################## 69 | # activate cuda 70 | if USE_CUDA: 71 | trainer.cuda() 72 | 73 | ################################################## 74 | # fit 75 | trainer.fit() 76 | -------------------------------------------------------------------------------- /tests/test_training/test_callbacks/test_base.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from inferno.trainers.callbacks.base import Callback, CallbackEngine 4 | from inferno.trainers.basic import Trainer 5 | from os.path import join, dirname, exists 6 | from os import makedirs 7 | from shutil import rmtree 8 | 9 | 10 | class DummyCallback(Callback): 11 | def end_of_training_iteration(self, **_): 12 | assert self.trainer is not None 13 | 14 | 15 | class WrongDummyCallback(Callback): 16 | def end_of_iteration(self): 17 | pass 18 | 19 | 20 | class CallbackMechTest(unittest.TestCase): 21 | ROOT_DIR = join(dirname(__file__), 'root') 22 | 23 | def setUp(self): 24 | makedirs(self.ROOT_DIR, exist_ok=True) 25 | 26 | def tearDown(self): 27 | if exists(self.ROOT_DIR): 28 | rmtree(self.ROOT_DIR) 29 | 30 | def test_serialization(self): 31 | # Build engine and trainer 32 | callback_engine = CallbackEngine().bind_trainer(Trainer()) 33 | callback_engine.register_callback(DummyCallback()) 34 | # Serialize 35 | torch.save(callback_engine, join(self.ROOT_DIR, 'callback_engine.pkl')) 36 | # Deserialize 37 | callback_engine = torch.load(join(self.ROOT_DIR, 'callback_engine.pkl')) 38 | # Make sure the trainer is detached 39 | self.assertIsNone(callback_engine._trainer) 40 |
self.assertIsInstance(next(iter(callback_engine 41 | ._callback_registry 42 | .get('end_of_training_iteration'))), 43 | DummyCallback) 44 | 45 | def test_auto_registry(self): 46 | callback_engine = CallbackEngine().bind_trainer(Trainer()) 47 | callback_engine.register_callback(DummyCallback()) 48 | self.assertIsInstance(next(iter(callback_engine 49 | ._callback_registry 50 | .get('end_of_training_iteration'))), 51 | DummyCallback) 52 | with self.assertRaises(AssertionError): 53 | callback_engine.register_callback(WrongDummyCallback()) 54 | 55 | def test_instance_registry(self): 56 | class Foo(Callback): 57 | pass 58 | 59 | class Bar(Callback): 60 | pass 61 | 62 | foo = Foo() 63 | bar = Bar() 64 | self.assertIs(foo.get_instances(), foo) 65 | self.assertIs(bar.get_instances(), bar) 66 | foo2 = Foo() 67 | self.assertSequenceEqual(foo2.get_instances(), [foo, foo2]) 68 | self.assertIs(bar.get_instances(), bar) 69 | 70 | if __name__ == '__main__': 71 | unittest.main() 72 | -------------------------------------------------------------------------------- /tests/test_extensions/test_layers/test_reshape.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | 4 | 5 | class TestReshape(unittest.TestCase): 6 | def _get_input_variable(self, *shape): 7 | return torch.rand(*shape) 8 | 9 | def test_as_matrix(self): 10 | from inferno.extensions.layers.reshape import AsMatrix 11 | 12 | input = self._get_input_variable(10, 20, 1, 1) 13 | as_matrix = AsMatrix() 14 | output = as_matrix(input) 15 | self.assertEqual(list(output.size()), [10, 20]) 16 | 17 | def test_flatten(self): 18 | from inferno.extensions.layers.reshape import Flatten 19 | 20 | input = self._get_input_variable(10, 20, 2, 2) 21 | flatten = Flatten() 22 | output = flatten(input) 23 | self.assertEqual(list(output.size()), [10, 80]) 24 | 25 | def test_as_2d(self): 26 | from inferno.extensions.layers.reshape import As2D 27 | 28 | as_2d = As2D() 29 | 30 |
output_shape = as_2d(self._get_input_variable(10, 20, 3, 30, 30)).size() 31 | self.assertEqual(list(output_shape), [10, 60, 30, 30]) 32 | 33 | output_shape = as_2d(self._get_input_variable(10, 20, 30, 30)).size() 34 | self.assertEqual(list(output_shape), [10, 20, 30, 30]) 35 | 36 | output_shape = as_2d(self._get_input_variable(10, 20)).size() 37 | self.assertEqual(list(output_shape), [10, 20, 1, 1]) 38 | 39 | def test_as_3d(self): 40 | from inferno.extensions.layers.reshape import As3D 41 | from inferno.utils.exceptions import ShapeError 42 | 43 | as_3d = As3D() 44 | 45 | output_shape = as_3d(self._get_input_variable(10, 20, 3, 30, 30)).size() 46 | self.assertEqual(list(output_shape), [10, 20, 3, 30, 30]) 47 | 48 | output_shape = as_3d(self._get_input_variable(10, 20, 30, 30)).size() 49 | self.assertEqual(list(output_shape), [10, 20, 1, 30, 30]) 50 | 51 | output_shape = as_3d(self._get_input_variable(10, 20)).size() 52 | self.assertEqual(list(output_shape), [10, 20, 1, 1, 1]) 53 | 54 | as_3d.channel_as_z = True 55 | output_shape = as_3d(self._get_input_variable(10, 20, 30, 30)).size() 56 | self.assertEqual(list(output_shape), [10, 1, 20, 30, 30]) 57 | 58 | as_3d.num_channels_or_num_z_slices = 2 59 | output_shape = as_3d(self._get_input_variable(10, 40, 30, 30)).size() 60 | self.assertEqual(list(output_shape), [10, 2, 20, 30, 30]) 61 | 62 | with self.assertRaises(ShapeError): 63 | output_shape = as_3d(self._get_input_variable(10, 41, 30, 30)).size() 64 | 65 | 66 | 67 | if __name__ == '__main__': 68 | unittest.main() 69 | -------------------------------------------------------------------------------- /tests/test_extensions/test_layers/deprecated/building_blocks.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import inferno.extensions.layers.building_blocks as bb 4 | 5 | 6 | class ResBlockTest(unittest.TestCase): 7 | 8 | def
test_2D_simple_(self): 9 | 10 | x = torch.rand(1, 3, 64, 15) 11 | model = bb.ResBlock(in_channels=3, out_channels=3, dim=2) 12 | xx = model(x) 13 | out_size = xx.size() 14 | self.assertEqual(list(out_size), [1,3, 64, 15]) 15 | 16 | def test_3D_simple_(self): 17 | 18 | x = torch.rand(1,3,20, 64,15) 19 | model = bb.ResBlock(in_channels=3, out_channels=3, dim=3) 20 | xx = model(x) 21 | out_size = xx.size() 22 | self.assertEqual(list(out_size), [1,3, 20, 64, 15]) 23 | 24 | def test_2D_simple_2(self): 25 | 26 | x = torch.rand(1,3,64,64) 27 | model = bb.ResBlock(in_channels=3, out_channels=6, dim=2) 28 | xx = model(x) 29 | out_size = xx.size() 30 | self.assertEqual(list(out_size), [1,6, 64, 64]) 31 | 32 | def test_2D_simple_3(self): 33 | 34 | x = torch.rand(1,3,64,64) 35 | model = bb.ResBlock(in_channels=3, out_channels=6, dim=2, size=4) 36 | xx = model(x) 37 | out_size = xx.size() 38 | self.assertEqual(list(out_size), [1,6, 64, 64]) 39 | 40 | def test_2D_simple_4(self): 41 | 42 | x = torch.rand(1,6,64,64) 43 | model = bb.ResBlock(in_channels=6, out_channels=6, dim=2, size=4, 44 | force_skip_op=True) 45 | xx = model(x) 46 | out_size = xx.size() 47 | self.assertEqual(list(out_size), [1,6, 64, 64]) 48 | 49 | def test_2D_simple_5(self): 50 | 51 | x = torch.rand(1,6,64,64) 52 | model = bb.ResBlock(in_channels=6, batchnorm=False, out_channels=6, dim=2, size=4, 53 | force_skip_op=True) 54 | xx = model(x) 55 | out_size = xx.size() 56 | self.assertEqual(list(out_size), [1,6, 64, 64]) 57 | 58 | def test_2D_simple_6(self): 59 | 60 | x = torch.rand(1,6,64,64) 61 | model = bb.ResBlock(in_channels=6, batchnorm=False, out_channels=6, dim=2, size=4, 62 | force_skip_op=True, activated=False) 63 | xx = model(x) 64 | out_size = xx.size() 65 | self.assertEqual(list(out_size), [1,6, 64, 64]) 66 | 67 | def test_3D_simple_6(self): 68 | 69 | x = torch.rand(1,6,64,64, 20) 70 | model = bb.ResBlock(in_channels=6, batchnorm=False, out_channels=6, dim=3, size=4, 71 | force_skip_op=True,
activated=False) 72 | xx = model(x) 73 | out_size = xx.size() 74 | self.assertEqual(list(out_size), [1,6, 64, 64, 20]) 75 | 76 | 77 | if __name__ == '__main__': 78 | unittest.main() 79 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | ================================== 4 | Installation 5 | ================================== 6 | 7 | Install on Linux and OSX 8 | ------------------------ 9 | 10 | Developers 11 | ~~~~~~~~~~~~~~~~~~~~~~ 12 | 13 | First, make sure `you have PyTorch installed <http://pytorch.org/>`_. 14 | 15 | Then, clone this repository with: 16 | 17 | .. code:: console 18 | 19 | $ git clone https://github.com/nasimrahaman/inferno.git 20 | 21 | 22 | Next, install the dependencies. 23 | 24 | .. code:: console 25 | 26 | $ cd inferno 27 | $ pip install -r requirements.txt 28 | 29 | 30 | If you use python from the shell: 31 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 32 | 33 | Finally, add *inferno* to your `PYTHONPATH` with: 34 | 35 | .. code:: shell 36 | 37 | source add2path.sh 38 | 39 | If you use PyCharm: 40 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 41 | Refer to this `QA `_ about setting up paths with Pycharm. 42 | 43 | 44 | 45 | 46 | 47 | ====================================================== 48 | Installation via PyPI / pip / setup.py(Experimental) 49 | ====================================================== 50 | 51 | You need to install pytorch via pip before installing 52 | inferno. Follow the `pytorch installation guide`_. 53 | 54 | Stable release 55 | -------------- 56 | 57 | To install inferno, run this command in your terminal: 58 | 59 | .. code-block:: console 60 | 61 | $ pip install inferno-pytorch 62 | 63 | This is the preferred method to install inferno, as it will always install the most recent stable release.
64 | 65 | If you don't have `pip`_ installed, this `Python installation guide`_ can guide 66 | you through the process. 67 | 68 | .. _pip: https://pip.pypa.io 69 | .. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/ 70 | .. _pytorch installation guide: http://pytorch.org/ 71 | 72 | From sources 73 | ------------------------ 74 | First, make sure `you have PyTorch installed <http://pytorch.org/>`_. 75 | The sources for inferno can be downloaded from the `Github repo`_. 76 | You can either clone the public repository: 77 | 78 | .. code-block:: console 79 | 80 | $ git clone https://github.com/nasimrahaman/inferno.git 81 | 82 | Or download the `tarball`_: 83 | 84 | .. code-block:: console 85 | 86 | $ curl -OL https://github.com/nasimrahaman/inferno/tarball/master 87 | 88 | Once you have a copy of the source, you can install it with: 89 | 90 | .. code-block:: console 91 | 92 | $ python setup.py install 93 | 94 | 95 | .. _Github repo: https://github.com/nasimrahaman/inferno 96 | .. _tarball: https://github.com/nasimrahaman/inferno/tarball/master 97 | -------------------------------------------------------------------------------- /inferno/io/core/concatenate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data.dataset import Dataset 3 | from ...utils import python_utils as pyu 4 | 5 | 6 | class Concatenate(Dataset): 7 | """ 8 | Concatenates multiple datasets to one. This class does not implement 9 | synchronization primitives. 10 | """ 11 | def __init__(self, *datasets, transforms=None): 12 | assert all([isinstance(dataset, Dataset) for dataset in datasets]) 13 | assert len(datasets) >= 1 14 | assert transforms is None or callable(transforms) 15 | self.datasets = datasets 16 | self.transforms = transforms 17 | 18 | def map_index(self, index): 19 | # Get a list of lengths of all datasets. Say the answer is [4, 3, 3], 20 | # and we're looking for index = 5.
21 | len_list = list(map(len, self.datasets)) 22 | # Cumulate to a numpy array. The answer is [4, 7, 10] 23 | cumulative_len_list = np.cumsum(len_list) 24 | # When the index is subtracted, we get [-1, 2, 5]. We're looking for the (index 25 | # of the) first cumulated len which is larger than the index (in this case, 26 | # 7 (index 1)). 27 | offset_cumulative_len_list = cumulative_len_list - index 28 | dataset_index = np.argmax(offset_cumulative_len_list > 0) 29 | # With the dataset index, we figure out the index in dataset 30 | if dataset_index == 0: 31 | # First dataset - index corresponds to index_in_dataset 32 | index_in_dataset = index 33 | else: 34 | # Get cumulated length up to the current dataset 35 | len_up_to_dataset = cumulative_len_list[dataset_index - 1] 36 | # Compute index_in_dataset as that what's left 37 | index_in_dataset = index - len_up_to_dataset 38 | return dataset_index, index_in_dataset 39 | 40 | def __getitem__(self, index): 41 | assert index < len(self) 42 | dataset_index, index_in_dataset = self.map_index(index) 43 | fetched = self.datasets[dataset_index][index_in_dataset] 44 | if self.transforms is None: 45 | return fetched 46 | elif callable(self.transforms): 47 | return self.transforms(*pyu.to_iterable(fetched)) 48 | else: 49 | raise NotImplementedError 50 | 51 | def __len__(self): 52 | return sum([len(dataset) for dataset in self.datasets]) 53 | 54 | def __repr__(self): 55 | if len(self.datasets) < 3: 56 | return "Concatenate(" + \ 57 | ", ".join([dataset.__repr__() 58 | for dataset in self.datasets]) + \ 59 | ")" 60 | else: 61 | return "Concatenate({}xDatasets)".format(len(self.datasets)) 62 | -------------------------------------------------------------------------------- /inferno/extensions/layers/convolutional_blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .convolutional import BNReLUConv2D, BNReLUDeconv2D,
Conv2D, Deconv2D 3 | from ...utils import python_utils as pyu 4 | from ...utils.exceptions import assert_ 5 | 6 | __all__ = ['ResidualBlock', 'PreActSimpleResidualBlock'] 7 | _all = __all__ 8 | 9 | 10 | class ResidualBlock(nn.Module): 11 | def __init__(self, layers, resample=None): 12 | super(ResidualBlock, self).__init__() 13 | assert pyu.is_listlike(layers) 14 | self.layers = nn.Sequential(*layers) 15 | self.resample = resample 16 | 17 | def forward(self, input): 18 | preaddition = self.layers(input) 19 | if self.resample is not None: 20 | skip = self.resample(input) 21 | else: 22 | skip = input 23 | output = preaddition + skip 24 | return output 25 | 26 | 27 | class PreActSimpleResidualBlock(ResidualBlock): 28 | def __init__(self, in_channels, num_hidden_channels, upsample=False, downsample=False): 29 | layers = [] 30 | if downsample: 31 | assert_(not upsample, "Both downsample and upsample is set to true.", ValueError) 32 | layers.append(BNReLUConv2D(in_channels=in_channels, 33 | out_channels=num_hidden_channels, 34 | kernel_size=3, 35 | stride=2)) 36 | resample = nn.Sequential(Conv2D(in_channels=in_channels, 37 | out_channels=in_channels, 38 | kernel_size=1, stride=2), 39 | nn.BatchNorm2d(in_channels)) 40 | elif upsample: 41 | layers.append(BNReLUDeconv2D(in_channels=in_channels, 42 | out_channels=num_hidden_channels, 43 | kernel_size=2, 44 | stride=2)) 45 | resample = nn.Sequential(Deconv2D(in_channels=in_channels, 46 | out_channels=in_channels, 47 | kernel_size=2, stride=2), 48 | nn.BatchNorm2d(in_channels)) 49 | else: 50 | layers.append(BNReLUConv2D(in_channels=in_channels, 51 | out_channels=num_hidden_channels, 52 | kernel_size=3)) 53 | resample = None 54 | layers.append(BNReLUConv2D(in_channels=num_hidden_channels, 55 | out_channels=in_channels, 56 | kernel_size=3)) 57 | super(PreActSimpleResidualBlock, self).__init__(layers, resample) 58 | 59 | 60 | # TODO PreActBottleneckResidualBlock 61 |
-------------------------------------------------------------------------------- /inferno/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .exceptions import assert_, NotTorchModuleError, ShapeError 3 | 4 | 5 | def is_model_cuda(model): 6 | try: 7 | return next(model.parameters()).is_cuda 8 | except StopIteration: 9 | # Assuming that if a network has no parameters, it doesn't use CUDA 10 | return False 11 | 12 | 13 | class ModelTester(object): 14 | def __init__(self, input_shape, expected_output_shape): 15 | self._is_cuda = False 16 | self.input_shape = input_shape 17 | self.expected_output_shape = expected_output_shape 18 | 19 | def cuda(self): 20 | self._is_cuda = True 21 | return self 22 | 23 | def get_input(self): 24 | with torch.no_grad(): 25 | if self._is_cuda: 26 | return torch.rand(*self.input_shape, requires_grad=False).cuda() 27 | else: 28 | return torch.rand(*self.input_shape, requires_grad=False) 29 | 30 | def __call__(self, model): 31 | # Make sure model is a model 32 | assert_(isinstance(model, torch.nn.Module), 33 | "Model is not a torch module.", 34 | NotTorchModuleError) 35 | # Transfer to cuda if required 36 | if not is_model_cuda(model) and self._is_cuda: 37 | model.cuda() 38 | input_ = self.get_input() 39 | output = model(input_) 40 | assert_(list(output.size()) == list(self.expected_output_shape), 41 | "Expected output shape {} for input shape {}, " 42 | "got output of shape {} instead.".format(list(self.expected_output_shape), 43 | list(self.input_shape), 44 | list(output.size())), 45 | ShapeError) 46 | return model 47 | 48 | 49 | class MultiscaleModelTester(ModelTester): 50 | def __call__(self, model): 51 | # Make sure model is a model 52 | assert_(isinstance(model, torch.nn.Module), 53 | "Model is not a torch module.", 54 | NotTorchModuleError) 55 | # Transfer to cuda if required 56 | if not is_model_cuda(model) and self._is_cuda: 57 | model.cuda() 58 | input_ = 
self.get_input() 59 | output = model(input_) 60 | assert_(isinstance(output, tuple), "Expect tuple output") 61 | for scale in range(len(output)): 62 | assert_(list(output[scale].size()) == list(self.expected_output_shape[scale]), 63 | "Expected output shape {} for input shape {}, " 64 | "got output of shape {} instead.".format(list(self.expected_output_shape[scale]), 65 | list(self.input_shape), 66 | list(output[scale].size())), 67 | ShapeError) 68 | return model 69 | -------------------------------------------------------------------------------- /inferno/trainers/callbacks/console.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from .base import Callback 3 | 4 | class StdoutPrinter(object): 5 | def print(self, message): 6 | print("[+][{}] {}".format(str(datetime.now()), message)) 7 | 8 | 9 | class Console(object): 10 | LEVEL_INFO = 1 11 | LEVEL_PROGRESS = 2 12 | LEVEL_WARNING = 3 13 | LEVEL_DEBUG = 4 14 | 15 | def __init__(self, printer=StdoutPrinter()): 16 | self._printer = printer 17 | self._enabled = {self.LEVEL_INFO, self.LEVEL_PROGRESS, self.LEVEL_WARNING} 18 | 19 | def set_console(self, console): 20 | self._printer = console 21 | 22 | def _print(self, message, level): 23 | if level not in self._enabled: 24 | return 25 | 26 | self._printer.print(message) 27 | 28 | def info(self, message): 29 | self._print("[INFO ] " + message, self.LEVEL_INFO) 30 | 31 | def print(self, message): 32 | self.info(message) 33 | 34 | def progress(self, message): 35 | self._print("[PROGRESS] " + message, self.LEVEL_PROGRESS) 36 | 37 | def warning(self, message): 38 | self._print("[WARNING ] " + message, self.LEVEL_WARNING) 39 | 40 | def debug(self, message): 41 | self._print("[DEBUG ] " + message, self.LEVEL_DEBUG) 42 | 43 | def _toggle(self, level, state): 44 | if state: 45 | self._enabled.add(level) 46 | else: 47 | if level in self._enabled: 48 | self._enabled.remove(level) 49 | 50 | def toggle_info(self, 
state): 51 | self._toggle(self.LEVEL_INFO, state) 52 | 53 | def toggle_progress(self, state): 54 | self._toggle(self.LEVEL_PROGRESS, state) 55 | 56 | def toggle_warning(self, state): 57 | self._toggle(self.LEVEL_WARNING, state) 58 | 59 | 60 | 61 | class ShowMinimalConsoleInfo(Callback): 62 | """ 63 | Callback to show only minimum training info on console 64 | viz. current epoch number, current learning rate, 65 | training loss and training error if exists. 66 | """ 67 | def __init__(self, *args, **kwargs): 68 | super(ShowMinimalConsoleInfo, self).__init__(*args, **kwargs) 69 | 70 | def begin_of_fit(self,**_): 71 | self.trainer.quiet() 72 | 73 | def end_of_epoch(self, **_): 74 | training_loss = self.trainer.get_state('training_loss') 75 | training_error = self.trainer.get_state('training_error') 76 | learning_rate = self.trainer.get_state('learning_rate') 77 | 78 | self.trainer.console.info("--------------------------------") 79 | self.trainer.console.info("Epoch "+str(self.trainer.epoch_count)) 80 | if training_loss is not None: 81 | self.trainer.console.info("Train Loss "+str(training_loss.item())) 82 | if training_error is not None: 83 | self.trainer.console.info("Train Error "+str(training_error.item())) 84 | self.trainer.console.info("Current LR "+str(learning_rate)) -------------------------------------------------------------------------------- /inferno/utils/io_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py as h5 3 | import numpy as np 4 | import yaml 5 | from skimage.io import imsave 6 | 7 | 8 | # Function to load in a dataset from a h5file 9 | def fromh5(path, datapath=None, dataslice=None, asnumpy=True, preptrain=None): 10 | """ 11 | Opens a hdf5 file at path, loads in the dataset at datapath, and returns dataset 12 | as a numpy array. 13 | """ 14 | # Check if path exists (thanks Lukas!) 
15 | assert os.path.exists(path), "Path {} does not exist.".format(path) 16 | with h5.File(path, 'r') as f: 17 | # Init dataset 18 | h5dataset = f[datapath] if datapath is not None else f.values()[0] 19 | # Slice dataset 20 | h5dataset = h5dataset[dataslice] if dataslice is not None else h5dataset 21 | # Convert to numpy if required 22 | h5dataset = np.asarray(h5dataset) if asnumpy else h5dataset 23 | # Apply preptrain 24 | h5dataset = preptrain(h5dataset) if preptrain is not None else h5dataset 25 | return h5dataset 26 | 27 | 28 | # TODO we could also do **h5_kwargs instead 29 | def toh5(data, path, datapath='data', compression=None, chunks=None): 30 | """Write `data` to a HDF5 volume.""" 31 | with h5.File(path) as f: 32 | f.create_dataset(datapath, data=data, compression=compression, chunks=chunks) 33 | 34 | 35 | def fromz5(path, datapath, dataslice=None, n_threads=8): 36 | # we import z5py only here because we don't want to assume that it's in the env 37 | import z5py 38 | assert os.path.exists(path), "Path {} does not exist.".format(path) 39 | with z5py.File(path) as f: 40 | ds = f[datapath] 41 | ds.n_threads = n_threads 42 | data = ds[:] if dataslice is None else ds[dataslice] 43 | return data 44 | 45 | 46 | # Yaml to dict reader 47 | def yaml2dict(path): 48 | if isinstance(path, dict): 49 | # Forgivable mistake that path is a dict already 50 | return path 51 | with open(path, 'r') as f: 52 | readict = yaml.load(f, Loader=yaml.FullLoader) 53 | return readict 54 | 55 | 56 | def print_tensor(tensor, prefix, directory): 57 | """Prints a image or volume tensor to file as images.""" 58 | def _print_image(image, prefix, batch, channel, z=None): 59 | if z is None: 60 | file_name = "{}--B-{}--CH-{}.png".format(prefix, batch, channel) 61 | else: 62 | file_name = "{}--B-{}--CH-{}--Z-{}.png".format(prefix, batch, channel, z) 63 | full_file_name = os.path.join(directory, file_name) 64 | imsave(arr=image, fname=full_file_name) 65 | 66 | for batch in range(tensor.shape[0]): 
67 | for channel in range(tensor.shape[1]): 68 | if tensor.ndim == 4: 69 | _print_image(tensor[batch, channel, ...], prefix, batch, channel) 70 | else: 71 | for plane in range(tensor.shape[2]): 72 | _print_image(tensor[batch, channel, plane, ...], prefix, batch, channel, plane) 73 | -------------------------------------------------------------------------------- /inferno/extensions/initializers/presets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn.init as init 3 | from functools import partial 4 | 5 | from .base import Initialization, Initializer 6 | 7 | 8 | __all__ = ['Constant', 'NormalWeights', 9 | 'SELUWeightsZeroBias', 10 | 'ELUWeightsZeroBias', 11 | 'OrthogonalWeightsZeroBias', 12 | 'KaimingNormalWeightsZeroBias'] 13 | 14 | 15 | class Constant(Initializer): 16 | """Initialize with a constant.""" 17 | def __init__(self, constant): 18 | self.constant = constant 19 | 20 | def call_on_tensor(self, tensor): 21 | tensor.fill_(self.constant) 22 | return tensor 23 | 24 | 25 | class NormalWeights(Initializer): 26 | """ 27 | Initialize weights with random numbers drawn from the normal distribution at 28 | `mean` and `stddev`. 
29 | """ 30 | def __init__(self, mean=0., stddev=1., sqrt_gain_over_fan_in=None): 31 | self.mean = mean 32 | self.stddev = stddev 33 | self.sqrt_gain_over_fan_in = sqrt_gain_over_fan_in 34 | 35 | def compute_fan_in(self, tensor): 36 | if tensor.dim() == 2: 37 | return tensor.size(1) 38 | else: 39 | return np.prod(list(tensor.size())[1:]) 40 | 41 | def call_on_weight(self, tensor): 42 | # Compute stddev if required 43 | if self.sqrt_gain_over_fan_in is not None: 44 | stddev = self.stddev * \ 45 | np.sqrt(self.sqrt_gain_over_fan_in / self.compute_fan_in(tensor)) 46 | else: 47 | stddev = self.stddev 48 | # Init 49 | tensor.normal_(self.mean, stddev) 50 | 51 | 52 | class OrthogonalWeightsZeroBias(Initialization): 53 | def __init__(self, orthogonal_gain=1.): 54 | # This prevents a deprecated warning in Pytorch 0.4+ 55 | orthogonal = getattr(init, 'orthogonal_', init.orthogonal) 56 | super(OrthogonalWeightsZeroBias, self)\ 57 | .__init__(weight_initializer=partial(orthogonal, gain=orthogonal_gain), 58 | bias_initializer=Constant(0.)) 59 | 60 | 61 | class KaimingNormalWeightsZeroBias(Initialization): 62 | def __init__(self, relu_leakage=0): 63 | # This prevents a deprecated warning in Pytorch 0.4+ 64 | kaiming_normal = getattr(init, 'kaiming_normal_', init.kaiming_normal) 65 | super(KaimingNormalWeightsZeroBias, self)\ 66 | .__init__(weight_initializer=partial(kaiming_normal, a=relu_leakage), 67 | bias_initializer=Constant(0.)) 68 | 69 | 70 | class SELUWeightsZeroBias(Initialization): 71 | def __init__(self): 72 | super(SELUWeightsZeroBias, self)\ 73 | .__init__(weight_initializer=NormalWeights(sqrt_gain_over_fan_in=1.), 74 | bias_initializer=Constant(0.)) 75 | 76 | 77 | class ELUWeightsZeroBias(Initialization): 78 | def __init__(self): 79 | super(ELUWeightsZeroBias, self)\ 80 | .__init__(weight_initializer=NormalWeights(sqrt_gain_over_fan_in=1.5505188080679277), 81 | bias_initializer=Constant(0.)) 82 | 
-------------------------------------------------------------------------------- /tests/test_extensions/test_models/test_res_unet.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torch.cuda as cuda 4 | from inferno.utils.model_utils import ModelTester 5 | 6 | 7 | class ResUNetTest(unittest.TestCase): 8 | def test_res_unet_2d(self): 9 | from inferno.extensions.models import ResBlockUNet 10 | tester = ModelTester((1, 1, 256, 256), (1, 1, 256, 256)) 11 | if cuda.is_available(): 12 | tester.cuda() 13 | tester(ResBlockUNet(in_channels=1, out_channels=1, dim=2)) 14 | 15 | def test_res_unet_3d(self): 16 | from inferno.extensions.models import ResBlockUNet 17 | tester = ModelTester((1, 1, 16, 64, 64), (1, 1, 16, 64, 64)) 18 | if cuda.is_available(): 19 | tester.cuda() 20 | # test default unet 3d 21 | tester(ResBlockUNet(in_channels=1, out_channels=1, dim=3)) 22 | 23 | def test_2d_side_out_bot_up(self): 24 | from inferno.extensions.models import ResBlockUNet 25 | depth = 3 26 | in_channels = 3 27 | 28 | x = torch.rand(1, in_channels, 64, 32) 29 | model = ResBlockUNet(in_channels=in_channels, 30 | out_channels=8, dim=2, 31 | side_out_parts=['bottom','up'], 32 | unet_kwargs=dict(depth=depth)) 33 | 34 | out_list = model(x) 35 | self.assertEqual(len(out_list), depth + 1) 36 | 37 | self.assertEqual(list(out_list[0].size()), [1, 24, 8, 4]) 38 | self.assertEqual(list(out_list[1].size()), [1, 12, 16, 8]) 39 | self.assertEqual(list(out_list[2].size()), [1, 6, 32, 16]) 40 | self.assertEqual(list(out_list[3].size()), [1, 8, 64, 32]) 41 | 42 | def test_2d_side_out_up(self): 43 | from inferno.extensions.models import ResBlockUNet 44 | depth = 3 45 | in_channels = 3 46 | 47 | x = torch.rand(1, in_channels, 64, 32) 48 | model = ResBlockUNet(in_channels=in_channels, 49 | out_channels=8, dim=2, 50 | side_out_parts=['up'], 51 | unet_kwargs=dict(depth=depth)) 52 | 53 | out_list = model(x) 54 | 
self.assertEqual(len(out_list), depth) 55 | 56 | self.assertEqual(list(out_list[0].size()), [1,12, 16, 8]) 57 | self.assertEqual(list(out_list[1].size()), [1, 6, 32, 16]) 58 | self.assertEqual(list(out_list[2].size()), [1, 8, 64, 32]) 59 | 60 | def test_2d_side_out_down(self): 61 | from inferno.extensions.models import ResBlockUNet 62 | depth = 3 63 | in_channels = 3 64 | 65 | x = torch.rand(1, in_channels, 64, 32) 66 | model = ResBlockUNet(in_channels=in_channels, 67 | out_channels=8, dim=2, 68 | side_out_parts=['down'], 69 | unet_kwargs=dict(depth=depth)) 70 | 71 | out_list = model(x) 72 | self.assertEqual(len(out_list), depth + 1) 73 | 74 | self.assertEqual(list(out_list[0].size()), [1, 6, 64, 32]) 75 | self.assertEqual(list(out_list[1].size()), [1, 12, 32, 16]) 76 | self.assertEqual(list(out_list[2].size()), [1, 24, 16, 8]) 77 | 78 | # the actual output 79 | self.assertEqual(list(out_list[3].size()), [1, 8, 64, 32]) 80 | 81 | 82 | if __name__ == '__main__': 83 | unittest.main() 84 | -------------------------------------------------------------------------------- /inferno/extensions/criteria/core.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from functools import reduce 3 | from ...utils.exceptions import assert_, ShapeError, NotTorchModuleError 4 | 5 | 6 | __all__ = ['Criteria', 'As2DCriterion'] 7 | 8 | 9 | class Criteria(nn.Module): 10 | """Aggregate multiple criteria to one.""" 11 | def __init__(self, *criteria): 12 | super(Criteria, self).__init__() 13 | if len(criteria) == 1 and isinstance(criteria[0], (list, tuple)): 14 | criteria = list(criteria[0]) 15 | else: 16 | criteria = list(criteria) 17 | # Validate criteria 18 | assert all([isinstance(criterion, nn.Module) for criterion in criteria]), \ 19 | "Criterion must be a torch module." 
20 | self.criteria = criteria 21 | 22 | def forward(self, prediction, target): 23 | assert isinstance(prediction, (list, tuple)), \ 24 | "`prediction` must be a list or a tuple, got {} instead."\ 25 | .format(type(prediction).__name__) 26 | assert isinstance(target, (list, tuple)), \ 27 | "`prediction` must be a list or a tuple, got {} instead." \ 28 | .format(type(target).__name__) 29 | assert len(prediction) == len(target), \ 30 | "Number of predictions must equal the number of targets. " \ 31 | "Got {} predictions but {} targets.".format(len(prediction), len(target)) 32 | # Compute losses 33 | losses = [criterion(prediction, target) 34 | for _prediction, _target, criterion in zip(prediction, target, self.criteria)] 35 | # Aggegate losses 36 | loss = reduce(lambda x, y: x + y, losses) 37 | # Done 38 | return loss 39 | 40 | 41 | class As2DCriterion(nn.Module): 42 | """ 43 | Makes a given criterion applicable on (N, C, H, W) prediction and (N, H, W) target tensors, 44 | if they're applicable to (N, C) prediction and (N,) target tensors . 45 | """ 46 | def __init__(self, criterion): 47 | super(As2DCriterion, self).__init__() 48 | assert_(isinstance(criterion, nn.Module), 49 | "Criterion must be a module, got a {} instead." 50 | .format(type(criterion).__name__), 51 | NotTorchModuleError) 52 | self.criterion = criterion 53 | 54 | def forward(self, prediction, target): 55 | # Validate input 56 | assert_(prediction.dim() == 4, "`prediction` is expected to be a 4D tensor of shape " 57 | "(N, C, H, W), got a {}D " 58 | "tensor instead.".format(prediction.dim()), 59 | ShapeError) 60 | assert_(target.dim() == 3, "`target` is expected to be a 3D tensor of shape " 61 | "(N, H, W), got a {}D " 62 | "tensor instead.".format(target.dim()), 63 | ShapeError) 64 | # prediction is assumed to be NCHW, and target NHW. 
65 | # this makes target (NHW,) 66 | target = target.contiguous().view(-1) 67 | # This makes prediction (N, H, W, C) --> (NHW, C) 68 | num_channels = prediction.size(1) 69 | prediction = prediction.permute(0, 2, 3, 1).contiguous().view(-1, num_channels) 70 | # Now, the criterion should be applicable as is 71 | loss = self.criterion(prediction, target) 72 | return loss 73 | -------------------------------------------------------------------------------- /inferno/extensions/optimizers/adam.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim import Optimizer 3 | 4 | 5 | class Adam(Optimizer): 6 | """Implements Adam algorithm with the option of adding a L1 penalty. 7 | 8 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 9 | 10 | Arguments: 11 | params (iterable): iterable of parameters to optimize or dicts defining 12 | parameter groups 13 | lr (float, optional): learning rate (default: 1e-3) 14 | betas (Tuple[float, float], optional): coefficients used for computing 15 | running averages of gradient and its square (default: (0.9, 0.999)) 16 | eps (float, optional): term added to the denominator to improve 17 | numerical stability (default: 1e-8) 18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 19 | 20 | .. _Adam\: A Method for Stochastic Optimization: 21 | https://arxiv.org/abs/1412.6980 22 | """ 23 | 24 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 25 | lambda_l1=0, weight_decay=0, **kwargs): 26 | defaults = dict(lr=lr, betas=betas, eps=eps, 27 | lambda_l1=lambda_l1, weight_decay=weight_decay, 28 | **kwargs) 29 | super(Adam, self).__init__(params, defaults) 30 | 31 | def step(self, closure=None): 32 | """Performs a single optimization step. 33 | 34 | Arguments: 35 | closure (callable, optional): A closure that reevaluates the model 36 | and returns the loss. 
37 | """ 38 | loss = None 39 | if closure is not None: 40 | loss = closure() 41 | 42 | for group in self.param_groups: 43 | for p in group['params']: 44 | if p.grad is None: 45 | continue 46 | grad = p.grad.data 47 | state = self.state[p] 48 | 49 | # State initialization 50 | if len(state) == 0: 51 | state['step'] = 0 52 | # Exponential moving average of gradient values 53 | state['exp_avg'] = grad.new().resize_as_(grad).zero_() 54 | # Exponential moving average of squared gradient values 55 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() 56 | 57 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 58 | beta1, beta2 = group['betas'] 59 | 60 | state['step'] += 1 61 | 62 | if group['lambda_l1'] != 0: 63 | grad.add_(group['lambda_l1'], p.data.sign()) 64 | if group['weight_decay'] != 0: 65 | grad.add_(group['weight_decay'], p.data) 66 | 67 | # Decay the first and second moment running average coefficient 68 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 69 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 70 | 71 | denom = exp_avg_sq.sqrt().add_(group['eps']) 72 | 73 | bias_correction1 = 1 - beta1 ** state['step'] 74 | bias_correction2 = 1 - beta2 ** state['step'] 75 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 76 | 77 | p.data.addcdiv_(-step_size, exp_avg, denom) 78 | 79 | return loss 80 | -------------------------------------------------------------------------------- /inferno/extensions/layers/sampling.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | __all__ = ['AnisotropicUpsample', 'AnisotropicPool', 'Upsample', 'AnisotropicUpsample2D', 'AnisotropicPool2D'] 4 | 5 | 6 | # torch is deprecating nn.Upsample in favor of nn.functional.interpolate 7 | # we wrap interpolate here to still use Upsample as class 8 | class Upsample(nn.Module): 9 | def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None): 10 | self.size = size 
11 | self.scale_factor = scale_factor 12 | self.mode = mode 13 | self.align_corners = align_corners 14 | super(Upsample, self).__init__() 15 | # interpolate was only introduced in torch 0.4.1 for backward compatibility 16 | # we check if we have the attribute here and fall back to Upsample otherwise 17 | if hasattr(nn.functional, 'interpolate'): 18 | self.have_interpolate = True 19 | else: 20 | self.have_interpolate = False 21 | self.sampler = nn.Upsample(size=size, scale_factor=scale_factor, 22 | mode=mode, align_corners=align_corners) 23 | 24 | def forward(self, input): 25 | if self.have_interpolate: 26 | return nn.functional.interpolate(input, self.size, self.scale_factor, 27 | self.mode, self.align_corners) 28 | else: 29 | return self.sampler(input) 30 | 31 | 32 | class AnisotropicUpsample(nn.Module): 33 | def __init__(self, scale_factor): 34 | super(AnisotropicUpsample, self).__init__() 35 | self.upsampler = Upsample(scale_factor=scale_factor) 36 | 37 | def forward(self, input): 38 | # input is 3D of shape NCDHW 39 | N, C, D, H, W = input.size() 40 | # Fold C and D axes in one 41 | folded = input.view(N, C * D, H, W) 42 | # Upsample 43 | upsampled = self.upsampler(folded) 44 | # Unfold out the C and D axes 45 | unfolded = upsampled.view(N, C, D, 46 | self.upsampler.scale_factor * H, 47 | self.upsampler.scale_factor * W) 48 | # Done 49 | return unfolded 50 | 51 | 52 | class AnisotropicPool(nn.MaxPool3d): 53 | def __init__(self, downscale_factor): 54 | ds = downscale_factor 55 | super(AnisotropicPool, self).__init__(kernel_size=(1, ds + 1, ds + 1), 56 | stride=(1, ds, ds), 57 | padding=(0, 1, 1)) 58 | 59 | class AnisotropicUpsample2D(nn.Module): 60 | def __init__(self, scale_factor): 61 | super(AnisotropicUpsample2D, self).__init__() 62 | self.upsampler = nn.Upsample(scale_factor=scale_factor) 63 | 64 | def forward(self, input): 65 | # input is 2D of shape NCDW (or NCDH, egal) 66 | N, C, D, W = input.size() 67 | # Fold C and D axes in one 68 | folded = 
input.view(N, C * D, W) 69 | # Upsample 70 | upsampled = self.upsampler(folded) 71 | # Unfold out the C and D axes 72 | unfolded = upsampled.view(N, C, D, 73 | self.upsampler.scale_factor * W) 74 | # Done 75 | return unfolded 76 | 77 | 78 | class AnisotropicPool2D(nn.MaxPool2d): 79 | def __init__(self, downscale_factor): 80 | ds = downscale_factor 81 | super(AnisotropicPool2D, self).__init__(kernel_size=(1, ds + 1), 82 | stride=(1, ds), 83 | padding=(0, 1)) 84 | 85 | -------------------------------------------------------------------------------- /tests/test_utils/test_partial_cls.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import inferno.utils.model_utils as mu 3 | from inferno.utils.partial_cls import register_partial_cls 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class TestCls(object): 9 | def __init__(self, a, b, c=1, d=2): 10 | self.a = a 11 | self.b = b 12 | self.c = c 13 | self.d = d 14 | 15 | class PartialClsTester(unittest.TestCase): 16 | 17 | def test_partial_cls(self): 18 | register_partial_cls(TestCls, 'TestA', 19 | fix=dict(a='a'), 20 | default=dict(b='b'), 21 | module=__name__ 22 | ) 23 | assert 'TestA' in globals() 24 | 25 | inst = TestA() 26 | assert inst.a == 'a' 27 | assert inst.b == 'b' 28 | assert inst.c == 1 29 | assert inst.d == 2 30 | 31 | inst = TestA('fu','bar','fubar') 32 | assert inst.a == 'a' 33 | assert inst.b == 'fu' 34 | assert inst.c == 'bar' 35 | assert inst.d == 'fubar' 36 | 37 | with self.assertRaises(TypeError): 38 | inst = TestA(a=2) 39 | 40 | def test_update_existing_default_cls(self): 41 | register_partial_cls(TestCls, 'TestA', 42 | fix=dict(a='a'), 43 | default=dict(d=3), 44 | module=__name__ 45 | ) 46 | assert 'TestA' in globals() 47 | 48 | inst = TestA(42) 49 | assert inst.a == 'a' 50 | assert inst.b == 42 51 | assert inst.c == 1 52 | assert inst.d == 3 53 | 54 | with self.assertRaises(TypeError): 55 | inst = TestA() 56 | 57 | def 
test_fix_nothing(self): 58 | register_partial_cls(TestCls, 'TestA', 59 | module=__name__ 60 | ) 61 | assert 'TestA' in globals() 62 | 63 | inst = TestA(1,2,3,4) 64 | assert inst.a == 1 65 | assert inst.b == 2 66 | assert inst.c == 3 67 | assert inst.d == 4 68 | # nothing fixed or defaulted: the required arguments a and b must be passed 69 | with self.assertRaises(TypeError): 70 | inst = TestA() 71 | 72 | def test_fix_all(self): 73 | register_partial_cls(TestCls, 'TestA', 74 | module=__name__, 75 | fix=dict(a=4, b=3, c=2, d=1) 76 | ) 77 | assert 'TestA' in globals() 78 | 79 | inst = TestA() 80 | assert inst.a == 4 81 | assert inst.b == 3 82 | assert inst.c == 2 83 | assert inst.d == 1 84 | # fixed arguments are rejected both positionally and by keyword 85 | with self.assertRaises(TypeError): 86 | inst = TestA('a') 87 | 88 | with self.assertRaises(TypeError): 89 | inst = TestA(a=1) 90 | with self.assertRaises(TypeError): 91 | inst = TestA(b=1) 92 | with self.assertRaises(TypeError): 93 | inst = TestA(c=1) 94 | with self.assertRaises(TypeError): 95 | inst = TestA(d=1) 96 | 97 | 98 | def test_default_all(self): 99 | register_partial_cls(TestCls, 'TestA', 100 | module=__name__, 101 | default=dict(a=4, b=3, c=2, d=1) 102 | ) 103 | assert 'TestA' in globals() 104 | 105 | inst = TestA() 106 | assert inst.a == 4 107 | assert inst.b == 3 108 | assert inst.c == 2 109 | assert inst.d == 1 110 | 111 | # positional arguments override the defaults in signature order (a, b, c, d) 112 | inst = TestA(2) 113 | assert inst.a == 2 114 | assert inst.b == 3 115 | assert inst.c == 2 116 | assert inst.d == 1 117 | 118 | inst = TestA(2,3,4,5) 119 | assert inst.a == 2 120 | assert inst.b == 3 121 | assert inst.c == 4 122 | assert inst.d == 5 123 | # `a` given both positionally and by keyword is a TypeError 124 | with self.assertRaises(TypeError): 125 | inst = TestA(3,4,5,a=2) 126 | 127 | inst = TestA(3,4,5,d=2) 128 | assert inst.a == 3 129 | assert inst.b == 4 130 | assert inst.c == 5 131 | assert inst.d == 2 132 | 133 | 134 | 135 | 136 | if 
__name__ == '__main__': 137 | unittest.main() 138 | -------------------------------------------------------------------------------- /inferno/trainers/callbacks/tqdm.py:
-------------------------------------------------------------------------------- 1 | from .base import Callback 2 | from tqdm import tqdm 3 | from datetime import datetime 4 | from .console import Console 5 | 6 | # Writes messages with tqdm.write, clearing the progress object's outer bar before and refreshing it after 7 | class TQDMPrinter(object): 8 | def __init__(self, progress): 9 | self._progress = progress 10 | 11 | def print(self, message): 12 | if self._progress.outer_bar is not None: 13 | self._progress.outer_bar.clear() 14 | tqdm.write(message) 15 | if self._progress.outer_bar is not None: 16 | self._progress.outer_bar.refresh() 17 | 18 | 19 | class TQDMConsole(Console): 20 | def __init__(self): 21 | super(TQDMConsole, self).__init__(printer=TQDMPrinter(TQDMProgressBar())) 22 | 23 | 24 | class TQDMProgressBar(Callback): 25 | def __init__(self, *args, **kwargs): 26 | super(TQDMProgressBar, self).__init__(*args, **kwargs) 27 | self.epoch_bar = None 28 | self.outer_bar = None 29 | self.is_training = False 30 | self.is_validation = False 31 | 32 | def bind_trainer(self, *args, **kwargs): 33 | super(TQDMProgressBar, self).bind_trainer(*args, **kwargs) 34 | self.trainer.console.toggle_progress(False) 35 | self.trainer.console.set_console(TQDMPrinter(self)) 36 | 37 | def _init_epoch_bar_train(self):  # NOTE(review): relies on private Trainer state (_loader_iters, _batch_count, _epoch_count) 38 | n_batch = len(self.trainer._loader_iters['train']) 39 | self.epoch_bar = tqdm(total=n_batch, position=1, dynamic_ncols=True) 40 | self.epoch_bar.update(self.trainer._batch_count) 41 | self.epoch_bar.set_description("Training epoch %d" % self.trainer._epoch_count) 42 | 43 | def print(self, message, **_): 44 | if self.outer_bar is not None: 45 | self.outer_bar.clear() 46 | tqdm.write("[+][{}] {}".format(str(datetime.now()), message)) 47 | if self.outer_bar is not None: 48 | self.outer_bar.refresh() 49 | 50 | def begin_of_fit(self, max_num_epochs, **_): 51 | if isinstance(max_num_epochs, int): 52 | self.outer_bar = tqdm(total=max_num_epochs, 
position=0, dynamic_ncols=True) 53 | else:  # non-int max_num_epochs: use a placeholder total of 1000 54 | self.outer_bar = tqdm(total=1000, position=0, dynamic_ncols=True) 55 |
self.outer_bar.set_description("Epochs") 56 | 57 | def end_of_fit(self, **_): 58 | if self.outer_bar is not None: 59 | self.outer_bar.close() 60 | self.outer_bar = None 61 | 62 | def begin_of_epoch(self, **_): 63 | if self.epoch_bar is not None: 64 | self.epoch_bar.close() 65 | 66 | def end_of_epoch(self, **_): 67 | if self.outer_bar is not None: 68 | self.outer_bar.update(1) 69 | 70 | def begin_of_training_iteration(self, **_): 71 | if not self.epoch_bar and 'train' in self.trainer._loader_iters.keys(): 72 | self._init_epoch_bar_train() 73 | return 74 | 75 | if self.epoch_bar: 76 | self.epoch_bar.update(1) 77 | 78 | def begin_of_validation_iteration(self, **_): 79 | if self.epoch_bar: 80 | self.epoch_bar.update(1) 81 | 82 | def begin_of_training_run(self, **_): 83 | self.is_training = True 84 | 85 | def end_of_training_run(self, **_): 86 | self.is_training = False 87 | if self.epoch_bar: 88 | self.epoch_bar.close() 89 | self.epoch_bar = None 90 | 91 | def begin_of_validation_run(self, num_iterations, num_iterations_in_generator, last_validated_at_epoch, **_): 92 | self.is_validation = True 93 | nmax = num_iterations 94 | if not nmax: 95 | nmax = num_iterations_in_generator 96 | 97 | self.epoch_bar = tqdm(total=nmax, position=1, dynamic_ncols=True) 98 | self.epoch_bar.set_description("Validating epoch %d" % (last_validated_at_epoch-1)) 99 | 100 | def end_of_validation_run(self, **_): 101 | self.is_validation = False 102 | if self.epoch_bar: 103 | self.epoch_bar.close() 104 | self.epoch_bar = None 105 | -------------------------------------------------------------------------------- /inferno/extensions/layers/device.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from ...utils.python_utils import from_iterable, to_iterable 3 | from ...utils.exceptions import assert_, DeviceError 4 | 5 | __all__ = ['DeviceTransfer', 'OnDevice'] 6 | _all = __all__ 7 | 8 | 9 | class DeviceTransfer(nn.Module): 10 | 
"""Layer to transfer variables to a specified device.""" 11 | def __init__(self, target_device, device_ordinal=None, asynchronous=False): 12 | """ 13 | Parameters 14 | ---------- 15 | target_device : {'cpu', 'cuda'} 16 | Device to transfer to. 17 | device_ordinal : int 18 | Device ordinal if target_device == 'cuda'. 19 | asynchronous : bool 20 | Whether to use asynchronous transfers. 21 | """ 22 | super(DeviceTransfer, self).__init__() 23 | # Validate arguments 24 | assert_(target_device in ['cpu', 'cuda'], 25 | "Target device must either be 'cpu' or 'cuda'.", 26 | DeviceError) 27 | if target_device == 'cpu': 28 | assert_(device_ordinal is None, 29 | "'device_ordinal' must be None if target_device is 'cpu'.", 30 | DeviceError) 31 | self.target_device = target_device 32 | self.device_ordinal = device_ordinal 33 | 34 | def forward(self, *inputs): 35 | if self.target_device == 'cuda': 36 | transferred = tuple(input_.cuda(device=self.device_ordinal, 37 | non_blocking=self.asynchronous) 38 | for input_ in inputs) 39 | elif self.target_device == 'cpu': 40 | transferred = tuple(input_.cpu() for input_ in inputs) 41 | else: 42 | raise NotImplementedError 43 | return from_iterable(transferred) 44 | 45 | 46 | class OnDevice(nn.Module): 47 | """ 48 | Moves a module to a device. The advantage of using this over `torch.nn.Module.cuda` is 49 | that the inputs are transferred to the same device as the module, enabling easy model 50 | parallelism. 51 | """ 52 | def __init__(self, module, target_device, device_ordinal=None, asynchronous=False): 53 | """ 54 | Parameters 55 | ---------- 56 | module : torch.nn.Module 57 | Module to transfer to device. 58 | target_device : {'cuda', 'cpu'} 59 | The device to move `module` to. Must be either 'cuda' or 'cpu'. 60 | device_ordinal : int 61 | Ordinal of the GPU device if `target_device = 'cuda'`. 62 | asynchronous : bool 63 | Whether to use asynchronous transfers. 
64 | """ 65 | super(OnDevice, self).__init__() 66 | # Validate arguments 67 | assert_(target_device in ['cpu', 'cuda'], 68 | "Target device must either be 'cpu' or 'cuda'.", 69 | DeviceError) 70 | if target_device == 'cpu': 71 | assert_(device_ordinal is None, 72 | "'device_ordinal' must be None if target_device is 'cpu'.", 73 | DeviceError) 74 | self.target_device = target_device 75 | self.device_ordinal = device_ordinal 76 | self.asynchronous = asynchronous 77 | # This is a no-op if module is already in the right device 78 | self.device_transfer = DeviceTransfer(self.target_device, 79 | device_ordinal=self.device_ordinal, 80 | asynchronous=self.asynchronous) 81 | 82 | self.module = self.transfer_module(module) 83 | 84 | def transfer_module(self, module): 85 | if self.target_device == 'cuda': 86 | return module.cuda(device_id=self.device_ordinal) 87 | elif self.target_device == 'cpu': 88 | return module.cpu() 89 | else: 90 | raise NotImplementedError 91 | 92 | def forward(self, *inputs): 93 | # Transfer inputs (no-op if they're already on the right device) 94 | transferred = to_iterable(self.device_transfer(*inputs)) 95 | output = self.module(*transferred) 96 | return output 97 | -------------------------------------------------------------------------------- /tests/test_training/test_callbacks/test_essentials.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import shutil 3 | import h5py as h5 4 | from os.path import dirname, join 5 | from os import listdir 6 | from inferno.trainers.basic import Trainer 7 | from inferno.trainers.callbacks.essentials import DumpHDF5Every 8 | from inferno.utils.test_utils import generate_random_dataloader 9 | from inferno.extensions.layers import Conv2D, AsMatrix 10 | from torch.nn import Sequential, MaxPool2d, AdaptiveAvgPool2d, Linear, Softmax 11 | 12 | 13 | class TestEssentials(unittest.TestCase): 14 | WORKING_DIRECTORY = dirname(__file__) 15 | 16 | def setUp(self): 17 | 
# Build a small conv net: four Conv2D blocks, global average pooling and a 10-way linear head 18 | model = Sequential(Conv2D(3, 8, 3, activation='ReLU'), 19 | MaxPool2d(2, 2), 20 | Conv2D(8, 8, 3, activation='ReLU'), 21 | MaxPool2d(2, 2), 22 | Conv2D(8, 8, 3, activation='ReLU'), 23 | MaxPool2d(2, 2), 24 | Conv2D(8, 8, 3, activation='ReLU'), 25 | AdaptiveAvgPool2d((1, 1)), 26 | AsMatrix(), 27 | Linear(8, 10)) 28 | 29 | train_dataloader = generate_random_dataloader(512, (3, 32, 32), 10, batch_size=16, 30 | dtype='float32') 31 | validate_dataloader = generate_random_dataloader(32, (3, 32, 32), 10, batch_size=16, 32 | dtype='float32') 33 | # Build trainer 34 | trainer = Trainer(model)\ 35 | .bind_loader('train', train_dataloader)\ 36 | .bind_loader('validate', validate_dataloader)\ 37 | .save_to_directory(to_directory=join(self.WORKING_DIRECTORY, 'Weights'))\ 38 | .build_criterion('CrossEntropyLoss').build_optimizer('RMSprop') 39 | self.trainer = trainer 40 | 41 | def test_dump_hdf5_every(self): 42 | # Configure callback 43 | dumper = DumpHDF5Every((1, 'epoch'), 44 | to_directory=join(self.WORKING_DIRECTORY, 'Weights'), 45 | dump_after_every_validation_run=True) 46 | self.trainer\ 47 | .set_max_num_epochs(4)\ 48 | .register_callback(dumper)\ 49 | .validate_every((16, 'iterations')) 50 | 51 | self.trainer.fit() 52 | all_files = listdir(join(self.WORKING_DIRECTORY, 'Weights')) 53 | for epoch in range(5):  # 512 samples / batch_size 16 = 32 iterations per epoch 54 | self.assertIn('dump.training.epoch{}.iteration{}.h5'.format(epoch, epoch * 32), 55 | all_files) 56 | # We don't validate at last epoch 57 | if epoch != 4: 58 | self.assertIn('dump.validation.epoch{}.iteration{}.h5' 59 | .format(epoch, (epoch * 32) + 16), 60 | all_files) 61 | self.assertIn('dump.validation.epoch{}.iteration{}.h5' 62 | .format(epoch, (epoch * 32) + 32), 63 | all_files) 64 | 65 | # Check if the keys are right in a training dump 66 | sample_file_path = join(self.WORKING_DIRECTORY, 'Weights', 67 | 
'dump.training.epoch0.iteration0.h5') 68 | with h5.File(sample_file_path, 'r') as sample_file: 69 | all_dataset_names =
list(sample_file.keys()) 70 | self.assertSequenceEqual(all_dataset_names, 71 | ['training_inputs_0', 'training_prediction', 'training_target']) 72 | # Check if the keys are right in a validation dump 73 | sample_file_path = join(self.WORKING_DIRECTORY, 'Weights', 74 | 'dump.validation.epoch0.iteration16.h5') 75 | with h5.File(sample_file_path, 'r') as sample_file: 76 | all_dataset_names = list(sample_file.keys()) 77 | self.assertSequenceEqual(all_dataset_names, 78 | ['validation_inputs_0', 'validation_prediction', 79 | 'validation_target']) 80 | 81 | def tearDown(self): 82 | shutil.rmtree(join(self.WORKING_DIRECTORY, 'Weights')) 83 | 84 | 85 | if __name__ == '__main__': 86 | unittest.main() 87 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | ============ 4 | Contributing 5 | ============ 6 | 7 | Contributions are welcome, and they are greatly appreciated! Every 8 | little bit helps, and credit will always be given. 9 | 10 | You can contribute in many ways: 11 | 12 | Types of Contributions 13 | ---------------------- 14 | 15 | Report Bugs 16 | ~~~~~~~~~~~ 17 | 18 | Report bugs at https://github.com/nasimrahaman/inferno/issues. 19 | 20 | If you are reporting a bug, please include: 21 | 22 | * Your operating system name and version. 23 | * Any details about your local setup that might be helpful in troubleshooting. 24 | * Detailed steps to reproduce the bug. 25 | 26 | Fix Bugs 27 | ~~~~~~~~ 28 | 29 | Look through the GitHub issues for bugs. Anything tagged with "bug" 30 | and "help wanted" is open to whoever wants to implement it. 31 | 32 | Implement Features 33 | ~~~~~~~~~~~~~~~~~~ 34 | 35 | Look through the GitHub issues for features. Anything tagged with "enhancement" 36 | and "help wanted" is open to whoever wants to implement it.
37 | 38 | Write Documentation 39 | ~~~~~~~~~~~~~~~~~~~ 40 | 41 | inferno could always use more documentation, whether as part of the 42 | official inferno docs, in docstrings, or even on the web in blog posts, 43 | articles, and such. 44 | 45 | Submit Feedback 46 | ~~~~~~~~~~~~~~~ 47 | 48 | The best way to send feedback is to file an issue at https://github.com/nasimrahaman/inferno/issues. 49 | 50 | If you are proposing a feature: 51 | 52 | * Explain in detail how it would work. 53 | * Keep the scope as narrow as possible, to make it easier to implement. 54 | * Remember that this is a volunteer-driven project, and that contributions 55 | are welcome :) 56 | 57 | Get Started! 58 | ------------ 59 | 60 | Ready to contribute? Here's how to set up `inferno` for local development. 61 | 62 | 1. Fork the `inferno` repo on GitHub. 63 | 2. Clone your fork locally:: 64 | 65 | $ git clone git@github.com:your_name_here/inferno.git 66 | 67 | 3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development:: 68 | 69 | $ mkvirtualenv inferno 70 | $ cd inferno/ 71 | $ python setup.py develop 72 | 73 | 4. Create a branch for local development:: 74 | 75 | $ git checkout -b name-of-your-bugfix-or-feature 76 | 77 | Now you can make your changes locally. 78 | 79 | 5. When you're done making changes, check that your changes pass flake8 and the tests, including testing other Python versions with tox:: 80 | 81 | $ flake8 inferno tests 82 | $ python setup.py test  # or: py.test 83 | $ tox 84 | 85 | To get flake8 and tox, just pip install them into your virtualenv. 86 | 87 | 6. Commit your changes and push your branch to GitHub:: 88 | 89 | $ git add . 90 | $ git commit -m "Your detailed description of your changes." 91 | $ git push origin name-of-your-bugfix-or-feature 92 | 93 | 7. Submit a pull request through the GitHub website.
94 | 95 | Pull Request Guidelines 96 | ----------------------- 97 | 98 | Before you submit a pull request, check that it meets these guidelines: 99 | 100 | 1. The pull request should include tests. 101 | 2. If the pull request adds functionality, the docs should be updated. Put 102 | your new functionality into a function with a docstring, and add the 103 | feature to the list in README.rst. 104 | 3. The pull request should work for Python 3.5 and 3.6. Check 105 | https://travis-ci.org/nasimrahaman/inferno/pull_requests 106 | and make sure that the tests pass for all supported Python versions. 107 | 108 | Tips 109 | ---- 110 | 111 | To run a subset of tests (for example, a single test module):: 112 | 113 | $ python -m unittest tests.test_utils.test_partial_cls 114 | 115 | 116 | 117 | Sphinx Apidoc 118 | -------------- 119 | Before building the documentation, 120 | you need to generate the Sphinx API documentation 121 | with ``sphinx-apidoc``. 122 | The generated files must be committed to the GitHub repository. 123 | 124 | .. code:: bash 125 | 126 | cd docs 127 | sphinx-apidoc -o inferno-apidoc ../inferno 128 | 129 | .. warning:: 130 | 131 | Do not make changes to ``inferno/docs/inferno-apidoc``. This folder is auto-generated 132 | by the command above. 133 | 134 | The following combines all the commands necessary to build the HTML documentation: 135 | 136 | ..
code:: bash 137 | 138 | ./build_docs.sh 139 | 140 | -------------------------------------------------------------------------------- /tests/test_training/test_callbacks/test_logging/test_tensorboard.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import os 4 | from shutil import rmtree 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from inferno.trainers.basic import Trainer 9 | from torch.utils.data.dataset import TensorDataset 10 | from torch.utils.data.dataloader import DataLoader 11 | from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger 12 | from inferno.extensions.layers.reshape import AsMatrix 13 | 14 | 15 | class TestTensorboard(unittest.TestCase): 16 | ROOT_DIR = os.path.dirname(__file__) 17 | PRECISION = 'float' 18 | SAVE_DIRECTORY = os.path.join(ROOT_DIR, 'saves') 19 | LOG_DIRECTORY = os.path.join(ROOT_DIR, 'logs') 20 | 21 | @staticmethod 22 | def _make_test_model(input_channels): 23 | toy_net = nn.Sequential(nn.Conv2d(input_channels, 8, 3, 1, 1), 24 | nn.ELU(), 25 | nn.MaxPool2d(2), 26 | nn.Conv2d(8, 8, 3, 1, 1), 27 | nn.ELU(), 28 | nn.MaxPool2d(2), 29 | nn.Conv2d(8, 16, 3, 1, 1), 30 | nn.ELU(), 31 | nn.AdaptiveMaxPool2d((1, 1)), 32 | AsMatrix(), 33 | nn.Linear(16, 10)) 34 | return toy_net 35 | 36 | def tearDown(self): 37 | for d in [self.SAVE_DIRECTORY, self.LOG_DIRECTORY]: 38 | try: 39 | rmtree(d) 40 | except OSError: 41 | pass 42 | 43 | def get_random_dataloaders(self, input_channels=3): 44 | # Convert build random tensor dataset 45 | data_shape = (1, input_channels, 64, 64) 46 | target_shape = (1) 47 | random_array = torch.from_numpy(np.random.rand(*data_shape)).float() 48 | target_array = torch.from_numpy(np.random.randint(0, 9, size=target_shape)) 49 | train_dataset = TensorDataset(random_array, target_array) 50 | test_dataset = TensorDataset(random_array, target_array) 51 | 52 | # Build dataloaders from dataset 53 | train_loader = 
DataLoader(train_dataset, batch_size=1, 54 | shuffle=True, num_workers=0, pin_memory=False) 55 | test_loader = DataLoader(test_dataset, batch_size=1, 56 | shuffle=True, num_workers=0, pin_memory=False) 57 | return train_loader, test_loader 58 | 59 | def get_trainer(self, input_channels): 60 | # Build model 61 | net = self._make_test_model(input_channels) 62 | # Build trainer 63 | trainer = Trainer(net)\ 64 | .build_logger(TensorboardLogger(send_image_at_batch_indices=0, 65 | send_image_at_channel_indices='all', 66 | log_images_every=(20, 'iterations')), 67 | log_directory=self.LOG_DIRECTORY)\ 68 | .build_criterion('CrossEntropyLoss')\ 69 | .build_metric('CategoricalError')\ 70 | .build_optimizer('Adam')\ 71 | .validate_every((1, 'epochs'))\ 72 | .save_every((2, 'epochs'), to_directory=self.SAVE_DIRECTORY)\ 73 | .save_at_best_validation_score()\ 74 | .set_max_num_epochs(2)\ 75 | .set_precision(self.PRECISION) 76 | # Bind loaders 77 | train_loader, test_loader = self.get_random_dataloaders(input_channels=input_channels) 78 | trainer.bind_loader('train', train_loader).bind_loader('validate', test_loader) 79 | return trainer 80 | 81 | def test_tensorboard(self): 82 | trainer = self.get_trainer(3) 83 | trainer.fit() 84 | 85 | def test_tensorboard_grayscale(self): 86 | trainer = self.get_trainer(1) 87 | trainer.fit() 88 | 89 | def test_serialization(self): 90 | trainer = self.get_trainer(3) 91 | # Serialize 92 | trainer.save() 93 | # Unserialize 94 | trainer = Trainer().load(os.path.join(self.ROOT_DIR, 'saves')) 95 | train_loader, test_loader = self.get_random_dataloaders(input_channels=3) 96 | trainer.bind_loader('train', train_loader).bind_loader('validate', test_loader) 97 | trainer.fit() 98 | 99 | 100 | if __name__ == '__main__': 101 | unittest.main() 102 | -------------------------------------------------------------------------------- /docs/graphics/tentative_logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 
18 | 20 | 38 | 40 | 41 | 43 | image/svg+xml 44 | 46 | 47 | 48 | 49 | 50 | 54 | INFERN 65 | 76 | 81 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /tests/test_io/test_volumetric/test_lazy_volume_loader.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import numpy as np 4 | 5 | # try to load io libraries (h5py and z5py) 6 | try: 7 | import h5py 8 | WITH_H5PY = True 9 | except ImportError: 10 | WITH_H5PY = False 11 | 12 | # try: 13 | # import z5py 14 | # WITH_Z5PY = True 15 | # except ImportError: 16 | # WITH_Z5PY = False 17 | 18 | 19 | class TestLazyVolumeLoader(unittest.TestCase): 20 | 21 | def tearDown(self): 22 | try: 23 | os.remove('tmp.h5') 24 | except OSError: 25 | pass 26 | 27 | @unittest.skipUnless(WITH_H5PY, "Need h5py") 28 | def test_h5_loader(self): 29 | from inferno.io.volumetric.lazy_volume_loader import LazyHDF5VolumeLoader 30 | shape = (100, 100) 31 | 32 | # test default data loader 33 | data = np.arange(np.product(shape)).reshape(shape) 34 | with h5py.File('tmp.h5') as f: 35 | f.create_dataset('data', data=data) 36 | 37 | loader = LazyHDF5VolumeLoader('tmp.h5', 'data', 38 | window_size=[10, 10], stride=[10, 10], 39 | return_index_spec=True) 40 | self.assertEqual(loader.shape, shape) 41 | for batch, index in loader: 42 | expected = data[index.base_sequence_at_index] 43 | self.assertEqual(batch.shape, expected.shape) 44 | self.assertTrue(np.allclose(batch, expected)) 45 | 46 | @unittest.skipUnless(WITH_H5PY, "Need h5py") 47 | def test_h5_loader_data_slice(self): 48 | from inferno.io.volumetric.lazy_volume_loader import LazyHDF5VolumeLoader 49 | shape = (100, 100, 100) 50 | data_slice = np.s_[:, 20:80, 10:30] 51 | 52 | # test default data loader 53 | data = np.arange(np.product(shape)).reshape(shape) 54 | with h5py.File('tmp.h5') as f: 55 | f.create_dataset('data', data=data) 56 | data = data[data_slice] 57 | 58 | loader = 
LazyHDF5VolumeLoader('tmp.h5', 'data', 59 | window_size=[10, 10, 10], stride=[10, 10, 10], 60 | return_index_spec=True, data_slice=data_slice) 61 | self.assertEqual(loader.shape, data.shape) 62 | for batch, index in loader: 63 | slice_ = index.base_sequence_at_index 64 | expected = data[slice_] 65 | self.assertEqual(batch.shape, expected.shape) 66 | self.assertTrue(np.allclose(batch, expected)) 67 | 68 | @unittest.skipUnless(WITH_H5PY, "Need h5py") 69 | def test_h5_loader_pad(self): 70 | from inferno.io.volumetric.lazy_volume_loader import LazyHDF5VolumeLoader 71 | shape = (100, 100, 100) 72 | pad = [[0, 10], [0, 0], [5, 15]] 73 | 74 | # test default data loader 75 | data = np.arange(np.product(shape)).reshape(shape) 76 | with h5py.File('tmp.h5') as f: 77 | f.create_dataset('data', data=data) 78 | data = np.pad(data, pad_width=pad, mode='constant') 79 | 80 | loader = LazyHDF5VolumeLoader('tmp.h5', 'data', 81 | window_size=[20, 20, 20], stride=[20, 20, 20], 82 | return_index_spec=True, padding=pad, padding_mode='constant') 83 | self.assertEqual(loader.shape, data.shape) 84 | for batch, index in loader: 85 | slice_ = index.base_sequence_at_index 86 | expected = data[slice_] 87 | self.assertEqual(batch.shape, expected.shape) 88 | self.assertTrue(np.allclose(batch, expected)) 89 | 90 | @unittest.skipUnless(WITH_H5PY, "Need h5py") 91 | def test_h5_loader_data_slice_pad(self): 92 | from inferno.io.volumetric.lazy_volume_loader import LazyHDF5VolumeLoader 93 | shape = (100, 100, 100) 94 | data_slice = np.s_[:, 20:80, 10:90] 95 | pad = [[0, 10], [5, 5], [5, 15]] 96 | 97 | # test default data loader 98 | data = np.arange(np.product(shape)).reshape(shape) 99 | with h5py.File('tmp.h5') as f: 100 | f.create_dataset('data', data=data) 101 | data = data[data_slice] 102 | data = np.pad(data, pad_width=pad, mode='constant') 103 | 104 | loader = LazyHDF5VolumeLoader('tmp.h5', 'data', 105 | window_size=[20, 20, 20], stride=[20, 20, 20], 106 | return_index_spec=True, padding=pad, 
padding_mode='constant', 107 | data_slice=data_slice) 108 | self.assertEqual(loader.shape, data.shape) 109 | for batch, index in loader: 110 | slice_ = index.base_sequence_at_index 111 | expected = data[slice_] 112 | self.assertEqual(batch.shape, expected.shape) 113 | self.assertTrue(np.allclose(batch, expected)) 114 | 115 | 116 | if __name__ == '__main__': 117 | unittest.main() 118 | -------------------------------------------------------------------------------- /inferno/utils/partial_cls.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import sys 3 | import types 4 | import inspect 5 | 6 | 7 | __all__ = [ 8 | 'partial_cls', 9 | 'register_partial_cls' 10 | ] 11 | 12 | 13 | def partial_cls(base_cls, name, module, fix=None, default=None): 14 | 15 | # helper function 16 | def insert_if_not_present(dict_a, dict_b): 17 | for kw,val in dict_b.items(): 18 | if kw not in dict_a: 19 | dict_a[kw] = val 20 | return dict_a 21 | 22 | # helper function 23 | def insert_call_if_present(dict_a, dict_b, callback): 24 | for kw,val in dict_b.items(): 25 | if kw not in dict_a: 26 | dict_a[kw] = val 27 | else: 28 | callback(kw) 29 | return dict_a 30 | 31 | # helper class 32 | class PartialCls(object): 33 | def __init__(self, base_cls, name, module, fix=None, default=None): 34 | 35 | self.base_cls = base_cls 36 | self.name = name 37 | self.module = module 38 | self.fix = [fix, {}][fix is None] 39 | self.default = [default, {}][default is None] 40 | 41 | if self.fix.keys() & self.default.keys(): 42 | raise TypeError('fix and default share keys') 43 | 44 | # remove binded kw 45 | self._allowed_kw = self._get_allowed_kw() 46 | 47 | def _get_allowed_kw(self): 48 | 49 | 50 | argspec = inspect.getfullargspec(base_cls.__init__) 51 | args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations = argspec 52 | 53 | if varargs is not None: 54 | raise TypeError('partial_cls can only be used if __init__ has no varargs') 55 | 56 
| if varkw is not None: 57 | raise TypeError('partial_cls can only be used if __init__ has no varkw') 58 | 59 | if kwonlyargs is not None and kwonlyargs != []: 60 | raise TypeError('partial_cls can only be used without kwonlyargs') 61 | 62 | if args is None or len(args) < 1: 63 | raise TypeError('seems like self is missing') 64 | 65 | 66 | return [kw for kw in args[1:] if kw not in self.fix] 67 | 68 | 69 | def _build_kw(self, args, kwargs): 70 | # handle *args 71 | if len(args) > len(self._allowed_kw): 72 | raise TypeError("to many arguments") 73 | 74 | all_args = {} 75 | for arg, akw in zip(args, self._allowed_kw): 76 | all_args[akw] = arg 77 | 78 | # handle **kwargs 79 | intersection = self.fix.keys() & kwargs.keys() 80 | if len(intersection) >= 1: 81 | kw = intersection.pop() 82 | raise TypeError("`{}.__init__` got unexpected keyword argument '{}'".format(name, kw)) 83 | 84 | def raise_cb(kw): 85 | raise TypeError("{}.__init__ got multiple values for argument '{}'".format(name, kw)) 86 | all_args = insert_call_if_present(all_args, kwargs, raise_cb) 87 | 88 | # handle fixed arguments 89 | def raise_cb(kw): 90 | raise TypeError() 91 | all_args = insert_call_if_present(all_args, self.fix, raise_cb) 92 | 93 | # handle defaults 94 | all_args = insert_if_not_present(all_args, self.default) 95 | 96 | # handle fixed 97 | all_args.update(self.fix) 98 | 99 | return all_args 100 | 101 | def build_cls(self): 102 | 103 | def new_init(self_of_new_cls, *args, **kwargs): 104 | combined_args = self._build_kw(args=args, kwargs=kwargs) 105 | 106 | #call base cls init 107 | super(self_of_new_cls.__class__, self_of_new_cls).__init__(**combined_args) 108 | 109 | return type(name, (self.base_cls,), { 110 | '__module__': self.module, 111 | '__init__' : new_init 112 | }) 113 | return cls 114 | 115 | 116 | return PartialCls(base_cls=base_cls, name=name, module=module, 117 | fix=fix, default=default).build_cls() 118 | 119 | 120 | def register_partial_cls(base_cls, name, module, fix=None, 
default=None): 121 | module_dict = sys.modules[module].__dict__ 122 | generatedClass = partial_cls(base_cls=base_cls,name=name, module=module, 123 | fix=fix, default=default) 124 | module_dict[generatedClass.__name__] = generatedClass 125 | del generatedClass 126 | 127 | 128 | if __name__ == "__main__": 129 | 130 | class Conv(object): 131 | def __init__(self, dim, activation, stride=1): 132 | print(f"dim {dim} act {activation} stride {stride}") 133 | 134 | 135 | Conv2D = partial_cls(Conv,'Conv2D',__name__, fix=dict(dim=2), default=dict(stride=2)) 136 | 137 | 138 | #obj = Conv2D(activation='a') 139 | #obj = Conv2D('a',activation='a', stride=3) 140 | obj = Conv2D('fu','bar') 141 | 142 | -------------------------------------------------------------------------------- /examples/regularized_mnist.py: -------------------------------------------------------------------------------- 1 | """ 2 | Regularized MNIST Example 3 | ================================ 4 | 5 | This example demonstrates adding and logging arbitrary regularization losses, in this case, 6 | L2 activity regularization and L1 weight regularization. 
7 | 8 | - Add a `_losses` dictionary to any module containing loss names and values 9 | - Use a criterion from `inferno.extensions.criteria.regularized` that will collect and add those losses 10 | - Call `Trainer.observe_training_and_validation_states` to log the losses as well 11 | """ 12 | 13 | import argparse 14 | import sys 15 | 16 | import torch 17 | import torch.nn as nn 18 | from torchvision import datasets, transforms 19 | 20 | from inferno.extensions.layers.reshape import Flatten 21 | from inferno.trainers.basic import Trainer 22 | from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger 23 | 24 | 25 | class RegularizedLinear(nn.Linear): 26 | def __init__(self, *args, ar_weight=1e-3, l1_weight=1e-3, **kwargs): 27 | super(RegularizedLinear, self).__init__(*args, **kwargs) 28 | self.ar_weight = ar_weight 29 | self.l1_weight = l1_weight 30 | self._losses = {} 31 | 32 | def forward(self, input): 33 | output = super(RegularizedLinear, self).forward(input) 34 | self._losses['activity_regularization'] = (output * output).sum() * self.ar_weight 35 | self._losses['l1_weight_regularization'] = torch.abs(self.weight).sum() * self.l1_weight 36 | return output 37 | 38 | 39 | def model_fn(): 40 | return nn.Sequential( 41 | Flatten(), 42 | RegularizedLinear(in_features=784, out_features=256), 43 | nn.LeakyReLU(), 44 | RegularizedLinear(in_features=256, out_features=128), 45 | nn.LeakyReLU(), 46 | RegularizedLinear(in_features=128, out_features=10) 47 | ) 48 | 49 | 50 | def mnist_data_loaders(args): 51 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 52 | train_loader = torch.utils.data.DataLoader( 53 | datasets.MNIST('./data', train=True, download=True, 54 | transform=transforms.Compose([ 55 | transforms.ToTensor(), 56 | transforms.Normalize((0.1307,), (0.3081,)) 57 | ])), 58 | batch_size=args.batch_size, shuffle=True, **kwargs) 59 | test_loader = torch.utils.data.DataLoader( 60 | datasets.MNIST('./data', train=False, 
transform=transforms.Compose([ 61 | transforms.ToTensor(), 62 | transforms.Normalize((0.1307,), (0.3081,)) 63 | ])), 64 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 65 | return train_loader, test_loader 66 | 67 | 68 | def train_model(args): 69 | model = model_fn() 70 | train_loader, validate_loader = mnist_data_loaders(args) 71 | 72 | # Build trainer 73 | trainer = Trainer(model) \ 74 | .build_criterion('RegularizedCrossEntropyLoss') \ 75 | .build_metric('CategoricalError') \ 76 | .build_optimizer('Adam') \ 77 | .validate_every((1, 'epochs')) \ 78 | .save_every((1, 'epochs')) \ 79 | .save_to_directory(args.save_directory) \ 80 | .set_max_num_epochs(args.epochs) \ 81 | .build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'), 82 | log_images_every='never'), 83 | log_directory=args.save_directory) 84 | 85 | # Record regularization losses 86 | trainer.logger.observe_training_and_validation_states([ 87 | 'main_loss', 88 | 'total_regularization_loss', 89 | 'activity_regularization', 90 | 'l1_weight_regularization' 91 | ]) 92 | 93 | # Bind loaders 94 | trainer \ 95 | .bind_loader('train', train_loader) \ 96 | .bind_loader('validate', validate_loader) 97 | 98 | if args.cuda: 99 | trainer.cuda() 100 | 101 | # Go! 
102 | trainer.fit() 103 | 104 | 105 | def main(argv): 106 | # Training settings 107 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 108 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 109 | help='input batch size for training (default: 64)') 110 | parser.add_argument('--save-directory', type=str, default='output/mnist/v1', 111 | help='output directory') 112 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 113 | help='input batch size for testing (default: 1000)') 114 | parser.add_argument('--epochs', type=int, default=20, metavar='N', 115 | help='number of epochs to train (default: 20)') 116 | parser.add_argument('--no-cuda', action='store_true', default=False, 117 | help='disables CUDA training') 118 | args = parser.parse_args(argv) 119 | args.cuda = not args.no_cuda and torch.cuda.is_available() 120 | train_model(args) 121 | 122 | 123 | if __name__ == '__main__': 124 | main(sys.argv[1:]) 125 | -------------------------------------------------------------------------------- /inferno/extensions/criteria/regularized.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from . 
import set_similarity_measures, core 7 | 8 | __all__ = [ 9 | 'RegularizedLoss', 10 | 'RegularizedCrossEntropyLoss', 11 | 'RegularizedBCEWithLogitsLoss', 12 | 'RegularizedBCELoss', 13 | 'RegularizedMSELoss', 14 | 'RegularizedNLLLoss' 15 | ] 16 | 17 | 18 | def collect_losses(module): 19 | """Collect `_losses` dictionaries from module and children 20 | 21 | :param module: a Module to be searched for losses 22 | :return: dictionary of loss names to values 23 | """ 24 | losses = {} 25 | 26 | def _collect(m): 27 | if hasattr(m, '_losses'): 28 | for k, v in m._losses.items(): 29 | if k in losses: 30 | losses[k] = losses[k] + v 31 | else: 32 | losses[k] = v 33 | 34 | module.apply(_collect) 35 | return losses 36 | 37 | 38 | def build_criterion(criterion, *args, **kwargs): 39 | """Build a criterion 40 | 41 | :param criterion: criterion class, name of criterion class, or instance of criterion 42 | :param args: args for constructor 43 | :param kwargs: kwargs for constructor 44 | :return: instance of criterion 45 | """ 46 | if isinstance(criterion, str): 47 | for module in [nn, core, set_similarity_measures]: 48 | criterion_class = getattr(module, criterion, None) 49 | if criterion_class is not None: 50 | break 51 | assert criterion_class is not None, "Criterion {} not found.".format(criterion) 52 | elif callable(criterion) and isinstance(criterion, type): 53 | criterion_class = criterion 54 | elif isinstance(criterion, torch.nn.Module): 55 | return criterion 56 | else: 57 | raise NotImplementedError 58 | return criterion_class(*args, **kwargs) 59 | 60 | 61 | class RegularizedLoss(nn.Module): 62 | """Wrap a criterion. Collect regularization losses from model and combine with wrapped criterion. 
63 | """ 64 | 65 | def __init__(self, criterion, *args, **kwargs): 66 | super(RegularizedLoss, self).__init__() 67 | self.criterion = build_criterion(criterion, *args, **kwargs) 68 | 69 | def forward(self, *args, trainer=None, model=None, **kwargs): 70 | # calculate wrapped loss 71 | main_loss = self.criterion(*args, **kwargs) 72 | 73 | # If no trainer, we cannot record states 74 | if trainer is None: 75 | warnings.warn('No trainer parameter provided. Not logging regularization losses.') 76 | elif model is None: 77 | model = trainer.model 78 | 79 | # If no model or trainer, we cannot record states or collect losses 80 | if model is None: 81 | warnings.warn('No model or trainer parameter provided. Not calculating regularization losses.') 82 | regularization_losses = {} 83 | total_regularization_loss = None 84 | total_loss = main_loss 85 | else: 86 | regularization_losses = collect_losses(model) 87 | total_regularization_loss = sum(regularization_losses.values()) 88 | total_loss = main_loss + total_regularization_loss 89 | 90 | # Record losses if trainer provided 91 | if trainer is not None: 92 | # prefix depending on mode 93 | if self.training: 94 | prefix = 'training' 95 | else: 96 | prefix = 'validation' 97 | # main loss 98 | updates = {'{}_main_loss'.format(prefix): main_loss} 99 | # total regulariztion loss 100 | if total_regularization_loss is not None: 101 | updates['{}_total_regularization_loss'.format(prefix)] = total_regularization_loss 102 | # detailed regularization losses 103 | for k, v in regularization_losses.items(): 104 | updates['{}_{}'.format(prefix, k)] = v 105 | # record state 106 | trainer.update_state_from_dictionary(updates) 107 | 108 | return total_loss 109 | 110 | 111 | # Convenience wrappers for common losses 112 | class RegularizedCrossEntropyLoss(RegularizedLoss): 113 | def __init__(self, *args, **kwargs): 114 | super(RegularizedCrossEntropyLoss, self).__init__(nn.CrossEntropyLoss, *args, **kwargs) 115 | 116 | 117 | class 
RegularizedBCEWithLogitsLoss(RegularizedLoss): 118 | def __init__(self, *args, **kwargs): 119 | super(RegularizedBCEWithLogitsLoss, self).__init__(nn.BCEWithLogitsLoss, *args, **kwargs) 120 | 121 | 122 | class RegularizedBCELoss(RegularizedLoss): 123 | def __init__(self, *args, **kwargs): 124 | super(RegularizedBCELoss, self).__init__(nn.BCELoss, *args, **kwargs) 125 | 126 | 127 | class RegularizedMSELoss(RegularizedLoss): 128 | def __init__(self, *args, **kwargs): 129 | super(RegularizedMSELoss, self).__init__(nn.MSELoss, *args, **kwargs) 130 | 131 | 132 | class RegularizedNLLLoss(RegularizedLoss): 133 | def __init__(self, *args, **kwargs): 134 | super(RegularizedNLLLoss, self).__init__(nn.NLLLoss, *args, **kwargs) 135 | -------------------------------------------------------------------------------- /inferno/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from .python_utils import delayed_keyboard_interrupt 5 | from .exceptions import assert_, ShapeError, NotUnwrappableError 6 | 7 | 8 | def unwrap(input_, to_cpu=True, as_numpy=False, extract_item=False): 9 | if isinstance(input_, (list, tuple)): 10 | return type(input_)([unwrap(_t, to_cpu=to_cpu, as_numpy=as_numpy) 11 | for _t in input_]) 12 | elif torch.is_tensor(input_): 13 | tensor = input_ 14 | elif isinstance(input_, np.ndarray): 15 | return input_ 16 | elif isinstance(input_, (float, int)): 17 | return input_ 18 | else: 19 | raise NotUnwrappableError("Cannot unwrap a '{}'." 
20 | .format(type(input_).__name__)) 21 | # Transfer to CPU if required 22 | if to_cpu: 23 | with delayed_keyboard_interrupt(): 24 | tensor = tensor.cpu().detach() 25 | # Convert to numpy if required 26 | if as_numpy: 27 | return tensor.cpu().detach().numpy() 28 | elif extract_item: 29 | try: 30 | return tensor.item() 31 | except AttributeError: 32 | return tensor[0] 33 | else: 34 | return tensor 35 | 36 | 37 | def is_tensor(object_): 38 | missed_tensor_classes = (torch.HalfTensor,) 39 | return torch.is_tensor(object_) or isinstance(object_, missed_tensor_classes) 40 | 41 | 42 | def is_label_tensor(object_): 43 | return is_tensor(object_) and object_.type() in ['torch.LongTensor', 'torch.cuda.LongTensor'] 44 | 45 | 46 | def is_image_tensor(object_): 47 | return is_tensor(object_) and object_.dim() == 4 48 | 49 | 50 | def is_volume_tensor(object_): 51 | return is_tensor(object_) and object_.dim() == 5 52 | 53 | 54 | def is_image_or_volume_tensor(object_): 55 | return is_image_tensor(object_) or is_volume_tensor(object_) 56 | 57 | 58 | def is_label_image_tensor(object_): 59 | return is_label_tensor(object_) and object_.dim() == 3 60 | 61 | 62 | def is_label_volume_tensor(object_): 63 | return is_label_tensor(object_) and object_.dim() == 4 64 | 65 | 66 | def is_label_image_or_volume_tensor(object_): 67 | return is_label_image_tensor(object_) or is_label_volume_tensor(object_) 68 | 69 | 70 | def is_matrix_tensor(object_): 71 | return is_tensor(object_) and object_.dim() == 2 72 | 73 | 74 | def is_scalar_tensor(object_): 75 | return is_tensor(object_) and object_.dim() <= 1 and object_.numel() == 1 76 | 77 | 78 | def is_vector_tensor(object_): 79 | return is_tensor(object_) and object_.dim() == 1 and object_.numel() > 1 80 | 81 | 82 | def assert_same_size(tensor_1, tensor_2): 83 | assert_(list(tensor_1.size()) == list(tensor_2.size()), 84 | "Tensor sizes {} and {} do not match.".format(tensor_1.size(), tensor_2.size()), 85 | ShapeError) 86 | 87 | 88 | def 
where(condition, if_true, if_false): 89 | """ 90 | Torch equivalent of numpy.where. 91 | 92 | Parameters 93 | ---------- 94 | condition : torch.ByteTensor or torch.cuda.ByteTensor 95 | Condition to check. 96 | if_true : torch.Tensor or torch.cuda.Tensor 97 | Output value if condition is true. 98 | if_false: torch.Tensor or torch.cuda.Tensor 99 | Output value if condition is false 100 | 101 | Returns 102 | ------- 103 | torch.Tensor 104 | 105 | Raises 106 | ------ 107 | AssertionError 108 | if if_true and if_false don't have the same datatype. 109 | """ 110 | # noinspection PyArgumentList 111 | assert if_true.type() == if_false.type(), \ 112 | "Type mismatch: {} and {}".format(if_true.data.type(), if_false.data.type()) 113 | casted_condition = condition.type_as(if_true) 114 | output = casted_condition * if_true + (1 - casted_condition) * if_false 115 | return output 116 | 117 | 118 | def flatten_samples(input_): 119 | """ 120 | Flattens a tensor or a variable such that the channel axis is first and the sample axis 121 | is second. The shapes are transformed as follows: 122 | (N, C, H, W) --> (C, N * H * W) 123 | (N, C, D, H, W) --> (C, N * D * H * W) 124 | (N, C) --> (C, N) 125 | The input must be atleast 2d. 126 | """ 127 | assert_(input_.dim() >= 2, 128 | "Tensor or variable must be atleast 2D. Got one of dim {}." 
129 | .format(input_.dim()), 130 | ShapeError) 131 | # Get number of channels 132 | num_channels = input_.size(1) 133 | # Permute the channel axis to first 134 | permute_axes = list(range(input_.dim())) 135 | permute_axes[0], permute_axes[1] = permute_axes[1], permute_axes[0] 136 | # For input shape (say) NCHW, this should have the shape CNHW 137 | permuted = input_.permute(*permute_axes).contiguous() 138 | # Now flatten out all but the first axis and return 139 | flattened = permuted.view(num_channels, -1) 140 | return flattened 141 | 142 | 143 | def clip_gradients_(parameters, mode, norm_or_value): 144 | assert_(mode in ['norm', 'value'], 145 | f"Mode must be 'norm' or 'value', got '{mode}' instead.", 146 | ValueError) 147 | if mode == 'norm': 148 | torch.nn.utils.clip_grad_norm_(parameters, norm_or_value) 149 | elif mode == 'value': 150 | torch.nn.utils.clip_grad_value_(parameters, norm_or_value) 151 | else: 152 | raise NotImplementedError 153 | -------------------------------------------------------------------------------- /inferno/extensions/initializers/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn.init as init 2 | 3 | 4 | __all__ = ['Initializer', 5 | 'Initialization', 6 | 'WeightInitFunction', 7 | 'BiasInitFunction', 8 | 'TensorInitFunction'] 9 | 10 | 11 | class Initializer(object): 12 | """ 13 | Base class for all initializers. 
14 | """ 15 | 16 | # TODO Support LSTMs and GRUs 17 | VALID_LAYERS = {'Conv1d', 'Conv2d', 'Conv3d', 18 | 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', 19 | 'Linear', 'Bilinear', 20 | 'Embedding'} 21 | 22 | def __call__(self, module): 23 | module_class_name = module.__class__.__name__ 24 | if module_class_name in self.VALID_LAYERS: 25 | # Apply to weight and bias 26 | try: 27 | if hasattr(module, 'weight'): 28 | self.call_on_weight(module.weight.data) 29 | except NotImplementedError: 30 | # Don't cry if it's not implemented 31 | pass 32 | 33 | try: 34 | if hasattr(module, 'bias'): 35 | self.call_on_bias(module.bias.data) 36 | except NotImplementedError: 37 | pass 38 | 39 | return module 40 | 41 | def call_on_bias(self, tensor): 42 | return self.call_on_tensor(tensor) 43 | 44 | def call_on_weight(self, tensor): 45 | return self.call_on_tensor(tensor) 46 | 47 | def call_on_tensor(self, tensor): 48 | raise NotImplementedError 49 | 50 | @classmethod 51 | def initializes_weight(cls): 52 | return 'call_on_tensor' in cls.__dict__ or 'call_on_weight' in cls.__dict__ 53 | 54 | @classmethod 55 | def initializes_bias(cls): 56 | return 'call_on_tensor' in cls.__dict__ or 'call_on_bias' in cls.__dict__ 57 | 58 | 59 | class Initialization(Initializer): 60 | def __init__(self, weight_initializer=None, bias_initializer=None): 61 | if weight_initializer is None: 62 | self.weight_initializer = Initializer() 63 | else: 64 | if isinstance(weight_initializer, Initializer): 65 | assert weight_initializer.initializes_weight() 66 | self.weight_initializer = weight_initializer 67 | elif isinstance(weight_initializer, str): 68 | init_function = getattr(init, weight_initializer, None) 69 | assert init_function is not None 70 | self.weight_initializer = WeightInitFunction(init_function=init_function) 71 | else: 72 | # Provison for weight_initializer to be a function 73 | assert callable(weight_initializer) 74 | self.weight_initializer = 
WeightInitFunction(init_function=weight_initializer) 75 | 76 | if bias_initializer is None: 77 | self.bias_initializer = Initializer() 78 | else: 79 | if isinstance(bias_initializer, Initializer): 80 | assert bias_initializer.initializes_bias 81 | self.bias_initializer = bias_initializer 82 | elif isinstance(bias_initializer, str): 83 | init_function = getattr(init, bias_initializer, None) 84 | assert init_function is not None 85 | self.bias_initializer = BiasInitFunction(init_function=init_function) 86 | else: 87 | assert callable(bias_initializer) 88 | self.bias_initializer = BiasInitFunction(init_function=bias_initializer) 89 | 90 | def call_on_weight(self, tensor): 91 | return self.weight_initializer.call_on_weight(tensor) 92 | 93 | def call_on_bias(self, tensor): 94 | return self.bias_initializer.call_on_bias(tensor) 95 | 96 | 97 | class WeightInitFunction(Initializer): 98 | def __init__(self, init_function, *init_function_args, **init_function_kwargs): 99 | super(WeightInitFunction, self).__init__() 100 | assert callable(init_function) 101 | self.init_function = init_function 102 | self.init_function_args = init_function_args 103 | self.init_function_kwargs = init_function_kwargs 104 | 105 | def call_on_weight(self, tensor): 106 | return self.init_function(tensor, *self.init_function_args, **self.init_function_kwargs) 107 | 108 | 109 | class BiasInitFunction(Initializer): 110 | def __init__(self, init_function, *init_function_args, **init_function_kwargs): 111 | super(BiasInitFunction, self).__init__() 112 | assert callable(init_function) 113 | self.init_function = init_function 114 | self.init_function_args = init_function_args 115 | self.init_function_kwargs = init_function_kwargs 116 | 117 | def call_on_bias(self, tensor): 118 | return self.init_function(tensor, *self.init_function_args, **self.init_function_kwargs) 119 | 120 | 121 | class TensorInitFunction(Initializer): 122 | def __init__(self, init_function, *init_function_args, **init_function_kwargs): 
123 | super(TensorInitFunction, self).__init__() 124 | assert callable(init_function) 125 | self.init_function = init_function 126 | self.init_function_args = init_function_args 127 | self.init_function_kwargs = init_function_kwargs 128 | 129 | def call_on_tensor(self, tensor): 130 | return self.init_function(tensor, *self.init_function_args, **self.init_function_kwargs) 131 | 132 | -------------------------------------------------------------------------------- /tests/test_extensions/test_containers/test_graph.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from functools import reduce 3 | import torch 4 | 5 | 6 | class TestGraph(unittest.TestCase): 7 | def setUp(self): 8 | import torch.nn as nn 9 | from inferno.utils.python_utils import from_iterable 10 | 11 | class DummyNamedModule(nn.Module): 12 | def __init__(self, name, history, num_inputs=1): 13 | super(DummyNamedModule, self).__init__() 14 | self.name = name 15 | self.history = history 16 | self.num_inputs = num_inputs 17 | 18 | def forward(self, *inputs): 19 | assert len(inputs) == self.num_inputs 20 | self.history.append(self.name) 21 | if self.num_inputs > 1: 22 | output = reduce(lambda x, y: x + y, inputs) 23 | else: 24 | output = from_iterable(inputs) 25 | 26 | return output 27 | 28 | self.DummyNamedModule = DummyNamedModule 29 | 30 | # @unittest.skip 31 | def test_graph_dummy_basic(self): 32 | import torch 33 | from inferno.extensions.containers.graph import Graph 34 | 35 | if not hasattr(self, 'DummyNamedModule'): 36 | self.setUp() 37 | 38 | DummyNamedModule = self.DummyNamedModule 39 | 40 | history = [] 41 | # Build graph 42 | model = Graph() 43 | model.add_input_node('input_0') 44 | model.add_input_node('input_1') 45 | model.add_node('conv0_0', DummyNamedModule('conv0_0', history)) 46 | model.add_node('conv0_1', DummyNamedModule('conv0_1', history)) 47 | model.add_node('conv1', DummyNamedModule('conv1', history, 2)) 48 | 
model.add_node('conv2', DummyNamedModule('conv2', history)) 49 | model.add_output_node('output_0') 50 | model.add_edge('input_0', 'conv0_0')\ 51 | .add_edge('input_1', 'conv0_1')\ 52 | .add_edge('conv0_0', 'conv1')\ 53 | .add_edge('conv0_1', 'conv1')\ 54 | .add_edge('conv1', 'conv2')\ 55 | .add_edge('conv2', 'output_0') 56 | 57 | input_0 = torch.rand(10, 10) 58 | input_1 = torch.rand(10, 10) 59 | model(input_0, input_1) 60 | self.assertTrue(history == ['conv0_0', 'conv0_1', 'conv1', 'conv2'] or 61 | history == ['conv0_1', 'conv0_0', 'conv1', 'conv2']) 62 | 63 | # @unittest.skip 64 | def test_graph_dummy_inception(self): 65 | import torch 66 | from inferno.extensions.containers.graph import Graph 67 | 68 | if not hasattr(self, 'DummyNamedModule'): 69 | self.setUp() 70 | 71 | DummyNamedModule = self.DummyNamedModule 72 | 73 | history = [] 74 | # Build graph 75 | model = Graph() 76 | model.add_input_node('input_0') 77 | model.add_node('conv0', DummyNamedModule('conv0', history), 'input_0') 78 | model.add_node('conv1_0', DummyNamedModule('conv1_0', history), 'conv0') 79 | model.add_node('conv1_1', DummyNamedModule('conv1_1', history), 'conv0') 80 | model.add_node('conv2', DummyNamedModule('conv2', history, 2), 81 | ['conv1_0', 'conv1_1']) 82 | model.add_output_node('output_0', 'conv2') 83 | input_0 = torch.rand(10, 10) 84 | model(input_0) 85 | self.assertTrue(history == ['conv0', 'conv1_0', 'conv1_1', 'conv2'] or 86 | history == ['conv0', 'conv1_1', 'conv1_2', 'conv2']) 87 | 88 | # @unittest.skip 89 | def test_graph_basic(self): 90 | from inferno.extensions.containers.graph import Graph 91 | from inferno.extensions.layers.convolutional import ConvELU2D 92 | from inferno.utils.model_utils import ModelTester 93 | # Build graph 94 | model = Graph() 95 | model.add_input_node('input_0') 96 | model.add_node('conv0', ConvELU2D(1, 10, 3), previous='input_0') 97 | model.add_node('conv1', ConvELU2D(10, 1, 3), previous='conv0') 98 | model.add_output_node('output_0', 
previous='conv1') 99 | ModelTester((1, 1, 100, 100), (1, 1, 100, 100))(model) 100 | 101 | @unittest.skipUnless(torch.cuda.is_available(), "No cuda.") 102 | def test_graph_device_transfers(self): 103 | from inferno.extensions.containers.graph import Graph 104 | from inferno.extensions.layers.convolutional import ConvELU2D 105 | import torch 106 | # Build graph 107 | model = Graph() 108 | model.add_input_node('input_0') 109 | model.add_node('conv0', ConvELU2D(1, 10, 3), previous='input_0') 110 | model.add_node('conv1', ConvELU2D(10, 1, 3), previous='conv0') 111 | model.add_output_node('output_0', previous='conv1') 112 | # Transfer 113 | model.to_device('conv0', 'cpu').to_device('conv1', 'cuda', 0) 114 | x = torch.rand(1, 1, 100, 100) 115 | y = model(x) 116 | self.assertIsInstance(y.data, torch.cuda.FloatTensor) 117 | 118 | @unittest.skip("Needs machine with 4 GPUs") 119 | def test_multi_gpu(self): 120 | import torch 121 | import torch.nn as nn 122 | from torch.nn.parallel.data_parallel import data_parallel 123 | from inferno.extensions.containers.graph import Graph 124 | 125 | input_shape = [8, 1, 3, 128, 128] 126 | model = Graph() \ 127 | .add_input_node('input') \ 128 | .add_node('conv0', nn.Conv3d(1, 10, 3, padding=1), previous='input') \ 129 | .add_node('conv1', nn.Conv3d(10, 1, 3, padding=1), previous='conv0') \ 130 | .add_output_node('output', previous='conv1') 131 | 132 | model.cuda() 133 | input = torch.rand(*input_shape).cuda() 134 | data_parallel(model, input, device_ids=[0, 1, 2, 3]) 135 | 136 | 137 | if __name__ == '__main__': 138 | unittest.main() 139 | -------------------------------------------------------------------------------- /tests/test_io/test_box/test_cityscapes.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join, dirname, exists, isdir 3 | import unittest 4 | import numpy as np 5 | import time 6 | 7 | _CITYSCAPES_ROOT = None 8 | 9 | 10 | def _cityscapes_available(): 11 
| return _CITYSCAPES_ROOT is not None or os.environ.get('CITYSCAPES_ROOT') is not None 12 | 13 | 14 | class TestCityscapes(unittest.TestCase): 15 | CITYSCAPES_ROOT = _CITYSCAPES_ROOT 16 | PLOT_DIRECTORY = join(dirname(__file__), 'plots') 17 | INCLUDE_COARSE = False 18 | 19 | def get_cityscapes_root(self): 20 | if self.CITYSCAPES_ROOT is None: 21 | root = os.environ.get('CITYSCAPES_ROOT') 22 | assert root is not None, "Cityscapes Root not found." 23 | else: 24 | return self.CITYSCAPES_ROOT 25 | 26 | @unittest.skipUnless(_cityscapes_available(), "No cityscapes available.") 27 | def test_cityscapes_dataset_without_transforms(self): 28 | from inferno.io.box.cityscapes import Cityscapes 29 | cityscapes = Cityscapes(self.get_cityscapes_root()) 30 | image, label = cityscapes[0] 31 | image = np.asarray(image) 32 | label = np.asarray(label) 33 | self.assertSequenceEqual(image.shape, (1024, 2048, 3)) 34 | self.assertSequenceEqual(label.shape, (1024, 2048)) 35 | self.assertLessEqual(label.max(), 33) 36 | 37 | @unittest.skipUnless(_cityscapes_available(), "No cityscapes available.") 38 | def test_cityscapes_dataset_without_transforms_unzipped(self): 39 | from inferno.io.box.cityscapes import Cityscapes 40 | cityscapes = Cityscapes(join(self.get_cityscapes_root(), 'extracted'), 41 | read_from_zip_archive=False) 42 | image, label = cityscapes[0] 43 | image = np.asarray(image) 44 | label = np.asarray(label) 45 | self.assertSequenceEqual(image.shape, (1024, 2048, 3)) 46 | self.assertSequenceEqual(label.shape, (1024, 2048)) 47 | self.assertLessEqual(label.max(), 33) 48 | 49 | @unittest.skipUnless(_cityscapes_available(), "No cityscapes available.") 50 | def test_cityscapes_dataset_with_transforms(self): 51 | from inferno.io.box.cityscapes import get_cityscapes_loaders 52 | from inferno.utils.io_utils import print_tensor 53 | 54 | train_loader, validate_loader = get_cityscapes_loaders(self.get_cityscapes_root(), 55 | include_coarse_dataset=self.INCLUDE_COARSE) 56 | train_dataset = 
train_loader.dataset 57 | tic = time.time() 58 | image, label = train_dataset[0] 59 | toc = time.time() 60 | print("[+] Loaded sample in {} seconds.".format(toc - tic)) 61 | # Make sure the shapes checkout 62 | self.assertSequenceEqual(image.size(), (3, 1024, 2048)) 63 | self.assertSequenceEqual(label.size(), (1024, 2048)) 64 | self.assertEqual(image.type(), 'torch.FloatTensor') 65 | self.assertEqual(label.type(), 'torch.LongTensor') 66 | # Print tensors to make sure they look legit 67 | if not exists(self.PLOT_DIRECTORY): 68 | os.mkdir(self.PLOT_DIRECTORY) 69 | else: 70 | assert isdir(self.PLOT_DIRECTORY) 71 | print_tensor(image.numpy()[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY) 72 | for class_id in np.unique(label.numpy()): 73 | print_tensor((label.numpy()[None, None, ...] == class_id).astype('float32'), 74 | prefix='LAB-{}--'.format(class_id), 75 | directory=self.PLOT_DIRECTORY) 76 | print_tensor(label.numpy()[None, None, ...], 77 | prefix='LAB--', 78 | directory=self.PLOT_DIRECTORY) 79 | print("[+] Inspect images at {}".format(self.PLOT_DIRECTORY)) 80 | 81 | @unittest.skipUnless(_cityscapes_available(), "No cityscapes available.") 82 | def test_cityscapes_dataset_with_transforms_unzipped(self): 83 | from inferno.io.box.cityscapes import get_cityscapes_loaders 84 | from inferno.utils.io_utils import print_tensor 85 | 86 | train_loader, validate_loader = get_cityscapes_loaders(join(self.get_cityscapes_root(), 87 | 'extracted'), 88 | include_coarse_dataset=self.INCLUDE_COARSE, 89 | read_from_zip_archive=False) 90 | train_dataset = train_loader.dataset 91 | tic = time.time() 92 | image, label = train_dataset[0] 93 | toc = time.time() 94 | print("[+] Loaded sample in {} seconds.".format(toc - tic)) 95 | # Make sure the shapes checkout 96 | self.assertSequenceEqual(image.size(), (3, 1024, 2048)) 97 | self.assertSequenceEqual(label.size(), (1024, 2048)) 98 | self.assertEqual(image.type(), 'torch.FloatTensor') 99 | self.assertEqual(label.type(), 
'torch.LongTensor') 100 | # Print tensors to make sure they look legit 101 | if not exists(self.PLOT_DIRECTORY): 102 | os.mkdir(self.PLOT_DIRECTORY) 103 | else: 104 | assert isdir(self.PLOT_DIRECTORY) 105 | print_tensor(image.numpy()[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY) 106 | for class_id in np.unique(label.numpy()): 107 | print_tensor((label.numpy()[None, None, ...] == class_id).astype('float32'), 108 | prefix='LAB-{}--'.format(class_id), 109 | directory=self.PLOT_DIRECTORY) 110 | print_tensor(label.numpy()[None, None, ...], 111 | prefix='LAB--', 112 | directory=self.PLOT_DIRECTORY) 113 | print("[+] Inspect images at {}".format(self.PLOT_DIRECTORY)) 114 | 115 | 116 | if __name__ == '__main__': 117 | unittest.main() 118 | -------------------------------------------------------------------------------- /tests/test_io/test_box/test_camvid.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join, dirname, exists, isdir 3 | import unittest 4 | import numpy as np 5 | 6 | 7 | _CAMVID_ROOT = None 8 | 9 | 10 | def _camvid_available(): 11 | return _CAMVID_ROOT is not None or os.environ.get('CAMVID_ROOT') is not None 12 | 13 | 14 | class TestCamvid(unittest.TestCase): 15 | CAMVID_ROOT = _CAMVID_ROOT 16 | PLOT_DIRECTORY = join(dirname(__file__), 'plots') 17 | 18 | def get_camvid_root(self): 19 | if self.CAMVID_ROOT is None: 20 | root = os.environ.get('CAMVID_ROOT') 21 | assert root is not None, "Camvid Root not found." 
22 | else: 23 | return self.CAMVID_ROOT 24 | 25 | @unittest.skipUnless(_camvid_available(), "No root available.") 26 | def test_camvid_dataset_without_transforms(self): 27 | from inferno.io.box.camvid import CamVid 28 | camvid = CamVid(self.get_camvid_root()) 29 | image, label = camvid[0] 30 | image = np.asarray(image) 31 | label = np.asarray(label) 32 | self.assertSequenceEqual(image.shape, (360, 480, 3)) 33 | self.assertSequenceEqual(label.shape, (360, 480)) 34 | self.assertLessEqual(label.max(), 11) 35 | 36 | @unittest.skipUnless(_camvid_available(), "No root available.") 37 | def _test_camvid_dataset_with_transforms(self): 38 | from inferno.io.box.camvid import CamVid 39 | from inferno.io.transform.base import Compose 40 | from inferno.io.transform.image import PILImage2NumPyArray, RandomSizedCrop, Scale 41 | from inferno.utils.io_utils import print_tensor 42 | 43 | camvid = CamVid(self.get_camvid_root(), 44 | image_transform=Compose(), 45 | label_transform=Compose(), 46 | joint_transform=Compose()) 47 | camvid.image_transform.add(PILImage2NumPyArray()) 48 | camvid.label_transform.add(PILImage2NumPyArray()) 49 | image, label = camvid[0] 50 | self.assertSequenceEqual(image.shape, (3, 360, 480)) 51 | self.assertSequenceEqual(label.shape, (360, 480)) 52 | # Add crop trafo 53 | camvid.joint_transform.add(RandomSizedCrop(ratio_between=(0.7, 1.0), 54 | preserve_aspect_ratio=True)) 55 | # We need 2 scale transforms, one with order 3 (image) and the other with order 0 (label) 56 | camvid.joint_transform.add(Scale(output_image_shape=(360, 480), 57 | interpolation_order=3, apply_to=[0])) 58 | camvid.joint_transform.add(Scale(output_image_shape=(360, 480), 59 | interpolation_order=0, apply_to=[1])) 60 | image, label = camvid[0] 61 | self.assertSequenceEqual(image.shape, (3, 360, 480)) 62 | self.assertSequenceEqual(label.shape, (360, 480)) 63 | self.assertLessEqual(len(np.unique(label)), 12) 64 | # Print tensors to make sure they look legit 65 | if not 
exists(self.PLOT_DIRECTORY): 66 | os.mkdir(self.PLOT_DIRECTORY) 67 | else: 68 | assert isdir(self.PLOT_DIRECTORY) 69 | print_tensor(image[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY) 70 | print_tensor(label[None, None, ...], prefix='LAB--', directory=self.PLOT_DIRECTORY) 71 | print("[+] Inspect images at {}".format(self.PLOT_DIRECTORY)) 72 | 73 | @unittest.skipUnless(_camvid_available(), "No root available.") 74 | def test_camvid_dataset_with_transforms(self): 75 | from inferno.io.box.camvid import get_camvid_loaders 76 | from inferno.utils.io_utils import print_tensor 77 | 78 | train_loader, validate_loader, test_loader = get_camvid_loaders(self.get_camvid_root()) 79 | train_dataset = train_loader.dataset 80 | image, label = train_dataset[0] 81 | # Make sure the shapes checkout 82 | self.assertSequenceEqual(image.size(), (3, 360, 480)) 83 | self.assertSequenceEqual(label.size(), (360, 480)) 84 | self.assertEqual(image.type(), 'torch.FloatTensor') 85 | self.assertEqual(label.type(), 'torch.LongTensor') 86 | # Print tensors to make sure they look legit 87 | if not exists(self.PLOT_DIRECTORY): 88 | os.mkdir(self.PLOT_DIRECTORY) 89 | else: 90 | assert isdir(self.PLOT_DIRECTORY) 91 | print_tensor(image.numpy()[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY) 92 | print_tensor(label.numpy()[None, None, ...], prefix='LAB--', directory=self.PLOT_DIRECTORY) 93 | print("[+] Inspect images at {}".format(self.PLOT_DIRECTORY)) 94 | 95 | @unittest.skipUnless(_camvid_available(), "No root available.") 96 | def test_camvid_dataset_with_transforms_onehot(self): 97 | from inferno.io.box.camvid import get_camvid_loaders 98 | from inferno.utils.io_utils import print_tensor 99 | 100 | train_loader, validate_loader, test_loader = get_camvid_loaders(self.get_camvid_root(), 101 | labels_as_onehot=True) 102 | train_dataset = train_loader.dataset 103 | image, label = train_dataset[0] 104 | # Make sure the shapes checkout 105 | self.assertSequenceEqual(image.size(), 
(3, 360, 480)) 106 | self.assertSequenceEqual(label.size(), (12, 360, 480)) 107 | self.assertEqual(image.type(), 'torch.FloatTensor') 108 | self.assertEqual(label.type(), 'torch.FloatTensor') 109 | # Print tensors to make sure they look legit 110 | if not exists(self.PLOT_DIRECTORY): 111 | os.mkdir(self.PLOT_DIRECTORY) 112 | else: 113 | assert isdir(self.PLOT_DIRECTORY) 114 | print_tensor(image.numpy()[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY) 115 | print_tensor(label.numpy()[None, ...], prefix='LAB--', directory=self.PLOT_DIRECTORY) 116 | print("[+] Inspect images at {}".format(self.PLOT_DIRECTORY)) 117 | 118 | 119 | if __name__ == '__main__': 120 | unittest.main() 121 | -------------------------------------------------------------------------------- /inferno/io/box/binary_blobs.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import skimage.data 3 | import numpy 4 | from operator import mul 5 | from functools import reduce 6 | 7 | class BinaryBlobs(data.Dataset): 8 | 9 | 10 | def __init__(self, size=20, length=512, blob_size_fraction=0.1, 11 | n_dim=2, volume_fraction=0.5,split='train', 12 | uniform_noise_range=(-1.2, 1.2), 13 | gaussian_noise_sigma=1.2, 14 | noise_scale_factor=8, 15 | image_transform=None, 16 | label_transform=None, 17 | joint_transform=None): 18 | # how many images are in the dataset 19 | self.size = size 20 | 21 | # blob related members 22 | self.length = length 23 | self.blob_size_fraction = blob_size_fraction 24 | self.n_dim = n_dim 25 | self.volume_fraction = volume_fraction 26 | 27 | # which split {'train', 'test', 'validate'} 28 | self.split = split 29 | 30 | # noise related members 31 | self.uniform_noise_range = uniform_noise_range 32 | self.gaussian_noise_sigma = float(gaussian_noise_sigma) 33 | self.noise_scale_factor = noise_scale_factor 34 | 35 | # transforms 36 | self.image_transform = image_transform 37 | self.label_transform = 
label_transform 38 | self.joint_transform = joint_transform 39 | 40 | # internal 41 | split_to_seed = dict(train=0, test=1, validate=2) 42 | self.master_seed = split_to_seed[self.split]*self.size 43 | 44 | def __getitem__(self, index): 45 | 46 | # generate the labels 47 | label = skimage.data.binary_blobs( 48 | length=self.length, 49 | blob_size_fraction=self.blob_size_fraction, 50 | n_dim=self.n_dim, 51 | volume_fraction=self.volume_fraction, 52 | seed=self.master_seed + index) 53 | 54 | # make the raw image [-1,1] 55 | image = label.astype('float32')*2 56 | image -= 1 57 | 58 | 59 | # add uniform noise 60 | low, high = self.uniform_noise_range 61 | uniform_noise = numpy.random.uniform(low=low, high=high, 62 | size=image.size) 63 | image += uniform_noise.reshape(image.shape) 64 | 65 | # add gaussian noise 66 | gaussian_noise = numpy.random.normal(scale=self.gaussian_noise_sigma, 67 | size=image.size) 68 | image += gaussian_noise.reshape(image.shape) 69 | 70 | 71 | # generate noise at lower scales 72 | small_shape = [s//self.noise_scale_factor for s in label.shape] 73 | small_size = reduce(mul, small_shape, 1) 74 | small_noise_img = numpy.random.uniform(low=low, high=high, 75 | size=small_size) 76 | small_noise_img = small_noise_img.reshape(small_shape) 77 | 78 | gaussian_noise = numpy.random.normal(scale=self.gaussian_noise_sigma, 79 | size=small_size) 80 | small_noise_img += gaussian_noise.reshape(small_shape) 81 | 82 | noise_img = skimage.transform.resize(image = small_noise_img, 83 | output_shape=image.shape, mode='reflect') 84 | 85 | 86 | image += noise_img 87 | 88 | image -= image.mean() 89 | image /= image.std() 90 | 91 | label = label.astype('long') 92 | try: 93 | # Apply transforms 94 | if self.image_transform is not None: 95 | image = self.image_transform(image) 96 | if self.label_transform is not None: 97 | label = self.label_transform(label) 98 | if self.joint_transform is not None: 99 | image, label = self.joint_transform(image, label) 100 | except 
Exception: 101 | print("[!] An Exception occurred while applying the transforms at " 102 | "index {} of split '{}'.".format(index, self.split)) 103 | raise 104 | 105 | image = image[None,...] 106 | return image, label 107 | 108 | def __len__(self): 109 | return self.size 110 | 111 | 112 | def get_binary_blob_loaders(train_batch_size=1, test_batch_size=1, 113 | num_workers=1, 114 | train_image_transform=None, 115 | train_label_transform=None, 116 | train_joint_transform=None, 117 | validate_image_transform=None, 118 | validate_label_transform=None, 119 | validate_joint_transform=None, 120 | test_image_transform=None, 121 | test_label_transform=None, 122 | test_joint_transform=None, 123 | **kwargs): 124 | 125 | trainset = BinaryBlobs(split='train', image_transform=train_image_transform, 126 | label_transform=train_label_transform, joint_transform=train_joint_transform, **kwargs) 127 | testset = BinaryBlobs(split='test', image_transform=test_image_transform, 128 | label_transform=test_label_transform, joint_transform=test_joint_transform, **kwargs) 129 | validset = BinaryBlobs(split='validate',image_transform=validate_image_transform, 130 | label_transform=validate_label_transform, joint_transform=validate_joint_transform, **kwargs) 131 | 132 | 133 | trainloader = data.DataLoader(trainset, batch_size=train_batch_size, 134 | num_workers=num_workers) 135 | 136 | testloader = data.DataLoader(testset, batch_size=test_batch_size, 137 | num_workers=num_workers) 138 | 139 | validloader = data.DataLoader(validset, batch_size=test_batch_size, 140 | num_workers=num_workers) 141 | 142 | return trainloader, testloader, validloader 143 | 144 | if __name__ == "__main__": 145 | ds = BinaryBlobs() 146 | ds[0] --------------------------------------------------------------------------------