├── inferno
├── utils
│ ├── __init__.py
│ ├── math_utils.py
│ ├── exceptions.py
│ ├── test_utils.py
│ ├── model_utils.py
│ ├── io_utils.py
│ ├── partial_cls.py
│ └── torch_utils.py
├── version.py
├── inferno.py
├── extensions
│ ├── containers
│ │ ├── __init__.py
│ │ └── sequential.py
│ ├── metrics
│ │ ├── __init__.py
│ │ ├── cremi_score.py
│ │ └── base.py
│ ├── initializers
│ │ ├── __init__.py
│ │ ├── presets.py
│ │ └── base.py
│ ├── models
│ │ └── __init__.py
│ ├── optimizers
│ │ ├── __init__.py
│ │ ├── ranger.py
│ │ ├── annealed_adam.py
│ │ └── adam.py
│ ├── criteria
│ │ ├── __init__.py
│ │ ├── elementwise_measures.py
│ │ ├── core.py
│ │ └── regularized.py
│ ├── layers
│ │ ├── identity.py
│ │ ├── activations.py
│ │ ├── normalization.py
│ │ ├── __init__.py
│ │ ├── convolutional_blocks.py
│ │ ├── sampling.py
│ │ └── device.py
│ └── __init__.py
├── io
│ ├── __init__.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── data_utils.py
│ │ ├── base.py
│ │ └── concatenate.py
│ ├── transform
│ │ └── __init__.py
│ ├── volumetric
│ │ └── __init__.py
│ └── box
│ │ ├── __init__.py
│ │ └── binary_blobs.py
├── trainers
│ ├── __init__.py
│ └── callbacks
│ │ ├── __init__.py
│ │ ├── logging
│ │ ├── __init__.py
│ │ └── base.py
│ │ ├── tqdmstub.py
│ │ ├── gradients.py
│ │ ├── console.py
│ │ └── tqdm.py
└── __init__.py
├── docs
├── authors.rst
├── history.rst
├── readme.rst
├── contributing.rst
├── .gitignore
├── inferno-apidoc
│ ├── modules.rst
│ ├── inferno.io.rst
│ ├── inferno.extensions.rst
│ ├── inferno.trainers.rst
│ ├── inferno.rst
│ ├── inferno.extensions.containers.rst
│ ├── inferno.extensions.initializers.rst
│ ├── inferno.extensions.optimizers.rst
│ ├── inferno.trainers.callbacks.logging.rst
│ ├── inferno.io.volumetric.rst
│ ├── inferno.io.box.rst
│ ├── inferno.io.core.rst
│ ├── inferno.io.transform.rst
│ ├── inferno.extensions.criteria.rst
│ ├── inferno.extensions.metrics.rst
│ ├── inferno.trainers.callbacks.rst
│ ├── inferno.utils.rst
│ └── inferno.extensions.layers.rst
├── graphics
│ ├── tentative_logo.pdf
│ ├── plain_tentative_logo.svg
│ └── tentative_logo.svg
├── _templates
│ ├── layout.html
│ └── template_module.rst
├── zbibliography.rst
├── examples.rst
├── refs.bib
├── environment.yml
├── index.rst
└── installation.rst
├── requirements.txt
├── tests
├── __init__.py
├── test_io
│ ├── __init__.py
│ ├── test_box
│ │ ├── __init__.py
│ │ ├── test_cityscapes.py
│ │ └── test_camvid.py
│ ├── test_core
│ │ ├── __init__.py
│ │ ├── test_concatenate.py
│ │ └── test_zip.py
│ └── test_volumetric
│ │ ├── __init__.py
│ │ ├── test_volume_loader.py
│ │ └── test_lazy_volume_loader.py
├── test_utils
│ ├── __init__.py
│ ├── test_model_utils.py
│ ├── test_train_utils.py
│ └── test_partial_cls.py
├── test_extensions
│ ├── __init__.py
│ ├── test_models
│ │ ├── __init__.py
│ │ ├── test_unet.py
│ │ └── test_res_unet.py
│ ├── test_layers
│ │ ├── test_activations.py
│ │ ├── test_convolutional.py
│ │ ├── test_device.py
│ │ ├── test_reshape.py
│ │ └── deprecated
│ │ │ └── building_blocks.py
│ ├── test_criteria
│ │ ├── test_core.py
│ │ ├── test_elementwise_measures.py
│ │ └── test_set_similarity_measures.py
│ ├── test_metrics
│ │ └── categorical.py
│ └── test_containers
│ │ └── test_graph.py
└── test_training
│ ├── __init__.py
│ └── test_callbacks
│ ├── __init__.py
│ ├── test_logging
│ ├── __init__.py
│ ├── test_base.py
│ └── test_tensorboard.py
│ ├── test_scheduling.py
│ ├── test_base.py
│ └── test_essentials.py
├── examples
├── README.txt
├── trainer.py
└── regularized_mnist.py
├── readthedocs.yml
├── add2path.sh
├── MANIFEST.in
├── conda-recipe
├── build.sh
└── meta.yaml
├── requirements_dev.txt
├── .editorconfig
├── .github
└── ISSUE_TEMPLATE.md
├── LICENSE
├── HISTORY.rst
├── .gitignore
├── AUTHORS.rst
├── .travis.yml
├── setup.py
├── Makefile
└── CONTRIBUTING.rst
/inferno/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/authors.rst:
--------------------------------------------------------------------------------
1 | .. include:: ../AUTHORS.rst
2 |
--------------------------------------------------------------------------------
/docs/history.rst:
--------------------------------------------------------------------------------
1 | .. include:: ../HISTORY.rst
2 |
--------------------------------------------------------------------------------
/docs/readme.rst:
--------------------------------------------------------------------------------
1 | .. include:: ../README.rst
2 |
--------------------------------------------------------------------------------
/inferno/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.4.0'
2 |
--------------------------------------------------------------------------------
/docs/contributing.rst:
--------------------------------------------------------------------------------
1 | .. include:: ../CONTRIBUTING.rst
2 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | /inferno.rst
2 | /inferno.*.rst
3 | /modules.rst
4 |
--------------------------------------------------------------------------------
/inferno/inferno.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Main module."""
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | dill
2 | pyyaml
3 | scipy>=0.13.0
4 | h5py
5 | numpy>=1.8
6 | scikit-image
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Unit test package for inferno."""
4 |
--------------------------------------------------------------------------------
/inferno/extensions/containers/__init__.py:
--------------------------------------------------------------------------------
1 | from .graph import *
2 | from .sequential import *
3 |
--------------------------------------------------------------------------------
/inferno/extensions/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .categorical import *
2 | from .arand import *
3 |
--------------------------------------------------------------------------------
/inferno/extensions/initializers/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import *
2 | from .presets import *
3 |
4 |
--------------------------------------------------------------------------------
/tests/test_io/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Unit test package for inferno."""
4 |
--------------------------------------------------------------------------------
/tests/test_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Unit test package for inferno."""
4 |
--------------------------------------------------------------------------------
/examples/README.txt:
--------------------------------------------------------------------------------
1 |
2 | .. _examples-index:
3 |
4 | Gallery of Examples
5 | ===================
6 |
7 |
--------------------------------------------------------------------------------
/tests/test_extensions/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Unit test package for inferno."""
4 |
--------------------------------------------------------------------------------
/tests/test_io/test_box/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Unit test package for inferno."""
4 |
--------------------------------------------------------------------------------
/tests/test_training/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Unit test package for inferno."""
4 |
--------------------------------------------------------------------------------
/inferno/extensions/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .unet import UNet, UNetBase
2 | from .res_unet import ResBlockUNet
3 |
--------------------------------------------------------------------------------
/readthedocs.yml:
--------------------------------------------------------------------------------
1 | conda:
2 | file: docs/environment.yml
3 | python:
4 | version: 3.5
5 | pip_install: false
--------------------------------------------------------------------------------
/tests/test_io/test_core/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Unit test package for inferno."""
4 |
--------------------------------------------------------------------------------
/docs/inferno-apidoc/modules.rst:
--------------------------------------------------------------------------------
1 | inferno
2 | =======
3 |
4 | .. toctree::
5 | :maxdepth: 4
6 |
7 | inferno
8 |
--------------------------------------------------------------------------------
/tests/test_io/test_volumetric/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Unit test package for inferno."""
4 |
--------------------------------------------------------------------------------
/add2path.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Append the current working directory to PYTHONPATH so that `import inferno`
# resolves to this checkout.
#
# Usage, from the repository root:
#     source add2path.sh
# The script must be *sourced*, not executed: an `export` performed in a child
# shell does not propagate back to the calling shell.
#
# `${PYTHONPATH:+${PYTHONPATH}:}` expands to "<old value>:" only when
# PYTHONPATH is set and non-empty, and to nothing otherwise. This avoids a
# leading ":" (an empty entry, which Python interprets as "the current
# directory", whatever that is at run time) when PYTHONPATH was unset.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${PWD}"
--------------------------------------------------------------------------------
/docs/graphics/tentative_logo.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inferno-pytorch/inferno/HEAD/docs/graphics/tentative_logo.pdf
--------------------------------------------------------------------------------
/inferno/io/__init__.py:
--------------------------------------------------------------------------------
1 | from . import box
2 | from . import core
3 | from . import transform
4 | from . import volumetric
--------------------------------------------------------------------------------
/tests/test_extensions/test_models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Unit test package for inferno."""
4 |
--------------------------------------------------------------------------------
/tests/test_training/test_callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Unit test package for inferno."""
4 |
--------------------------------------------------------------------------------
/tests/test_training/test_callbacks/test_logging/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Unit test package for inferno."""
4 |
--------------------------------------------------------------------------------
/inferno/io/core/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import SyncableDataset
2 | from .zip import Zip, ZipReject
3 | from .concatenate import Concatenate
4 |
--------------------------------------------------------------------------------
/inferno/io/transform/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Transform, Compose
2 | from . import generic
3 | from . import image
4 | from . import volume
5 |
--------------------------------------------------------------------------------
/inferno/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from . import basic
2 | from . import callbacks
3 | from . basic import Trainer
4 | __all__ = ['basic','callbacks','Trainer']
--------------------------------------------------------------------------------
/inferno/extensions/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .adam import Adam
2 | from .annealed_adam import AnnealedAdam
3 | from .ranger import Ranger, RangerQH, RangerVA
4 |
--------------------------------------------------------------------------------
/docs/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {# layout.html #}
2 | {# Import the theme's layout. #}
3 | {% extends "!layout.html" %}
4 |
5 | {% set css_files = css_files + ['_static/pygments.css'] %}
--------------------------------------------------------------------------------
/docs/zbibliography.rst:
--------------------------------------------------------------------------------
1 | .. _inferno_bibliography:
2 |
3 | Bibliography
4 | ============================
5 |
6 | The bibliography:
7 |
8 | .. bibliography:: refs.bib
9 | :style: alpha
--------------------------------------------------------------------------------
/inferno/io/volumetric/__init__.py:
--------------------------------------------------------------------------------
1 | from .volume import VolumeLoader, HDF5VolumeLoader, TIFVolumeLoader
from .lazy_volume_loader import LazyHDF5VolumeLoader, LazyZarrVolumeLoader, LazyN5VolumeLoader
--------------------------------------------------------------------------------
/docs/examples.rst:
--------------------------------------------------------------------------------
1 | .. _inferno_examples_gallery:
2 |
3 | Inferno Examples Gallery
4 | ============================
5 |
6 |
7 | .. toctree::
8 | :maxdepth: 5
9 |
10 | ../auto_examples/index
11 |
12 |
--------------------------------------------------------------------------------
/inferno/extensions/criteria/__init__.py:
--------------------------------------------------------------------------------
1 | from .set_similarity_measures import *
2 | from .elementwise_measures import *
3 | from .core import *
4 | from .regularized import *
5 |
6 | __all__ = ['set_similarity_measures', 'elementwise_measures','core','regularized']
--------------------------------------------------------------------------------
/inferno/extensions/layers/identity.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
# Public API of this module. Every entry must name an object the module
# actually defines: ``from .identity import *`` raises AttributeError if
# ``__all__`` lists a missing name (the class is ``Identity``, capitalised).
__all__ = ['Identity']
_all = __all__


class Identity(nn.Module):
    """Module whose forward pass returns its input unchanged.

    Handy as a placeholder wherever a layer is syntactically required but no
    transformation should be applied.
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        """Return ``x`` itself (the very same object, not a copy)."""
        return x
--------------------------------------------------------------------------------
/inferno/io/core/data_utils.py:
--------------------------------------------------------------------------------
1 |
def implements_sync_primitives(dataset):
    """Tell whether ``dataset`` has a callable ``sync_with`` attribute.

    Returns False both when the attribute is missing and when it exists but
    is not callable.
    """
    sync_method = getattr(dataset, 'sync_with', None)
    return callable(sync_method)
4 |
5 |
def defines_base_sequence(dataset):
    """Tell whether ``dataset`` has a ``base_sequence`` attribute that is not None.

    A missing attribute and an attribute explicitly set to None are both
    reported as False; any other value (even an empty sequence) is True.
    """
    base = getattr(dataset, 'base_sequence', None)
    return base is not None
8 |
--------------------------------------------------------------------------------
/docs/refs.bib:
--------------------------------------------------------------------------------
1 |
2 | @inproceedings{alush_2013_simbad,
3 | title={Break and Conquer: Efficient Correlation Clustering for Image Segmentation},
4 | author={Alush, Amir and Goldberger, Jacob},
5 | booktitle={2nd International Workshop on Similarity-Based Pattern Analysis and Recognition},
6 | year={2013}
7 | }
8 |
--------------------------------------------------------------------------------
/inferno/extensions/optimizers/ranger.py:
--------------------------------------------------------------------------------
1 | # easy support for additional ranger optimizers from
2 | # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
3 | try:
4 | from ranger import Ranger, RangerVA, RangerQH
5 | except ImportError:
6 | Ranger = None
7 | RangerVA = None
8 | RangerQH = None
9 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include AUTHORS.rst
2 | include CONTRIBUTING.rst
3 | include HISTORY.rst
4 | include LICENSE
5 | include README.rst
6 |
7 | recursive-include tests *
8 | recursive-exclude * __pycache__
9 | recursive-exclude * *.py[co]
10 |
11 | recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif
12 |
--------------------------------------------------------------------------------
/conda-recipe/build.sh:
--------------------------------------------------------------------------------
# conda-build script: copy the inferno package into ${PREFIX}/inferno and add
# ${PREFIX} to sys.path through a .pth file in site-packages.
# All ${PREFIX}-based paths are quoted so that prefixes containing spaces work.

# "major.minor" version of the build Python, used to locate site-packages.
PY_VER=$(python -c "import sys; print('{}.{}'.format(*sys.version_info[:2]))")

# Install python modules
mkdir -p "${PREFIX}/inferno"
# The glob must stay outside the quotes so that it is still expanded.
cp -r inferno/* "${PREFIX}/inferno"
# The .pth file puts ${PREFIX} on sys.path, so `import inferno` finds
# ${PREFIX}/inferno.
echo "${PREFIX}" > "${PREFIX}/lib/python${PY_VER}/site-packages/inferno.pth"
python -m compileall "${PREFIX}/inferno"
8 |
--------------------------------------------------------------------------------
/requirements_dev.txt:
--------------------------------------------------------------------------------
1 | pip==8.1.2
2 | bumpversion==0.5.3
3 | wheel==0.29.0
4 | watchdog==0.8.3
5 | flake8==2.6.0
6 | tox==2.3.1
7 | coverage==4.1
8 | Sphinx==1.4.8
9 | cryptography==1.7
10 | PyYAML==5.1
11 | dill
12 | pyyaml
13 | scipy>=0.13.0
14 | h5py
15 | scikit-image
16 | sphinx-gallery
17 | sphinxcontrib-napoleon
18 | sphinxcontrib-inlinesyntaxhighlight
19 | sphinx_rtd_theme
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | # http://editorconfig.org
2 |
3 | root = true
4 |
5 | [*]
6 | indent_style = space
7 | indent_size = 4
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 | charset = utf-8
11 | end_of_line = lf
12 |
13 | [*.bat]
14 | indent_style = tab
15 | end_of_line = crlf
16 |
17 | [LICENSE]
18 | insert_final_newline = false
19 |
20 | [Makefile]
21 | indent_style = tab
22 |
--------------------------------------------------------------------------------
/inferno/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Top-level package for inferno."""
4 |
5 | from . import extensions
6 | from . import io
7 | from . import trainers
8 | from . import utils
9 | from .version import __version__
10 |
11 |
12 | __all__ = ['extensions', 'io', 'trainers', 'utils']
13 |
14 | __author__ = """Nasim Rahaman"""
15 | __email__ = 'nasim.rahaman@iwr.uni-heidelberg.de'
16 |
--------------------------------------------------------------------------------
/inferno/extensions/__init__.py:
--------------------------------------------------------------------------------
1 | from . import containers
2 | from . import criteria
3 | from . import initializers
4 | from . import layers
5 | from . import metrics
6 | from . import optimizers
7 | from . import models
8 | # Backward support
9 | from . import models as model
10 |
11 | __all__ = ['containers', 'criteria', 'initializers', 'layers', 'metrics', 'optimizers',
12 | 'models', 'model']
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | * inferno version:
2 | * Python version:
3 | * Operating System:
4 |
5 | ### Description
6 |
7 | Describe what you were trying to get done.
8 | Tell us what happened, what went wrong, and what you expected to happen.
9 |
10 | ### What I Did
11 |
12 | ```
13 | Paste the command(s) you ran and the output.
14 | If there was a crash, please include the traceback here.
15 | ```
16 |
--------------------------------------------------------------------------------
/docs/inferno-apidoc/inferno.io.rst:
--------------------------------------------------------------------------------
1 | inferno.io package
2 | ==================
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 |
9 | inferno.io.box
10 | inferno.io.core
11 | inferno.io.transform
12 | inferno.io.volumetric
13 |
14 | Module contents
15 | ---------------
16 |
17 | .. automodule:: inferno.io
18 | :members:
19 | :undoc-members:
20 | :show-inheritance:
21 |
--------------------------------------------------------------------------------
/inferno/io/box/__init__.py:
--------------------------------------------------------------------------------
1 | """Things that work out of the box. ;)"""
2 |
3 | from .camvid import CamVid, get_camvid_loaders
4 | from .cityscapes import Cityscapes, get_cityscapes_loaders
5 | from .cifar import get_cifar10_loaders, get_cifar100_loaders
6 |
7 |
8 | __all__ = [
9 | 'CamVid','get_camvid_loaders', 'Cityscapes', 'get_cityscapes_loaders',
10 | 'get_cifar10_loaders','get_cifar100_loaders'
11 | ]
--------------------------------------------------------------------------------
/tests/test_extensions/test_layers/test_activations.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import inferno.extensions.layers.activations as activations
4 |
5 |
class ActivationTest(unittest.TestCase):
    """Unit tests for ``inferno.extensions.layers.activations``."""

    def test_selu(self):
        # SELU is elementwise, so its output must keep the input's shape.
        inputs = torch.rand(100)
        outputs = activations.SELU()(inputs)
        self.assertListEqual(list(inputs.size()), list(outputs.size()))


if __name__ == '__main__':
    unittest.main()
15 |
--------------------------------------------------------------------------------
/docs/environment.yml:
--------------------------------------------------------------------------------
1 | name: inferno_docs
2 |
3 | channels:
4 | - soumith
5 | - anaconda
6 |
7 | dependencies:
8 | - python==3.5
9 | - pytorch>=0.1.12
10 | - torchvision
11 | - scikit-image
12 | - pip:
13 | - scipy>=0.13.0
14 | - h5py
15 | - scikit-image
16 | - pyyaml
17 | - dill
18 | - sphinx-gallery
19 | - sphinxcontrib-napoleon
20 | - sphinxcontrib-bibtex
21 | - sphinxcontrib-inlinesyntaxhighlight
22 |
--------------------------------------------------------------------------------
/inferno/extensions/metrics/cremi_score.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .voi import voi
3 | from .arand import adapted_rand
4 |
5 |
6 | # TODO build metrics object
7 |
8 |
def cremi_metrics(seg, gt, no_seg_ignore=True):
    """Score the segmentation ``seg`` against the ground truth ``gt``.

    Args:
        seg: Candidate label array. It is NOT modified by this function.
        gt: Ground-truth label array, passed unchanged to ``voi`` and
            ``adapted_rand``.
        no_seg_ignore: If True and ``seg`` contains the label 0, every label
            of ``seg`` is shifted by +1 before scoring, so that no segment
            keeps label 0. NOTE(review): presumably 0 is treated as an
            ignore label by the helpers; confirm in ``voi``/``adapted_rand``.

    Returns:
        Tuple ``(cs, vi_s, vi_m, rand)``:

        - ``vi_s``, ``vi_m``: the two values returned by ``voi`` (presumably
          the split and merge terms).
        - ``rand``: ``1 - adapted_rand(seg, gt)[0]``.
        - ``cs``: ``sqrt((vi_s + vi_m) * rand)``.
    """
    if no_seg_ignore and 0 in seg:
        # Shift out of place: ``seg += 1`` would silently relabel the caller's
        # array, because numpy arrays are mutable and shared by reference.
        seg = seg + 1
    vi_s, vi_m = voi(seg, gt)
    rand = 1. - adapted_rand(seg, gt)[0]
    cs = np.sqrt((vi_s + vi_m) * rand)
    return cs, vi_s, vi_m, rand
17 |
--------------------------------------------------------------------------------
/inferno/trainers/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['CallbackEngine', 'Callback', 'Console', 'essentials', 'scheduling', 'gradients']
2 |
3 | from .base import CallbackEngine, Callback
4 | from .console import Console
5 | from . import essentials
6 | from . import scheduling
7 | from . import gradients
8 |
9 | try:
10 | from .tqdm import TQDMProgressBar
11 | __all__.append('TQDMProgressBar')
12 | except ImportError:
13 | from .tqdmstub import TQDMProgressBar
14 |
--------------------------------------------------------------------------------
/inferno/trainers/callbacks/logging/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['get_logger']
2 | try:
3 | INFERNO_WITH_TENSORBOARD_LOGGER = True
4 | from .tensorboard import TensorboardLogger
5 | __all__.append('TensorboardLogger')
6 | except ImportError:
7 | INFERNO_WITH_TENSORBOARD_LOGGER = False
8 |
9 |
def get_logger(name):
    """Return the object bound to ``name`` in this module's global namespace.

    Intended for looking up logger classes by name, e.g.
    ``'TensorboardLogger'`` (present only when its optional import succeeded).
    NOTE(review): any module-level name is accepted, not only loggers.

    Raises:
        NotImplementedError: If the module has no global called ``name``.
    """
    module_namespace = globals()
    if name not in module_namespace:
        raise NotImplementedError("Logger not found.")
    return module_namespace.get(name)
15 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to inferno's documentation!
2 | ======================================
3 |
4 | Contents:
5 |
6 | .. toctree::
7 | :maxdepth: 1
8 |
9 | readme
10 | installation
11 | usage
12 | examples
13 | contributing
14 | inferno-apidoc/modules
15 | authors
16 | history
17 | zbibliography
18 |
19 | .. automodule:: inferno
20 |
21 | Indices and tables
22 | ==================
23 |
24 | * :ref:`genindex`
25 | * :ref:`modindex`
26 | * :ref:`search`
27 |
--------------------------------------------------------------------------------
/inferno/trainers/callbacks/tqdmstub.py:
--------------------------------------------------------------------------------
1 | from .base import Callback
2 |
class TQDMProgressBar(Callback):
    """Fallback ``TQDMProgressBar`` used when ``tqdm`` cannot be imported.

    It takes the same constructor arguments as ``Callback`` and does no
    progress reporting of its own. When bound to a trainer, it only logs a
    warning through ``trainer.console``.
    """

    def __init__(self, *args, **kwargs):
        # Nothing stub-specific to initialise; defer entirely to Callback.
        super(TQDMProgressBar, self).__init__(*args, **kwargs)

    def bind_trainer(self, *args, **kwargs):
        """Bind to the trainer as usual, then warn that tqdm is missing."""
        super(TQDMProgressBar, self).bind_trainer(*args, **kwargs)
        console = self.trainer.console
        console.warning("tqdm is not installed. will fall back to normal stdout console.")

    def begin_of_fit(self, **kwargs):
        """Do nothing at the start of fit: there is no progress bar to create."""
        return None
13 |
--------------------------------------------------------------------------------
/inferno/extensions/layers/activations.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch.nn as nn
3 | from ...utils.torch_utils import where
4 |
5 | __all__ = ['SELU']
6 | _all = __all__
7 |
class SELU(nn.Module):
    """Scaled exponential linear unit, applied elementwise.

    Computes ``scale * where(x >= 0, x, alpha * elu(x))``, with the SELU
    constants ``alpha ~ 1.6733`` and ``scale ~ 1.0507``. ``where`` is
    inferno's helper from ``utils.torch_utils`` (presumably an elementwise
    select like ``torch.where`` -- confirm there).
    """

    def forward(self, input):
        activated = self.selu(input)
        return activated

    @staticmethod
    def selu(x):
        """Apply SELU to tensor ``x`` and return the resulting tensor."""
        # Standard SELU constants, written with full precision.
        alpha = 1.6732632423543772848170429916717
        scale = 1.0507009873554804934193349852946
        is_non_negative = x >= 0
        # noinspection PyTypeChecker
        selected = where(is_non_negative, x, alpha * F.elu(x))
        return scale * selected
--------------------------------------------------------------------------------
/docs/inferno-apidoc/inferno.extensions.rst:
--------------------------------------------------------------------------------
1 | inferno.extensions package
2 | ==========================
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 |
9 | inferno.extensions.containers
10 | inferno.extensions.criteria
11 | inferno.extensions.initializers
12 | inferno.extensions.layers
13 | inferno.extensions.metrics
14 | inferno.extensions.optimizers
15 |
16 | Module contents
17 | ---------------
18 |
19 | .. automodule:: inferno.extensions
20 | :members:
21 | :undoc-members:
22 | :show-inheritance:
23 |
--------------------------------------------------------------------------------
/docs/inferno-apidoc/inferno.trainers.rst:
--------------------------------------------------------------------------------
1 | inferno.trainers package
2 | ========================
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 |
9 | inferno.trainers.callbacks
10 |
11 | Submodules
12 | ----------
13 |
14 | inferno.trainers.basic module
15 | -----------------------------
16 |
17 | .. automodule:: inferno.trainers.basic
18 | :members:
19 | :undoc-members:
20 | :show-inheritance:
21 |
22 |
23 | Module contents
24 | ---------------
25 |
26 | .. automodule:: inferno.trainers
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
--------------------------------------------------------------------------------
/inferno/extensions/layers/normalization.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
class BatchNormND(nn.Module):
    """Dimension-agnostic wrapper around ``nn.BatchNorm1d/2d/3d``.

    Args:
        dim: 1, 2 or 3; selects ``nn.BatchNorm{dim}d``.
        num_features, eps, momentum, affine, track_running_stats: Forwarded
            unchanged (as keyword arguments) to the selected layer.
    """

    def __init__(self, dim, num_features,
                 eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(BatchNormND, self).__init__()
        assert dim in (1, 2, 3)
        layer_name = 'BatchNorm{}d'.format(dim)
        layer_cls = getattr(nn, layer_name)
        self.bn = layer_cls(
            num_features=num_features,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )

    def forward(self, x):
        """Normalize ``x`` with the wrapped batch-norm layer."""
        normalized = self.bn(x)
        return normalized
--------------------------------------------------------------------------------
/docs/inferno-apidoc/inferno.rst:
--------------------------------------------------------------------------------
1 | inferno package
2 | ===============
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 |
9 | inferno.extensions
10 | inferno.io
11 | inferno.trainers
12 | inferno.utils
13 |
14 | Submodules
15 | ----------
16 |
17 | inferno.inferno module
18 | ----------------------
19 |
20 | .. automodule:: inferno.inferno
21 | :members:
22 | :undoc-members:
23 | :show-inheritance:
24 |
25 |
26 | Module contents
27 | ---------------
28 |
29 | .. automodule:: inferno
30 | :members:
31 | :undoc-members:
32 | :show-inheritance:
33 |
--------------------------------------------------------------------------------
/tests/test_extensions/test_criteria/test_core.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import torch.nn as nn
4 |
5 |
class TestCore(unittest.TestCase):
    """Tests for ``inferno.extensions.criteria.core``."""

    def test_as_2d_criterion(self):
        # Smoke test: As2DCriterion wrapping CrossEntropyLoss must accept a
        # (N, C, H, W) float prediction and an (N, H, W) long target without
        # raising. The loss value itself is not checked.
        from inferno.extensions.criteria.core import As2DCriterion

        batch, classes, height, width = 2, 10, 100, 100
        scores = torch.FloatTensor(batch, classes, height, width).uniform_()
        probabilities = nn.Softmax2d()(scores)
        labels = torch.LongTensor(batch, height, width).fill_(0)
        wrapped_loss = As2DCriterion(nn.CrossEntropyLoss())
        wrapped_loss(probabilities, labels)


if __name__ == '__main__':
    unittest.main()
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache Software License 2.0
3 |
4 | Copyright (c) 2017, Inferno Developers
5 |
6 | Licensed under the Apache License, Version 2.0 (the "License");
7 | you may not use this file except in compliance with the License.
8 | You may obtain a copy of the License at
9 |
10 | http://www.apache.org/licenses/LICENSE-2.0
11 |
12 | Unless required by applicable law or agreed to in writing, software
13 | distributed under the License is distributed on an "AS IS" BASIS,
14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | See the License for the specific language governing permissions and
16 | limitations under the License.
17 |
18 |
--------------------------------------------------------------------------------
/tests/test_extensions/test_layers/test_convolutional.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | from inferno.utils.model_utils import ModelTester
4 |
5 |
class TestConvolutional(unittest.TestCase):
    """CUDA-only tests for inferno's convolutional layers."""

    @unittest.skipIf(not torch.cuda.is_available(), "GPU not available.")
    def test_bn_relu_depthwise_conv2d_pyinn(self):
        from inferno.extensions.layers.convolutional import BNReLUDepthwiseConv2D

        channels = 10
        model = BNReLUDepthwiseConv2D(channels, 'auto', 3)
        # ModelTester is given the input shape and the expected output shape
        # (identical here) and run on the GPU.
        tester = ModelTester((1, channels, 100, 100),
                             (1, channels, 100, 100)).cuda()
        tester(model)
        self.assertTrue(model.depthwise)
        # Depthwise convolution: one group per channel.
        self.assertEqual(model.conv.groups, channels)


if __name__ == '__main__':
    unittest.main()
19 |
--------------------------------------------------------------------------------
/tests/test_extensions/test_criteria/test_elementwise_measures.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import inferno.extensions.criteria.elementwise_measures as em
3 | import torch
4 |
5 |
class TestElementwiseMeasures(unittest.TestCase):
    """Tests for the element-wise loss measures."""

    def test_weighted_mse_loss(self):
        # All targets positive: every unit squared error is weighted by 2.
        all_positive_loss = em.WeightedMSELoss(positive_class_weight=2.)(
            torch.zeros(10, 10), torch.ones(10, 10))
        self.assertAlmostEqual(all_positive_loss.item(), 2., delta=1e-5)
        # All targets negative: every unit squared error keeps weight 1.
        all_negative_loss = em.WeightedMSELoss(positive_class_weight=2.)(
            torch.ones(10, 10), torch.zeros(10, 10))
        self.assertAlmostEqual(all_negative_loss.item(), 1., delta=1e-5)


if __name__ == '__main__':
    unittest.main()
20 |
--------------------------------------------------------------------------------
/inferno/extensions/containers/sequential.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from ...utils import python_utils as pyu
3 |
4 |
__all__ = ['Sequential1', 'Sequential2']


class Sequential1(nn.Sequential):
    """Like torch.nn.Sequential, but with a few extra methods."""

    def __len__(self):
        """Return the number of child modules held by the container."""
        return len(self._modules)


class Sequential2(Sequential1):
    """Another sequential container.

    Identical to torch.nn.Sequential, except that modules may return multiple
    outputs and accept multiple inputs: whatever one module returns is unpacked
    as positional arguments into the next module.
    """

    def forward(self, *input):
        current = input
        for module in self._modules.values():
            # Presumably `to_iterable` wraps a single output so that it can be
            # unpacked with `*` just like a tuple of outputs.
            current = pyu.to_iterable(module(*pyu.to_iterable(current)))
        return pyu.from_iterable(current)
23 |
--------------------------------------------------------------------------------
/docs/inferno-apidoc/inferno.extensions.containers.rst:
--------------------------------------------------------------------------------
1 | inferno.extensions.containers package
2 | =====================================
3 |
4 | Submodules
5 | ----------
6 |
7 | inferno.extensions.containers.graph module
8 | ------------------------------------------
9 |
10 | .. automodule:: inferno.extensions.containers.graph
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | inferno.extensions.containers.sequential module
16 | -----------------------------------------------
17 |
18 | .. automodule:: inferno.extensions.containers.sequential
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 |
24 | Module contents
25 | ---------------
26 |
27 | .. automodule:: inferno.extensions.containers
28 | :members:
29 | :undoc-members:
30 | :show-inheritance:
31 |
--------------------------------------------------------------------------------
/docs/inferno-apidoc/inferno.extensions.initializers.rst:
--------------------------------------------------------------------------------
1 | inferno.extensions.initializers package
2 | =======================================
3 |
4 | Submodules
5 | ----------
6 |
7 | inferno.extensions.initializers.base module
8 | -------------------------------------------
9 |
10 | .. automodule:: inferno.extensions.initializers.base
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | inferno.extensions.initializers.presets module
16 | ----------------------------------------------
17 |
18 | .. automodule:: inferno.extensions.initializers.presets
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 |
24 | Module contents
25 | ---------------
26 |
27 | .. automodule:: inferno.extensions.initializers
28 | :members:
29 | :undoc-members:
30 | :show-inheritance:
31 |
--------------------------------------------------------------------------------
/docs/inferno-apidoc/inferno.extensions.optimizers.rst:
--------------------------------------------------------------------------------
1 | inferno.extensions.optimizers package
2 | =====================================
3 |
4 | Submodules
5 | ----------
6 |
7 | inferno.extensions.optimizers.adam module
8 | -----------------------------------------
9 |
10 | .. automodule:: inferno.extensions.optimizers.adam
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | inferno.extensions.optimizers.annealed\_adam module
16 | ---------------------------------------------------
17 |
18 | .. automodule:: inferno.extensions.optimizers.annealed_adam
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 |
24 | Module contents
25 | ---------------
26 |
27 | .. automodule:: inferno.extensions.optimizers
28 | :members:
29 | :undoc-members:
30 | :show-inheritance:
31 |
--------------------------------------------------------------------------------
/docs/inferno-apidoc/inferno.trainers.callbacks.logging.rst:
--------------------------------------------------------------------------------
1 | inferno.trainers.callbacks.logging package
2 | ==========================================
3 |
4 | Submodules
5 | ----------
6 |
7 | inferno.trainers.callbacks.logging.base module
8 | ----------------------------------------------
9 |
10 | .. automodule:: inferno.trainers.callbacks.logging.base
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | inferno.trainers.callbacks.logging.tensorboard module
16 | -----------------------------------------------------
17 |
18 | .. automodule:: inferno.trainers.callbacks.logging.tensorboard
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 |
24 | Module contents
25 | ---------------
26 |
27 | .. automodule:: inferno.trainers.callbacks.logging
28 | :members:
29 | :undoc-members:
30 | :show-inheritance:
31 |
--------------------------------------------------------------------------------
/HISTORY.rst:
--------------------------------------------------------------------------------
1 | =======
2 | History
3 | =======
4 |
5 | 0.1.0 (2017-08-24)
6 | ------------------
7 |
8 | * First early release on PyPI
9 |
10 | 0.1.1 (2017-08-24)
11 | ------------------
12 |
13 | * Version Increment
14 |
15 | 0.1.2 (2017-08-24)
16 | ------------------
17 |
18 | * Version Increment
19 |
20 |
21 | 0.1.3 (2017-08-24)
22 | ------------------
23 |
24 | * Updated Documentation
25 |
26 | 0.1.4 (2017-08-24)
27 | ------------------
28 |
29 | * travis auto-deployment on pypi
30 |
31 |
32 | 0.1.5 (2017-08-24)
33 | ------------------
34 |
35 | * travis changes to run unittest
36 |
37 |
38 | 0.1.6 (2017-08-24)
39 | ------------------
40 |
41 | * travis missing packages for unittesting
42 | * fixed inconsistent version numbers
43 |
44 | 0.1.7 (2017-08-25)
45 | ------------------
46 |
47 | * setup.py critical bugfix in install procedure
48 |
49 |
50 |
51 | CURRENT CHANGES
52 | -----------------
53 | * Flexible Unet
54 |
--------------------------------------------------------------------------------
/tests/test_utils/test_model_utils.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import inferno.utils.model_utils as mu
3 | from inferno.utils.exceptions import ShapeError
4 | import torch
5 | import torch.nn as nn
6 |
7 |
class ModelUtilTester(unittest.TestCase):
    """Tests for ``ModelTester``'s shape checking."""

    def test_model_tester(self):
        conv = nn.Conv2d(10, 20, 3, padding=1)
        model = mu.ModelTester((1, 10, 32, 32), (1, 20, 32, 32))(conv)
        # Claiming 30 output channels for a 20-channel conv must be rejected.
        with self.assertRaises(ShapeError):
            mu.ModelTester((1, 10, 32, 32), (1, 30, 32, 32))(model)

    @unittest.skipUnless(torch.cuda.is_available(), "need cuda")
    def test_model_tester_cuda(self):
        conv = nn.Conv2d(10, 20, 3, padding=1).cuda()
        model = mu.ModelTester((1, 10, 32, 32), (1, 20, 32, 32)).cuda()(conv)
        with self.assertRaises(ShapeError):
            mu.ModelTester((1, 10, 32, 32), (1, 30, 32, 32)).cuda()(model)


if __name__ == '__main__':
    unittest.main()
23 |
--------------------------------------------------------------------------------
/inferno/extensions/layers/__init__.py:
--------------------------------------------------------------------------------
__all__ = []
from .activations import *
from .convolutional import *
from .device import *
from .reshape import *
from .convolutional_blocks import *
# `identity`'s `_all` is added to `__all__` below, so its public names must
# also be imported here. Otherwise `from inferno.extensions.layers import *`
# raises AttributeError for any name listed in identity's `_all`.
from .identity import *

#######################################################
# The following makes the sphinx example gallery
# produce proper cross-references: the package-level
# `__all__` is the union of every submodule's `_all`.
from .activations import _all as _activations_all
from .convolutional import _all as _convolutional_all
from .device import _all as _device_all
from .reshape import _all as _reshape_all
from .convolutional_blocks import _all as _convolutional_blocks_all
from .identity import _all as _identity_all

__all__.extend(_activations_all)
__all__.extend(_convolutional_all)
__all__.extend(_device_all)
__all__.extend(_reshape_all)
__all__.extend(_convolutional_blocks_all)
__all__.extend(_identity_all)

_all = __all__
26 |
--------------------------------------------------------------------------------
/docs/inferno-apidoc/inferno.io.volumetric.rst:
--------------------------------------------------------------------------------
1 | inferno.io.volumetric package
2 | =============================
3 |
4 | Submodules
5 | ----------
6 |
7 | inferno.io.volumetric.lazy\_volume\_loader module
8 | -------------------------------------------------
9 |
10 | .. automodule:: inferno.io.volumetric.lazy_volume_loader
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | inferno.io.volumetric.volume module
16 | -----------------------------------
17 |
18 | .. automodule:: inferno.io.volumetric.volume
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | inferno.io.volumetric.volumetric\_utils module
24 | ----------------------------------------------
25 |
26 | .. automodule:: inferno.io.volumetric.volumetric_utils
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 |
32 | Module contents
33 | ---------------
34 |
35 | .. automodule:: inferno.io.volumetric
36 | :members:
37 | :undoc-members:
38 | :show-inheritance:
39 |
--------------------------------------------------------------------------------
/inferno/utils/math_utils.py:
--------------------------------------------------------------------------------
1 |
2 |
def max_allowed_ds_steps(shape, factor):
    """How often can a shape be down-sampled by a given factor
    such that none of the divisions will give non-integers.

    Args:
        shape (listlike): tensor shape
        factor (integer): downsample factor; must be greater than 1.

    Returns:
        int: maximum allowed downsample operations, i.e. the minimum over all
        dimensions. For an empty ``shape`` no dimension constrains the result
        and ``float('inf')`` is returned.

    Raises:
        ValueError: if ``factor`` is not greater than 1. Such a factor never
            shrinks the size, so the step count would be unbounded (the
            search would never terminate).
    """
    if factor <= 1:
        raise ValueError("Down-sampling factor must be greater than 1, got {}."
                         .format(factor))
    divisor = float(factor)

    def _steps_for_size(size):
        # Divide repeatedly while the result stays a positive whole number.
        current_size = float(size)
        steps = 0
        while True:
            new_size = current_size / divisor
            if new_size >= 1 and new_size.is_integer():
                current_size = new_size
                steps += 1
            else:
                return steps

    min_steps = float('inf')
    for s in shape:
        min_steps = int(min(min_steps, _steps_for_size(s)))
    return min_steps
35 |
--------------------------------------------------------------------------------
/inferno/utils/exceptions.py:
--------------------------------------------------------------------------------
1 | """Exceptions and Error Handling"""
2 |
3 |
def assert_(condition, message='', exception_type=AssertionError):
    """Raise ``exception_type(message)`` unless ``condition`` is truthy.

    Works like the ``assert`` statement, but the caller chooses the exception
    class, and the check is a regular function call rather than a statement.
    """
    if condition:
        return
    raise exception_type(message)
8 |
9 |
# ------ VALUE ERRORS ------


class ShapeError(ValueError):
    """A value has an unexpected shape."""


class FrequencyValueError(ValueError):
    """A frequency specification has an invalid value."""


class DeviceError(ValueError):
    """A device specification has an invalid value."""


class NotSetError(ValueError):
    """A required value has not been set."""


# ------ TYPE ERRORS ------


class NotTorchModuleError(TypeError):
    """An object was expected to be a torch module but is not."""


class FrequencyTypeError(TypeError):
    """A frequency specification has an invalid type."""


class DTypeError(TypeError):
    """A value has an unexpected dtype."""


# ------ LOOKUP ERRORS ------


class ClassNotFoundError(LookupError):
    """A requested class could not be found."""


# ------ NOT-IMPLEMENTED ERRORS ------


class NotUnwrappableError(NotImplementedError):
    """An object cannot be unwrapped."""
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 |
55 | # Sphinx documentation
56 | docs/_build/
57 |
58 | # PyBuilder
59 | target/
60 |
61 | # pyenv python configuration file
62 | .python-version
63 |
--------------------------------------------------------------------------------
/docs/inferno-apidoc/inferno.io.box.rst:
--------------------------------------------------------------------------------
1 | inferno.io.box package
2 | ======================
3 |
4 | Submodules
5 | ----------
6 |
7 | inferno.io.box.binary\_blobs module
8 | -----------------------------------
9 |
10 | .. automodule:: inferno.io.box.binary_blobs
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | inferno.io.box.camvid module
16 | ----------------------------
17 |
18 | .. automodule:: inferno.io.box.camvid
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | inferno.io.box.cifar module
24 | ---------------------------
25 |
26 | .. automodule:: inferno.io.box.cifar
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | inferno.io.box.cityscapes module
32 | --------------------------------
33 |
34 | .. automodule:: inferno.io.box.cityscapes
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 |
40 | Module contents
41 | ---------------
42 |
43 | .. automodule:: inferno.io.box
44 | :members:
45 | :undoc-members:
46 | :show-inheritance:
47 |
--------------------------------------------------------------------------------
/docs/inferno-apidoc/inferno.io.core.rst:
--------------------------------------------------------------------------------
1 | inferno.io.core package
2 | =======================
3 |
4 | Submodules
5 | ----------
6 |
7 | inferno.io.core.base module
8 | ---------------------------
9 |
10 | .. automodule:: inferno.io.core.base
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | inferno.io.core.concatenate module
16 | ----------------------------------
17 |
18 | .. automodule:: inferno.io.core.concatenate
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | inferno.io.core.data\_utils module
24 | ----------------------------------
25 |
26 | .. automodule:: inferno.io.core.data_utils
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | inferno.io.core.zip module
32 | --------------------------
33 |
34 | .. automodule:: inferno.io.core.zip
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 |
40 | Module contents
41 | ---------------
42 |
43 | .. automodule:: inferno.io.core
44 | :members:
45 | :undoc-members:
46 | :show-inheritance:
47 |
--------------------------------------------------------------------------------
/inferno/extensions/metrics/base.py:
--------------------------------------------------------------------------------
1 |
2 |
class Metric(object):
    """Base class for metrics.

    Subclasses implement :meth:`forward`. Calling the metric normalizes its
    inputs (first element of list/tuple inputs, target moved to the
    prediction's device) before delegating to :meth:`forward`.
    """

    def forward(self, *args, **kwargs):
        """Compute the metric; must be overridden by subclasses."""
        raise NotImplementedError

    def __call__(self, prediction, target, **kwargs):
        # Listlike inputs (e.g. multi-scale outputs): evaluate on the first
        # entry, which should be at the original scale. Same for the target.
        prediction, target = (value[0] if isinstance(value, (list, tuple)) else value
                              for value in (prediction, target))
        # Put the target on the same device as the prediction.
        if prediction.is_cuda:
            target = target.cuda(prediction.get_device())
        else:
            target = target.cpu()
        return self.forward(prediction, target, **kwargs)
27 |
--------------------------------------------------------------------------------
/conda-recipe/meta.yaml:
--------------------------------------------------------------------------------
1 | package:
2 | name: inferno
3 |
4 | {% set tagged_version = GIT_DESCRIBE_TAG|replace("v","")|replace("-", ".") %}
5 |
6 | # If we're using a non-tagged revision, append '.postN' to the version
7 | {% if GIT_DESCRIBE_NUMBER|int != 0 %}
8 | {% set tagged_version = tagged_version + '.post' + GIT_DESCRIBE_NUMBER %}
9 | {% endif %}
10 |
11 | version: {{tagged_version}}
12 |
13 | source:
14 | path: ..
15 |
16 | build:
17 | number: 1
18 | string: py_{{PKG_BUILDNUM}}_g{{GIT_FULL_HASH[:7]}}
19 |
20 | requirements:
21 | build:
22 | - python {{PY_VER}}*
23 | run:
24 | - python {{PY_VER}}*
25 | - pytorch
26 | - torchvision
27 | - pyyaml
28 | - scipy
29 | - scikit-image
30 | - scikit-learn
31 | - h5py
32 | - dill
33 | - networkx 1.11
34 | - tensorboardx
35 | - sphinx_rtd_theme
36 |
37 |
38 | test:
39 | imports:
40 | - inferno
41 |
42 | about:
43 | license: Apache License 2.0
44 | summary: A utility library around PyTorch
45 |
--------------------------------------------------------------------------------
/tests/test_io/test_core/test_concatenate.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 |
class ConcatenateTest(unittest.TestCase):
    def test_concatenate(self):
        from inferno.io.core import Concatenate
        from torch.utils.data.dataset import Dataset

        # Plain lists are not accepted as datasets.
        with self.assertRaises(AssertionError):
            Concatenate([1, 2, 3], [4, 5, 6, 7])

        class ListDataset(list, Dataset):
            """A list that also counts as a torch Dataset."""

        cated = Concatenate(ListDataset([1, 2, 3, 4]),
                            ListDataset([5, 6, 7]),
                            ListDataset([8, 9, 10, 11, 12]))
        self.assertEqual(len(cated), 12)

        # Global indices run across the boundaries of the 4 + 3 + 5 elements.
        for index, expected in ((2, 3), (4, 5), (6, 7), (10, 11), (11, 12)):
            self.assertEqual(cated[index], expected)

        # One past the end is out of range.
        with self.assertRaises(AssertionError):
            cated[12]


if __name__ == '__main__':
    unittest.main()
34 |
--------------------------------------------------------------------------------
/docs/inferno-apidoc/inferno.io.transform.rst:
--------------------------------------------------------------------------------
1 | inferno.io.transform package
2 | ============================
3 |
4 | Submodules
5 | ----------
6 |
7 | inferno.io.transform.base module
8 | --------------------------------
9 |
10 | .. automodule:: inferno.io.transform.base
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | inferno.io.transform.generic module
16 | -----------------------------------
17 |
18 | .. automodule:: inferno.io.transform.generic
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | inferno.io.transform.image module
24 | ---------------------------------
25 |
26 | .. automodule:: inferno.io.transform.image
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | inferno.io.transform.volume module
32 | ----------------------------------
33 |
34 | .. automodule:: inferno.io.transform.volume
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 |
40 | Module contents
41 | ---------------
42 |
43 | .. automodule:: inferno.io.transform
44 | :members:
45 | :undoc-members:
46 | :show-inheritance:
47 |
--------------------------------------------------------------------------------
/tests/test_training/test_callbacks/test_logging/test_base.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from inferno.trainers.callbacks.logging.base import Logger
3 | from inferno.trainers.basic import Trainer
4 | from os.path import join, dirname
5 |
6 |
class DummyLogger(Logger):
    """Minimal concrete logger that ignores every training iteration."""

    def end_of_training_iteration(self, **_):
        pass


class TestLogger(unittest.TestCase):
    ROOT = dirname(__file__)

    def test_serialization(self):
        save_directory = join(self.ROOT, 'saves')
        trainer = (Trainer()
                   .build_logger(logger=DummyLogger())
                   .save_to_directory(save_directory))
        trainer.save()
        # Unserialize into a fresh trainer.
        trainer = Trainer().load(from_directory=save_directory)
        # After loading, the trainer and the callback engine must share one
        # logger object, and that logger must point back at the trainer.
        logger_from_trainer = trainer._logger
        registered = trainer.callbacks._callback_registry['end_of_training_iteration']
        logger_from_callback_engine = next(iter(registered))
        self.assertIs(logger_from_trainer, logger_from_callback_engine)
        self.assertIs(logger_from_callback_engine.trainer, trainer)


if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
/docs/_templates/template_module.rst:
--------------------------------------------------------------------------------
1 | {{ fullname }}
2 | {{ underline }}
3 |
4 | .. automodule:: {{ fullname }}
5 |
6 | {% block functions %}
7 | {% if functions %}
8 |
9 | Functions
10 | ---------
11 |
12 | {% for item in functions %}
13 |
14 | .. autofunction:: {{ item }}
15 |
16 | .. include:: backreferences/{{fullname}}.{{item}}.examples
17 |
18 | .. raw:: html
19 |
20 |
21 |
22 | {%- endfor %}
23 | {% endif %}
24 | {% endblock %}
25 |
26 | {% block classes %}
27 | {% if classes %}
28 |
29 | Classes
30 | -------
31 |
32 | {% for item in classes %}
33 | .. autoclass:: {{ item }}
34 | :members:
35 |
36 | .. include:: backreferences/{{fullname}}.{{item}}.examples
37 |
38 | .. raw:: html
39 |
40 |
41 |
42 |
43 | {%- endfor %}
44 | {% endif %}
45 | {% endblock %}
46 |
47 | {% block exceptions %}
48 | {% if exceptions %}
49 |
50 | Exceptions
51 | ----------
52 |
53 | .. autosummary::
54 | {% for item in exceptions %}
55 | {{ item }}
56 | {%- endfor %}
57 | {% endif %}
58 | {% endblock %}
--------------------------------------------------------------------------------
/inferno/io/core/base.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.dataset import Dataset
2 |
3 |
class SyncableDataset(Dataset):
    """Dataset whose length is taken from a ``base_sequence``.

    Args:
        base_sequence: sequence whose ``len`` defines the dataset length.
            Subclasses that leave it as ``None`` must override ``__len__``.
    """

    def __init__(self, base_sequence=None):
        self.base_sequence = base_sequence

    def sync_with(self, dataset):
        """Adopt ``dataset.base_sequence`` when ``dataset`` has one; return ``self``."""
        try:
            shared_sequence = dataset.base_sequence
        except AttributeError:
            return self
        self.base_sequence = shared_sequence
        return self

    def __len__(self):
        if self.base_sequence is not None:
            return len(self.base_sequence)
        raise RuntimeError("Class {} does not specify a base sequence. Either specify "
                           "one by assigning to self.base_sequence or override the "
                           "__len__ method.".format(self.__class__.__name__))
20 |
21 |
class IndexSpec(object):
    """
    Class to wrap any extra index information a `Dataset` object might want to send back.
    This could be useful in (say) inference, where we would wish to (asynchronously) know
    more about the current input.

    Args:
        index: the index itself; ``int(spec)`` converts it with ``int``.
        base_sequence_at_index: extra information for this index (presumably the
            base-sequence entry at ``index`` — confirm with callers).
    """

    def __init__(self, index=None, base_sequence_at_index=None):
        self.index, self.base_sequence_at_index = index, base_sequence_at_index

    def __int__(self):
        # Allows an IndexSpec to be converted with int() like a plain index.
        return int(self.index)
34 |
--------------------------------------------------------------------------------
/docs/inferno-apidoc/inferno.extensions.criteria.rst:
--------------------------------------------------------------------------------
1 | inferno.extensions.criteria package
2 | ===================================
3 |
4 | Submodules
5 | ----------
6 |
7 | inferno.extensions.criteria.core module
8 | ---------------------------------------
9 |
10 | .. automodule:: inferno.extensions.criteria.core
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | inferno.extensions.criteria.elementwise\_measures module
16 | --------------------------------------------------------
17 |
18 | .. automodule:: inferno.extensions.criteria.elementwise_measures
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | inferno.extensions.criteria.regularized module
24 | ----------------------------------------------
25 |
26 | .. automodule:: inferno.extensions.criteria.regularized
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | inferno.extensions.criteria.set\_similarity\_measures module
32 | ------------------------------------------------------------
33 |
34 | .. automodule:: inferno.extensions.criteria.set_similarity_measures
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 |
40 | Module contents
41 | ---------------
42 |
43 | .. automodule:: inferno.extensions.criteria
44 | :members:
45 | :undoc-members:
46 | :show-inheritance:
47 |
--------------------------------------------------------------------------------
/tests/test_extensions/test_layers/test_device.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from inferno.extensions.layers.device import DeviceTransfer, OnDevice
3 | import torch
4 |
5 |
class TransferTest(unittest.TestCase):
    """Tests for the device-transfer layers (CUDA only)."""

    @unittest.skipIf(not torch.cuda.is_available(), "GPU not available.")
    def test_device_transfer(self):
        # Build transfer model
        transfer = DeviceTransfer('cpu')
        # `x` must be a leaf tensor that requires grad. `torch.rand(...).cuda()`
        # does not require grad, so nothing in the graph would require grad,
        # `loss.backward()` would raise, and `x.grad` would never be set.
        x = torch.rand(10, 10, device='cuda', requires_grad=True)
        y = transfer(x)
        loss = y.mean()
        loss.backward()
        self.assertFalse(y.is_cuda)
        self.assertIsNotNone(x.grad)
        # The gradient must arrive back on the input's (GPU) device.
        self.assertTrue(x.grad.is_cuda)

    @unittest.skipIf(not torch.cuda.is_available(), "GPU not available.")
    def test_on_device(self):
        # The input starts on the CPU.
        x = torch.rand(1, 10)
        # Build model over multiple devices
        multi_device_model = torch.nn.Sequential(OnDevice(torch.nn.Linear(10, 10), 'cuda'),
                                                 OnDevice(torch.nn.Linear(10, 10), 'cpu'))
        y = multi_device_model(x)
        # The last stage is on the CPU, so the output should be a CPU float tensor.
        self.assertIsInstance(y.data, torch.FloatTensor)


if __name__ == '__main__':
    unittest.main()
36 |
--------------------------------------------------------------------------------
/inferno/trainers/callbacks/logging/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | from ..base import Callback
3 |
4 |
class Logger(Callback):
    """
    A special callback for logging.

    Loggers are special because they're required to be serializable, whereas other
    callbacks have no such guarantees. In this regard, they are jointly handled by
    trainers and the callback engine.

    Args:
        log_directory (str, optional): directory to log to; created if missing.
            When omitted, :attr:`log_directory` falls back to the trainer's.
    """

    def __init__(self, log_directory=None):
        super(Logger, self).__init__()
        self._log_directory = None
        if log_directory is not None:
            self.set_log_directory(log_directory)

    @property
    def log_directory(self):
        """The logger's own directory if set, otherwise the trainer's.

        Raises:
            RuntimeError: if neither is available.
        """
        if self._log_directory is not None:
            return self._log_directory
        trainer = self.trainer
        if trainer is None or trainer._log_directory is None:
            raise RuntimeError("No log directory found.")
        return trainer._log_directory

    @log_directory.setter
    def log_directory(self, value):
        self.set_log_directory(value)

    def set_log_directory(self, log_directory):
        """Use ``log_directory`` (creating it if needed) for logging; return ``self``."""
        assert isinstance(log_directory, str)
        is_existing_directory = os.path.isdir(log_directory)
        if not is_existing_directory:
            # An existing path that is not a directory (e.g. a file) is refused.
            assert not os.path.exists(log_directory)
            os.makedirs(log_directory)
        self._log_directory = log_directory
        return self
39 |
--------------------------------------------------------------------------------
/tests/test_training/test_callbacks/test_scheduling.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from inferno.trainers.callbacks.scheduling import ManualLR
3 | from torch import nn
4 | from torch.optim import Adam
5 |
6 |
class TestSchedulers(unittest.TestCase):

    def test_manual_lr(self):
        class DummyTrainer(object):
            """Minimal trainer stand-in: iteration/epoch counters and an optimizer."""

            def __init__(self):
                self.iteration_count = 0
                self.epoch_count = 0
                self.optimizer = Adam(nn.Linear(10, 10).parameters(), lr=1.)

        manual_lr = ManualLR([((100, 'iterations'), 0.5),
                              ((200, 'iterations'), 0.5),
                              ((200, 'iterations'), 0.1)])
        trainer = DummyTrainer()
        manual_lr._trainer = trainer

        def current_lr():
            return trainer.optimizer.param_groups[0]['lr']

        manual_lr.end_of_training_iteration()
        self.assertEqual(current_lr(), 1.)
        trainer.iteration_count = 100
        manual_lr.end_of_training_iteration()
        self.assertEqual(current_lr(), 0.5)
        # Both entries scheduled at iteration 200 apply: 1 * 0.5 * 0.5 * 0.1.
        trainer.iteration_count = 200
        manual_lr.end_of_training_iteration()
        self.assertAlmostEqual(current_lr(), 0.025)
        # Past the last scheduled entry the learning rate must not change.
        # The callback has to be invoked here, otherwise this check is vacuous.
        trainer.iteration_count = 300
        manual_lr.end_of_training_iteration()
        self.assertAlmostEqual(current_lr(), 0.025)


if __name__ == '__main__':
    unittest.main()
35 |
--------------------------------------------------------------------------------
/docs/inferno-apidoc/inferno.extensions.metrics.rst:
--------------------------------------------------------------------------------
1 | inferno.extensions.metrics package
2 | ==================================
3 |
4 | Submodules
5 | ----------
6 |
7 | inferno.extensions.metrics.arand module
8 | ---------------------------------------
9 |
10 | .. automodule:: inferno.extensions.metrics.arand
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | inferno.extensions.metrics.base module
16 | --------------------------------------
17 |
18 | .. automodule:: inferno.extensions.metrics.base
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | inferno.extensions.metrics.categorical module
24 | ---------------------------------------------
25 |
26 | .. automodule:: inferno.extensions.metrics.categorical
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | inferno.extensions.metrics.cremi\_score module
32 | ----------------------------------------------
33 |
34 | .. automodule:: inferno.extensions.metrics.cremi_score
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | inferno.extensions.metrics.voi module
40 | -------------------------------------
41 |
42 | .. automodule:: inferno.extensions.metrics.voi
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
47 |
48 | Module contents
49 | ---------------
50 |
51 | .. automodule:: inferno.extensions.metrics
52 | :members:
53 | :undoc-members:
54 | :show-inheritance:
55 |
--------------------------------------------------------------------------------
/inferno/extensions/criteria/elementwise_measures.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from ...utils.exceptions import assert_
3 |
4 |
class WeightedMSELoss(nn.Module):
    """Mean-squared-error loss with a separate weight for a "positive" class.

    Elements whose ``target`` equals ``positive_class_value`` have their squared
    error scaled by ``positive_class_weight``. All other elements are scaled by
    ``NEGATIVE_CLASS_WEIGHT`` (1).

    Args:
        positive_class_weight (float): weight of positive-class elements; must be >= 0.
        positive_class_value (float): target value that marks the positive class.
        size_average (bool): if True, average the weighted squared errors;
            otherwise sum them.

    Raises:
        ValueError: if ``positive_class_weight`` is negative (or NaN).
    """
    NEGATIVE_CLASS_WEIGHT = 1.

    def __init__(self, positive_class_weight=1., positive_class_value=1., size_average=True):
        super(WeightedMSELoss, self).__init__()
        # Written as `not (w >= 0)` so that NaN weights are rejected as well.
        if not positive_class_weight >= 0:
            raise ValueError("Positive class weight can't be less than zero, got {}."
                             .format(positive_class_weight))
        # `size_average` is deprecated in PyTorch and warns on use; translate it
        # to the equivalent `reduction` mode.
        self.mse = nn.MSELoss(reduction='mean' if size_average else 'sum')
        self.positive_class_weight = positive_class_weight
        self.positive_class_value = positive_class_value

    def forward(self, input, target):
        """Return the class-weighted MSE between ``input`` and ``target``."""
        # The mask is built from `target.data`, so the weights carry no gradient.
        positive_class_mask = target.data.eq(self.positive_class_value).type_as(target.data)
        # Get differential weights (positive_weight - negative_weight,
        # i.e. subtract 1, assuming the negative weight is gauged at 1)
        weight_differential = (positive_class_mask
                               .mul_(self.positive_class_weight - self.NEGATIVE_CLASS_WEIGHT))
        # Get final weight by adding weight differential to a tensor with negative weights
        weights = weight_differential.add_(self.NEGATIVE_CLASS_WEIGHT)
        # `weights` should be positive if NEGATIVE_CLASS_WEIGHT is not messed with.
        # Scaling both operands by sqrt(w) yields w * (input - target) ** 2 in the MSE.
        sqrt_weights = weights.sqrt_()
        return self.mse(input * sqrt_weights, target * sqrt_weights)
30 |
--------------------------------------------------------------------------------
/docs/inferno-apidoc/inferno.trainers.callbacks.rst:
--------------------------------------------------------------------------------
1 | inferno.trainers.callbacks package
2 | ==================================
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 |
9 | inferno.trainers.callbacks.logging
10 |
11 | Submodules
12 | ----------
13 |
14 | inferno.trainers.callbacks.base module
15 | --------------------------------------
16 |
17 | .. automodule:: inferno.trainers.callbacks.base
18 | :members:
19 | :undoc-members:
20 | :show-inheritance:
21 |
22 | inferno.trainers.callbacks.console module
23 | -----------------------------------------
24 |
25 | .. automodule:: inferno.trainers.callbacks.console
26 | :members:
27 | :undoc-members:
28 | :show-inheritance:
29 |
30 | inferno.trainers.callbacks.essentials module
31 | --------------------------------------------
32 |
33 | .. automodule:: inferno.trainers.callbacks.essentials
34 | :members:
35 | :undoc-members:
36 | :show-inheritance:
37 |
38 | inferno.trainers.callbacks.scheduling module
39 | --------------------------------------------
40 |
41 | .. automodule:: inferno.trainers.callbacks.scheduling
42 | :members:
43 | :undoc-members:
44 | :show-inheritance:
45 |
46 | inferno.trainers.callbacks.tqdm module
47 | --------------------------------------
48 |
49 | .. automodule:: inferno.trainers.callbacks.tqdm
50 | :members:
51 | :undoc-members:
52 | :show-inheritance:
53 |
54 | inferno.trainers.callbacks.tqdmstub module
55 | ------------------------------------------
56 |
57 | .. automodule:: inferno.trainers.callbacks.tqdmstub
58 | :members:
59 | :undoc-members:
60 | :show-inheritance:
61 |
62 |
63 | Module contents
64 | ---------------
65 |
66 | .. automodule:: inferno.trainers.callbacks
67 | :members:
68 | :undoc-members:
69 | :show-inheritance:
70 |
--------------------------------------------------------------------------------
/docs/inferno-apidoc/inferno.utils.rst:
--------------------------------------------------------------------------------
1 | inferno.utils package
2 | =====================
3 |
4 | Submodules
5 | ----------
6 |
7 | inferno.utils.exceptions module
8 | -------------------------------
9 |
10 | .. automodule:: inferno.utils.exceptions
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | inferno.utils.io\_utils module
16 | ------------------------------
17 |
18 | .. automodule:: inferno.utils.io_utils
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | inferno.utils.math\_utils module
24 | --------------------------------
25 |
26 | .. automodule:: inferno.utils.math_utils
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | inferno.utils.model\_utils module
32 | ---------------------------------
33 |
34 | .. automodule:: inferno.utils.model_utils
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | inferno.utils.python\_utils module
40 | ----------------------------------
41 |
42 | .. automodule:: inferno.utils.python_utils
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
47 | inferno.utils.test\_utils module
48 | --------------------------------
49 |
50 | .. automodule:: inferno.utils.test_utils
51 | :members:
52 | :undoc-members:
53 | :show-inheritance:
54 |
55 | inferno.utils.torch\_utils module
56 | ---------------------------------
57 |
58 | .. automodule:: inferno.utils.torch_utils
59 | :members:
60 | :undoc-members:
61 | :show-inheritance:
62 |
63 | inferno.utils.train\_utils module
64 | ---------------------------------
65 |
66 | .. automodule:: inferno.utils.train_utils
67 | :members:
68 | :undoc-members:
69 | :show-inheritance:
70 |
71 |
72 | Module contents
73 | ---------------
74 |
75 | .. automodule:: inferno.utils
76 | :members:
77 | :undoc-members:
78 | :show-inheritance:
79 |
--------------------------------------------------------------------------------
/docs/graphics/plain_tentative_logo.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests/test_io/test_volumetric/test_volume_loader.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import os
3 | from shutil import rmtree
4 |
5 | import numpy as np
6 | import h5py
7 |
8 |
class TestVolumeLoader(unittest.TestCase):
    """Check that every window produced by VolumeLoader matches direct slicing."""
    shape = (100, 100, 100)

    def setUp(self):
        self.data = np.random.rand(*self.shape)

    def test_loader(self):
        from inferno.io.volumetric import VolumeLoader
        window = (10, 10, 10)
        loader = VolumeLoader(self.data, window_size=window, stride=window,
                              return_index_spec=True)
        for window_data, index in loader:
            # The index spec maps back to the slice of the source array.
            reference = self.data[loader.base_sequence[int(index)]]
            self.assertEqual(window_data.shape, reference.shape)
            self.assertTrue(np.allclose(window_data, reference))
24 |
25 |
class TestHDF5VolumeLoader(unittest.TestCase):
    """Check that every window produced by HDF5VolumeLoader matches direct slicing."""
    shape = (100, 100, 100)

    def setUp(self):
        from tempfile import mkdtemp
        # Use a private scratch directory so a pre-existing './tmp' in the
        # working directory is never written to or deleted by tearDown.
        self.tmp_dir = mkdtemp()
        self.path = os.path.join(self.tmp_dir, 'data.h5')
        self.data = np.random.rand(*self.shape)
        # Explicit write mode: since h5py 3.0 `File(path)` opens read-only,
        # which would fail because the file does not exist yet.
        with h5py.File(self.path, 'w') as f:
            f.create_dataset('data', data=self.data)

    def tearDown(self):
        # Best-effort cleanup, mirroring the original OSError-tolerant removal.
        rmtree(self.tmp_dir, ignore_errors=True)

    def test_hdf5_loader(self):
        from inferno.io.volumetric import HDF5VolumeLoader
        loader = HDF5VolumeLoader(self.path, 'data',
                                  window_size=(10, 10, 10),
                                  stride=(10, 10, 10), return_index_spec=True)
        for batch, idx in loader:
            slice_ = loader.base_sequence[int(idx)]
            expected = self.data[slice_]
            self.assertEqual(batch.shape, expected.shape)
            self.assertTrue(np.allclose(batch, expected))


if __name__ == '__main__':
    unittest.main()
58 |
--------------------------------------------------------------------------------
/inferno/extensions/optimizers/annealed_adam.py:
--------------------------------------------------------------------------------
1 | from .adam import Adam
2 |
3 |
class AnnealedAdam(Adam):
    r"""Implements Adam algorithm with learning rate annealing and optional L1 penalty.

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.
    After every call to :meth:`step`, the learning rate of each parameter
    group is multiplied by that group's ``lr_decay``.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        lambda_l1 (float, optional): L1 penalty (default: 0)
        weight_decay (float, optional): L2 penalty (weight decay) (default: 0)
        lr_decay (float, optional): decay learning rate by this factor after every step
            (default: 1.)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 lambda_l1=0, weight_decay=0, lr_decay=1.):
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        lambda_l1=lambda_l1, weight_decay=weight_decay,
                        lr_decay=lr_decay)
        super(AnnealedAdam, self).__init__(params, **defaults)

    def step(self, closure=None):
        """Performs a single optimization step, then anneals the learning rate.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            Whatever the parent ``Adam.step`` returns — presumably the loss
            computed by ``closure`` (None without a closure), as in
            ``torch.optim.Adam``.
        """
        # Do an optimization step; keep its result so callers can use the loss.
        loss = super(AnnealedAdam, self).step(closure=closure)
        # Update learning rate
        for group in self.param_groups:
            group['lr'] *= group['lr_decay']
        return loss
45 |
--------------------------------------------------------------------------------
/AUTHORS.rst:
--------------------------------------------------------------------------------
1 | =======
2 | Credits
3 | =======
4 |
5 | Development Lead
6 | ----------------
7 |
8 | * `Nasim Rahaman `_ @ `Image Analysis and Learning Lab `_ , `Heidelberg Collaboratory for Image Processing `_ ,
9 |
10 |
11 | Contributors
12 | ------------
13 |
14 | In no particular order:

15 | * `Steffen Wolf `_ @
16 | `Image Analysis and Learning Lab `_ ,
17 | `Heidelberg Collaboratory for Image Processing `_ ,
18 | * `Maurice Weiler `_ @
19 | `Amsterdam Machine Learning Lab `_ ,
20 | `University of Amsterdam `_ ,
21 | * `Constantin Pape `_ @
22 | `Image Analysis and Learning Lab `_ ,
23 | `Heidelberg Collaboratory for Image Processing `_ ,
24 | * `Sven Peter `_ @
25 | `Image Analysis and Learning Lab `_ ,
26 | `Heidelberg Collaboratory for Image Processing `_ ,
27 | * `Manuel Haussmann `_ @
28 | `Image Analysis and Learning Lab `_ ,
29 | `Heidelberg Collaboratory for Image Processing `_ ,
30 | * `Thorsten Beier `_ @
31 | `Image Analysis and Learning Lab `_ ,
32 | `Heidelberg Collaboratory for Image Processing `_ ,
33 | * `Benjamin Striner `_ @
34 | `Machine Learning Department `_ ,
35 | `Carnegie Mellon University `_ ,
36 |
--------------------------------------------------------------------------------
/inferno/utils/test_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data.dataset import TensorDataset
3 | from torch.utils.data.dataloader import DataLoader
4 | import numpy as np
5 |
6 |
def generate_random_data(num_samples, shape, num_classes, hardness=0.3, dtype=None):
    """Generate a random dataset with a given hardness and number of classes.

    Each sample is filled with normal noise centred on its (uniformly drawn)
    class label, with standard deviation ``1 - hardness``.

    Parameters
    ----------
    num_samples : int
        Number of samples to draw.
    shape : int or sequence of int
        Shape of a single input sample; an integer ``k`` is treated as ``(k,)``.
    num_classes : int
        Labels are drawn uniformly from ``range(num_classes)``.
    hardness : float
        The noise scale is ``1 - hardness``, so it must not exceed 1.
    dtype : numpy dtype, optional
        dtype of the input array (numpy's default float dtype when None).

    Returns
    -------
    tuple of numpy.ndarray
        ``(inputs, targets)`` with shapes ``(num_samples,) + shape`` and
        ``(num_samples,)``.
    """
    # Normalise `shape` so that ints and lists work as well as tuples.
    if isinstance(shape, (int, np.integer)):
        shape = (int(shape),)
    else:
        shape = tuple(shape)
    dataset_input = np.zeros((num_samples,) + shape, dtype=dtype)
    dataset_target = np.random.randint(num_classes, size=num_samples)
    scale = 1 - hardness
    for sample_num, label in enumerate(dataset_target):
        dataset_input[sample_num] = np.random.normal(loc=label, scale=scale, size=shape)
    return dataset_input, dataset_target
16 |
17 |
def generate_random_dataset(num_samples, shape, num_classes, hardness=0.3, dtype=None):
    """Build a ``TensorDataset`` of random samples and their class labels.

    The numpy arrays come from ``generate_random_data`` (same arguments) and
    are wrapped, without copying, via ``torch.from_numpy``.
    """
    inputs, targets = generate_random_data(num_samples, shape, num_classes,
                                           hardness=hardness, dtype=dtype)
    tensors = (torch.from_numpy(inputs), torch.from_numpy(targets))
    return TensorDataset(*tensors)
27 |
28 |
def generate_random_dataloader(num_samples, shape, num_classes, hardness=0.3, dtype=None,
                               batch_size=1, shuffle=False, num_workers=0, pin_memory=False):
    """Wrap a random classification dataset in a ``DataLoader``.

    ``num_samples``, ``shape``, ``num_classes``, ``hardness`` and ``dtype`` are
    passed to ``generate_random_dataset``; the remaining arguments are
    forwarded unchanged to ``DataLoader``.
    """
    loader_options = dict(batch_size=batch_size, shuffle=shuffle,
                          num_workers=num_workers, pin_memory=pin_memory)
    random_dataset = generate_random_dataset(num_samples, shape, num_classes,
                                             hardness=hardness, dtype=dtype)
    return DataLoader(random_dataset, **loader_options)
37 |
--------------------------------------------------------------------------------
/inferno/trainers/callbacks/gradients.py:
--------------------------------------------------------------------------------
1 | from ...utils.train_utils import Frequency
2 | from ...utils.exceptions import assert_, FrequencyValueError
3 | from .base import Callback
4 |
5 |
class LogOutputGradients(Callback):
    """Logs the gradient of the network output.

    A backward hook is registered on ``trainer.model``. Whenever ``log_every``
    matches the trainer's iteration/epoch counters, the first entry of the
    hook's ``grad_output`` is detached, cast to float, copied to the CPU and
    stored in the trainer state under ``'output_gradient'``.

    Parameters
    ----------
    frequency
        Any value accepted by ``Frequency``; interpreted in iterations.
    """

    def __init__(self, frequency):
        super(LogOutputGradients, self).__init__()
        self.log_every = frequency
        # True while a backward hook is registered on the model.
        self.registered = False
        self.hook_handle = None

    @property
    def log_every(self):
        return self._log_every

    @log_every.setter
    def log_every(self, value):
        self._log_every = Frequency(value, 'iterations')
        assert_(self.log_every.is_consistent,
                "Log frequency is not consistent.",
                FrequencyValueError)

    def hook(self, module, grad_input, grad_output):
        # The trainer no longer exists: detach from the model and do nothing.
        if self.trainer is None:
            self.remove_hook()
            return

        if self.log_every.match(iteration_count=self.trainer.iteration_count,
                                epoch_count=self.trainer.epoch_count,
                                persistent=True, match_zero=True):
            self.trainer.update_state('output_gradient',
                                      grad_output[0].detach().float().clone().cpu())

    def remove_hook(self):
        """Remove the backward hook from the model, if one is registered."""
        if self.hook_handle is not None:
            self.hook_handle.remove()
            self.hook_handle = None
        self.registered = False

    def add_hook(self):
        """Register the backward hook on the trainer's model.

        Any previously registered hook is removed first, so repeated calls
        (e.g. ``fit`` invoked more than once) never stack duplicate hooks.
        """
        self.remove_hook()
        # NOTE(review): `register_backward_hook` is deprecated in recent PyTorch
        # in favour of `register_full_backward_hook`; kept for compatibility.
        self.hook_handle = self.trainer.model.register_backward_hook(self.hook)
        self.registered = True

    def begin_of_fit(self, **kwargs):
        self.trainer.logger.observe_state("output_gradient",
                                          observe_while='training')
        self.add_hook()

    def begin_of_save(self, **_):
        # remove hook from model, because you can't pickle it.
        self.remove_hook()

    def end_of_save(self, **_):
        # add hook after model save
        self.add_hook()
55 |
56 |
--------------------------------------------------------------------------------
/tests/test_utils/test_train_utils.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import inferno.utils.train_utils as tu
3 | import numpy as np
4 |
5 |
class FrequencyTest(unittest.TestCase):
    """Tests for ``Frequency`` and ``Duration`` from ``inferno.utils.train_utils``."""

    def test_from_string(self):
        every_ten_epochs = tu.Frequency.from_string('10 epochs')
        self.assertFalse(every_ten_epochs.match(epoch_count=9))
        self.assertTrue(every_ten_epochs.match(epoch_count=10))

        every_iteration = tu.Frequency.from_string('1 iteration')
        self.assertEqual(every_iteration.units, 'iterations')
        self.assertTrue(every_iteration.match(iteration_count=10))

        # Both spellings of an infinite frequency must not match.
        for spec in ('never', 'inf epochs'):
            self.assertFalse(tu.Frequency.from_string(spec).match(epoch_count=9))

    def test_from_tuple(self):
        infinite = tu.Frequency.build_from((np.inf, 'epoch'))
        for epoch in (9, 10):
            self.assertFalse(infinite.match(epoch_count=epoch))

    def test_is_consistent(self):
        corrupted = tu.Frequency.build_from('10 epochs')
        # Overwrite the private units attribute with an invalid unit.
        corrupted._units = 'banana'
        self.assertFalse(corrupted.is_consistent)

    def test_init(self):
        default = tu.Frequency()
        self.assertEqual(default.value, np.inf)
        self.assertEqual(default.units, default.UNIT_PRIORITY)

    def test_duration(self):
        three_iterations = tu.Duration.build_from((3, 'iterations'))
        self.assertFalse(three_iterations.match(iteration_count=2))
        self.assertFalse(three_iterations.match(iteration_count=3))
        self.assertTrue(three_iterations.match(iteration_count=3, when_equal_return=True))
        self.assertTrue(three_iterations.match(iteration_count=4))
        comparison = three_iterations.compare(iteration_count=1, epoch_count=3)
        self.assertEqual(comparison.get('iterations'), 2)
        # Matching with only an epoch count is rejected for this duration.
        with self.assertRaises(ValueError):
            three_iterations.match(epoch_count=2)


if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
language: python

dist: xenial

python:
  - 3.7

env:
  - PYTORCH_CONDA="pytorch" TORCHVISION_CONDA="torchvision" TORCHVISION_CHANNEL=pytorch

install:
  # repo.continuum.io is the legacy host name; installers are served from repo.anaconda.com.
  - wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
  - bash miniconda.sh -b -p $HOME/miniconda
  - export PATH="$HOME/miniconda/bin:$PATH"
  - conda config --set always_yes yes --set changeps1 no
  - conda update -q conda
  - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION
  - source activate test-environment
  - conda install -c conda-forge networkx h5py scikit-image pyyaml dill tensorboardx
  - conda install -c pytorch $PYTORCH_CONDA
  - conda install -c $TORCHVISION_CHANNEL $TORCHVISION_CONDA

# Publish to PyPI only for tagged commits.
deploy:
  provider: pypi
  user: nasimrahaman
  password:
    secure: !!binary |
      bWwzZitLcEpibHBaUWNhUVA4UUlGa2JsZDQxVkx3eFlkY1FiYWJqYkFvWm5pdDErRzlKRXZFM0hR
      ZE15V0tIWm5JQlJRSGlveXdYNjAzQVc1UFV3ZjNBOG0zc21vK3RaZjVSYnM5aE5ySE93ajBXc1N4
      akNHNGhOSnF6UnBDY2kwakxPeWhxaEwxQkR0empSaFdJbWVlOE81RDVPY2pSdGw1TDQ3QjhwVGor
      TVREdlpSYTVFd2xNNXdadTJYWFVXL3ZQY0VLZE9xckFoVk5PSHpkTTh5MGM1S1lHaS9nNThVK2JO
      OVp5RkFROVpuOEY3YmxPdzBQZnAvL202ZUkxamlKSmxhaE13UU4zV2tJRWRpNklVSTE0RUp1ck5s
      Q28xL2kzNER0dGVkZzI0eVhULzcxRFl5Y0pZQWMrcWtoa1VVVUo4NEZKV3JjUjNqTnF5bVI3Ykty
      cFJrR3JydjV0dUpGUnBhc2NIdEdKVUswMkdJWEJUc3JJWGg4bS9oRGtMaVJaMExBeitJQWR4b2tF
      MzB0OWppZ0x5VXFSMmxnVmNvZERzRWZMRnJEMTBHeTJVS2FueVhlYmpsck9qK3V5S1dtZm5UTXg4
      bGNzN09HWEZiUmo2K0ZuYTg5a00xN3poSXhzc3pSMnRGSVJwamV4a0gzZUpyZlpYY1daTFZ3QnV0
      clUwZW10VEsxeGFmOGFjNTd3Wll1R3JXNEZJT1h2bmxoeS9pV0FMVlE4YnVFZFFjQnJ5YWFiRjUy
      RkZvZk1SUnp3aDFhZ3Q3cUxVa0FIbXVuZ1NYQWZxMUlOTkVNYXRTcFVJUURJM3huWmNPeTNhSWFP
      YkVpSlFHY1lrWlhXZ1Z2cVdvcktPOW53a29Hem5BSm1HRVZHYU11dDYwaGg2SGU1MVJPTll3WHc9
  on:
    all_branches: false
    tags: true

script:
  - source activate test-environment
  # `setup.py install` is deprecated by setuptools; let pip build and install the package.
  - pip install .
  - python -m unittest discover -s tests -v
48 |
--------------------------------------------------------------------------------
/tests/test_extensions/test_criteria/test_set_similarity_measures.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 |
4 |
class SetSimilarityTest(unittest.TestCase):
    """Base class supplying random input/target tensor pairs for set-similarity losses."""

    @staticmethod
    def _uniform_pair(*shape):
        """Return two independent tensors of ``shape`` filled from U[0, 1)."""
        first = torch.zeros(*shape).uniform_()
        second = torch.zeros(*shape).uniform_()
        return first, second

    def get_dummy_variables(self):
        # (batch_size, channels, ...)
        return self._uniform_pair(3, 2, 100, 100)

    def get_dummy_variables_with_channels_and_classes(self):
        # (batch_size, channels, classes, ...)
        return self._uniform_pair(3, 2, 5, 100, 100)
16 |
17 |
class TestSorensenDice(SetSimilarityTest):
    """The channelwise Sorensen-Dice loss must equal the sum of per-channel losses."""

    # noinspection PyCallingNonCallable
    def test_channelwise(self):
        from inferno.extensions.criteria.set_similarity_measures import SorensenDiceLoss
        x, y = self.get_dummy_variables()
        channelwise = SorensenDiceLoss(channelwise=True)
        not_channelwise = SorensenDiceLoss(channelwise=False)
        # Compute expected channelwise loss
        expected_channelwise_loss = \
            not_channelwise(x[:, 0, ...], y[:, 0, ...]) + \
            not_channelwise(x[:, 1, ...], y[:, 1, ...])
        # Compute channelwise
        channelwise_loss = channelwise(x, y)
        # Both sides are float32 reductions over tens of thousands of elements,
        # performed in a different order; the default of 7 decimal places sits
        # at the edge of float32 precision, so compare to 5 places instead.
        self.assertAlmostEqual(expected_channelwise_loss.item(), channelwise_loss.item(),
                               places=5)
33 |
34 |
class TestGeneralizedSorensenDice(SetSimilarityTest):
    """The channelwise generalized Dice loss must equal the sum of per-channel losses."""

    def test_channelwise(self):
        from inferno.extensions.criteria.set_similarity_measures import GeneralizedDiceLoss
        x, y = self.get_dummy_variables_with_channels_and_classes()
        channelwise = GeneralizedDiceLoss(channelwise=True)
        not_channelwise = GeneralizedDiceLoss(channelwise=False)
        # Compute channelwise loss and expected one:
        channelwise_loss = channelwise(x, y)
        expected_channelwise_loss = \
            not_channelwise(x[:, 0, ...], y[:, 0, ...]) + \
            not_channelwise(x[:, 1, ...], y[:, 1, ...])
        # float32 reductions in different orders only agree to a few ulps;
        # 7 decimal places (the default) is too strict to be reliable.
        self.assertAlmostEqual(expected_channelwise_loss.item(), channelwise_loss.item(),
                               places=5)


if __name__ == '__main__':
    unittest.main()
52 |
--------------------------------------------------------------------------------
/tests/test_extensions/test_metrics/categorical.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | from inferno.extensions.metrics import IOU
4 |
5 |
class TestCategorical(unittest.TestCase):
    """Tests for the ``IOU`` metric on one-hot (one plane per class) images.

    NOTE(review): this module is named ``categorical.py``; ``unittest discover``
    only collects files matching its default pattern ``test*.py``, so these
    tests are not run by ``python -m unittest discover -s tests``.
    """

    @staticmethod
    def _square(num_planes, side, planes=slice(None)):
        """Return a zero (num_planes, 10, 10) image with a ``side`` x ``side``
        square of ones in the top-left corner of the selected ``planes``."""
        image = torch.zeros(num_planes, 10, 10)
        image[planes, 0:side, 0:side] = 1
        return image

    def test_iou_basic(self):
        # from one hot
        predicted = self._square(2, 4)
        target = self._square(2, 3)
        expected_iou = (3 * 3) / (4 * 4)
        iou = IOU()(predicted[None, ...], target[None, ...])
        self.assertAlmostEqual(iou, expected_iou, places=4)

    def test_iou_with_ignore_class(self):
        predicted = self._square(2, 4, planes=0)
        target = self._square(2, 3)
        expected_iou = (3 * 3) / (4 * 4)
        iou = IOU(ignore_class=1)(predicted[None, ...], target[None, ...])
        self.assertAlmostEqual(iou, expected_iou, places=4)

    def test_multiclass_iou(self):
        predicted = self._square(2, 4, planes=0)
        target = self._square(2, 3)
        per_class = ((3 * 3) / (4 * 4), 0)
        expected_mean_iou = 0.5 * (per_class[0] + per_class[1])
        iou = IOU()(predicted[None, ...], target[None, ...])
        self.assertAlmostEqual(iou, expected_mean_iou, places=4)

    def test_multiclass_iou_with_ignore_class(self):
        predicted = self._square(3, 4, planes=0)
        # Fill the third plane entirely; it would spoil the score if it were
        # counted (ignore_class=-1 presumably selects this last plane).
        predicted[2, :, :] = 1
        target = self._square(3, 3)
        per_class = ((3 * 3) / (4 * 4), 0)
        expected_mean_iou = 0.5 * (per_class[0] + per_class[1])
        iou = IOU(ignore_class=-1)(predicted[None, ...], target[None, ...])
        self.assertAlmostEqual(iou, expected_mean_iou, places=4)


if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
/tests/test_io/test_core/test_zip.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 |
class ZipTest(unittest.TestCase):
    """Tests for the ``Zip`` and ``ZipReject`` dataset combinators."""

    @staticmethod
    def _list_dataset_type():
        """Return a ``list`` subclass that also passes Zip's torch ``Dataset`` check."""
        from torch.utils.data.dataset import Dataset

        class ListDataset(list, Dataset):
            pass

        return ListDataset

    def test_zip_minimal(self):
        """Minimal test with python lists as iterators."""
        from inferno.io.core import Zip

        # Plain python lists are not torch datasets and must be rejected.
        with self.assertRaises(TypeError):
            Zip([1, 2, 3], [4, 5, 6, 7])

        ListDataset = self._list_dataset_type()
        zipped = Zip(ListDataset([1, 2, 3, 4]), ListDataset([5, 6, 7, 8, 9]))
        # Here the zip is as long as the shorter dataset.
        self.assertEqual(len(zipped), 4)
        self.assertEqual(zipped[1], [2, 6])
        with self.assertRaises(IndexError):
            zipped[4]

    def test_zip_sync(self):
        """Test synchronization mechanics."""
        # TODO

    def test_zip_reject(self):
        from inferno.io.core import ZipReject

        ListDataset = self._list_dataset_type()
        datasets = (ListDataset([1, 2, 3, 4]),
                    ListDataset([2, 1, 3, 4]),
                    ListDataset([0, 1, 2, 3]))

        def rejection_criterion(sample_1, sample_2):
            return sample_1 < sample_2

        without_jump = ZipReject(*datasets,
                                 rejection_criterion=rejection_criterion,
                                 random_jump_after_reject=False,
                                 rejection_dataset_indices=[0, 1])
        self.assertSequenceEqual(without_jump[0], [2, 1, 1])

        swapped_indices = ZipReject(*datasets,
                                    rejection_criterion=rejection_criterion,
                                    rejection_dataset_indices=[1, 0])
        self.assertSequenceEqual(swapped_indices[0], [1, 2, 0])


if __name__ == '__main__':
    unittest.main()
63 |
--------------------------------------------------------------------------------
/tests/test_extensions/test_models/test_unet.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch.cuda as cuda
3 | from inferno.utils.model_utils import ModelTester, MultiscaleModelTester
4 | from inferno.extensions.models import UNet
5 |
class _MultiscaleUNet(UNet):
    """UNet variant whose forward pass returns one output per scale.

    ``forward`` deliberately skips ``UNet.forward`` by calling
    ``super(UNet, self).forward`` (the next class after ``UNet`` in the MRO),
    and applies ``_output`` only to the last returned tensor.
    """

    def conv_op_factory(self, in_channels, out_channels, part, index):
        parent_result = super(_MultiscaleUNet, self).conv_op_factory(
            in_channels, out_channels, part, index)
        # Keep only the op and pair it with True.
        # NOTE(review): the flag's meaning is defined by the UNet base class —
        # presumably "return this op's output"; confirm there.
        return parent_result[0], True

    def forward(self, input):
        features = self._initial_conv(input)
        scale_outputs = list(super(UNet, self).forward(features))
        scale_outputs[-1] = self._output(scale_outputs[-1])
        return tuple(scale_outputs)
15 |
16 |
class UNetTest(unittest.TestCase):
    """Output-shape tests for plain and multiscale UNets in 2D and 3D."""

    @staticmethod
    def _on_available_device(tester):
        """Move ``tester`` to the GPU when CUDA is available, then return it."""
        if cuda.is_available():
            tester.cuda()
        return tester

    def test_unet_2d(self):
        image_shape = (1, 1, 256, 256)
        tester = self._on_available_device(ModelTester(image_shape, image_shape))
        tester(UNet(1, 1, dim=2, initial_features=32))

    def test_unet_3d(self):
        volume_shape = (1, 1, 16, 64, 64)
        tester = self._on_available_device(ModelTester(volume_shape, volume_shape))
        # default 3D UNet
        tester(UNet(1, 1, dim=3, initial_features=8))

    def test_monochannel_unet_3d(self):
        nc = 2

        class _UNetMonochannel(_MultiscaleUNet):
            # Every depth uses the same number of channels.
            def _get_num_channels(self, depth):
                return nc

        expected_shapes = [
            (1, nc, 16, 64, 64), (1, nc, 8, 32, 32), (1, nc, 4, 16, 16),
            (1, nc, 2, 8, 8), (1, nc, 1, 4, 4),
            (1, nc, 2, 8, 8), (1, nc, 4, 16, 16), (1, nc, 8, 32, 32),
            (1, 1, 16, 64, 64),
        ]
        tester = self._on_available_device(
            MultiscaleModelTester((1, 1, 16, 64, 64), expected_shapes))
        tester(_UNetMonochannel(1, 1, dim=3, initial_features=8))

    def test_inverse_pyramid_unet_2d(self):
        class _UNetInversePyramid(_MultiscaleUNet):
            # Channel count shrinks with depth: 13, 12, 11.
            def _get_num_channels(self, depth):
                return [13, 12, 11][depth - 1]

        expected_shapes = [
            (1, 13, 16, 64), (1, 12, 8, 32), (1, 11, 4, 16), (1, 11, 2, 8),
            (1, 12, 4, 16), (1, 13, 8, 32), (1, 1, 16, 64),
        ]
        tester = self._on_available_device(
            MultiscaleModelTester((1, 1, 16, 64), expected_shapes))
        tester(_UNetInversePyramid(1, 1, dim=2, depth=3, initial_features=8))


if __name__ == '__main__':
    unittest.main()
59 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """The setup script."""
5 |
6 | from setuptools import setup, find_packages
7 | import runpy
__version__ = runpy.run_path('inferno/version.py')['__version__']


# Read the long description as UTF-8 explicitly so installation does not
# depend on the platform's default locale encoding.
with open('README.rst', encoding='utf-8') as readme_file:
    readme = readme_file.read()

with open('HISTORY.rst', encoding='utf-8') as history_file:
    history = history_file.read()

requirements = [
    # TODO: put package requirements here
    "pip>=8.1.2",
    "torch>=0.1.12",
    "dill",
    "pyyaml",
    "scipy>=0.13.0",
    "h5py",
    "numpy>=1.8",
    "scikit-image",
    "torchvision",
    "tqdm"
]


setup_requirements = [
    'pytest-runner'
]

# `unittest` ships with the standard library and is not a PyPI dependency.
test_requirements = [
    'pytest'
]

# NOTE(review): pip >= 19.0 ignores dependency_links entirely; this pinned
# CUDA 7.5 / CPython 3.5 wheel only affects legacy installers.
dependency_links = [
    'http://download.pytorch.org/whl/cu75/torch-0.2.0.post1-cp35-cp35m-manylinux1_x86_64.whl#egg=torch-0.2.0'
]

setup(
    name='inferno-pytorch',
    version=__version__,
    description="Inferno is a little library providing utilities and convenience functions/classes around PyTorch.",
    long_description=readme + '\n\n' + history,
    author="Nasim Rahaman",
    author_email='nasim.rahaman@iwr.uni-heidelberg.de',
    url='https://github.com/inferno-pytorch/inferno',
    packages=find_packages(where='.', exclude=["*.tests", "*.tests.*",
                                               "tests.*", "tests",
                                               "__pycache__", "*.pyc"]),
    dependency_links=dependency_links,
    include_package_data=True,
    install_requires=requirements,
    license="Apache Software License 2.0",
    zip_safe=False,
    keywords='inferno pytorch torch deep learning cnn deep-pyromania',
    classifiers=[
        # How mature is this project? Common values are
        #   2 - Pre-Alpha
        #   3 - Alpha
        #   4 - Beta
        #   5 - Production/Stable
        'Development Status :: 2 - Pre-Alpha',
        # Indicate who your project is intended for
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache Software License',
        'Natural Language :: English',
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
    ],
    # The suite lives in `tests/`; `test` would resolve to CPython's own
    # regression-test package instead.
    test_suite='tests',
    tests_require=test_requirements,
    setup_requires=setup_requirements,
)
79 |
--------------------------------------------------------------------------------
/docs/inferno-apidoc/inferno.extensions.layers.rst:
--------------------------------------------------------------------------------
1 | inferno.extensions.layers package
2 | =================================
3 |
4 | Submodules
5 | ----------
6 |
7 | inferno.extensions.layers.activations module
8 | --------------------------------------------
9 |
10 | .. automodule:: inferno.extensions.layers.activations
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | inferno.extensions.layers.building\_blocks module
16 | -------------------------------------------------
17 |
18 | .. automodule:: inferno.extensions.layers.building_blocks
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | inferno.extensions.layers.convolutional module
24 | ----------------------------------------------
25 |
26 | .. automodule:: inferno.extensions.layers.convolutional
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | inferno.extensions.layers.device module
32 | ---------------------------------------
33 |
34 | .. automodule:: inferno.extensions.layers.device
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | inferno.extensions.layers.identity module
40 | -----------------------------------------
41 |
42 | .. automodule:: inferno.extensions.layers.identity
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
47 | inferno.extensions.layers.prefab module
48 | ---------------------------------------
49 |
50 | .. automodule:: inferno.extensions.layers.prefab
51 | :members:
52 | :undoc-members:
53 | :show-inheritance:
54 |
55 | inferno.extensions.layers.res\_unet module
56 | ------------------------------------------
57 |
58 | .. automodule:: inferno.extensions.layers.res_unet
59 | :members:
60 | :undoc-members:
61 | :show-inheritance:
62 |
63 | inferno.extensions.layers.reshape module
64 | ----------------------------------------
65 |
66 | .. automodule:: inferno.extensions.layers.reshape
67 | :members:
68 | :undoc-members:
69 | :show-inheritance:
70 |
71 | inferno.extensions.layers.sampling module
72 | -----------------------------------------
73 |
74 | .. automodule:: inferno.extensions.layers.sampling
75 | :members:
76 | :undoc-members:
77 | :show-inheritance:
78 |
79 | inferno.extensions.layers.unet\_base module
80 | -------------------------------------------
81 |
82 | .. automodule:: inferno.extensions.layers.unet_base
83 | :members:
84 | :undoc-members:
85 | :show-inheritance:
86 |
87 |
88 | Module contents
89 | ---------------
90 |
91 | .. automodule:: inferno.extensions.layers
92 | :members:
93 | :undoc-members:
94 | :show-inheritance:
95 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
# Every target that does not produce a file of the same name is phony, so that
# stray files/directories (e.g. `dist/`, `docs/`) never mask a target.
.PHONY: clean clean-test clean-pyc clean-build docs help lint test test-all coverage servedocs release dist install
.DEFAULT_GOAL := help
define BROWSER_PYSCRIPT
import os, webbrowser, sys
try:
    from urllib import pathname2url
except ImportError:
    # Python 3 moved pathname2url into urllib.request.
    from urllib.request import pathname2url

webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1])))
endef
export BROWSER_PYSCRIPT

define PRINT_HELP_PYSCRIPT
import re, sys

for line in sys.stdin:
    match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line)
    if match:
        target, help = match.groups()
        print("%-20s %s" % (target, help))
endef
export PRINT_HELP_PYSCRIPT
BROWSER := python -c "$$BROWSER_PYSCRIPT"

help:
	@python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST)

clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts


clean-build: ## remove build artifacts
	rm -fr build/
	rm -fr dist/
	rm -fr .eggs/
	find . -name '*.egg-info' -exec rm -fr {} +
	find . -name '*.egg' -exec rm -f {} +

clean-pyc: ## remove Python file artifacts
	find . -name '*.pyc' -exec rm -f {} +
	find . -name '*.pyo' -exec rm -f {} +
	find . -name '*~' -exec rm -f {} +
	find . -name '__pycache__' -exec rm -fr {} +

clean-test: ## remove test and coverage artifacts
	rm -fr .tox/
	rm -f .coverage
	rm -fr htmlcov/

lint: ## check style with flake8
	flake8 inferno tests

test: ## run tests quickly with the default Python
	python setup.py test

test-all: ## run tests on every Python version with tox
	tox

coverage: ## check code coverage quickly with the default Python
	coverage run --source inferno setup.py test
	coverage report -m
	coverage html
	$(BROWSER) htmlcov/index.html

docs: ## generate Sphinx HTML documentation, including API docs
	rm -f docs/inferno.rst
	rm -f docs/modules.rst
	sphinx-apidoc -o docs/ inferno
	$(MAKE) -C docs clean
	$(MAKE) -C docs html
	$(BROWSER) docs/_build/html/index.html

servedocs: docs ## compile the docs watching for changes
	watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D .

# setuptools no longer ships a working `upload` command; build the
# distributions first (via `dist`) and upload them with twine.
release: dist ## package and upload a release
	twine upload dist/*

dist: clean ## builds source and wheel package
	python setup.py sdist
	python setup.py bdist_wheel
	ls -l dist

install: clean ## install the package to the active Python's site-packages
	python setup.py install
88 |
--------------------------------------------------------------------------------
/examples/trainer.py:
--------------------------------------------------------------------------------
"""
Trainer Example
================================

This example illustrates how to use the trainer class to fit a small
convolutional classifier on CIFAR-10.

"""

import torch
import torch.nn as nn
from inferno.io.box.cifar import get_cifar10_loaders
from inferno.trainers.basic import Trainer
from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger
from inferno.extensions.layers import ConvELU2D
from inferno.extensions.layers import Flatten
from inferno.utils.python_utils import ensure_dir

from inferno.extensions.layers import SELU

##################################################
# change directories to your needs
LOG_DIRECTORY = ensure_dir('log')
SAVE_DIRECTORY = ensure_dir('save')
DATASET_DIRECTORY = ensure_dir('dataset')

##################################################
# shall the CIFAR-10 dataset be downloaded
DOWNLOAD_CIFAR = True
# Only use the GPU when one is actually available, so the example also runs
# on CPU-only machines.
USE_CUDA = torch.cuda.is_available()

##################################################
# Build torch model
model = nn.Sequential(
    ConvELU2D(in_channels=3, out_channels=256, kernel_size=3),
    nn.MaxPool2d(kernel_size=2, stride=2),
    ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
    nn.MaxPool2d(kernel_size=2, stride=2),
    ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
    nn.MaxPool2d(kernel_size=2, stride=2),
    Flatten(),
    # Assumes ConvELU2D preserves the spatial size, so three 2x poolings
    # reduce 32x32 inputs to 4x4 feature maps.
    nn.Linear(in_features=(256 * 4 * 4), out_features=10),
    # No Softmax here: CrossEntropyLoss expects raw logits and applies
    # log-softmax itself; an extra Softmax would squash the gradients.
)

##################################################
# data loaders
train_loader, validate_loader = get_cifar10_loaders(DATASET_DIRECTORY,
                                                    download=DOWNLOAD_CIFAR)

##################################################
# Build trainer
trainer = Trainer(model)
trainer.build_criterion('CrossEntropyLoss')
trainer.build_metric('CategoricalError')
trainer.build_optimizer('Adam')
trainer.validate_every((2, 'epochs'))
trainer.save_every((5, 'epochs'))
trainer.save_to_directory(SAVE_DIRECTORY)
trainer.set_max_num_epochs(10)
trainer.build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
                                       log_images_every='never'),
                     log_directory=LOG_DIRECTORY)

##################################################
# Bind loaders
trainer.bind_loader('train', train_loader)
trainer.bind_loader('validate', validate_loader)

##################################################
# activate cuda
if USE_CUDA:
    trainer.cuda()

##################################################
# fit
trainer.fit()
76 |
--------------------------------------------------------------------------------
/tests/test_training/test_callbacks/test_base.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | from inferno.trainers.callbacks.base import Callback, CallbackEngine
4 | from inferno.trainers.basic import Trainer
5 | from os.path import join, dirname, exists
6 | from os import makedirs
7 | from shutil import rmtree
8 |
9 |
class DummyCallback(Callback):
    """Minimal valid callback: only checks that it was bound to a trainer."""

    def end_of_training_iteration(self, **_):
        # By the time this event fires the engine must have attached a trainer.
        assert self.trainer is not None
13 |
14 |
class WrongDummyCallback(Callback):
    """Callback defining a handler name the engine should reject on registration."""

    def end_of_iteration(self):
        # Intentionally a no-op; only the (invalid) method name matters here.
        pass
18 |
19 |
20 | class CallbackMechTest(unittest.TestCase):
21 | ROOT_DIR = join(dirname(__file__), 'root')
22 |
23 | def setUp(self):
24 | makedirs(self.ROOT_DIR, exist_ok=True)
25 |
26 | def tearDown(self):
27 | if exists(self.ROOT_DIR):
28 | rmtree(self.ROOT_DIR)
29 |
30 | def test_serialization(self):
31 | # Build engine and trainer
32 | callback_engine = CallbackEngine().bind_trainer(Trainer())
33 | callback_engine.register_callback(DummyCallback())
34 | # Serialize
35 | torch.save(callback_engine, join(self.ROOT_DIR, 'callback_engine.pkl'))
36 | # Unserialize
37 | callback_engine = torch.load(join(self.ROOT_DIR, 'callback_engine.pkl'))
38 | # Make sure the trainer is detached
39 | self.assertIsNone(callback_engine._trainer)
40 | self.assertIsInstance(next(iter(callback_engine
41 | ._callback_registry
42 | .get('end_of_training_iteration'))),
43 | DummyCallback)
44 |
45 | def test_auto_registry(self):
46 | callback_engine = CallbackEngine().bind_trainer(Trainer())
47 | callback_engine.register_callback(DummyCallback())
48 | self.assertIsInstance(next(iter(callback_engine
49 | ._callback_registry
50 | .get('end_of_training_iteration'))),
51 | DummyCallback)
52 | with self.assertRaises(AssertionError):
53 | callback_engine.register_callback(WrongDummyCallback())
54 |
55 | def test_instance_registry(self):
56 | class Foo(Callback):
57 | pass
58 |
59 | class Bar(Callback):
60 | pass
61 |
62 | foo = Foo()
63 | bar = Bar()
64 | self.assertIs(foo.get_instances(), foo)
65 | self.assertIs(bar.get_instances(), bar)
66 | foo2 = Foo()
67 | self.assertSequenceEqual(foo2.get_instances(), [foo, foo2])
68 | self.assertIs(bar.get_instances(), bar)
69 |
70 | if __name__ == '__main__':
71 | unittest.main()
72 |
--------------------------------------------------------------------------------
/tests/test_extensions/test_layers/test_reshape.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 |
4 |
5 | class TestReshape(unittest.TestCase):
6 | def _get_input_variable(self, *shape):
7 | return torch.rand(*shape)
8 |
9 | def test_as_matrix(self):
10 | from inferno.extensions.layers.reshape import AsMatrix
11 |
12 | input = self._get_input_variable(10, 20, 1, 1)
13 | as_matrix = AsMatrix()
14 | output = as_matrix(input)
15 | self.assertEqual(list(output.size()), [10, 20])
16 |
17 | def test_flatten(self):
18 | from inferno.extensions.layers.reshape import Flatten
19 |
20 | input = self._get_input_variable(10, 20, 2, 2)
21 | flatten = Flatten()
22 | output = flatten(input)
23 | self.assertEqual(list(output.size()), [10, 80])
24 |
25 | def test_as_2d(self):
26 | from inferno.extensions.layers.reshape import As2D
27 |
28 | as_2d = As2D()
29 |
30 | output_shape = as_2d(self._get_input_variable(10, 20, 3, 30, 30)).size()
31 | self.assertEqual(list(output_shape), [10, 60, 30, 30])
32 |
33 | output_shape = as_2d(self._get_input_variable(10, 20, 30, 30)).size()
34 | self.assertEqual(list(output_shape), [10, 20, 30, 30])
35 |
36 | output_shape = as_2d(self._get_input_variable(10, 20)).size()
37 | self.assertEqual(list(output_shape), [10, 20, 1, 1])
38 |
39 | def test_as_3d(self):
40 | from inferno.extensions.layers.reshape import As3D
41 | from inferno.utils.exceptions import ShapeError
42 |
43 | as_3d = As3D()
44 |
45 | output_shape = as_3d(self._get_input_variable(10, 20, 3, 30, 30)).size()
46 | self.assertEqual(list(output_shape), [10, 20, 3, 30, 30])
47 |
48 | output_shape = as_3d(self._get_input_variable(10, 20, 30, 30)).size()
49 | self.assertEqual(list(output_shape), [10, 20, 1, 30, 30])
50 |
51 | output_shape = as_3d(self._get_input_variable(10, 20)).size()
52 | self.assertEqual(list(output_shape), [10, 20, 1, 1, 1])
53 |
54 | as_3d.channel_as_z = True
55 | output_shape = as_3d(self._get_input_variable(10, 20, 30, 30)).size()
56 | self.assertEqual(list(output_shape), [10, 1, 20, 30, 30])
57 |
58 | as_3d.num_channels_or_num_z_slices = 2
59 | output_shape = as_3d(self._get_input_variable(10, 40, 30, 30)).size()
60 | self.assertEqual(list(output_shape), [10, 2, 20, 30, 30])
61 |
62 | with self.assertRaises(ShapeError):
63 | output_shape = as_3d(self._get_input_variable(10, 41, 30, 30)).size()
64 | self.assertEqual(list(output_shape), [10, 2, 20, 30, 30])
65 |
66 |
67 | if __name__ == '__main__':
68 | unittest.main()
69 |
--------------------------------------------------------------------------------
/tests/test_extensions/test_layers/deprecated/building_blocks.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import inferno.extensions.layers.building_blocks as bb
4 |
5 |
6 | class ResBlockTest(unittest.TestCase):
7 |
8 | def test_2D_simple_(self):
9 |
10 | x = torch.rand(1, 3, 64, 15)
11 | model = bb.ResBlock(in_channels=3, out_channels=3, dim=2)
12 | xx = model(x)
13 | out_size = xx.size()
14 | self.assertEqual(list(out_size), [1,3, 64, 15])
15 |
16 | def test_3D_simple_(self):
17 |
18 | x = torch.rand(1,3,20, 64,15)
19 | model = bb.ResBlock(in_channels=3, out_channels=3, dim=3)
20 | xx = model(x)
21 | out_size = xx.size()
22 | self.assertEqual(list(out_size), [1,3, 20, 64, 15])
23 |
24 | def test_2D_simple_2(self):
25 |
26 | x = torch.rand(1,3,64,64)
27 | model = bb.ResBlock(in_channels=3, out_channels=6, dim=2)
28 | xx = model(x)
29 | out_size = xx.size()
30 | self.assertEqual(list(out_size), [1,6, 64, 64])
31 |
32 | def test_2D_simple_3(self):
33 |
34 | x = torch.rand(1,3,64,64)
35 | model = bb.ResBlock(in_channels=3, out_channels=6, dim=2, size=4)
36 | xx = model(x)
37 | out_size = xx.size()
38 | self.assertEqual(list(out_size), [1,6, 64, 64])
39 |
40 | def test_2D_simple_4(self):
41 |
42 | x = torch.rand(1,6,64,64)
43 | model = bb.ResBlock(in_channels=6, out_channels=6, dim=2, size=4,
44 | force_skip_op=True)
45 | xx = model(x)
46 | out_size = xx.size()
47 | self.assertEqual(list(out_size), [1,6, 64, 64])
48 |
49 | def test_2D_simple_5(self):
50 |
51 | x = torch.rand(1,6,64,64)
52 | model = bb.ResBlock(in_channels=6, batchnorm=False, out_channels=6, dim=2, size=4,
53 | force_skip_op=True)
54 | xx = model(x)
55 | out_size = xx.size()
56 | self.assertEqual(list(out_size), [1,6, 64, 64])
57 |
58 | def test_2D_simple_6(self):
59 |
60 | x = torch.rand(1,6,64,64)
61 | model = bb.ResBlock(in_channels=6, batchnorm=False, out_channels=6, dim=2, size=4,
62 | force_skip_op=True, activated=False)
63 | xx = model(x)
64 | out_size = xx.size()
65 | self.assertEqual(list(out_size), [1,6, 64, 64])
66 |
67 | def test_3D_simple_6(self):
68 |
69 | x = torch.rand(1,6,64,64, 20)
70 | model = bb.ResBlock(in_channels=6, batchnorm=False, out_channels=6, dim=3, size=4,
71 | force_skip_op=True, activated=False)
72 | xx = model(x)
73 | out_size = xx.size()
74 | self.assertEqual(list(out_size), [1,6, 64, 64, 20])
75 |
76 |
77 | if __name__ == '__main__':
78 | unittest.main()
79 |
--------------------------------------------------------------------------------
/docs/installation.rst:
--------------------------------------------------------------------------------
1 | .. highlight:: shell
2 |
3 | ==================================
4 | Installation
5 | ==================================
6 |
7 | Install on Linux and OSX
8 | ------------------------
9 |
10 | Developers
11 | ~~~~~~~~~~~~~~~~~~~~~~
12 |
13 | First, make sure `you have PyTorch installed <http://pytorch.org/>`__.
14 |
15 | Then, clone this repository with:
16 |
17 | .. code-block:: console
18 |
19 | $ git clone https://github.com/nasimrahaman/inferno.git
20 |
21 |
22 | Next, install the dependencies.
23 |
24 | .. code-block:: console
25 |
26 | $ cd inferno
27 | $ pip install -r requirements.txt
28 |
29 |
30 | If you use python from the shell:
31 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
32 |
33 | Finally, add *inferno* to your `PYTHONPATH` with:
34 |
35 | .. code-block:: console
36 |
37 | source add2path.sh
38 |
39 | If you use PyCharm:
40 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
41 | Refer to this `QA `_ about setting up paths with Pycharm.
42 |
43 |
44 |
45 |
46 |
47 | ======================================================
48 | Installation via PyPI / pip / setup.py (Experimental)
49 | ======================================================
50 |
51 | You need to install PyTorch via pip before installing
52 | inferno. Follow the `pytorch installation guide`_.
53 |
54 | Stable release
55 | --------------
56 |
57 | To install inferno, run this command in your terminal:
58 |
59 | .. code-block:: console
60 |
61 | $ pip install inferno-pytorch
62 |
63 | This is the preferred method to install inferno, as it will always install the most recent stable release.
64 |
65 | If you don't have `pip`_ installed, this `Python installation guide`_ can guide
66 | you through the process.
67 |
68 | .. _pip: https://pip.pypa.io
69 | .. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/
70 | .. _pytorch installation guide: http://pytorch.org/
71 |
72 | From sources
73 | ------------------------
74 | First, make sure `you have PyTorch installed <http://pytorch.org/>`__.
75 | The sources for inferno can be downloaded from the `Github repo`_.
76 | You can either clone the public repository:
77 |
78 | .. code-block:: console
79 |
80 | $ git clone git://github.com/nasimrahaman/inferno
81 |
82 | Or download the `tarball`_:
83 |
84 | .. code-block:: console
85 |
86 | $ curl -OL https://github.com/nasimrahaman/inferno/tarball/master
87 |
88 | Once you have a copy of the source, you can install it with:
89 |
90 | .. code-block:: console
91 |
92 | $ python setup.py install
93 |
94 |
95 | .. _Github repo: https://github.com/nasimrahaman/inferno
96 | .. _tarball: https://github.com/nasimrahaman/inferno/tarball/master
97 |
--------------------------------------------------------------------------------
/inferno/io/core/concatenate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data.dataset import Dataset
3 | from ...utils import python_utils as pyu
4 |
5 |
class Concatenate(Dataset):
    """
    Concatenates multiple datasets to one. Global indices run through the
    datasets in the order they were passed. This class does not implement
    synchronization primitives.

    Parameters
    ----------
    *datasets : Dataset
        At least one dataset to concatenate.
    transforms : callable, optional
        If given, called as ``transforms(*fetched)`` on every fetched item
        (a non-iterable item is first wrapped into an iterable).
    """
    def __init__(self, *datasets, transforms=None):
        assert all([isinstance(dataset, Dataset) for dataset in datasets])
        assert len(datasets) >= 1
        assert transforms is None or callable(transforms)
        self.datasets = datasets
        self.transforms = transforms

    def map_index(self, index):
        """Map a global index in ``[0, len(self))`` to ``(dataset_index, index_in_dataset)``."""
        # Get a list of lengths of all datasets. Say the answer is [4, 3, 3],
        # and we're looking for index = 5.
        len_list = [len(dataset) for dataset in self.datasets]
        # Cumulate to a numpy array. The answer is [4, 7, 10]
        cumulative_len_list = np.cumsum(len_list)
        # When the index is subtracted, we get [-1, 2, 5]. We're looking for the
        # (index of the) first cumulated len which is larger than the index (in
        # this case, 7 (index 1)).
        offset_cumulative_len_list = cumulative_len_list - index
        dataset_index = np.argmax(offset_cumulative_len_list > 0)
        if dataset_index == 0:
            # First dataset - index corresponds to index_in_dataset
            index_in_dataset = index
        else:
            # Subtract the total length of all preceding datasets.
            index_in_dataset = index - cumulative_len_list[dataset_index - 1]
        return dataset_index, index_in_dataset

    def __getitem__(self, index):
        """Fetch the item at a global ``index``; negative indices count from the end.

        Raises
        ------
        IndexError
            If ``index`` is out of range. Raising IndexError (rather than an
            AssertionError) lets ``iter()``/``list()`` on this dataset terminate.
        """
        length = len(self)
        requested_index = index
        if index < 0:
            index += length
        if not 0 <= index < length:
            raise IndexError("Index {} is out of range for a Concatenate of length {}."
                             .format(requested_index, length))
        dataset_index, index_in_dataset = self.map_index(index)
        fetched = self.datasets[dataset_index][index_in_dataset]
        if self.transforms is None:
            return fetched
        elif callable(self.transforms):
            return self.transforms(*pyu.to_iterable(fetched))
        else:
            raise NotImplementedError

    def __len__(self):
        return sum(len(dataset) for dataset in self.datasets)

    def __repr__(self):
        if len(self.datasets) < 3:
            # Join all reprs at once so a single dataset gets no stray leading ", ".
            return "Concatenate(" + ", ".join(repr(dataset) for dataset in self.datasets) + ")"
        return "Concatenate({}xDatasets)".format(len(self.datasets))
62 |
--------------------------------------------------------------------------------
/inferno/extensions/layers/convolutional_blocks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .convolutional import BNReLUConv2D, BNReLUDeconv2D, Conv2D, Deconv2D
3 | from ...utils import python_utils as pyu
4 | from ...utils.exceptions import assert_
5 |
# Public API of this module.
__all__ = ['ResidualBlock', 'PreActSimpleResidualBlock']
# NOTE(review): private alias of __all__; presumably collected by a
# package-level __init__ elsewhere — verify before renaming or removing.
_all = __all__
8 |
9 |
class ResidualBlock(nn.Module):
    """Generic residual block computing ``layers(input) + skip(input)``.

    Parameters
    ----------
    layers : list-like of nn.Module
        Modules of the residual branch, chained with ``nn.Sequential``.
    resample : nn.Module, optional
        Applied to the input on the skip path (e.g. to match the residual
        branch's shape). If None, the input itself is added.
    """
    def __init__(self, layers, resample=None):
        super(ResidualBlock, self).__init__()
        assert pyu.is_listlike(layers)
        self.layers = nn.Sequential(*layers)
        self.resample = resample

    def forward(self, input):
        # The residual branch runs before the skip path (keeps the original
        # evaluation order, which matters e.g. for random-number consumption).
        residual = self.layers(input)
        skip = input if self.resample is None else self.resample(input)
        return residual + skip
25 |
26 |
class PreActSimpleResidualBlock(ResidualBlock):
    """Pre-activation residual block made of two BN-ReLU-conv stages.

    The first stage maps ``in_channels`` to ``num_hidden_channels``:
    - ``downsample``: 3x3 ``BNReLUConv2D`` with stride 2; the skip path is a
      1x1 stride-2 ``Conv2D`` followed by ``BatchNorm2d``.
    - ``upsample``: 2x2 ``BNReLUDeconv2D`` with stride 2; the skip path is a
      2x2 stride-2 ``Deconv2D`` followed by ``BatchNorm2d``.
    - neither: 3x3 ``BNReLUConv2D`` and an identity skip path.
    The second stage is a 3x3 ``BNReLUConv2D`` back to ``in_channels``.
    Requesting both ``downsample`` and ``upsample`` is rejected via ``assert_``
    with a ValueError.
    """
    def __init__(self, in_channels, num_hidden_channels, upsample=False, downsample=False):
        if downsample:
            assert_(not upsample, "Both downsample and upsample is set to true.", ValueError)
            first_stage = BNReLUConv2D(in_channels=in_channels,
                                       out_channels=num_hidden_channels,
                                       kernel_size=3, stride=2)
            skip_path = nn.Sequential(
                Conv2D(in_channels=in_channels, out_channels=in_channels,
                       kernel_size=1, stride=2),
                nn.BatchNorm2d(in_channels))
        elif upsample:
            first_stage = BNReLUDeconv2D(in_channels=in_channels,
                                         out_channels=num_hidden_channels,
                                         kernel_size=2, stride=2)
            skip_path = nn.Sequential(
                Deconv2D(in_channels=in_channels, out_channels=in_channels,
                         kernel_size=2, stride=2),
                nn.BatchNorm2d(in_channels))
        else:
            first_stage = BNReLUConv2D(in_channels=in_channels,
                                       out_channels=num_hidden_channels,
                                       kernel_size=3)
            skip_path = None
        # Built after the first stage and skip path to keep the original
        # module construction (and parameter initialization) order.
        second_stage = BNReLUConv2D(in_channels=num_hidden_channels,
                                    out_channels=in_channels,
                                    kernel_size=3)
        super(PreActSimpleResidualBlock, self).__init__([first_stage, second_stage], skip_path)
58 |
59 |
60 | # TODO PreActBottleneckResidualBlock
61 |
--------------------------------------------------------------------------------
/inferno/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .exceptions import assert_, NotTorchModuleError, ShapeError
3 |
4 |
def is_model_cuda(model):
    """Return whether ``model``'s first parameter lives on a CUDA device.

    Models without parameters are reported as not using CUDA.
    """
    first_parameter = next(iter(model.parameters()), None)
    if first_parameter is None:
        # No parameters to inspect: assume the model doesn't use CUDA.
        return False
    return first_parameter.is_cuda
11 |
12 |
class ModelTester(object):
    """Shape test for a model: feeds a random input and checks the output shape.

    Parameters
    ----------
    input_shape : sequence of int
        Shape of the random input tensor.
    expected_output_shape : sequence of int
        Shape the model output must have; a mismatch raises ``ShapeError``.
    """
    def __init__(self, input_shape, expected_output_shape):
        self.input_shape = input_shape
        self.expected_output_shape = expected_output_shape
        self._is_cuda = False

    def cuda(self):
        """Run subsequent tests on the GPU; returns self for chaining."""
        self._is_cuda = True
        return self

    def get_input(self):
        """Return a uniform random tensor of ``input_shape`` that does not require grad."""
        with torch.no_grad():
            sample = torch.rand(*self.input_shape, requires_grad=False)
            return sample.cuda() if self._is_cuda else sample

    def __call__(self, model):
        """Run ``model`` on a random input, check the output shape and return the model."""
        assert_(isinstance(model, torch.nn.Module),
                "Model is not a torch module.",
                NotTorchModuleError)
        # Move the model to the GPU only when requested and not already there.
        if self._is_cuda and not is_model_cuda(model):
            model.cuda()
        output = model(self.get_input())
        expected_shape = list(self.expected_output_shape)
        actual_shape = list(output.size())
        assert_(actual_shape == expected_shape,
                "Expected output shape {} for input shape {}, "
                "got output of shape {} instead.".format(expected_shape,
                                                         list(self.input_shape),
                                                         actual_shape),
                ShapeError)
        return model
47 |
48 |
class MultiscaleModelTester(ModelTester):
    """``ModelTester`` for models that return a tuple of outputs, one per scale.

    ``expected_output_shape`` must be a sequence holding one expected shape per
    output; both the number of outputs and each output's shape are checked.
    """
    def __call__(self, model):
        """Run ``model`` on a random input, check every output's shape and return the model."""
        # Make sure model is a model
        assert_(isinstance(model, torch.nn.Module),
                "Model is not a torch module.",
                NotTorchModuleError)
        # Transfer to cuda if required
        if not is_model_cuda(model) and self._is_cuda:
            model.cuda()
        input_ = self.get_input()
        output = model(input_)
        assert_(isinstance(output, tuple), "Expect tuple output")
        # Without this check, missing scales would pass silently and extra
        # scales would fail with an IndexError instead of a ShapeError.
        assert_(len(output) == len(self.expected_output_shape),
                "Expected {} outputs, got {} instead.".format(len(self.expected_output_shape),
                                                             len(output)),
                ShapeError)
        for scale_output, expected_shape in zip(output, self.expected_output_shape):
            assert_(list(scale_output.size()) == list(expected_shape),
                    "Expected output shape {} for input shape {}, "
                    "got output of shape {} instead.".format(list(expected_shape),
                                                             list(self.input_shape),
                                                             list(scale_output.size())),
                    ShapeError)
        return model
69 |
--------------------------------------------------------------------------------
/inferno/trainers/callbacks/console.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from .base import Callback
3 |
class StdoutPrinter(object):
    """Printer backend writing to stdout with a ``[+][<timestamp>]`` prefix."""

    def print(self, message):
        timestamp = str(datetime.now())
        line = "[+][{}] {}".format(timestamp, message)
        print(line)
7 |
8 |
class Console(object):
    """
    Level-filtered console output.

    Each message gets a level tag prefix and is forwarded to the printer's
    ``print(message)`` method. INFO, PROGRESS and WARNING are enabled by
    default; DEBUG is disabled until ``toggle_debug(True)`` is called.
    """
    LEVEL_INFO = 1
    LEVEL_PROGRESS = 2
    LEVEL_WARNING = 3
    LEVEL_DEBUG = 4

    def __init__(self, printer=None):
        # Create the default printer per console instead of sharing a single
        # instance evaluated once as a default argument.
        self._printer = StdoutPrinter() if printer is None else printer
        self._enabled = {self.LEVEL_INFO, self.LEVEL_PROGRESS, self.LEVEL_WARNING}

    def set_console(self, console):
        """Replace the printer; ``console`` must provide a ``print(message)`` method."""
        self._printer = console

    def _print(self, message, level):
        """Forward ``message`` to the printer if ``level`` is enabled."""
        if level not in self._enabled:
            return

        self._printer.print(message)

    def info(self, message):
        self._print("[INFO ] " + message, self.LEVEL_INFO)

    def print(self, message):
        """Alias for ``info``."""
        self.info(message)

    def progress(self, message):
        self._print("[PROGRESS] " + message, self.LEVEL_PROGRESS)

    def warning(self, message):
        self._print("[WARNING ] " + message, self.LEVEL_WARNING)

    def debug(self, message):
        self._print("[DEBUG ] " + message, self.LEVEL_DEBUG)

    def _toggle(self, level, state):
        """Enable ``level`` if ``state`` is truthy, otherwise disable it."""
        if state:
            self._enabled.add(level)
        else:
            # discard() is a no-op when the level is already disabled.
            self._enabled.discard(level)

    def toggle_info(self, state):
        self._toggle(self.LEVEL_INFO, state)

    def toggle_progress(self, state):
        self._toggle(self.LEVEL_PROGRESS, state)

    def toggle_warning(self, state):
        self._toggle(self.LEVEL_WARNING, state)

    def toggle_debug(self, state):
        """Enable or disable DEBUG messages (disabled by default)."""
        self._toggle(self.LEVEL_DEBUG, state)
58 |
59 |
60 |
class ShowMinimalConsoleInfo(Callback):
    """
    Callback that limits console output to a short per-epoch summary:
    epoch number, current learning rate, and the training loss and
    training error when they exist.
    """
    def __init__(self, *args, **kwargs):
        super(ShowMinimalConsoleInfo, self).__init__(*args, **kwargs)

    def begin_of_fit(self, **_):
        # Presumably silences the trainer's default output so that only the
        # summary below is shown — confirm against Trainer.quiet.
        self.trainer.quiet()

    def end_of_epoch(self, **_):
        trainer = self.trainer
        training_loss = trainer.get_state('training_loss')
        training_error = trainer.get_state('training_error')
        learning_rate = trainer.get_state('learning_rate')

        console = trainer.console
        console.info("--------------------------------")
        console.info("Epoch " + str(trainer.epoch_count))
        # Loss and error are optional states; .item() assumes tensor-like values.
        if training_loss is not None:
            console.info("Train Loss " + str(training_loss.item()))
        if training_error is not None:
            console.info("Train Error " + str(training_error.item()))
        console.info("Current LR " + str(learning_rate))
--------------------------------------------------------------------------------
/inferno/utils/io_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import h5py as h5
3 | import numpy as np
4 | import yaml
5 | from skimage.io import imsave
6 |
7 |
8 | # Function to load in a dataset from a h5file
def fromh5(path, datapath=None, dataslice=None, asnumpy=True, preptrain=None):
    """
    Opens a hdf5 file at path, loads in the dataset at datapath, and returns dataset
    as a numpy array.

    Parameters
    ----------
    path : str
        Path to the HDF5 file; must exist.
    datapath : str, optional
        Key of the dataset inside the file. If None, the first top-level
        entry of the file is used.
    dataslice : optional
        Index/slice applied to the dataset before conversion.
    asnumpy : bool
        Convert the (sliced) data to a numpy array.
    preptrain : callable, optional
        Applied to the loaded data before it is returned.

    Raises
    ------
    AssertionError
        If `path` does not exist.
    ValueError
        If `datapath` is None and the file has no top-level entries.
    """
    # Check if path exists (thanks Lukas!)
    assert os.path.exists(path), "Path {} does not exist.".format(path)
    with h5.File(path, 'r') as f:
        # Init dataset
        if datapath is not None:
            h5dataset = f[datapath]
        else:
            # f.values() is a view, not a list, so it cannot be indexed with [0].
            h5dataset = next(iter(f.values()), None)
            if h5dataset is None:
                raise ValueError("File {} contains no datasets.".format(path))
        # Slice dataset
        h5dataset = h5dataset[dataslice] if dataslice is not None else h5dataset
        # Convert to numpy if required
        # NOTE(review): with asnumpy=False and no dataslice, the returned h5py
        # object belongs to a file that is closed on return — likely unusable.
        h5dataset = np.asarray(h5dataset) if asnumpy else h5dataset
        # Apply preptrain
        h5dataset = preptrain(h5dataset) if preptrain is not None else h5dataset
    return h5dataset
26 |
27 |
28 | # TODO we could also do **h5_kwargs instead
def toh5(data, path, datapath='data', compression=None, chunks=None):
    """Write `data` to a HDF5 volume.

    The file at `path` is opened in append mode: it is created if missing and
    other existing entries are kept. Writing to an existing `datapath` fails.

    Parameters
    ----------
    data : array-like
        Data to store.
    path : str
        Target HDF5 file.
    datapath : str
        Name of the dataset inside the file.
    compression, chunks
        Passed through to ``create_dataset``.
    """
    # h5py >= 3.0 opens files read-only by default, which makes create_dataset
    # fail; 'a' is the mode older h5py versions used implicitly.
    with h5.File(path, 'a') as f:
        f.create_dataset(datapath, data=data, compression=compression, chunks=chunks)
33 |
34 |
def fromz5(path, datapath, dataslice=None, n_threads=8):
    """Read the dataset at `datapath` from the z5py container at `path`.

    Parameters
    ----------
    path : str
        Path to the container; must exist.
    datapath : str
        Key of the dataset inside the container.
    dataslice : optional
        Index/slice to read; the whole dataset is read when None.
    n_threads : int
        Number of threads z5py uses for reading.
    """
    # z5py is imported lazily so it is only required when this function is used.
    import z5py
    assert os.path.exists(path), "Path {} does not exist.".format(path)
    with z5py.File(path) as container:
        dataset = container[datapath]
        dataset.n_threads = n_threads
        if dataslice is None:
            return dataset[:]
        return dataset[dataslice]
44 |
45 |
46 | # Yaml to dict reader
def yaml2dict(path):
    """Load the YAML file at `path`; a dict passed as `path` is returned unchanged."""
    if not isinstance(path, dict):
        # NOTE(review): yaml.FullLoader is not intended for untrusted input;
        # consider yaml.safe_load if configs can come from external sources.
        with open(path, 'r') as stream:
            return yaml.load(stream, Loader=yaml.FullLoader)
    # Forgivable mistake: the caller already passed a parsed dict.
    return path
54 |
55 |
def print_tensor(tensor, prefix, directory):
    """Save every 2D slice of an image or volume tensor to `directory` as PNG files.

    A 4D tensor is indexed as (batch, channel, ...) and yields one file per
    batch/channel pair; any other rank is additionally iterated over axis 2
    (one file per z-plane). File names encode prefix, batch, channel and z.
    """
    def _save_slice(image, batch_idx, channel_idx, z_idx=None):
        if z_idx is None:
            name = "{}--B-{}--CH-{}.png".format(prefix, batch_idx, channel_idx)
        else:
            name = "{}--B-{}--CH-{}--Z-{}.png".format(prefix, batch_idx, channel_idx, z_idx)
        imsave(arr=image, fname=os.path.join(directory, name))

    for b in range(tensor.shape[0]):
        for c in range(tensor.shape[1]):
            if tensor.ndim == 4:
                _save_slice(tensor[b, c, ...], b, c)
                continue
            for z in range(tensor.shape[2]):
                _save_slice(tensor[b, c, z, ...], b, c, z)
73 |
--------------------------------------------------------------------------------
/inferno/extensions/initializers/presets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn.init as init
3 | from functools import partial
4 |
5 | from .base import Initialization, Initializer
6 |
7 |
# Public initializer presets exported by this module.
__all__ = ['Constant', 'NormalWeights',
           'SELUWeightsZeroBias',
           'ELUWeightsZeroBias',
           'OrthogonalWeightsZeroBias',
           'KaimingNormalWeightsZeroBias']
13 |
14 |
class Constant(Initializer):
    """Initializer that fills tensors in place with a single constant value."""

    def __init__(self, constant):
        self.constant = constant

    def call_on_tensor(self, tensor):
        """Fill `tensor` in place with the constant and return it."""
        value = self.constant
        tensor.fill_(value)
        return tensor
23 |
24 |
class NormalWeights(Initializer):
    """
    Initialize weights with random numbers drawn from the normal distribution at
    `mean` and `stddev`.

    If `sqrt_gain_over_fan_in` is given, the standard deviation actually used is
    ``stddev * sqrt(sqrt_gain_over_fan_in / fan_in)``, where ``fan_in`` is the
    product of all but the first dimension of the weight tensor.
    """
    def __init__(self, mean=0., stddev=1., sqrt_gain_over_fan_in=None):
        self.mean = mean
        self.stddev = stddev
        self.sqrt_gain_over_fan_in = sqrt_gain_over_fan_in

    def compute_fan_in(self, tensor):
        """Return the number of inputs per output unit (product of dims 1..n-1)."""
        if tensor.dim() == 2:
            return tensor.size(1)
        else:
            return np.prod(list(tensor.size())[1:])

    def call_on_weight(self, tensor):
        """Fill `tensor` in place with normal samples and return it."""
        # Compute stddev if required
        if self.sqrt_gain_over_fan_in is not None:
            stddev = self.stddev * \
                np.sqrt(self.sqrt_gain_over_fan_in / self.compute_fan_in(tensor))
        else:
            stddev = self.stddev
        # Init
        tensor.normal_(self.mean, stddev)
        # Return the tensor, consistent with Constant.call_on_tensor.
        return tensor
50 |
51 |
class OrthogonalWeightsZeroBias(Initialization):
    """Orthogonal weight initialization scaled by `orthogonal_gain`, with zero biases."""
    def __init__(self, orthogonal_gain=1.):
        # Prefer the in-place ``orthogonal_`` (PyTorch 0.4+) to avoid the
        # deprecation warning. The deprecated ``orthogonal`` is only looked up
        # when ``orthogonal_`` is missing: a getattr default would be evaluated
        # eagerly and break once the deprecated alias is removed.
        orthogonal = getattr(init, 'orthogonal_', None) or init.orthogonal
        super(OrthogonalWeightsZeroBias, self)\
            .__init__(weight_initializer=partial(orthogonal, gain=orthogonal_gain),
                      bias_initializer=Constant(0.))
59 |
60 |
class KaimingNormalWeightsZeroBias(Initialization):
    """Kaiming-normal weights (negative slope `relu_leakage`) with zero biases."""
    def __init__(self, relu_leakage=0):
        # Prefer the in-place ``kaiming_normal_`` (PyTorch 0.4+) to avoid the
        # deprecation warning. The deprecated ``kaiming_normal`` is only looked
        # up when needed, so this keeps working once that alias is removed.
        kaiming_normal = getattr(init, 'kaiming_normal_', None) or init.kaiming_normal
        super(KaimingNormalWeightsZeroBias, self)\
            .__init__(weight_initializer=partial(kaiming_normal, a=relu_leakage),
                      bias_initializer=Constant(0.))
68 |
69 |
class SELUWeightsZeroBias(Initialization):
    """Zero-mean normal weights with std ``sqrt(1 / fan_in)``, zero biases (for SELU nets)."""
    def __init__(self):
        weights = NormalWeights(sqrt_gain_over_fan_in=1.)
        biases = Constant(0.)
        super(SELUWeightsZeroBias, self).__init__(weight_initializer=weights,
                                                  bias_initializer=biases)
75 |
76 |
class ELUWeightsZeroBias(Initialization):
    """Zero-mean normal weights with std ``sqrt(1.5505... / fan_in)``, zero biases.

    NOTE(review): the gain 1.5505188080679277 is presumably tuned for ELU
    activations; its derivation is not documented here.
    """
    def __init__(self):
        weights = NormalWeights(sqrt_gain_over_fan_in=1.5505188080679277)
        biases = Constant(0.)
        super(ELUWeightsZeroBias, self).__init__(weight_initializer=weights,
                                                 bias_initializer=biases)
82 |
--------------------------------------------------------------------------------
/tests/test_extensions/test_models/test_res_unet.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import torch.cuda as cuda
4 | from inferno.utils.model_utils import ModelTester
5 |
6 |
class ResUNetTest(unittest.TestCase):
    """Output-shape tests for ``ResBlockUNet`` in 2D/3D, with and without side outputs."""

    @staticmethod
    def _check_with_model_tester(shape, dim):
        # ResBlockUNet is imported inside the test body, as in every test here, so an
        # import failure surfaces as an error of the individual test.
        from inferno.extensions.models import ResBlockUNet
        tester = ModelTester(shape, shape)
        if cuda.is_available():
            tester.cuda()
        tester(ResBlockUNet(in_channels=1, out_channels=1, dim=dim))

    @staticmethod
    def _side_output_sizes(side_out_parts, depth, in_channels=3):
        """Run a 2D ResBlockUNet (8 output channels) on a random
        (1, in_channels, 64, 32) input and return every output's size as a list."""
        from inferno.extensions.models import ResBlockUNet
        x = torch.rand(1, in_channels, 64, 32)
        model = ResBlockUNet(in_channels=in_channels,
                             out_channels=8, dim=2,
                             side_out_parts=side_out_parts,
                             unet_kwargs=dict(depth=depth))
        return [list(output.size()) for output in model(x)]

    def test_res_unet_2d(self):
        self._check_with_model_tester((1, 1, 256, 256), dim=2)

    def test_res_unet_3d(self):
        # test default unet 3d
        self._check_with_model_tester((1, 1, 16, 64, 64), dim=3)

    def test_2d_side_out_bot_up(self):
        depth = 3
        sizes = self._side_output_sizes(['bottom', 'up'], depth)
        self.assertEqual(len(sizes), depth + 1)
        self.assertEqual(sizes, [[1, 24, 8, 4],
                                 [1, 12, 16, 8],
                                 [1, 6, 32, 16],
                                 [1, 8, 64, 32]])

    def test_2d_side_out_up(self):
        depth = 3
        sizes = self._side_output_sizes(['up'], depth)
        self.assertEqual(len(sizes), depth)
        self.assertEqual(sizes, [[1, 12, 16, 8],
                                 [1, 6, 32, 16],
                                 [1, 8, 64, 32]])

    def test_2d_side_out_down(self):
        depth = 3
        sizes = self._side_output_sizes(['down'], depth)
        self.assertEqual(len(sizes), depth + 1)
        # side outputs along the down path
        self.assertEqual(sizes[:3], [[1, 6, 64, 32],
                                     [1, 12, 32, 16],
                                     [1, 24, 16, 8]])
        # the actual output
        self.assertEqual(sizes[3], [1, 8, 64, 32])


if __name__ == '__main__':
    unittest.main()
84 |
--------------------------------------------------------------------------------
/inferno/extensions/criteria/core.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from functools import reduce
3 | from ...utils.exceptions import assert_, ShapeError, NotTorchModuleError
4 |
5 |
# Public names exported by `from inferno.extensions.criteria.core import *`.
__all__ = ['Criteria', 'As2DCriterion']
7 |
8 |
class Criteria(nn.Module):
    """Aggregate multiple criteria to one.

    The i-th criterion is applied to the i-th (prediction, target) pair and the
    resulting losses are summed.

    Parameters
    ----------
    *criteria : torch.nn.Module
        The criteria to aggregate, given either as separate positional arguments or
        as a single list/tuple.
    """
    def __init__(self, *criteria):
        super(Criteria, self).__init__()
        if len(criteria) == 1 and isinstance(criteria[0], (list, tuple)):
            criteria = list(criteria[0])
        else:
            criteria = list(criteria)
        # Validate criteria
        assert all(isinstance(criterion, nn.Module) for criterion in criteria), \
            "Criterion must be a torch module."
        # ModuleList (instead of a plain list) registers the criteria as submodules, so
        # they follow .cuda()/.to() and appear in children() and state_dict().
        self.criteria = nn.ModuleList(criteria)

    def forward(self, prediction, target):
        """Return the sum of ``criterion_i(prediction[i], target[i])`` over all i.

        Parameters
        ----------
        prediction : list or tuple
            One prediction per criterion.
        target : list or tuple
            One target per criterion.

        Raises
        ------
        AssertionError
            If the inputs are not lists/tuples, or if the numbers of predictions,
            targets and criteria do not agree.
        """
        assert isinstance(prediction, (list, tuple)), \
            "`prediction` must be a list or a tuple, got {} instead."\
            .format(type(prediction).__name__)
        assert isinstance(target, (list, tuple)), \
            "`target` must be a list or a tuple, got {} instead." \
            .format(type(target).__name__)
        assert len(prediction) == len(target), \
            "Number of predictions must equal the number of targets. " \
            "Got {} predictions but {} targets.".format(len(prediction), len(target))
        # zip() would otherwise silently drop the surplus entries.
        assert len(prediction) == len(self.criteria), \
            "Number of predictions must equal the number of criteria. " \
            "Got {} predictions but {} criteria.".format(len(prediction), len(self.criteria))
        # Compute one loss per (prediction, target, criterion) triple
        losses = [criterion(_prediction, _target)
                  for _prediction, _target, criterion in zip(prediction, target, self.criteria)]
        # Aggregate losses
        loss = reduce(lambda x, y: x + y, losses)
        # Done
        return loss
39 |
40 |
class As2DCriterion(nn.Module):
    """
    Adapt a criterion for (N, C) predictions and (N,) targets so that it accepts
    (N, C, H, W) predictions and (N, H, W) targets.

    Every pixel is treated as an independent sample. The target is flattened to
    (N*H*W,), and the prediction is moved channels-last and flattened to (N*H*W, C),
    before both are handed to the wrapped criterion.
    """
    def __init__(self, criterion):
        super(As2DCriterion, self).__init__()
        message = "Criterion must be a module, got a {} instead.".format(type(criterion).__name__)
        assert_(isinstance(criterion, nn.Module), message, NotTorchModuleError)
        self.criterion = criterion

    def forward(self, prediction, target):
        """Apply the wrapped criterion pixel-wise and return its result.

        ShapeError is passed to ``assert_`` for a non-4D prediction or a non-3D target.
        """
        prediction_rank = prediction.dim()
        assert_(prediction_rank == 4,
                "`prediction` is expected to be a 4D tensor of shape (N, C, H, W), "
                "got a {}D tensor instead.".format(prediction_rank),
                ShapeError)
        target_rank = target.dim()
        assert_(target_rank == 3,
                "`target` is expected to be a 3D tensor of shape (N, H, W), "
                "got a {}D tensor instead.".format(target_rank),
                ShapeError)
        # (N, H, W) -> (N*H*W,)
        flat_target = target.contiguous().view(-1)
        # (N, C, H, W) -> (N, H, W, C) -> (N*H*W, C)
        channels = prediction.size(1)
        flat_prediction = prediction.permute(0, 2, 3, 1).contiguous().view(-1, channels)
        return self.criterion(flat_prediction, flat_target)
73 |
--------------------------------------------------------------------------------
/inferno/extensions/optimizers/adam.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.optim import Optimizer
3 |
4 |
class Adam(Optimizer):
    r"""Implements Adam algorithm with the option of adding a L1 penalty.

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        lambda_l1 (float, optional): L1 penalty coefficient; ``lambda_l1 * sign(p)``
            is added to the gradient used for the update (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty); ``weight_decay * p``
            is added to the gradient used for the update (default: 0)
        **kwargs: additional entries stored in the parameter-group defaults.

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 lambda_l1=0, weight_decay=0, **kwargs):
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        lambda_l1=lambda_l1, weight_decay=weight_decay,
                        **kwargs)
        super(Adam, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = grad.new_zeros(grad.size())
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = grad.new_zeros(grad.size())

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1

                # The penalty terms are added out of place so that p.grad itself is
                # left untouched (in-place adds would corrupt accumulated gradients
                # and anything else that reads p.grad after the step).
                if group['lambda_l1'] != 0:
                    grad = grad.add(p.data.sign(), alpha=group['lambda_l1'])
                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
80 |
--------------------------------------------------------------------------------
/inferno/extensions/layers/sampling.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
# Public names exported by `from inferno.extensions.layers.sampling import *`.
__all__ = ['AnisotropicUpsample', 'AnisotropicPool', 'Upsample', 'AnisotropicUpsample2D', 'AnisotropicPool2D']
4 |
5 |
6 | # torch is deprecating nn.Upsample in favor of nn.functional.interpolate
7 | # we wrap interpolate here to still use Upsample as class
class Upsample(nn.Module):
    """Module wrapper around ``nn.functional.interpolate``.

    This keeps an ``Upsample``-style class interface on top of ``interpolate``. On torch
    versions without ``nn.functional.interpolate`` (it appeared in 0.4.1), the
    module falls back to delegating to an ``nn.Upsample`` instance.
    """
    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
        super(Upsample, self).__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners
        self.have_interpolate = hasattr(nn.functional, 'interpolate')
        if not self.have_interpolate:
            self.sampler = nn.Upsample(size=size, scale_factor=scale_factor,
                                       mode=mode, align_corners=align_corners)

    def forward(self, input):
        """Resample ``input`` with the stored size/scale/mode settings."""
        if not self.have_interpolate:
            return self.sampler(input)
        return nn.functional.interpolate(input,
                                         size=self.size,
                                         scale_factor=self.scale_factor,
                                         mode=self.mode,
                                         align_corners=self.align_corners)
30 |
31 |
class AnisotropicUpsample(nn.Module):
    """Upsample a 5D (N, C, D, H, W) tensor along H and W only; D is left unchanged.

    Parameters
    ----------
    scale_factor : number or tuple
        Passed to ``Upsample``; applied to the H and W axes.
    """
    def __init__(self, scale_factor):
        super(AnisotropicUpsample, self).__init__()
        self.upsampler = Upsample(scale_factor=scale_factor)

    def forward(self, input):
        # input is 3D of shape NCDHW
        N, C, D, H, W = input.size()
        # Fold C and D into one axis so that the 2D upsampler only resamples H and W
        # (contiguous() lets view() accept non-contiguous inputs).
        folded = input.contiguous().view(N, C * D, H, W)
        upsampled = self.upsampler(folded)
        # Unfold C and D again. The new H and W are read off the result instead of
        # computing `scale_factor * H`, which is a float for float factors (and a
        # repeated tuple for tuple factors) and is then rejected by view().
        unfolded = upsampled.view(N, C, D, upsampled.size(-2), upsampled.size(-1))
        return unfolded
50 |
51 |
class AnisotropicPool(nn.MaxPool3d):
    """3D max-pooling that downsamples only the two trailing (H, W) axes.

    The leading spatial axis gets kernel size 1, stride 1 and no padding. H and W
    use kernel ``downscale_factor + 1``, stride ``downscale_factor`` and padding 1.
    """
    def __init__(self, downscale_factor):
        window = downscale_factor + 1
        super(AnisotropicPool, self).__init__(
            kernel_size=(1, window, window),
            stride=(1, downscale_factor, downscale_factor),
            padding=(0, 1, 1),
        )
58 |
class AnisotropicUpsample2D(nn.Module):
    """Upsample a 4D (N, C, D, W) tensor along its last axis only; D is left unchanged.

    Parameters
    ----------
    scale_factor : number
        Passed to ``nn.Upsample``; applied to the last axis.
    """
    def __init__(self, scale_factor):
        super(AnisotropicUpsample2D, self).__init__()
        self.upsampler = nn.Upsample(scale_factor=scale_factor)

    def forward(self, input):
        # input is 2D of shape NCDW (the last axis may just as well be H)
        N, C, D, W = input.size()
        # Fold C and D into one axis so the upsampler sees a 3D (N, C*D, W) tensor and
        # resamples W only (contiguous() lets view() accept non-contiguous inputs).
        folded = input.contiguous().view(N, C * D, W)
        upsampled = self.upsampler(folded)
        # Unfold C and D again. The new width is read off the result instead of
        # computing `scale_factor * W`: recent torch versions store
        # nn.Upsample.scale_factor as a float, and view() rejects non-integer sizes.
        unfolded = upsampled.view(N, C, D, upsampled.size(-1))
        return unfolded
76 |
77 |
class AnisotropicPool2D(nn.MaxPool2d):
    """2D max-pooling that downsamples only the last axis.

    The first spatial axis gets kernel size 1, stride 1 and no padding. The last
    axis uses kernel ``downscale_factor + 1``, stride ``downscale_factor`` and padding 1.
    """
    def __init__(self, downscale_factor):
        window = downscale_factor + 1
        super(AnisotropicPool2D, self).__init__(
            kernel_size=(1, window),
            stride=(1, downscale_factor),
            padding=(0, 1),
        )
84 |
85 |
--------------------------------------------------------------------------------
/tests/test_utils/test_partial_cls.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import inferno.utils.model_utils as mu
3 | from inferno.utils.partial_cls import register_partial_cls
4 | import torch
5 | import torch.nn as nn
6 |
7 |
class TestCls(object):
    """Plain target class for the partial-class tests; it stores its four arguments."""
    def __init__(self, a, b, c=1, d=2):
        self.a, self.b = a, b
        self.c, self.d = c, d
14 |
class PartialClsTester(unittest.TestCase):
    """Tests for ``register_partial_cls``.

    Each test registers a partial version of ``TestCls`` under the name ``TestA`` with
    ``module=__name__``, asserts that ``'TestA'`` then appears in this module's
    ``globals()``, and calls ``TestA`` as a module-level global. Arguments in ``fix``
    are expected to be un-overridable (``TypeError``); arguments in ``default`` may be
    overridden.
    """

    def test_partial_cls(self):
        # `a` fixed, `b` defaulted: positional arguments are expected to bind to b, c, d.
        register_partial_cls(TestCls, 'TestA',
            fix=dict(a='a'),
            default=dict(b='b'),
            module=__name__
        )
        assert 'TestA' in globals()

        inst = TestA()
        assert inst.a == 'a'
        assert inst.b == 'b'
        assert inst.c == 1
        assert inst.d == 2

        inst = TestA('fu','bar','fubar')
        assert inst.a == 'a'
        assert inst.b == 'fu'
        assert inst.c == 'bar'
        assert inst.d == 'fubar'

        # Passing a fixed argument explicitly is expected to fail.
        with self.assertRaises(TypeError):
            inst = TestA(a=2)

    def test_update_existing_default_cls(self):
        # Registers 'TestA' again with another configuration; the asserts below expect
        # the new registration to be the one in effect (`b` is required now).
        register_partial_cls(TestCls, 'TestA',
            fix=dict(a='a'),
            default=dict(d=3),
            module=__name__
        )
        assert 'TestA' in globals()

        inst = TestA(42)
        assert inst.a == 'a'
        assert inst.b == 42
        assert inst.c == 1
        assert inst.d == 3

        with self.assertRaises(TypeError):
            inst = TestA()

    def test_fix_nothing(self):
        # Without `fix`/`default`, TestA is expected to behave like TestCls itself.
        register_partial_cls(TestCls, 'TestA',
            module=__name__
        )
        assert 'TestA' in globals()

        inst = TestA(1,2,3,4)
        assert inst.a == 1
        assert inst.b == 2
        assert inst.c == 3
        assert inst.d == 4

        with self.assertRaises(TypeError):
            inst = TestA()

    def test_fix_all(self):
        # Every argument fixed: no positional or keyword arguments are accepted.
        register_partial_cls(TestCls, 'TestA',
            module=__name__,
            fix=dict(a=4, b=3, c=2, d=1)
        )
        assert 'TestA' in globals()

        inst = TestA()
        assert inst.a == 4
        assert inst.b == 3
        assert inst.c == 2
        assert inst.d == 1

        with self.assertRaises(TypeError):
            inst = TestA('a')

        with self.assertRaises(TypeError):
            inst = TestA(a=1)
        with self.assertRaises(TypeError):
            inst = TestA(b=1)
        with self.assertRaises(TypeError):
            inst = TestA(c=1)
        with self.assertRaises(TypeError):
            inst = TestA(d=1)


    def test_default_all(self):
        # Every argument defaulted: positional and keyword overrides are allowed, but
        # giving the same argument both ways is expected to raise TypeError.
        register_partial_cls(TestCls, 'TestA',
            module=__name__,
            default=dict(a=4, b=3, c=2, d=1)
        )
        assert 'TestA' in globals()

        inst = TestA()
        assert inst.a == 4
        assert inst.b == 3
        assert inst.c == 2
        assert inst.d == 1


        inst = TestA(2)
        assert inst.a == 2
        assert inst.b == 3
        assert inst.c == 2
        assert inst.d == 1

        inst = TestA(2,3,4,5)
        assert inst.a == 2
        assert inst.b == 3
        assert inst.c == 4
        assert inst.d == 5

        with self.assertRaises(TypeError):
            inst = TestA(3,4,5,a=2)

        inst = TestA(3,4,5,d=2)
        assert inst.a == 3
        assert inst.b == 4
        assert inst.c == 5
        assert inst.d == 2




if __name__ == '__main__':
    unittest.main()
138 |
--------------------------------------------------------------------------------
/inferno/trainers/callbacks/tqdm.py:
--------------------------------------------------------------------------------
1 | from .base import Callback
2 | from tqdm import tqdm
3 | from datetime import datetime
4 | from .console import Console
5 |
6 |
class TQDMPrinter(object):
    """Printer that writes messages through ``tqdm.write``.

    If the wrapped progress object currently has an ``outer_bar``, that bar is cleared
    before the message is written and refreshed afterwards.
    """
    def __init__(self, progress):
        self._progress = progress

    def print(self, message):
        has_bar = self._progress.outer_bar is not None
        if has_bar:
            self._progress.outer_bar.clear()
        tqdm.write(message)
        if self._progress.outer_bar is not None:
            self._progress.outer_bar.refresh()
17 |
18 |
class TQDMConsole(Console):
    """Console whose printer writes through ``tqdm.write`` via ``TQDMPrinter``."""
    def __init__(self):
        # NOTE(review): this builds a fresh TQDMProgressBar that is not registered as a
        # callback here, so its `outer_bar` presumably stays None and the printer's
        # clear/refresh branches never run — confirm this is intended.
        super(TQDMConsole, self).__init__(printer=TQDMPrinter(TQDMProgressBar()))
22 |
23 |
class TQDMProgressBar(Callback):
    """Callback that renders fit progress with two tqdm bars.

    ``outer_bar`` (position 0) counts epochs from ``begin_of_fit`` to ``end_of_fit``.
    ``epoch_bar`` (position 1) counts iterations of the current training or
    validation run.

    Bars are compared against ``None`` explicitly. tqdm derives truthiness from a
    bar's total, so a bar with ``total=0`` is falsy, and recent tqdm versions raise
    ``TypeError`` for ``bool()`` on a bar whose total is ``None``.
    """
    def __init__(self, *args, **kwargs):
        super(TQDMProgressBar, self).__init__(*args, **kwargs)
        self.epoch_bar = None
        self.outer_bar = None
        self.is_training = False
        self.is_validation = False

    def bind_trainer(self, *args, **kwargs):
        super(TQDMProgressBar, self).bind_trainer(*args, **kwargs)
        self.trainer.console.toggle_progress(False)
        # Console output goes through a printer that clears/refreshes `outer_bar`.
        self.trainer.console.set_console(TQDMPrinter(self))

    def _close_epoch_bar(self):
        # Close the per-run bar (if any) and drop the reference, so a closed bar is
        # never updated again and a fresh one is created on demand.
        if self.epoch_bar is not None:
            self.epoch_bar.close()
            self.epoch_bar = None

    def _init_epoch_bar_train(self):
        n_batch = len(self.trainer._loader_iters['train'])
        self.epoch_bar = tqdm(total=n_batch, position=1, dynamic_ncols=True)
        # Presumably the number of batches already seen this epoch, so a bar re-created
        # mid-epoch resumes at the right count — confirm against the trainer.
        self.epoch_bar.update(self.trainer._batch_count)
        self.epoch_bar.set_description("Training epoch %d" % self.trainer._epoch_count)

    def print(self, message, **_):
        if self.outer_bar is not None:
            self.outer_bar.clear()
        tqdm.write("[+][{}] {}".format(str(datetime.now()), message))
        if self.outer_bar is not None:
            self.outer_bar.refresh()

    def begin_of_fit(self, max_num_epochs, **_):
        # Non-int epoch limits fall back to a nominal total of 1000.
        total = max_num_epochs if isinstance(max_num_epochs, int) else 1000
        self.outer_bar = tqdm(total=total, position=0, dynamic_ncols=True)
        self.outer_bar.set_description("Epochs")

    def end_of_fit(self, **_):
        self._close_epoch_bar()
        if self.outer_bar is not None:
            self.outer_bar.close()
            self.outer_bar = None

    def begin_of_epoch(self, **_):
        self._close_epoch_bar()

    def end_of_epoch(self, **_):
        if self.outer_bar is not None:
            self.outer_bar.update(1)

    def begin_of_training_iteration(self, **_):
        if self.epoch_bar is None and 'train' in self.trainer._loader_iters.keys():
            self._init_epoch_bar_train()
            return

        if self.epoch_bar is not None:
            self.epoch_bar.update(1)

    def begin_of_validation_iteration(self, **_):
        if self.epoch_bar is not None:
            self.epoch_bar.update(1)

    def begin_of_training_run(self, **_):
        self.is_training = True

    def end_of_training_run(self, **_):
        self.is_training = False
        self._close_epoch_bar()

    def begin_of_validation_run(self, num_iterations, num_iterations_in_generator, last_validated_at_epoch, **_):
        self.is_validation = True
        nmax = num_iterations
        if not nmax:
            nmax = num_iterations_in_generator
        # Close any bar that is still open (e.g. a training bar, if validation begins
        # before end_of_training_run) instead of overwriting and leaking it.
        self._close_epoch_bar()
        self.epoch_bar = tqdm(total=nmax, position=1, dynamic_ncols=True)
        self.epoch_bar.set_description("Validating epoch %d" % (last_validated_at_epoch-1))

    def end_of_validation_run(self, **_):
        self.is_validation = False
        self._close_epoch_bar()
105 |
--------------------------------------------------------------------------------
/inferno/extensions/layers/device.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from ...utils.python_utils import from_iterable, to_iterable
3 | from ...utils.exceptions import assert_, DeviceError
4 |
# Public names exported by `from inferno.extensions.layers.device import *`.
__all__ = ['DeviceTransfer', 'OnDevice']
# NOTE(review): `_all` aliases the export list; its consumer is not visible here.
_all = __all__
7 |
8 |
class DeviceTransfer(nn.Module):
    """Layer to transfer variables to a specified device."""
    def __init__(self, target_device, device_ordinal=None, asynchronous=False):
        """
        Parameters
        ----------
        target_device : {'cpu', 'cuda'}
            Device to transfer to.
        device_ordinal : int
            Device ordinal if target_device == 'cuda'.
        asynchronous : bool
            Whether to use asynchronous transfers (``non_blocking`` CUDA copies).

        Notes
        -----
        ``DeviceError`` is passed to ``assert_`` as the exception type when
        `target_device` is neither 'cpu' nor 'cuda', or when a `device_ordinal` is
        given with target_device == 'cpu'.
        """
        super(DeviceTransfer, self).__init__()
        # Validate arguments
        assert_(target_device in ['cpu', 'cuda'],
                "Target device must either be 'cpu' or 'cuda'.",
                DeviceError)
        if target_device == 'cpu':
            assert_(device_ordinal is None,
                    "'device_ordinal' must be None if target_device is 'cpu'.",
                    DeviceError)
        self.target_device = target_device
        self.device_ordinal = device_ordinal
        # forward() reads this attribute; without it, CUDA transfers raised
        # AttributeError.
        self.asynchronous = asynchronous

    def forward(self, *inputs):
        """Move every input to the target device.

        The transferred inputs are returned through ``from_iterable``.
        NOTE(review): ``from_iterable`` presumably unwraps a single-element tuple into
        the bare object — confirm in ``utils.python_utils``.
        """
        if self.target_device == 'cuda':
            transferred = tuple(input_.cuda(device=self.device_ordinal,
                                            non_blocking=self.asynchronous)
                                for input_ in inputs)
        elif self.target_device == 'cpu':
            transferred = tuple(input_.cpu() for input_ in inputs)
        else:
            raise NotImplementedError
        return from_iterable(transferred)
44 |
45 |
class OnDevice(nn.Module):
    """
    Moves a module to a device. The advantage of using this over `torch.nn.Module.cuda` is
    that the inputs are transferred to the same device as the module, enabling easy model
    parallelism.
    """
    def __init__(self, module, target_device, device_ordinal=None, asynchronous=False):
        """
        Parameters
        ----------
        module : torch.nn.Module
            Module to transfer to device.
        target_device : {'cuda', 'cpu'}
            The device to move `module` to. Must be either 'cuda' or 'cpu'.
        device_ordinal : int
            Ordinal of the GPU device if `target_device = 'cuda'`.
        asynchronous : bool
            Whether to use asynchronous transfers.
        """
        super(OnDevice, self).__init__()
        # Validate arguments
        assert_(target_device in ['cpu', 'cuda'],
                "Target device must either be 'cpu' or 'cuda'.",
                DeviceError)
        if target_device == 'cpu':
            assert_(device_ordinal is None,
                    "'device_ordinal' must be None if target_device is 'cpu'.",
                    DeviceError)
        self.target_device = target_device
        self.device_ordinal = device_ordinal
        self.asynchronous = asynchronous
        # Transfers the inputs in forward(); a no-op for inputs already on the device.
        self.device_transfer = DeviceTransfer(self.target_device,
                                              device_ordinal=self.device_ordinal,
                                              asynchronous=self.asynchronous)

        self.module = self.transfer_module(module)

    def transfer_module(self, module):
        """Move `module` to the target device and return it."""
        if self.target_device == 'cuda':
            # nn.Module.cuda takes the device as its first argument; the old
            # `device_id=` keyword is not accepted, so pass it positionally.
            return module.cuda(self.device_ordinal)
        elif self.target_device == 'cpu':
            return module.cpu()
        else:
            raise NotImplementedError

    def forward(self, *inputs):
        # Transfer inputs (no-op if they're already on the right device)
        transferred = to_iterable(self.device_transfer(*inputs))
        output = self.module(*transferred)
        return output
97 |
--------------------------------------------------------------------------------
/tests/test_training/test_callbacks/test_essentials.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import shutil
3 | import h5py as h5
4 | from os.path import dirname, join
5 | from os import listdir
6 | from inferno.trainers.basic import Trainer
7 | from inferno.trainers.callbacks.essentials import DumpHDF5Every
8 | from inferno.utils.test_utils import generate_random_dataloader
9 | from inferno.extensions.layers import Conv2D, AsMatrix
10 | from torch.nn import Sequential, MaxPool2d, AdaptiveAvgPool2d, Linear, Softmax
11 |
12 |
class TestEssentials(unittest.TestCase):
    """Tests for the essential trainer callbacks (here: ``DumpHDF5Every``)."""
    WORKING_DIRECTORY = dirname(__file__)

    def setUp(self):
        # Build a small convolutional classifier: 3 input channels -> 10 outputs
        model = Sequential(Conv2D(3, 8, 3, activation='ReLU'),
                           MaxPool2d(2, 2),
                           Conv2D(8, 8, 3, activation='ReLU'),
                           MaxPool2d(2, 2),
                           Conv2D(8, 8, 3, activation='ReLU'),
                           MaxPool2d(2, 2),
                           Conv2D(8, 8, 3, activation='ReLU'),
                           AdaptiveAvgPool2d((1, 1)),
                           AsMatrix(),
                           Linear(8, 10))

        train_dataloader = generate_random_dataloader(512, (3, 32, 32), 10, batch_size=16,
                                                      dtype='float32')
        validate_dataloader = generate_random_dataloader(32, (3, 32, 32), 10, batch_size=16,
                                                         dtype='float32')
        # Build trainer
        trainer = Trainer(model)\
            .bind_loader('train', train_dataloader)\
            .bind_loader('validate', validate_dataloader)\
            .save_to_directory(to_directory=join(self.WORKING_DIRECTORY, 'Weights'))\
            .build_criterion('CrossEntropyLoss').build_optimizer('RMSprop')
        self.trainer = trainer

    def test_dump_hdf5_every(self):
        # Configure callback
        dumper = DumpHDF5Every((1, 'epoch'),
                               to_directory=join(self.WORKING_DIRECTORY, 'Weights'),
                               dump_after_every_validation_run=True)
        self.trainer\
            .set_max_num_epochs(4)\
            .register_callback(dumper)\
            .validate_every((16, 'iterations'))

        self.trainer.fit()
        all_files = listdir(join(self.WORKING_DIRECTORY, 'Weights'))
        # The expected names assume 32 iterations per epoch (512 samples, batch size 16)
        # and a validation run every 16 iterations.
        for epoch in range(5):
            self.assertIn('dump.training.epoch{}.iteration{}.h5'.format(epoch, epoch * 32),
                          all_files)
            # We don't validate at last epoch
            if epoch != 4:
                self.assertIn('dump.validation.epoch{}.iteration{}.h5'
                              .format(epoch, (epoch * 32) + 16),
                              all_files)
                self.assertIn('dump.validation.epoch{}.iteration{}.h5'
                              .format(epoch, (epoch * 32) + 32),
                              all_files)

        # Check if the keys are right in a training dump
        sample_file_path = join(self.WORKING_DIRECTORY, 'Weights',
                                'dump.training.epoch0.iteration0.h5')
        with h5.File(sample_file_path, 'r') as sample_file:
            all_dataset_names = list(sample_file.keys())
        self.assertSequenceEqual(all_dataset_names,
                                 ['training_inputs_0', 'training_prediction', 'training_target'])
        # Check if the keys are right in a validation dump
        sample_file_path = join(self.WORKING_DIRECTORY, 'Weights',
                                'dump.validation.epoch0.iteration16.h5')
        with h5.File(sample_file_path, 'r') as sample_file:
            all_dataset_names = list(sample_file.keys())
        self.assertSequenceEqual(all_dataset_names,
                                 ['validation_inputs_0', 'validation_prediction',
                                  'validation_target'])

    def tearDown(self):
        # ignore_errors=True: if a test fails before the directory was created, a
        # FileNotFoundError here would only add a misleading second error.
        shutil.rmtree(join(self.WORKING_DIRECTORY, 'Weights'), ignore_errors=True)


if __name__ == '__main__':
    unittest.main()
87 |
--------------------------------------------------------------------------------
/CONTRIBUTING.rst:
--------------------------------------------------------------------------------
1 | .. highlight:: shell
2 |
3 | ============
4 | Contributing
5 | ============
6 |
7 | Contributions are welcome, and they are greatly appreciated! Every
8 | little bit helps, and credit will always be given.
9 |
10 | You can contribute in many ways:
11 |
12 | Types of Contributions
13 | ----------------------
14 |
15 | Report Bugs
16 | ~~~~~~~~~~~
17 |
18 | Report bugs at https://github.com/nasimrahaman/inferno/issues.
19 |
20 | If you are reporting a bug, please include:
21 |
22 | * Your operating system name and version.
23 | * Any details about your local setup that might be helpful in troubleshooting.
24 | * Detailed steps to reproduce the bug.
25 |
26 | Fix Bugs
27 | ~~~~~~~~
28 |
29 | Look through the GitHub issues for bugs. Anything tagged with "bug"
30 | and "help wanted" is open to whoever wants to implement it.
31 |
32 | Implement Features
33 | ~~~~~~~~~~~~~~~~~~
34 |
35 | Look through the GitHub issues for features. Anything tagged with "enhancement"
36 | and "help wanted" is open to whoever wants to implement it.
37 |
38 | Write Documentation
39 | ~~~~~~~~~~~~~~~~~~~
40 |
41 | inferno could always use more documentation, whether as part of the
42 | official inferno docs, in docstrings, or even on the web in blog posts,
43 | articles, and such.
44 |
45 | Submit Feedback
46 | ~~~~~~~~~~~~~~~
47 |
48 | The best way to send feedback is to file an issue at https://github.com/nasimrahaman/inferno/issues.
49 |
50 | If you are proposing a feature:
51 |
52 | * Explain in detail how it would work.
53 | * Keep the scope as narrow as possible, to make it easier to implement.
54 | * Remember that this is a volunteer-driven project, and that contributions
55 | are welcome :)
56 |
57 | Get Started!
58 | ------------
59 |
60 | Ready to contribute? Here's how to set up `inferno` for local development.
61 |
62 | 1. Fork the `inferno` repo on GitHub.
63 | 2. Clone your fork locally::
64 |
65 | $ git clone git@github.com:your_name_here/inferno.git
66 |
67 | 3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development::
68 |
69 | $ mkvirtualenv inferno
70 | $ cd inferno/
71 | $ python setup.py develop
72 |
73 | 4. Create a branch for local development::
74 |
75 | $ git checkout -b name-of-your-bugfix-or-feature
76 |
77 | Now you can make your changes locally.
78 |
79 | 5. When you're done making changes, check that your changes pass flake8 and the tests, including testing other Python versions with tox::
80 |
81 | $ flake8 inferno tests
82 | $ python setup.py test or py.test
83 | $ tox
84 |
85 | To get flake8 and tox, just pip install them into your virtualenv.
86 |
87 | 6. Commit your changes and push your branch to GitHub::
88 |
89 | $ git add .
90 | $ git commit -m "Your detailed description of your changes."
91 | $ git push origin name-of-your-bugfix-or-feature
92 |
93 | 7. Submit a pull request through the GitHub website.
94 |
95 | Pull Request Guidelines
96 | -----------------------
97 |
98 | Before you submit a pull request, check that it meets these guidelines:
99 |
100 | 1. The pull request should include tests.
101 | 2. If the pull request adds functionality, the docs should be updated. Put
102 | your new functionality into a function with a docstring, and add the
103 | feature to the list in README.rst.
104 | 3. The pull request should work for Python 3.5 and 3.6. Check
105 | https://travis-ci.org/nasimrahaman/inferno/pull_requests
106 | and make sure that the tests pass for all supported Python versions.
107 |
108 | Tips
109 | ----
110 |
111 | To run a subset of tests::
112 |
113 | $ python -m unittest tests.test_inferno
114 |
115 |
116 |
117 | Sphinx Apidoc
118 | --------------
119 | Before building the documentation, you need to generate
120 | the Sphinx API documentation with ``sphinx-apidoc``.
121 | The generated files must be committed to the GitHub
122 | repository.
123 |
124 | .. code:: bash
125 |
126 | cd docs
127 | sphinx-apidoc -o inferno-apidoc ../inferno
128 |
129 | .. warning::
130 |
131 | Do not make changes to ``inferno/docs/inferno-apidoc``. This folder is auto-generated
132 | by the command above.
133 |
134 | The following script runs all the commands necessary to build the HTML documentation:
135 |
136 | .. code:: bash
137 |
138 | ./build_docs.sh
139 |
140 |
--------------------------------------------------------------------------------
/tests/test_training/test_callbacks/test_logging/test_tensorboard.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import os
4 | from shutil import rmtree
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | from inferno.trainers.basic import Trainer
9 | from torch.utils.data.dataset import TensorDataset
10 | from torch.utils.data.dataloader import DataLoader
11 | from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger
12 | from inferno.extensions.layers.reshape import AsMatrix
13 |
14 |
15 | class TestTensorboard(unittest.TestCase):
16 | ROOT_DIR = os.path.dirname(__file__)
17 | PRECISION = 'float'
18 | SAVE_DIRECTORY = os.path.join(ROOT_DIR, 'saves')
19 | LOG_DIRECTORY = os.path.join(ROOT_DIR, 'logs')
20 |
@staticmethod
def _make_test_model(input_channels):
    """Return a small conv network mapping `input_channels` input channels to 10 outputs."""
    layers = [nn.Conv2d(input_channels, 8, 3, 1, 1), nn.ELU(), nn.MaxPool2d(2),
              nn.Conv2d(8, 8, 3, 1, 1), nn.ELU(), nn.MaxPool2d(2),
              nn.Conv2d(8, 16, 3, 1, 1), nn.ELU(), nn.AdaptiveMaxPool2d((1, 1)),
              AsMatrix(),
              nn.Linear(16, 10)]
    return nn.Sequential(*layers)
35 |
def tearDown(self):
    """Delete the save and log directories; ones that cannot be removed are skipped."""
    for directory in (self.SAVE_DIRECTORY, self.LOG_DIRECTORY):
        try:
            rmtree(directory)
        except OSError:
            continue
42 |
43 | def get_random_dataloaders(self, input_channels=3):
44 | # Convert build random tensor dataset
45 | data_shape = (1, input_channels, 64, 64)
46 | target_shape = (1)
47 | random_array = torch.from_numpy(np.random.rand(*data_shape)).float()
48 | target_array = torch.from_numpy(np.random.randint(0, 9, size=target_shape))
49 | train_dataset = TensorDataset(random_array, target_array)
50 | test_dataset = TensorDataset(random_array, target_array)
51 |
52 | # Build dataloaders from dataset
53 | train_loader = DataLoader(train_dataset, batch_size=1,
54 | shuffle=True, num_workers=0, pin_memory=False)
55 | test_loader = DataLoader(test_dataset, batch_size=1,
56 | shuffle=True, num_workers=0, pin_memory=False)
57 | return train_loader, test_loader
58 |
59 | def get_trainer(self, input_channels):
60 | # Build model
61 | net = self._make_test_model(input_channels)
62 | # Build trainer
63 | trainer = Trainer(net)\
64 | .build_logger(TensorboardLogger(send_image_at_batch_indices=0,
65 | send_image_at_channel_indices='all',
66 | log_images_every=(20, 'iterations')),
67 | log_directory=self.LOG_DIRECTORY)\
68 | .build_criterion('CrossEntropyLoss')\
69 | .build_metric('CategoricalError')\
70 | .build_optimizer('Adam')\
71 | .validate_every((1, 'epochs'))\
72 | .save_every((2, 'epochs'), to_directory=self.SAVE_DIRECTORY)\
73 | .save_at_best_validation_score()\
74 | .set_max_num_epochs(2)\
75 | .set_precision(self.PRECISION)
76 | # Bind loaders
77 | train_loader, test_loader = self.get_random_dataloaders(input_channels=input_channels)
78 | trainer.bind_loader('train', train_loader).bind_loader('validate', test_loader)
79 | return trainer
80 |
81 | def test_tensorboard(self):
82 | trainer = self.get_trainer(3)
83 | trainer.fit()
84 |
85 | def test_tensorboard_grayscale(self):
86 | trainer = self.get_trainer(1)
87 | trainer.fit()
88 |
89 | def test_serialization(self):
90 | trainer = self.get_trainer(3)
91 | # Serialize
92 | trainer.save()
93 | # Unserialize
94 | trainer = Trainer().load(os.path.join(self.ROOT_DIR, 'saves'))
95 | train_loader, test_loader = self.get_random_dataloaders(input_channels=3)
96 | trainer.bind_loader('train', train_loader).bind_loader('validate', test_loader)
97 | trainer.fit()
98 |
99 |
100 | if __name__ == '__main__':
101 | unittest.main()
102 |
--------------------------------------------------------------------------------
/docs/graphics/tentative_logo.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
90 |
--------------------------------------------------------------------------------
/tests/test_io/test_volumetric/test_lazy_volume_loader.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import os
3 | import numpy as np
4 |
5 | # try to load io libraries (h5py and z5py)
6 | try:
7 | import h5py
8 | WITH_H5PY = True
9 | except ImportError:
10 | WITH_H5PY = False
11 |
12 | # try:
13 | # import z5py
14 | # WITH_Z5PY = True
15 | # except ImportError:
16 | # WITH_Z5PY = False
17 |
18 |
class TestLazyVolumeLoader(unittest.TestCase):
    """Compare windows yielded by LazyHDF5VolumeLoader with slices of the in-memory array."""

    # Scratch file, created relative to the current working directory.
    H5_PATH = 'tmp.h5'

    def tearDown(self):
        try:
            os.remove(self.H5_PATH)
        except OSError:
            pass

    def _write_h5(self, data):
        """Write `data` as dataset 'data' into a freshly created H5_PATH."""
        # Open in write mode explicitly: h5py >= 3 opens files read-only by default,
        # which fails for a file that does not exist yet.
        with h5py.File(self.H5_PATH, 'w') as f:
            f.create_dataset('data', data=data)

    def _assert_loader_matches(self, loader, volume):
        """Check the loader's shape and every (batch, index) pair against `volume`."""
        self.assertEqual(loader.shape, volume.shape)
        for batch, index in loader:
            expected = volume[index.base_sequence_at_index]
            self.assertEqual(batch.shape, expected.shape)
            self.assertTrue(np.allclose(batch, expected))

    @unittest.skipUnless(WITH_H5PY, "Need h5py")
    def test_h5_loader(self):
        from inferno.io.volumetric.lazy_volume_loader import LazyHDF5VolumeLoader
        shape = (100, 100)

        # test default data loader
        data = np.arange(np.prod(shape)).reshape(shape)
        self._write_h5(data)

        loader = LazyHDF5VolumeLoader(self.H5_PATH, 'data',
                                      window_size=[10, 10], stride=[10, 10],
                                      return_index_spec=True)
        self._assert_loader_matches(loader, data)

    @unittest.skipUnless(WITH_H5PY, "Need h5py")
    def test_h5_loader_data_slice(self):
        from inferno.io.volumetric.lazy_volume_loader import LazyHDF5VolumeLoader
        shape = (100, 100, 100)
        data_slice = np.s_[:, 20:80, 10:30]

        data = np.arange(np.prod(shape)).reshape(shape)
        self._write_h5(data)
        # the loader only sees the sliced sub-volume
        data = data[data_slice]

        loader = LazyHDF5VolumeLoader(self.H5_PATH, 'data',
                                      window_size=[10, 10, 10], stride=[10, 10, 10],
                                      return_index_spec=True, data_slice=data_slice)
        self._assert_loader_matches(loader, data)

    @unittest.skipUnless(WITH_H5PY, "Need h5py")
    def test_h5_loader_pad(self):
        from inferno.io.volumetric.lazy_volume_loader import LazyHDF5VolumeLoader
        shape = (100, 100, 100)
        pad = [[0, 10], [0, 0], [5, 15]]

        data = np.arange(np.prod(shape)).reshape(shape)
        self._write_h5(data)
        # reference volume: zero-padded like the loader is asked to do
        data = np.pad(data, pad_width=pad, mode='constant')

        loader = LazyHDF5VolumeLoader(self.H5_PATH, 'data',
                                      window_size=[20, 20, 20], stride=[20, 20, 20],
                                      return_index_spec=True, padding=pad,
                                      padding_mode='constant')
        self._assert_loader_matches(loader, data)

    @unittest.skipUnless(WITH_H5PY, "Need h5py")
    def test_h5_loader_data_slice_pad(self):
        from inferno.io.volumetric.lazy_volume_loader import LazyHDF5VolumeLoader
        shape = (100, 100, 100)
        data_slice = np.s_[:, 20:80, 10:90]
        pad = [[0, 10], [5, 5], [5, 15]]

        data = np.arange(np.prod(shape)).reshape(shape)
        self._write_h5(data)
        # reference volume: slice first, then pad
        data = data[data_slice]
        data = np.pad(data, pad_width=pad, mode='constant')

        loader = LazyHDF5VolumeLoader(self.H5_PATH, 'data',
                                      window_size=[20, 20, 20], stride=[20, 20, 20],
                                      return_index_spec=True, padding=pad,
                                      padding_mode='constant', data_slice=data_slice)
        self._assert_loader_matches(loader, data)


if __name__ == '__main__':
    unittest.main()
118 |
--------------------------------------------------------------------------------
/inferno/utils/partial_cls.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import sys
3 | import types
4 | import inspect
5 |
6 |
7 | __all__ = [
8 | 'partial_cls',
9 | 'register_partial_cls'
10 | ]
11 |
12 |
def partial_cls(base_cls, name, module, fix=None, default=None):
    """Create a subclass of ``base_cls`` with some ``__init__`` arguments bound.

    :param base_cls: class to specialize. Its ``__init__`` may only have plain
        positional-or-keyword parameters (no ``*args``, ``**kwargs`` or
        keyword-only parameters).
    :param name: ``__name__`` of the generated class.
    :param module: value stored as ``__module__`` of the generated class.
    :param fix: dict of ``__init__`` arguments that are always passed with the
        given value; callers of the generated class may not pass them.
    :param default: dict of ``__init__`` arguments whose default value is
        replaced; callers may still override them.
    :return: the generated subclass of ``base_cls``.
    :raises TypeError: if ``fix`` and ``default`` share keys, or if
        ``base_cls.__init__`` has an unsupported signature. Instantiating the
        generated class raises ``TypeError`` for too many positional arguments,
        for passing a fixed argument, or for passing an argument both
        positionally and by keyword.
    """

    def insert_if_not_present(dict_a, dict_b):
        """Copy the items of ``dict_b`` into ``dict_a`` unless the key is already there."""
        for kw, val in dict_b.items():
            if kw not in dict_a:
                dict_a[kw] = val
        return dict_a

    def insert_call_if_present(dict_a, dict_b, callback):
        """Copy the items of ``dict_b`` into ``dict_a``; call ``callback(key)`` on collisions."""
        for kw, val in dict_b.items():
            if kw not in dict_a:
                dict_a[kw] = val
            else:
                callback(kw)
        return dict_a

    class PartialCls(object):
        """Holds the binding specification and builds the generated class."""

        def __init__(self, base_cls, name, module, fix=None, default=None):
            self.base_cls = base_cls
            self.name = name
            self.module = module
            # Copy so later changes to the caller's dicts do not alter the generated class.
            self.fix = {} if fix is None else dict(fix)
            self.default = {} if default is None else dict(default)

            if self.fix.keys() & self.default.keys():
                raise TypeError('fix and default share keys')

            # Parameter names callers may still pass (fixed ones removed), in
            # declaration order; used to map positional arguments.
            self._allowed_kw = self._get_allowed_kw()

        def _get_allowed_kw(self):
            argspec = inspect.getfullargspec(self.base_cls.__init__)

            if argspec.varargs is not None:
                raise TypeError('partial_cls can only be used if __init__ has no varargs')
            if argspec.varkw is not None:
                raise TypeError('partial_cls can only be used if __init__ has no varkw')
            if argspec.kwonlyargs:
                raise TypeError('partial_cls can only be used without kwonlyargs')
            if not argspec.args:
                raise TypeError('seems like self is missing')

            # args[0] is `self`
            return [kw for kw in argspec.args[1:] if kw not in self.fix]

        def _build_kw(self, args, kwargs):
            # Positional arguments fill the non-fixed parameters in order.
            if len(args) > len(self._allowed_kw):
                raise TypeError("too many arguments")
            all_args = dict(zip(self._allowed_kw, args))

            # Fixed arguments must not be passed by the caller.
            forbidden = self.fix.keys() & kwargs.keys()
            if forbidden:
                kw = min(forbidden)
                raise TypeError("`{}.__init__` got unexpected keyword argument '{}'"
                                .format(self.name, kw))

            def raise_multiple_values(kw):
                raise TypeError("{}.__init__ got multiple values for argument '{}'"
                                .format(self.name, kw))

            all_args = insert_call_if_present(all_args, kwargs, raise_multiple_values)

            # New defaults only apply where the caller supplied nothing.
            all_args = insert_if_not_present(all_args, self.default)

            # Fixed values cannot collide with caller arguments: fixed names are
            # excluded from the positional mapping and rejected as keywords above.
            all_args.update(self.fix)
            return all_args

        def build_cls(self):
            spec = self

            def new_init(self_of_new_cls, *args, **kwargs):
                combined_args = spec._build_kw(args=args, kwargs=kwargs)
                # Name the generated class explicitly: super(type(self), self) would
                # re-enter this __init__ forever once the generated class is subclassed.
                super(new_cls, self_of_new_cls).__init__(**combined_args)

            new_cls = type(self.name, (self.base_cls,), {
                '__module__': self.module,
                '__init__': new_init,
            })
            return new_cls

    return PartialCls(base_cls=base_cls, name=name, module=module,
                      fix=fix, default=default).build_cls()
118 |
119 |
def register_partial_cls(base_cls, name, module, fix=None, default=None):
    """Build a class with ``partial_cls`` and store it in module ``module``.

    The generated class is inserted into the namespace of ``sys.modules[module]``
    under its ``__name__`` (i.e. ``name``), so it can be imported from there.
    """
    # Resolve the target namespace first (KeyError if the module is not loaded).
    target_namespace = vars(sys.modules[module])
    generated_cls = partial_cls(base_cls=base_cls, name=name, module=module,
                                fix=fix, default=default)
    target_namespace[generated_cls.__name__] = generated_cls
if __name__ == "__main__":
    # Demo: Conv2D fixes `dim` to 2 and changes the default `stride` to 2.

    class Conv(object):
        def __init__(self, dim, activation, stride=1):
            print(f"dim {dim} act {activation} stride {stride}")

    Conv2D = partial_cls(Conv, 'Conv2D', __name__,
                         fix=dict(dim=2), default=dict(stride=2))

    # Positional arguments map onto the non-fixed parameters (activation, stride),
    # so this prints "dim 2 act fu stride bar".
    obj = Conv2D('fu', 'bar')
141 |
142 |
--------------------------------------------------------------------------------
/examples/regularized_mnist.py:
--------------------------------------------------------------------------------
1 | """
2 | Regularized MNIST Example
3 | ================================
4 |
5 | This example demonstrates adding and logging arbitrary regularization losses, in this case,
6 | L2 activity regularization and L1 weight regularization.
7 |
8 | - Add a `_losses` dictionary to any module containing loss names and values
9 | - Use a criterion from `inferno.extensions.criteria.regularized` that will collect and add those losses
10 | - Call `Trainer.observe_training_and_validation_states` to log the losses as well
11 | """
12 |
13 | import argparse
14 | import sys
15 |
16 | import torch
17 | import torch.nn as nn
18 | from torchvision import datasets, transforms
19 |
20 | from inferno.extensions.layers.reshape import Flatten
21 | from inferno.trainers.basic import Trainer
22 | from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger
23 |
24 |
class RegularizedLinear(nn.Linear):
    """``nn.Linear`` that records two regularization terms in ``self._losses``.

    After each forward pass ``_losses`` holds:

    * ``'activity_regularization'``: sum of squared outputs times ``ar_weight``;
    * ``'l1_weight_regularization'``: sum of absolute weights times ``l1_weight``.

    All other arguments are forwarded to ``nn.Linear``.
    """

    def __init__(self, *args, ar_weight=1e-3, l1_weight=1e-3, **kwargs):
        super().__init__(*args, **kwargs)
        self.ar_weight = ar_weight
        self.l1_weight = l1_weight
        self._losses = {}

    def forward(self, input):
        activations = super().forward(input)
        squared_activity = (activations * activations).sum()
        weight_l1 = torch.abs(self.weight).sum()
        self._losses['activity_regularization'] = squared_activity * self.ar_weight
        self._losses['l1_weight_regularization'] = weight_l1 * self.l1_weight
        return activations
37 |
38 |
def model_fn():
    """Return an MLP: Flatten, then RegularizedLinear 784->256->128->10 with LeakyReLU in between."""
    widths = [784, 256, 128, 10]
    layers = [Flatten()]
    for position, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
        # activation between consecutive linear layers, none after the last one
        if position > 0:
            layers.append(nn.LeakyReLU())
        layers.append(RegularizedLinear(in_features=n_in, out_features=n_out))
    return nn.Sequential(*layers)
48 |
49 |
def mnist_data_loaders(args):
    """Return (train_loader, test_loader) over MNIST stored under ./data.

    The training split is downloaded if missing. Images are converted to tensors
    and normalized with mean 0.1307 and std 0.3081. Uses ``args.batch_size``,
    ``args.test_batch_size`` and ``args.cuda`` (which enables one worker and pinned memory).
    """
    loader_kwargs = dict(num_workers=1, pin_memory=True) if args.cuda else {}
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_set = datasets.MNIST('./data', train=True, download=True,
                               transform=to_normalized_tensor)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, **loader_kwargs)

    test_set = datasets.MNIST('./data', train=False, transform=to_normalized_tensor)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=args.test_batch_size, shuffle=True, **loader_kwargs)

    return train_loader, test_loader
66 |
67 |
def train_model(args):
    """Train the regularized MLP on MNIST with settings from `args`, logging to TensorBoard."""
    model = model_fn()
    train_loader, validate_loader = mnist_data_loaders(args)

    trainer = (
        Trainer(model)
        .build_criterion('RegularizedCrossEntropyLoss')
        .build_metric('CategoricalError')
        .build_optimizer('Adam')
        .validate_every((1, 'epochs'))
        .save_every((1, 'epochs'))
        .save_to_directory(args.save_directory)
        .set_max_num_epochs(args.epochs)
        .build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
                                        log_images_every='never'),
                      log_directory=args.save_directory)
    )

    # Losses put into the trainer state by the regularized criterion; ask the
    # logger to log them for both training and validation.
    regularization_states = [
        'main_loss',
        'total_regularization_loss',
        'activity_regularization',
        'l1_weight_regularization',
    ]
    trainer.logger.observe_training_and_validation_states(regularization_states)

    trainer.bind_loader('train', train_loader).bind_loader('validate', validate_loader)

    if args.cuda:
        trainer.cuda()

    trainer.fit()
103 |
104 |
def main(argv):
    """Parse command-line arguments `argv` (program name excluded) and train."""
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    cli_options = [
        ('--batch-size', dict(type=int, default=64, metavar='N',
                              help='input batch size for training (default: 64)')),
        ('--save-directory', dict(type=str, default='output/mnist/v1',
                                  help='output directory')),
        ('--test-batch-size', dict(type=int, default=1000, metavar='N',
                                   help='input batch size for testing (default: 1000)')),
        ('--epochs', dict(type=int, default=20, metavar='N',
                          help='number of epochs to train (default: 20)')),
        ('--no-cuda', dict(action='store_true', default=False,
                           help='disables CUDA training')),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args(argv)
    # CUDA is used only when not disabled and actually available.
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    train_model(args)


if __name__ == '__main__':
    main(sys.argv[1:])
125 |
--------------------------------------------------------------------------------
/inferno/extensions/criteria/regularized.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from . import set_similarity_measures, core
7 |
8 | __all__ = [
9 | 'RegularizedLoss',
10 | 'RegularizedCrossEntropyLoss',
11 | 'RegularizedBCEWithLogitsLoss',
12 | 'RegularizedBCELoss',
13 | 'RegularizedMSELoss',
14 | 'RegularizedNLLLoss'
15 | ]
16 |
17 |
def collect_losses(module):
    """Sum the `_losses` dictionaries of `module` and all of its submodules.

    :param module: a Module to be searched for losses (visited via ``Module.apply``)
    :return: dictionary mapping each loss name to the sum of its values over all
        modules that define it
    """
    totals = {}

    def _accumulate(submodule):
        for loss_name, value in getattr(submodule, '_losses', {}).items():
            totals[loss_name] = totals[loss_name] + value if loss_name in totals else value

    module.apply(_accumulate)
    return totals
36 |
37 |
def build_criterion(criterion, *args, **kwargs):
    """Return a criterion instance.

    :param criterion: criterion class, name of a criterion class (looked up in
        ``torch.nn``, then ``core``, then ``set_similarity_measures``), or an
        already constructed ``torch.nn.Module`` (returned unchanged; args/kwargs ignored)
    :param args: positional args for the constructor
    :param kwargs: keyword args for the constructor
    :return: instance of criterion
    :raises AssertionError: if a criterion name cannot be found
    :raises NotImplementedError: for any other kind of `criterion`
    """
    if isinstance(criterion, torch.nn.Module):
        return criterion

    if isinstance(criterion, str):
        candidates = (getattr(source, criterion, None)
                      for source in (nn, core, set_similarity_measures))
        criterion_class = next((found for found in candidates if found is not None), None)
        assert criterion_class is not None, "Criterion {} not found.".format(criterion)
    elif isinstance(criterion, type) and callable(criterion):
        criterion_class = criterion
    else:
        raise NotImplementedError

    return criterion_class(*args, **kwargs)
59 |
60 |
class RegularizedLoss(nn.Module):
    """Wrap a criterion and add the regularization losses collected from a model.

    The wrapped criterion is built with ``build_criterion(criterion, *args, **kwargs)``.
    In ``forward`` the model comes from the ``model`` keyword or, failing that, from
    ``trainer.model``; the ``_losses`` entries gathered by ``collect_losses`` are summed
    and added to the wrapped loss. When a ``trainer`` is given, the main, total and
    individual regularization losses are written to its state under ``training_*``
    or ``validation_*`` keys, depending on ``self.training``.
    """

    def __init__(self, criterion, *args, **kwargs):
        super(RegularizedLoss, self).__init__()
        self.criterion = build_criterion(criterion, *args, **kwargs)

    def forward(self, *args, trainer=None, model=None, **kwargs):
        main_loss = self.criterion(*args, **kwargs)

        if trainer is None:
            warnings.warn('No trainer parameter provided. Not logging regularization losses.')
        elif model is None:
            # fall back to the model owned by the trainer
            model = trainer.model

        if model is not None:
            regularization_losses = collect_losses(model)
            total_regularization_loss = sum(regularization_losses.values())
            total_loss = main_loss + total_regularization_loss
        else:
            warnings.warn('No model or trainer parameter provided. Not calculating regularization losses.')
            regularization_losses = {}
            total_regularization_loss = None
            total_loss = main_loss

        if trainer is not None:
            trainer.update_state_from_dictionary(
                self._state_updates(main_loss, total_regularization_loss, regularization_losses))

        return total_loss

    def _state_updates(self, main_loss, total_regularization_loss, regularization_losses):
        """Build the prefixed state dictionary recorded on the trainer."""
        prefix = 'training' if self.training else 'validation'
        updates = {'{}_main_loss'.format(prefix): main_loss}
        if total_regularization_loss is not None:
            updates['{}_total_regularization_loss'.format(prefix)] = total_regularization_loss
        updates.update(('{}_{}'.format(prefix, loss_name), value)
                       for loss_name, value in regularization_losses.items())
        return updates
109 |
110 |
# Convenience wrappers for common losses
class RegularizedCrossEntropyLoss(RegularizedLoss):
    """RegularizedLoss around ``nn.CrossEntropyLoss``; args/kwargs go to its constructor."""

    def __init__(self, *args, **kwargs):
        super().__init__(nn.CrossEntropyLoss, *args, **kwargs)
115 |
116 |
class RegularizedBCEWithLogitsLoss(RegularizedLoss):
    """RegularizedLoss around ``nn.BCEWithLogitsLoss``; args/kwargs go to its constructor."""

    def __init__(self, *args, **kwargs):
        super().__init__(nn.BCEWithLogitsLoss, *args, **kwargs)
120 |
121 |
class RegularizedBCELoss(RegularizedLoss):
    """RegularizedLoss around ``nn.BCELoss``; args/kwargs go to its constructor."""

    def __init__(self, *args, **kwargs):
        super().__init__(nn.BCELoss, *args, **kwargs)
125 |
126 |
class RegularizedMSELoss(RegularizedLoss):
    """RegularizedLoss around ``nn.MSELoss``; args/kwargs go to its constructor."""

    def __init__(self, *args, **kwargs):
        super().__init__(nn.MSELoss, *args, **kwargs)
130 |
131 |
class RegularizedNLLLoss(RegularizedLoss):
    """RegularizedLoss around ``nn.NLLLoss``; args/kwargs go to its constructor."""

    def __init__(self, *args, **kwargs):
        super().__init__(nn.NLLLoss, *args, **kwargs)
135 |
--------------------------------------------------------------------------------
/inferno/utils/torch_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from .python_utils import delayed_keyboard_interrupt
5 | from .exceptions import assert_, ShapeError, NotUnwrappableError
6 |
7 |
def unwrap(input_, to_cpu=True, as_numpy=False, extract_item=False):
    """Convert tensors to CPU tensors, numpy arrays or Python numbers.

    :param input_: tensor, numpy array, Python int/float, or a (nested) list/tuple
        of those. Lists/tuples are unwrapped element-wise with the same options and
        keep their container type.
    :param to_cpu: move tensors to the CPU and detach them.
    :param as_numpy: return tensors as numpy arrays (moved to CPU and detached).
    :param extract_item: return tensors as Python scalars via ``tensor.item()``;
        ignored when ``as_numpy`` is set.
    :return: the unwrapped object; numpy arrays and Python numbers are returned unchanged.
    :raises NotUnwrappableError: if ``input_`` has an unsupported type.
    """
    if isinstance(input_, (list, tuple)):
        # Forward *all* options to the elements (extract_item used to be dropped here).
        return type(input_)([unwrap(_t, to_cpu=to_cpu, as_numpy=as_numpy,
                                    extract_item=extract_item)
                             for _t in input_])
    elif torch.is_tensor(input_):
        tensor = input_
    elif isinstance(input_, np.ndarray):
        return input_
    elif isinstance(input_, (float, int)):
        return input_
    else:
        raise NotUnwrappableError("Cannot unwrap a '{}'."
                                  .format(type(input_).__name__))
    # Transfer to CPU if required; presumably delayed_keyboard_interrupt defers Ctrl-C
    # until the transfer finishes (see python_utils).
    if to_cpu:
        with delayed_keyboard_interrupt():
            tensor = tensor.cpu().detach()
    # Convert to numpy if required
    if as_numpy:
        return tensor.cpu().detach().numpy()
    elif extract_item:
        try:
            return tensor.item()
        except AttributeError:
            # fallback for tensor types without .item()
            return tensor[0]
    else:
        return tensor
35 |
36 |
def is_tensor(object_):
    """Return True if `object_` is a torch tensor (HalfTensor is also checked explicitly)."""
    if torch.is_tensor(object_):
        return True
    return isinstance(object_, (torch.HalfTensor,))


def is_label_tensor(object_):
    """Return True for LongTensors on CPU or CUDA."""
    if not is_tensor(object_):
        return False
    return object_.type() in ('torch.LongTensor', 'torch.cuda.LongTensor')


def _is_tensor_of_dim(object_, ndim):
    """Return True if `object_` is a tensor with exactly `ndim` dimensions."""
    return is_tensor(object_) and object_.dim() == ndim


def is_image_tensor(object_):
    """Return True for 4D tensors."""
    return _is_tensor_of_dim(object_, 4)


def is_volume_tensor(object_):
    """Return True for 5D tensors."""
    return _is_tensor_of_dim(object_, 5)


def is_image_or_volume_tensor(object_):
    """Return True for 4D or 5D tensors."""
    return is_image_tensor(object_) or is_volume_tensor(object_)


def is_label_image_tensor(object_):
    """Return True for 3D label (Long) tensors."""
    return is_label_tensor(object_) and object_.dim() == 3


def is_label_volume_tensor(object_):
    """Return True for 4D label (Long) tensors."""
    return is_label_tensor(object_) and object_.dim() == 4


def is_label_image_or_volume_tensor(object_):
    """Return True for 3D or 4D label (Long) tensors."""
    return is_label_image_tensor(object_) or is_label_volume_tensor(object_)


def is_matrix_tensor(object_):
    """Return True for 2D tensors."""
    return _is_tensor_of_dim(object_, 2)


def is_scalar_tensor(object_):
    """Return True for single-element tensors with at most one dimension."""
    return is_tensor(object_) and object_.numel() == 1 and object_.dim() <= 1


def is_vector_tensor(object_):
    """Return True for 1D tensors with more than one element."""
    return is_tensor(object_) and object_.numel() > 1 and object_.dim() == 1
80 |
81 |
def assert_same_size(tensor_1, tensor_2):
    """Check that both tensors have the same size; assert_ is given ShapeError as the exception type."""
    size_1, size_2 = tensor_1.size(), tensor_2.size()
    assert_(list(size_1) == list(size_2),
            "Tensor sizes {} and {} do not match.".format(size_1, size_2),
            ShapeError)
86 |
87 |
def where(condition, if_true, if_false):
    """
    Torch equivalent of numpy.where.

    Parameters
    ----------
    condition : torch.Tensor
        Condition to check (e.g. a ByteTensor or BoolTensor). Non-zero entries
        select `if_true`, zero entries select `if_false`.
    if_true : torch.Tensor or torch.cuda.Tensor
        Output value if condition is true.
    if_false: torch.Tensor or torch.cuda.Tensor
        Output value if condition is false

    Returns
    -------
    torch.Tensor
        Element-wise selection (broadcast over the three inputs).

    Raises
    ------
    AssertionError
        if if_true and if_false don't have the same datatype.
    """
    # noinspection PyArgumentList
    assert if_true.type() == if_false.type(), \
        "Type mismatch: {} and {}".format(if_true.data.type(), if_false.data.type())
    # Select instead of blending: the previous `c * a + (1 - c) * b` let inf/nan in
    # the branch that was *not* chosen leak into the result (0 * inf == nan).
    return torch.where(condition.bool(), if_true, if_false)
116 |
117 |
def flatten_samples(input_):
    """
    Reshape a tensor so that channels come first and all other axes are merged
    into one sample axis:
        (N, C, H, W) --> (C, N * H * W)
        (N, C, D, H, W) --> (C, N * D * H * W)
        (N, C) --> (C, N)
    The input must be at least 2D; otherwise assert_ is given ShapeError as the exception type.
    """
    assert_(input_.dim() >= 2,
            "Tensor or variable must be atleast 2D. Got one of dim {}."
            .format(input_.dim()),
            ShapeError)
    num_channels = input_.size(1)
    # Swap the first two axes, e.g. NCHW -> CNHW; contiguous() makes view() valid.
    channels_first = input_.transpose(0, 1).contiguous()
    return channels_first.view(num_channels, -1)
141 |
142 |
def clip_gradients_(parameters, mode, norm_or_value):
    """Clip the gradients of `parameters` in place.

    `mode` is 'norm' (torch.nn.utils.clip_grad_norm_) or 'value'
    (torch.nn.utils.clip_grad_value_); `norm_or_value` is the threshold.
    Any other mode is rejected through assert_ with ValueError.
    """
    assert_(mode in ['norm', 'value'],
            f"Mode must be 'norm' or 'value', got '{mode}' instead.",
            ValueError)
    clippers = {'norm': torch.nn.utils.clip_grad_norm_,
                'value': torch.nn.utils.clip_grad_value_}
    clip = clippers.get(mode)
    if clip is None:
        raise NotImplementedError
    clip(parameters, norm_or_value)
153 |
--------------------------------------------------------------------------------
/inferno/extensions/initializers/base.py:
--------------------------------------------------------------------------------
1 | import torch.nn.init as init
2 |
3 |
4 | __all__ = ['Initializer',
5 | 'Initialization',
6 | 'WeightInitFunction',
7 | 'BiasInitFunction',
8 | 'TensorInitFunction']
9 |
10 |
class Initializer(object):
    """
    Base class for all initializers.

    Calling an initializer on a module whose class name is in VALID_LAYERS
    initializes its ``weight`` and ``bias`` tensors (through ``.data``) with
    :meth:`call_on_weight` / :meth:`call_on_bias`. Both default to
    :meth:`call_on_tensor`, which raises ``NotImplementedError``; hooks that
    raise ``NotImplementedError`` are skipped. The module is returned.
    """

    # TODO Support LSTMs and GRUs
    VALID_LAYERS = {'Conv1d', 'Conv2d', 'Conv3d',
                    'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',
                    'Linear', 'Bilinear',
                    'Embedding'}

    def __call__(self, module):
        module_class_name = module.__class__.__name__
        if module_class_name in self.VALID_LAYERS:
            # Layers built with bias=False register `bias` as None, so test for None
            # rather than just for the attribute's presence.
            weight = getattr(module, 'weight', None)
            if weight is not None:
                try:
                    self.call_on_weight(weight.data)
                except NotImplementedError:
                    # Don't cry if it's not implemented
                    pass

            bias = getattr(module, 'bias', None)
            if bias is not None:
                try:
                    self.call_on_bias(bias.data)
                except NotImplementedError:
                    pass

        return module

    def call_on_bias(self, tensor):
        return self.call_on_tensor(tensor)

    def call_on_weight(self, tensor):
        return self.call_on_tensor(tensor)

    def call_on_tensor(self, tensor):
        raise NotImplementedError

    @classmethod
    def initializes_weight(cls):
        # Only looks at the class's own namespace, not at inherited methods.
        return 'call_on_tensor' in cls.__dict__ or 'call_on_weight' in cls.__dict__

    @classmethod
    def initializes_bias(cls):
        # Only looks at the class's own namespace, not at inherited methods.
        return 'call_on_tensor' in cls.__dict__ or 'call_on_bias' in cls.__dict__
57 |
58 |
class Initialization(Initializer):
    """Combine separate weight and bias initializers into one Initializer.

    Each of ``weight_initializer`` / ``bias_initializer`` may be:

    * ``None``: a plain :class:`Initializer`, whose hooks raise
      ``NotImplementedError`` (the tensor is then left untouched by ``__call__``);
    * an :class:`Initializer` instance that initializes weights / biases respectively;
    * the name of a function in ``torch.nn.init``;
    * any callable taking the tensor as its first argument.
    """

    def __init__(self, weight_initializer=None, bias_initializer=None):
        if weight_initializer is None:
            self.weight_initializer = Initializer()
        elif isinstance(weight_initializer, Initializer):
            assert weight_initializer.initializes_weight()
            self.weight_initializer = weight_initializer
        elif isinstance(weight_initializer, str):
            init_function = getattr(init, weight_initializer, None)
            assert init_function is not None
            self.weight_initializer = WeightInitFunction(init_function=init_function)
        else:
            # Provision for weight_initializer to be a function
            assert callable(weight_initializer)
            self.weight_initializer = WeightInitFunction(init_function=weight_initializer)

        if bias_initializer is None:
            self.bias_initializer = Initializer()
        elif isinstance(bias_initializer, Initializer):
            # initializes_bias is a classmethod and must be *called*: the bare bound
            # method is always truthy, which turned this check into a no-op.
            assert bias_initializer.initializes_bias()
            self.bias_initializer = bias_initializer
        elif isinstance(bias_initializer, str):
            init_function = getattr(init, bias_initializer, None)
            assert init_function is not None
            self.bias_initializer = BiasInitFunction(init_function=init_function)
        else:
            # Provision for bias_initializer to be a function
            assert callable(bias_initializer)
            self.bias_initializer = BiasInitFunction(init_function=bias_initializer)

    def call_on_weight(self, tensor):
        return self.weight_initializer.call_on_weight(tensor)

    def call_on_bias(self, tensor):
        return self.bias_initializer.call_on_bias(tensor)
95 |
96 |
class WeightInitFunction(Initializer):
    """Initializer that applies a function to weight tensors only.

    The function is called as ``init_function(tensor, *init_function_args,
    **init_function_kwargs)``. Biases use the inherited hook, which raises
    ``NotImplementedError``.
    """

    def __init__(self, init_function, *init_function_args, **init_function_kwargs):
        super().__init__()
        assert callable(init_function)
        self.init_function = init_function
        self.init_function_args = init_function_args
        self.init_function_kwargs = init_function_kwargs

    def call_on_weight(self, tensor):
        fn, extra_args, extra_kwargs = (self.init_function,
                                        self.init_function_args,
                                        self.init_function_kwargs)
        return fn(tensor, *extra_args, **extra_kwargs)
107 |
108 |
class BiasInitFunction(Initializer):
    """Initializer that applies a function to bias tensors only.

    The function is called as ``init_function(tensor, *init_function_args,
    **init_function_kwargs)``. Weights use the inherited hook, which raises
    ``NotImplementedError``.
    """

    def __init__(self, init_function, *init_function_args, **init_function_kwargs):
        super().__init__()
        assert callable(init_function)
        self.init_function = init_function
        self.init_function_args = init_function_args
        self.init_function_kwargs = init_function_kwargs

    def call_on_bias(self, tensor):
        fn, extra_args, extra_kwargs = (self.init_function,
                                        self.init_function_args,
                                        self.init_function_kwargs)
        return fn(tensor, *extra_args, **extra_kwargs)
119 |
120 |
class TensorInitFunction(Initializer):
    """Initializer that applies a function to both weight and bias tensors.

    The function is called as ``init_function(tensor, *init_function_args,
    **init_function_kwargs)`` through :meth:`call_on_tensor`, which the inherited
    weight and bias hooks delegate to.
    """

    def __init__(self, init_function, *init_function_args, **init_function_kwargs):
        super().__init__()
        assert callable(init_function)
        self.init_function = init_function
        self.init_function_args = init_function_args
        self.init_function_kwargs = init_function_kwargs

    def call_on_tensor(self, tensor):
        fn, extra_args, extra_kwargs = (self.init_function,
                                        self.init_function_args,
                                        self.init_function_kwargs)
        return fn(tensor, *extra_args, **extra_kwargs)
131 |
132 |
--------------------------------------------------------------------------------
/tests/test_extensions/test_containers/test_graph.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from functools import reduce
3 | import torch
4 |
5 |
class TestGraph(unittest.TestCase):
    """Tests for ``inferno.extensions.containers.graph.Graph``.

    The "dummy" tests use ``DummyNamedModule`` (built in ``setUp``) to record
    the order in which graph nodes are executed.
    """

    def setUp(self):
        """Define ``self.DummyNamedModule``, a module that logs its own calls."""
        import torch.nn as nn
        from inferno.utils.python_utils import from_iterable

        class DummyNamedModule(nn.Module):
            """Appends ``name`` to the shared ``history`` list on every forward.

            With ``num_inputs > 1`` the inputs are summed; with a single input
            the result of ``from_iterable(inputs)`` is returned.
            """

            def __init__(self, name, history, num_inputs=1):
                super(DummyNamedModule, self).__init__()
                self.name = name
                self.history = history
                self.num_inputs = num_inputs

            def forward(self, *inputs):
                assert len(inputs) == self.num_inputs
                self.history.append(self.name)
                if self.num_inputs > 1:
                    return reduce(lambda x, y: x + y, inputs)
                return from_iterable(inputs)

        self.DummyNamedModule = DummyNamedModule

    def test_graph_dummy_basic(self):
        """Two input branches merge into ``conv1``, then ``conv2`` runs."""
        from inferno.extensions.containers.graph import Graph

        if not hasattr(self, 'DummyNamedModule'):
            self.setUp()
        DummyNamedModule = self.DummyNamedModule

        history = []
        # Build graph
        model = Graph()
        model.add_input_node('input_0')
        model.add_input_node('input_1')
        model.add_node('conv0_0', DummyNamedModule('conv0_0', history))
        model.add_node('conv0_1', DummyNamedModule('conv0_1', history))
        model.add_node('conv1', DummyNamedModule('conv1', history, 2))
        model.add_node('conv2', DummyNamedModule('conv2', history))
        model.add_output_node('output_0')
        (model.add_edge('input_0', 'conv0_0')
              .add_edge('input_1', 'conv0_1')
              .add_edge('conv0_0', 'conv1')
              .add_edge('conv0_1', 'conv1')
              .add_edge('conv1', 'conv2')
              .add_edge('conv2', 'output_0'))

        input_0 = torch.rand(10, 10)
        input_1 = torch.rand(10, 10)
        model(input_0, input_1)
        # The two independent input branches may execute in either order.
        self.assertIn(history, [['conv0_0', 'conv0_1', 'conv1', 'conv2'],
                                ['conv0_1', 'conv0_0', 'conv1', 'conv2']])

    def test_graph_dummy_inception(self):
        """``conv0`` fans out to two siblings that are merged by ``conv2``."""
        from inferno.extensions.containers.graph import Graph

        if not hasattr(self, 'DummyNamedModule'):
            self.setUp()
        DummyNamedModule = self.DummyNamedModule

        history = []
        # Build graph
        model = Graph()
        model.add_input_node('input_0')
        model.add_node('conv0', DummyNamedModule('conv0', history), 'input_0')
        model.add_node('conv1_0', DummyNamedModule('conv1_0', history), 'conv0')
        model.add_node('conv1_1', DummyNamedModule('conv1_1', history), 'conv0')
        model.add_node('conv2', DummyNamedModule('conv2', history, 2),
                       ['conv1_0', 'conv1_1'])
        model.add_output_node('output_0', 'conv2')
        input_0 = torch.rand(10, 10)
        model(input_0)
        # conv1_0 and conv1_1 are siblings, so either execution order is valid.
        self.assertIn(history, [['conv0', 'conv1_0', 'conv1_1', 'conv2'],
                                ['conv0', 'conv1_1', 'conv1_0', 'conv2']])

    def test_graph_basic(self):
        """A two-layer convolutional chain passes ``ModelTester``."""
        from inferno.extensions.containers.graph import Graph
        from inferno.extensions.layers.convolutional import ConvELU2D
        from inferno.utils.model_utils import ModelTester
        # Build graph
        model = Graph()
        model.add_input_node('input_0')
        model.add_node('conv0', ConvELU2D(1, 10, 3), previous='input_0')
        model.add_node('conv1', ConvELU2D(10, 1, 3), previous='conv0')
        model.add_output_node('output_0', previous='conv1')
        ModelTester((1, 1, 100, 100), (1, 1, 100, 100))(model)

    @unittest.skipUnless(torch.cuda.is_available(), "No cuda.")
    def test_graph_device_transfers(self):
        """Nodes placed on CPU and CUDA yield a CUDA output."""
        from inferno.extensions.containers.graph import Graph
        from inferno.extensions.layers.convolutional import ConvELU2D
        # Build graph
        model = Graph()
        model.add_input_node('input_0')
        model.add_node('conv0', ConvELU2D(1, 10, 3), previous='input_0')
        model.add_node('conv1', ConvELU2D(10, 1, 3), previous='conv0')
        model.add_output_node('output_0', previous='conv1')
        # Transfer
        model.to_device('conv0', 'cpu').to_device('conv1', 'cuda', 0)
        x = torch.rand(1, 1, 100, 100)
        y = model(x)
        self.assertIsInstance(y.data, torch.cuda.FloatTensor)

    @unittest.skip("Needs machine with 4 GPUs")
    def test_multi_gpu(self):
        """The graph can be run through ``data_parallel`` on four GPUs."""
        import torch.nn as nn
        from torch.nn.parallel.data_parallel import data_parallel
        from inferno.extensions.containers.graph import Graph

        input_shape = [8, 1, 3, 128, 128]
        model = (Graph()
                 .add_input_node('input')
                 .add_node('conv0', nn.Conv3d(1, 10, 3, padding=1), previous='input')
                 .add_node('conv1', nn.Conv3d(10, 1, 3, padding=1), previous='conv0')
                 .add_output_node('output', previous='conv1'))

        model.cuda()
        # Named input_tensor to avoid shadowing the builtin ``input``.
        input_tensor = torch.rand(*input_shape).cuda()
        data_parallel(model, input_tensor, device_ids=[0, 1, 2, 3])
135 |
136 |
if __name__ == '__main__':
    # Run the tests in this module when it is executed as a script.
    unittest.main()
139 |
--------------------------------------------------------------------------------
/tests/test_io/test_box/test_cityscapes.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join, dirname, exists, isdir
3 | import unittest
4 | import numpy as np
5 | import time
6 |
_CITYSCAPES_ROOT = None  # optional hard-coded dataset path; when None, $CITYSCAPES_ROOT is consulted


def _cityscapes_available():
    """Tell whether a Cityscapes root is configured (module constant or ``$CITYSCAPES_ROOT``)."""
    if _CITYSCAPES_ROOT is not None:
        return True
    return 'CITYSCAPES_ROOT' in os.environ
12 |
13 |
class TestCityscapes(unittest.TestCase):
    """Smoke tests for the Cityscapes dataset and its loaders.

    Every test is skipped unless ``_CITYSCAPES_ROOT`` is set in this module or
    the ``CITYSCAPES_ROOT`` environment variable is defined.
    """
    CITYSCAPES_ROOT = _CITYSCAPES_ROOT
    PLOT_DIRECTORY = join(dirname(__file__), 'plots')
    # Passed as ``include_coarse_dataset`` to ``get_cityscapes_loaders``.
    INCLUDE_COARSE = False

    def get_cityscapes_root(self):
        """Return the dataset root directory.

        Uses the ``CITYSCAPES_ROOT`` attribute when set, otherwise the
        ``CITYSCAPES_ROOT`` environment variable.

        Raises:
            AssertionError: if neither source provides a root.
        """
        if self.CITYSCAPES_ROOT is not None:
            return self.CITYSCAPES_ROOT
        root = os.environ.get('CITYSCAPES_ROOT')
        assert root is not None, "Cityscapes Root not found."
        # Previously this value was computed but never returned.
        return root

    def _ensure_plot_directory(self):
        """Create ``PLOT_DIRECTORY`` if missing; assert it is a directory otherwise."""
        if not exists(self.PLOT_DIRECTORY):
            os.mkdir(self.PLOT_DIRECTORY)
        else:
            assert isdir(self.PLOT_DIRECTORY)

    def _check_raw_sample(self, cityscapes):
        """Check shape and label range of the first untransformed sample."""
        image, label = cityscapes[0]
        image = np.asarray(image)
        label = np.asarray(label)
        self.assertSequenceEqual(image.shape, (1024, 2048, 3))
        self.assertSequenceEqual(label.shape, (1024, 2048))
        self.assertLessEqual(label.max(), 33)

    def _check_and_plot_train_sample(self, train_loader):
        """Time-load the first training sample, check it, and plot it to disk."""
        from inferno.utils.io_utils import print_tensor

        train_dataset = train_loader.dataset
        tic = time.time()
        image, label = train_dataset[0]
        toc = time.time()
        print("[+] Loaded sample in {} seconds.".format(toc - tic))
        # Make sure the shapes check out
        self.assertSequenceEqual(image.size(), (3, 1024, 2048))
        self.assertSequenceEqual(label.size(), (1024, 2048))
        self.assertEqual(image.type(), 'torch.FloatTensor')
        self.assertEqual(label.type(), 'torch.LongTensor')
        # Print tensors to make sure they look legit
        self._ensure_plot_directory()
        label_array = label.numpy()
        print_tensor(image.numpy()[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY)
        # One binary mask per class present in the label.
        for class_id in np.unique(label_array):
            print_tensor((label_array[None, None, ...] == class_id).astype('float32'),
                         prefix='LAB-{}--'.format(class_id),
                         directory=self.PLOT_DIRECTORY)
        print_tensor(label_array[None, None, ...],
                     prefix='LAB--',
                     directory=self.PLOT_DIRECTORY)
        print("[+] Inspect images at {}".format(self.PLOT_DIRECTORY))

    @unittest.skipUnless(_cityscapes_available(), "No cityscapes available.")
    def test_cityscapes_dataset_without_transforms(self):
        from inferno.io.box.cityscapes import Cityscapes
        self._check_raw_sample(Cityscapes(self.get_cityscapes_root()))

    @unittest.skipUnless(_cityscapes_available(), "No cityscapes available.")
    def test_cityscapes_dataset_without_transforms_unzipped(self):
        from inferno.io.box.cityscapes import Cityscapes
        cityscapes = Cityscapes(join(self.get_cityscapes_root(), 'extracted'),
                                read_from_zip_archive=False)
        self._check_raw_sample(cityscapes)

    @unittest.skipUnless(_cityscapes_available(), "No cityscapes available.")
    def test_cityscapes_dataset_with_transforms(self):
        from inferno.io.box.cityscapes import get_cityscapes_loaders

        train_loader, _ = get_cityscapes_loaders(self.get_cityscapes_root(),
                                                 include_coarse_dataset=self.INCLUDE_COARSE)
        self._check_and_plot_train_sample(train_loader)

    @unittest.skipUnless(_cityscapes_available(), "No cityscapes available.")
    def test_cityscapes_dataset_with_transforms_unzipped(self):
        from inferno.io.box.cityscapes import get_cityscapes_loaders

        train_loader, _ = get_cityscapes_loaders(join(self.get_cityscapes_root(), 'extracted'),
                                                 include_coarse_dataset=self.INCLUDE_COARSE,
                                                 read_from_zip_archive=False)
        self._check_and_plot_train_sample(train_loader)
114 |
115 |
if __name__ == '__main__':
    # Run the tests in this module when it is executed as a script.
    unittest.main()
118 |
--------------------------------------------------------------------------------
/tests/test_io/test_box/test_camvid.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join, dirname, exists, isdir
3 | import unittest
4 | import numpy as np
5 |
6 |
_CAMVID_ROOT = None  # optional hard-coded dataset path; when None, $CAMVID_ROOT is consulted


def _camvid_available():
    """Tell whether a CamVid root is configured (module constant or ``$CAMVID_ROOT``)."""
    if _CAMVID_ROOT is not None:
        return True
    return 'CAMVID_ROOT' in os.environ
12 |
13 |
class TestCamvid(unittest.TestCase):
    """Smoke tests for the CamVid dataset and its loaders.

    Every test is skipped unless ``_CAMVID_ROOT`` is set in this module or the
    ``CAMVID_ROOT`` environment variable is defined.
    """
    CAMVID_ROOT = _CAMVID_ROOT
    PLOT_DIRECTORY = join(dirname(__file__), 'plots')

    def get_camvid_root(self):
        """Return the dataset root directory.

        Uses the ``CAMVID_ROOT`` attribute when set, otherwise the
        ``CAMVID_ROOT`` environment variable.

        Raises:
            AssertionError: if neither source provides a root.
        """
        if self.CAMVID_ROOT is not None:
            return self.CAMVID_ROOT
        root = os.environ.get('CAMVID_ROOT')
        assert root is not None, "Camvid Root not found."
        # Previously this value was computed but never returned.
        return root

    def _ensure_plot_directory(self):
        """Create ``PLOT_DIRECTORY`` if missing; assert it is a directory otherwise."""
        if not exists(self.PLOT_DIRECTORY):
            os.mkdir(self.PLOT_DIRECTORY)
        else:
            assert isdir(self.PLOT_DIRECTORY)

    @unittest.skipUnless(_camvid_available(), "No root available.")
    def test_camvid_dataset_without_transforms(self):
        from inferno.io.box.camvid import CamVid
        camvid = CamVid(self.get_camvid_root())
        image, label = camvid[0]
        image = np.asarray(image)
        label = np.asarray(label)
        self.assertSequenceEqual(image.shape, (360, 480, 3))
        self.assertSequenceEqual(label.shape, (360, 480))
        self.assertLessEqual(label.max(), 11)

    # Disabled: without the ``test`` prefix the unittest loader does not collect it.
    @unittest.skipUnless(_camvid_available(), "No root available.")
    def _test_camvid_dataset_with_transforms(self):
        from inferno.io.box.camvid import CamVid
        from inferno.io.transform.base import Compose
        from inferno.io.transform.image import PILImage2NumPyArray, RandomSizedCrop, Scale
        from inferno.utils.io_utils import print_tensor

        camvid = CamVid(self.get_camvid_root(),
                        image_transform=Compose(),
                        label_transform=Compose(),
                        joint_transform=Compose())
        camvid.image_transform.add(PILImage2NumPyArray())
        camvid.label_transform.add(PILImage2NumPyArray())
        image, label = camvid[0]
        self.assertSequenceEqual(image.shape, (3, 360, 480))
        self.assertSequenceEqual(label.shape, (360, 480))
        # Add crop trafo
        camvid.joint_transform.add(RandomSizedCrop(ratio_between=(0.7, 1.0),
                                                   preserve_aspect_ratio=True))
        # We need 2 scale transforms, one with order 3 (image) and the other with order 0 (label)
        camvid.joint_transform.add(Scale(output_image_shape=(360, 480),
                                         interpolation_order=3, apply_to=[0]))
        camvid.joint_transform.add(Scale(output_image_shape=(360, 480),
                                         interpolation_order=0, apply_to=[1]))
        image, label = camvid[0]
        self.assertSequenceEqual(image.shape, (3, 360, 480))
        self.assertSequenceEqual(label.shape, (360, 480))
        self.assertLessEqual(len(np.unique(label)), 12)
        # Print tensors to make sure they look legit
        self._ensure_plot_directory()
        print_tensor(image[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY)
        print_tensor(label[None, None, ...], prefix='LAB--', directory=self.PLOT_DIRECTORY)
        print("[+] Inspect images at {}".format(self.PLOT_DIRECTORY))

    @unittest.skipUnless(_camvid_available(), "No root available.")
    def test_camvid_dataset_with_transforms(self):
        from inferno.io.box.camvid import get_camvid_loaders
        from inferno.utils.io_utils import print_tensor

        train_loader, _, _ = get_camvid_loaders(self.get_camvid_root())
        image, label = train_loader.dataset[0]
        # Make sure the shapes check out
        self.assertSequenceEqual(image.size(), (3, 360, 480))
        self.assertSequenceEqual(label.size(), (360, 480))
        self.assertEqual(image.type(), 'torch.FloatTensor')
        self.assertEqual(label.type(), 'torch.LongTensor')
        # Print tensors to make sure they look legit
        self._ensure_plot_directory()
        print_tensor(image.numpy()[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY)
        print_tensor(label.numpy()[None, None, ...], prefix='LAB--', directory=self.PLOT_DIRECTORY)
        print("[+] Inspect images at {}".format(self.PLOT_DIRECTORY))

    @unittest.skipUnless(_camvid_available(), "No root available.")
    def test_camvid_dataset_with_transforms_onehot(self):
        from inferno.io.box.camvid import get_camvid_loaders
        from inferno.utils.io_utils import print_tensor

        train_loader, _, _ = get_camvid_loaders(self.get_camvid_root(),
                                                labels_as_onehot=True)
        image, label = train_loader.dataset[0]
        # Make sure the shapes check out (one-hot labels carry a class axis)
        self.assertSequenceEqual(image.size(), (3, 360, 480))
        self.assertSequenceEqual(label.size(), (12, 360, 480))
        self.assertEqual(image.type(), 'torch.FloatTensor')
        self.assertEqual(label.type(), 'torch.FloatTensor')
        # Print tensors to make sure they look legit
        self._ensure_plot_directory()
        print_tensor(image.numpy()[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY)
        print_tensor(label.numpy()[None, ...], prefix='LAB--', directory=self.PLOT_DIRECTORY)
        print("[+] Inspect images at {}".format(self.PLOT_DIRECTORY))
117 |
118 |
if __name__ == '__main__':
    # Run the tests in this module when it is executed as a script.
    unittest.main()
121 |
--------------------------------------------------------------------------------
/inferno/io/box/binary_blobs.py:
--------------------------------------------------------------------------------
from functools import reduce
from operator import mul

import numpy
import skimage.data
import skimage.transform
import torch.utils.data as data
6 |
class BinaryBlobs(data.Dataset):
    """Synthetic dataset of noisy binary-blob images and their blob labels.

    Samples are generated on the fly. The label comes from
    ``skimage.data.binary_blobs`` seeded with ``master_seed + index``, so labels
    are reproducible per split and index. The image is derived from the label
    plus noise drawn from numpy's global random state, so the noise differs
    between calls unless the caller seeds numpy.

    Args:
        size: number of samples (also returned by ``len``).
        length, blob_size_fraction, n_dim, volume_fraction: forwarded to
            ``skimage.data.binary_blobs``.
        split: one of ``'train'``, ``'test'``, ``'validate'``. Each split uses
            its own block of ``size`` consecutive label seeds.
        uniform_noise_range: ``(low, high)`` bounds of the uniform noise.
        gaussian_noise_sigma: standard deviation of the Gaussian noise.
        noise_scale_factor: per-axis downsampling factor of the extra
            low-frequency noise layer.
        image_transform, label_transform, joint_transform: optional callables
            applied to the image, to the label, and to ``(image, label)``.

    Raises:
        KeyError: if ``split`` is not one of the three supported names.
    """

    def __init__(self, size=20, length=512, blob_size_fraction=0.1,
                 n_dim=2, volume_fraction=0.5, split='train',
                 uniform_noise_range=(-1.2, 1.2),
                 gaussian_noise_sigma=1.2,
                 noise_scale_factor=8,
                 image_transform=None,
                 label_transform=None,
                 joint_transform=None):
        # how many images are in the dataset
        self.size = size

        # blob related members (forwarded to skimage.data.binary_blobs)
        self.length = length
        self.blob_size_fraction = blob_size_fraction
        self.n_dim = n_dim
        self.volume_fraction = volume_fraction

        # which split {'train', 'test', 'validate'}
        self.split = split

        # noise related members
        self.uniform_noise_range = uniform_noise_range
        self.gaussian_noise_sigma = float(gaussian_noise_sigma)
        self.noise_scale_factor = noise_scale_factor

        # transforms
        self.image_transform = image_transform
        self.label_transform = label_transform
        self.joint_transform = joint_transform

        # Split k uses seeds [k * size, (k + 1) * size), so splits never share
        # labels for indices in [0, size).
        split_to_seed = dict(train=0, test=1, validate=2)
        self.master_seed = split_to_seed[self.split] * self.size

    def __getitem__(self, index):
        """Generate sample ``index`` and return ``(image, label)``.

        The label is cast to int64 before any transform runs. A new leading
        axis is prepended to the image after the transforms.
        """
        label = self._make_label(index)
        image = self._make_noisy_image(label)
        # Explicit int64: numpy's 'long' is the C long, which is 32-bit on
        # Windows; int64 matches torch.LongTensor on every platform.
        label = label.astype(numpy.int64)
        image, label = self._apply_transforms(image, label, index)
        image = image[None, ...]
        return image, label

    def __len__(self):
        return self.size

    def _make_label(self, index):
        """Return the binary blob mask for ``index`` (deterministic per split/index)."""
        # NOTE(review): newer scikit-image releases renamed ``seed`` to ``rng``;
        # confirm against the installed version.
        return skimage.data.binary_blobs(
            length=self.length,
            blob_size_fraction=self.blob_size_fraction,
            n_dim=self.n_dim,
            volume_fraction=self.volume_fraction,
            seed=self.master_seed + index)

    def _make_noisy_image(self, label):
        """Build a zero-mean, unit-std float image from ``label`` plus noise."""
        low, high = self.uniform_noise_range

        # Map the binary label {0, 1} to {-1, 1}.
        image = label.astype('float32') * 2
        image -= 1

        # Full-resolution uniform and Gaussian noise.
        image += numpy.random.uniform(low=low, high=high, size=image.shape)
        image += numpy.random.normal(scale=self.gaussian_noise_sigma,
                                     size=image.shape)

        # Low-frequency noise: sample on a coarse grid, then upsample. Clamp
        # every axis to >= 1 so small volumes do not produce an empty grid.
        small_shape = [max(1, s // self.noise_scale_factor) for s in label.shape]
        small_noise_img = numpy.random.uniform(low=low, high=high, size=small_shape)
        small_noise_img += numpy.random.normal(scale=self.gaussian_noise_sigma,
                                               size=small_shape)
        image += skimage.transform.resize(image=small_noise_img,
                                          output_shape=image.shape,
                                          mode='reflect')

        # Standardize.
        image -= image.mean()
        image /= image.std()
        return image

    def _apply_transforms(self, image, label, index):
        """Apply the image, label and joint transforms (each optional) in that order."""
        try:
            if self.image_transform is not None:
                image = self.image_transform(image)
            if self.label_transform is not None:
                label = self.label_transform(label)
            if self.joint_transform is not None:
                image, label = self.joint_transform(image, label)
        except Exception:
            # Add context about the failing sample, then propagate unchanged.
            print("[!] An Exception occurred while applying the transforms at "
                  "index {} of split '{}'.".format(index, self.split))
            raise
        return image, label
110 |
111 |
def get_binary_blob_loaders(train_batch_size=1, test_batch_size=1,
                            num_workers=1,
                            train_image_transform=None,
                            train_label_transform=None,
                            train_joint_transform=None,
                            validate_image_transform=None,
                            validate_label_transform=None,
                            validate_joint_transform=None,
                            test_image_transform=None,
                            test_label_transform=None,
                            test_joint_transform=None,
                            **kwargs):
    """Build ``DataLoader`` objects over the three ``BinaryBlobs`` splits.

    Args:
        train_batch_size: batch size of the training loader.
        test_batch_size: batch size shared by the test and validation loaders.
        num_workers: number of worker processes used by every loader.
        <split>_image_transform, <split>_label_transform,
        <split>_joint_transform: transforms given to that split's dataset.
        **kwargs: forwarded unchanged to every ``BinaryBlobs`` constructor.

    Returns:
        ``(trainloader, testloader, validloader)``. The test loader comes
        before the validation loader, and no loader is configured to shuffle.
    """
    def build_dataset(split, image_transform, label_transform, joint_transform):
        return BinaryBlobs(split=split,
                           image_transform=image_transform,
                           label_transform=label_transform,
                           joint_transform=joint_transform,
                           **kwargs)

    # Datasets are all built before any loader, in train/test/validate order.
    datasets = [
        build_dataset('train', train_image_transform,
                      train_label_transform, train_joint_transform),
        build_dataset('test', test_image_transform,
                      test_label_transform, test_joint_transform),
        build_dataset('validate', validate_image_transform,
                      validate_label_transform, validate_joint_transform),
    ]
    batch_sizes = [train_batch_size, test_batch_size, test_batch_size]

    return tuple(data.DataLoader(dataset, batch_size=batch_size,
                                 num_workers=num_workers)
                 for dataset, batch_size in zip(datasets, batch_sizes))
143 |
if __name__ == "__main__":
    # Smoke run: build the default dataset and generate its first sample.
    ds = BinaryBlobs()
    ds[0]
--------------------------------------------------------------------------------