├── .coveragerc ├── .gitignore ├── .scrutinizer.yml ├── .travis-data.sh ├── .travis.yml ├── LICENSE ├── README.rst ├── docs ├── api │ ├── converters.rst │ ├── data_streams.rst │ ├── dataset.rst │ ├── downloaders.rst │ ├── index.rst │ ├── iteration_schemes.rst │ ├── transformers.rst │ └── utils.rst ├── built_in_datasets.rst ├── caching.rst ├── conf.py ├── extending_fuel.rst ├── h5py_dataset.rst ├── index.rst ├── new_dataset.rst ├── overview.rst ├── server.rst └── setup.rst ├── doctests └── __init__.py ├── fuel ├── __init__.py ├── bin │ ├── __init__.py │ ├── fuel_convert.py │ ├── fuel_download.py │ └── fuel_info.py ├── config_parser.py ├── converters │ ├── __init__.py │ ├── adult.py │ ├── base.py │ ├── binarized_mnist.py │ ├── caltech101_silhouettes.py │ ├── celeba.py │ ├── cifar10.py │ ├── cifar100.py │ ├── dogs_vs_cats.py │ ├── ilsvrc2010.py │ ├── ilsvrc2012.py │ ├── iris.py │ ├── mnist.py │ ├── svhn.py │ └── youtube_audio.py ├── datasets │ ├── __init__.py │ ├── adult.py │ ├── base.py │ ├── billion.py │ ├── binarized_mnist.py │ ├── caltech101_silhouettes.py │ ├── celeba.py │ ├── cifar10.py │ ├── cifar100.py │ ├── dogs_vs_cats.py │ ├── hdf5.py │ ├── imagenet.py │ ├── iris.py │ ├── mnist.py │ ├── svhn.py │ ├── text.py │ ├── toy.py │ └── youtube_audio.py ├── downloaders │ ├── __init__.py │ ├── adult.py │ ├── base.py │ ├── binarized_mnist.py │ ├── caltech101_silhouettes.py │ ├── celeba.py │ ├── cifar10.py │ ├── cifar100.py │ ├── dogs_vs_cats.py │ ├── ilsvrc2010.py │ ├── ilsvrc2012.py │ ├── iris.py │ ├── mnist.py │ ├── svhn.py │ └── youtube_audio.py ├── exceptions.py ├── iterator.py ├── schemes.py ├── server.py ├── streams.py ├── transformers │ ├── __init__.py │ ├── _image.c │ ├── _image.pyx │ ├── defaults.py │ ├── image.py │ └── sequences.py ├── utils │ ├── __init__.py │ ├── cache.py │ ├── disk.py │ ├── formats.py │ ├── lock.py │ └── parallel.py └── version.py ├── req-rtd.txt ├── req-travis-conda.txt ├── req-travis-pip.txt ├── requirements.txt ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── converters ├── __init__.py ├── test_convert_ilsvrc2010.py └── test_convert_ilsvrc2012.py ├── test_adult.py ├── test_billion.py ├── test_binarized_mnist.py ├── test_caltech101_silhouettes.py ├── test_celeba.py ├── test_cifar10.py ├── test_cifar100.py ├── test_config_parser.py ├── test_converters.py ├── test_datasets.py ├── test_dogs_vs_cats.py ├── test_downloaders.py ├── test_hdf5.py ├── test_iris.py ├── test_mnist.py ├── test_schemes.py ├── test_sequences.py ├── test_serialization.py ├── test_server.py ├── test_streams.py ├── test_svhn.py ├── test_toy.py ├── test_utils.py └── transformers ├── __init__.py ├── test_image.py └── test_transformers.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | omit = 3 | fuel/bin/* 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | lib/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | 25 | # PyInstaller 26 | # Usually these files are written by a python script from a template 27 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
28 | *.manifest 29 | *.spec 30 | 31 | # Installer logs 32 | pip-log.txt 33 | pip-delete-this-directory.txt 34 | 35 | # Unit test / coverage reports 36 | htmlcov/ 37 | .tox/ 38 | .coverage 39 | .cache 40 | nosetests.xml 41 | coverage.xml 42 | 43 | # Translations 44 | *.mo 45 | *.pot 46 | 47 | # Django stuff: 48 | *.log 49 | 50 | # Sphinx documentation 51 | docs/_build/ 52 | 53 | # PyBuilder 54 | target/ 55 | 56 | # Editors 57 | *~ 58 | *.sw[op] 59 | -------------------------------------------------------------------------------- /.scrutinizer.yml: -------------------------------------------------------------------------------- 1 | build: 2 | dependencies: 3 | override: 4 | - pip install -q flake8 5 | - pip install -q git+git://github.com/bartvm/pep257.git@numpy 6 | tests: 7 | override: 8 | - flake8 fuel doctests tests 9 | - pep257 fuel --numpy --ignore=D100,D101,D102,D103 10 | - pep257 doctests tests --numpy --ignore=D100,D101,D102,D103 --match='.*\.py' 11 | checks: 12 | python: 13 | code_rating: true 14 | duplicate_code: true 15 | format_bad_indentation: 16 | indentation: '4 spaces' 17 | format_mixed_indentation: true 18 | format_line_too_long: 19 | max_length: '79' 20 | imports_relative_import: true 21 | imports_wildcard_import: true 22 | format_bad_whitespace: true 23 | format_multiple_statements: true 24 | basic_invalid_name: 25 | functions: '[a-z_][a-z0-9_]{0,30}$' 26 | variables: '(([a-z_][a-z0-9_]{0,30})|(_?[A-Z]))$' 27 | whitelisted_names: '_,floatX,logger,config' 28 | constants: '(([A-Z_][A-Z0-9_]*)|(__.*__))$' 29 | attributes: '(([a-z_][a-z0-9_]{0,30})|(_?[A-Z]))$' 30 | arguments: '(([a-z_][a-z0-9_]{0,30})|(_?[A-Z]))$' 31 | class_attributes: '([A-Za-z_][A-Za-z0-9_]{0,30}|(__.*__))$' 32 | inline_vars: '[A-Za-z_][A-Za-z0-9_]*$' 33 | classes: '[A-Z_][a-zA-Z0-9]+$' 34 | modules: '(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$' 35 | methods: '[a-z_][a-z0-9_]{0,30}$' 36 | classes_no_self_argument: true 37 | classes_bad_mcs_method_argument: true 38 | classes_bad_classmethod_argument: true 39 | variables_unused_variable: true 40 | variables_unused_import: true 41 | variables_used_before_assignment: true 42 | variables_undefined_variable: true 43 | variables_undefined_loop_variable: true 44 | variables_redefined_outer_name: true 45 | variables_redefined_builtin: true 46 | variables_redefine_in_handler: true 47 | variables_no_name_in_module: true 48 | variables_global_variable_undefined: true 49 | variables_global_variable_not_assigned: true 50 | variables_global_statement: true 51 | typecheck_unexpected_keyword_arg: true 52 | variables_global_at_module_level: true 53 | variables_unused_wildcard_import: true 54 | variables_unused_argument: true 55 | variables_unpacking_non_sequence: true 56 | variables_undefined_all_variable: true 57 | variables_unbalanced_tuple_unpacking: true 58 | variables_invalid_all_object: true 59 | typecheck_too_many_function_args: true 60 | typecheck_redundant_keyword_arg: true 61 | typecheck_not_callable: true 62 | typecheck_no_member: true 63 | typecheck_missing_kwoa: true 64 | typecheck_maybe_no_member: true 65 | typecheck_duplicate_keyword_arg: true 66 | typecheck_assignment_from_none: true 67 | typecheck_assignment_from_no_return: true 68 | string_unused_format_string_key: true 69 | string_truncated_format_string: true 70 | string_too_many_format_args: true 71 | string_too_few_format_args: true 72 | string_mixed_format_string: true 73 | string_missing_format_string_key: true 74 | string_format_needs_mapping: true 75 | 
string_constant_anomalous_unicode_escape_in_string: true 76 | string_constant_anomalous_backslash_in_string: true 77 | string_bad_str_strip_call: true 78 | string_bad_format_string_key: true 79 | string_bad_format_character: true 80 | open_mode_bad_open_mode: true 81 | newstyle_bad_super_call: true 82 | logging_unsupported_format: true 83 | logging_too_many_args: true 84 | logging_too_few_args: true 85 | logging_not_lazy: true 86 | logging_format_truncated: true 87 | imports_reimported: true 88 | imports_import_self: true 89 | imports_deprecated_module: true 90 | imports_cyclic_import: true 91 | format_unnecessary_semicolon: true 92 | format_trailing_whitespace: true 93 | format_superfluous_parens: true 94 | format_old_ne_operator: true 95 | format_missing_final_newline: true 96 | format_lowercase_l_suffix: true 97 | format_backtick: true 98 | exceptions_raising_string: true 99 | exceptions_raising_non_exception: true 100 | exceptions_raising_bad_type: true 101 | exceptions_pointless_except: true 102 | exceptions_notimplemented_raised: true 103 | exceptions_catching_non_exception: true 104 | exceptions_broad_except: true 105 | exceptions_binary_op_exception: true 106 | exceptions_bare_except: true 107 | exceptions_bad_except_order: true 108 | design_interface_not_implemented: true 109 | design_abstract_class_not_used: true 110 | design_abstract_class_little_used: true 111 | classes_valid_slots: true 112 | classes_super_init_not_called: true 113 | classes_signature_differs: true 114 | classes_protected_access: true 115 | classes_non_parent_init_called: true 116 | classes_non_iterator_returned: true 117 | classes_no_method_argument: true 118 | classes_no_method_argument: true 119 | classes_no_init: true 120 | classes_missing_interface_method: true 121 | classes_method_hidden: true 122 | classes_interface_is_not_class: true 123 | classes_bad_staticmethod_argument: true 124 | classes_bad_mcs_classmethod_argument: true 125 | classes_bad_context_manager: true 126 | classes_arguments_differ: true 127 | classes_access_member_before_definition: true 128 | basic_yield_outside_function: true 129 | basic_useless_else_on_loop: true 130 | basic_unreachable: true 131 | basic_unnecessary_pass: true 132 | basic_unnecessary_lambda: true 133 | basic_return_outside_function: true 134 | basic_return_in_init: true 135 | basic_return_arg_in_generator: true 136 | basic_pointless_string_statement: true 137 | basic_pointless_statement: true 138 | basic_old_raise_syntax: true 139 | basic_not_in_loop: true 140 | basic_nonexistent_operator: true 141 | basic_missing_reversed_argument: true 142 | basic_missing_module_attribute: true 143 | basic_lost_exception: true 144 | basic_init_is_generator: true 145 | basic_function_redefined: true 146 | basic_expression_not_assigned: true 147 | basic_exec_used: true 148 | basic_eval_used: true 149 | basic_empty_docstring: true 150 | basic_duplicate_key: true 151 | basic_duplicate_argument_name: true 152 | basic_dangerous_default_value: true 153 | basic_bad_reversed_sequence: true 154 | basic_assert_on_tuple: true 155 | basic_abstract_class_instantiated: true 156 | filter: 157 | paths: 158 | - fuel/* 159 | -------------------------------------------------------------------------------- /.travis-data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Download and convert given datasets if not present already 3 | # Usage: .travis-download.sh [dataset ...] 4 | set -ev 5 | 6 | function download { 7 | if [ ! 
-f $FUEL_DATA_PATH/$1.hdf5 ]; then 8 | fuel-download $@ 9 | fuel-convert $@ 10 | fuel-download $@ --clear 11 | fi 12 | } 13 | 14 | cd $FUEL_DATA_PATH 15 | 16 | for dataset in "$@"; do 17 | if [ "$dataset" == "ilsvrc2010" ]; then 18 | wget "http://www.image-net.org/challenges/LSVRC/2010/download/ILSVRC2010_devkit-1.0.tar.gz" 19 | wget "http://www.image-net.org/challenges/LSVRC/2010/ILSVRC2010_test_ground_truth.txt" 20 | else 21 | download $dataset 22 | fi 23 | done 24 | 25 | 26 | 27 | cd - 28 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | cache: 3 | directories: 4 | - $TRAVIS_BUILD_DIR/data 5 | branches: 6 | only: 7 | - master 8 | - stable 9 | language: python 10 | python: 11 | - "2.7" 12 | - "3.5" 13 | env: 14 | - TESTS=fuel 15 | - TESTS=blocks 16 | before_install: 17 | # Setup Python environment with BLAS libraries 18 | - | 19 | if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then 20 | # "./.travis-data.sh: line 8: fuel-download: command not found" with minocnda 4.4.10 21 | wget -q http://repo.continuum.io/miniconda/Miniconda2-4.3.31-Linux-x86_64.sh -O miniconda.sh 22 | else 23 | wget -q http://repo.continuum.io/miniconda/Miniconda3-4.3.31-Linux-x86_64.sh -O miniconda.sh 24 | fi 25 | - chmod +x miniconda.sh 26 | - ./miniconda.sh -b -p $HOME/miniconda 27 | - export PATH=$HOME/miniconda/bin:$PATH 28 | - python --version 29 | - pip --version 30 | - conda --version 31 | - export FUEL_DATA_PATH=$TRAVIS_BUILD_DIR/data 32 | - export FUEL_LOCAL_DATA_PATH=$TRAVIS_BUILD_DIR/data_local 33 | - export FUEL_FLOATX=float64 34 | install: 35 | # Install all Python dependencies 36 | - | 37 | if [[ $TESTS == 'blocks' ]]; then 38 | curl -O https://raw.githubusercontent.com/mila-udem/blocks/$TRAVIS_BRANCH/req-travis-conda.txt 39 | conda install -q --yes python=$TRAVIS_PYTHON_VERSION --file req-travis-conda.txt 40 | pip install -r https://raw.githubusercontent.com/mila-udem/blocks/$TRAVIS_BRANCH/req-travis-pip.txt 41 | pip install -e git+git://github.com/mila-udem/blocks.git@$TRAVIS_BRANCH#egg=blocks --src=$HOME -r https://raw.githubusercontent.com/mila-udem/blocks/$TRAVIS_BRANCH/requirements.txt 42 | fi 43 | - | 44 | if [[ $TESTS == 'fuel' ]]; then 45 | conda install -q --yes python=$TRAVIS_PYTHON_VERSION --file req-travis-conda.txt 46 | pip install -r req-travis-pip.txt 47 | pip install . 
-r requirements.txt # Installs the fuel-download command needed by .travis-data.sh 48 | python setup.py build_ext --inplace 49 | fi 50 | script: 51 | - ./.travis-data.sh adult mnist binarized_mnist "caltech101_silhouettes 16" cifar10 cifar100 iris ilsvrc2010 52 | - function fail { export FAILED=1; } 53 | - | 54 | if [[ $TESTS == 'blocks' ]]; then 55 | export MKL_THREADING_LAYER=GNU # For Theano 56 | bokeh-server &> /dev/null & 57 | export PYTHONPATH=$HOME/blocks 58 | nose2 tests --start-dir $HOME/blocks || fail 59 | nose2 doctests --start-dir $HOME/blocks || fail 60 | return $FAILED 61 | fi 62 | - | 63 | if [[ $TESTS == 'fuel' ]]; then 64 | # Running nose2 within coverage makes imports count towards coverage 65 | coverage run -p --source=fuel -m nose2.__main__ -v tests || fail 66 | coverage run -p --source=fuel -m nose2.__main__ -v doctests || fail 67 | return $FAILED 68 | fi 69 | after_script: 70 | - | 71 | if [[ $TESTS == 'fuel' ]]; then 72 | coverage combine 73 | coveralls 74 | fi 75 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Bart van Merriënboer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | .. image:: https://img.shields.io/coveralls/mila-udem/fuel.svg 2 | :target: https://coveralls.io/r/mila-udem/fuel 3 | 4 | .. image:: https://travis-ci.org/mila-udem/fuel.svg?branch=master 5 | :target: https://travis-ci.org/mila-udem/fuel 6 | 7 | .. image:: https://readthedocs.org/projects/fuel/badge/?version=latest 8 | :target: https://fuel.readthedocs.org/ 9 | 10 | .. image:: https://img.shields.io/scrutinizer/g/mila-udem/fuel.svg 11 | :target: https://scrutinizer-ci.com/g/mila-udem/fuel/ 12 | 13 | .. image:: https://requires.io/github/mila-udem/fuel/requirements.svg?branch=master 14 | :target: https://requires.io/github/mila-udem/fuel/requirements/?branch=master 15 | 16 | .. image:: https://img.shields.io/badge/license-MIT-blue.svg 17 | :target: https://github.com/mila-udem/fuel/blob/master/LICENSE 18 | 19 | Fuel 20 | ==== 21 | 22 | Fuel provides your machine learning models with the data they need to learn. 
23 | 24 | * Interfaces to common datasets such as MNIST, CIFAR-10 (image datasets), Google's One Billion Words (text), and many more 25 | * The ability to iterate over your data in a variety of ways, such as in minibatches with shuffled/sequential examples 26 | * A pipeline of preprocessors that allow you to edit your data on-the-fly, for example by adding noise, extracting n-grams from sentences, extracting patches from images, etc. 27 | * Ensure that the entire pipeline is serializable with pickle; this is a requirement for being able to checkpoint and resume long-running experiments. For this, we rely heavily on the picklable_itertools_ library. 28 | 29 | Fuel is developed primarily for use by Blocks_, a Theano toolkit that helps you train neural networks. 30 | 31 | If you have questions, don't hesitate to write to the `mailing list`_. 32 | 33 | Citing Fuel 34 | If you use Blocks or Fuel in your work, we'd really appreciate it if you could cite the following paper: 35 | 36 | Bart van Merriënboer, Dzmitry Bahdanau, Vincent Dumoulin, Dmitriy Serdyuk, David Warde-Farley, Jan Chorowski, and Yoshua Bengio, "`Blocks and Fuel: Frameworks for deep learning`_," *arXiv preprint arXiv:1506.00619 [cs.LG]*, 2015. 37 | 38 | Documentation 39 | Please see the documentation_ for more information. 40 | 41 | 42 | .. _picklable_itertools: http://github.com/dwf/picklable_itertools 43 | .. _Blocks: http://github.com/mila-udem/blocks 44 | .. _mailing list: https://groups.google.com/d/forum/fuel-users 45 | .. _documentation: http://fuel.readthedocs.org/en/latest/ 46 | .. _Blocks and Fuel\: Frameworks for deep learning: http://arxiv.org/abs/1506.00619 47 | -------------------------------------------------------------------------------- /docs/api/converters.rst: -------------------------------------------------------------------------------- 1 | Converters 2 | ========== 3 | 4 | Base classes 5 | ------------ 6 | 7 | .. automodule:: fuel.converters.base 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | Adult 13 | ----- 14 | 15 | .. automodule:: fuel.converters.adult 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | CalTech 101 Silhouettes 21 | ----------------------- 22 | 23 | .. automodule:: fuel.converters.caltech101_silhouettes 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | Binarized MNIST 29 | --------------- 30 | 31 | .. automodule:: fuel.converters.binarized_mnist 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | CIFAR100 37 | -------- 38 | 39 | .. automodule:: fuel.converters.cifar100 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | 44 | CIFAR10 45 | ------- 46 | 47 | .. automodule:: fuel.converters.cifar10 48 | :members: 49 | :undoc-members: 50 | :show-inheritance: 51 | 52 | IRIS 53 | ---- 54 | 55 | .. automodule:: fuel.converters.iris 56 | :members: 57 | :undoc-members: 58 | :show-inheritance: 59 | 60 | MNIST 61 | ----- 62 | 63 | .. automodule:: fuel.converters.mnist 64 | :members: 65 | :undoc-members: 66 | :show-inheritance: 67 | 68 | SVHN 69 | ---- 70 | 71 | .. automodule:: fuel.converters.svhn 72 | :members: 73 | :undoc-members: 74 | :show-inheritance: 75 | 76 | -------------------------------------------------------------------------------- /docs/api/data_streams.rst: -------------------------------------------------------------------------------- 1 | Data streams 2 | ============ 3 | 4 | .. 
automodule:: fuel.streams 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/api/dataset.rst: -------------------------------------------------------------------------------- 1 | Datasets 2 | ======== 3 | 4 | Base classes 5 | ------------ 6 | 7 | .. automodule:: fuel.datasets.base 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | Adult 13 | ----- 14 | 15 | .. automodule:: fuel.datasets.adult 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | One Billion Word 21 | ---------------- 22 | 23 | .. automodule:: fuel.datasets.billion 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | CalTech 101 Silhouettes 29 | ----------------------- 30 | 31 | .. automodule:: fuel.datasets.caltech101_silhouettes 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | Binarized MNIST 37 | --------------- 38 | 39 | .. automodule:: fuel.datasets.binarized_mnist 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | 44 | CIFAR100 45 | -------- 46 | 47 | .. automodule:: fuel.datasets.cifar100 48 | :members: 49 | :undoc-members: 50 | :show-inheritance: 51 | 52 | CIFAR10 53 | ------- 54 | 55 | .. automodule:: fuel.datasets.cifar10 56 | :members: 57 | :undoc-members: 58 | :show-inheritance: 59 | 60 | IRIS 61 | ---- 62 | 63 | .. automodule:: fuel.datasets.iris 64 | :members: 65 | :undoc-members: 66 | :show-inheritance: 67 | 68 | MNIST 69 | ----- 70 | 71 | .. automodule:: fuel.datasets.mnist 72 | :members: 73 | :undoc-members: 74 | :show-inheritance: 75 | 76 | SVHN 77 | ---- 78 | 79 | .. automodule:: fuel.datasets.svhn 80 | :members: 81 | :undoc-members: 82 | :show-inheritance: 83 | 84 | Text-based datasets 85 | ------------------- 86 | 87 | .. automodule:: fuel.datasets.text 88 | :members: 89 | :undoc-members: 90 | :show-inheritance: 91 | 92 | Toy datasets 93 | ------------ 94 | 95 | .. automodule:: fuel.datasets.toy 96 | :members: 97 | :undoc-members: 98 | :show-inheritance: 99 | 100 | -------------------------------------------------------------------------------- /docs/api/downloaders.rst: -------------------------------------------------------------------------------- 1 | Downloaders 2 | =========== 3 | 4 | Base Classes 5 | ------------ 6 | 7 | .. automodule:: fuel.downloaders.base 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | Adult 13 | ----- 14 | 15 | .. automodule:: fuel.downloaders.adult 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | CalTech 101 Silhouettes 21 | ----------------------- 22 | 23 | .. automodule:: fuel.downloaders.caltech101_silhouettes 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | Binarized MNIST 29 | --------------- 30 | 31 | .. automodule:: fuel.downloaders.binarized_mnist 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | CIFAR100 37 | -------- 38 | 39 | .. automodule:: fuel.downloaders.cifar100 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | 44 | CIFAR10 45 | ------- 46 | 47 | .. automodule:: fuel.downloaders.cifar10 48 | :members: 49 | :undoc-members: 50 | :show-inheritance: 51 | 52 | IRIS 53 | ---- 54 | 55 | .. automodule:: fuel.downloaders.iris 56 | :members: 57 | :undoc-members: 58 | :show-inheritance: 59 | 60 | MNIST 61 | ----- 62 | 63 | .. automodule:: fuel.downloaders.mnist 64 | :members: 65 | :undoc-members: 66 | :show-inheritance: 67 | 68 | SVHN 69 | ---- 70 | 71 | .. 
automodule:: fuel.downloaders.svhn 72 | :members: 73 | :undoc-members: 74 | :show-inheritance: 75 | 76 | -------------------------------------------------------------------------------- /docs/api/index.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | .. warning:: 5 | 6 | This API reference is currently nothing but a dump of docstrings, ordered 7 | alphabetically. 8 | 9 | .. toctree:: 10 | :glob: 11 | 12 | * 13 | -------------------------------------------------------------------------------- /docs/api/iteration_schemes.rst: -------------------------------------------------------------------------------- 1 | Iteration schemes 2 | ================= 3 | 4 | .. automodule:: fuel.schemes 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | -------------------------------------------------------------------------------- /docs/api/transformers.rst: -------------------------------------------------------------------------------- 1 | Transformers 2 | ============ 3 | 4 | General transformers 5 | -------------------- 6 | 7 | .. automodule:: fuel.transformers.defaults 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | Transformers for image 13 | ---------------------- 14 | 15 | .. automodule:: fuel.transformers.image 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | Transformers for sequences 21 | -------------------------- 22 | 23 | .. automodule:: fuel.transformers.sequences 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | Other 29 | ----- 30 | 31 | .. automodule:: fuel.transformers 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | -------------------------------------------------------------------------------- /docs/api/utils.rst: -------------------------------------------------------------------------------- 1 | Utilities 2 | ========= 3 | 4 | .. automodule:: fuel.utils 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Caching 10 | ------- 11 | 12 | .. automodule:: fuel.utils.cache 13 | :members: 14 | :undoc-members: 15 | :show-inheritance: 16 | -------------------------------------------------------------------------------- /docs/built_in_datasets.rst: -------------------------------------------------------------------------------- 1 | Built-in datasets 2 | ================= 3 | 4 | Fuel has a growing number of built-in datasets that simplify working on 5 | standard benchmark datasets, such as MNIST or CIFAR10. 6 | 7 | These datasets are defined in the ``fuel.datasets`` module. Some user 8 | intervention is needed before they're used for the first time: a given 9 | dataset has to be downloaded and converted into a format that is recognized by 10 | its corresponding dataset class. Fortunately, Fuel also has built-in tools 11 | to automate these operations. 12 | 13 | Environment variable 14 | -------------------- 15 | 16 | In order for Fuel to know where to look for its data, the ``data_path`` 17 | configuration variable has to be set inside ``~/.fuelrc``. It's expected to be 18 | a sequence of paths separated by an OS-specific delimiter (``:`` for Linux and 19 | OSX, ``;`` for Windows): 20 | 21 | .. code-block:: yaml 22 | 23 | # ~/.fuelrc 24 | data_path: "/first/path/to/my/data:/second/path/to/my/data" 25 | 26 | When looking for a specific file (e.g. ``mnist.hdf5``), Fuel will search each of 27 | these paths in sequence, using the first matching file that it finds. 
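The lookup itself is a simple first-match search over those directories. The sketch below is an illustrative helper written for this tutorial (not Fuel's own implementation); it assumes the raw ``data_path`` string is passed in explicitly:

.. code-block:: python

    import os

    def find_file(filename, data_path):
        """Return the first match for `filename` among the data paths."""
        for directory in data_path.split(os.pathsep):
            candidate = os.path.join(directory, filename)
            if os.path.isfile(candidate):
                return candidate
        raise IOError('{} not found in any of the data '
                      'paths'.format(filename))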
28 | 29 | This configuration variable can be overridden by setting the ``FUEL_DATA_PATH`` 30 | environment variable: 31 | 32 | .. code-block:: bash 33 | 34 | $ export FUEL_DATA_PATH="/first/path/to/my/data:/second/path/to/my/data" 35 | 36 | Let's now change directory for the rest of this tutorial: 37 | 38 | .. code-block:: bash 39 | 40 | $ cd $FUEL_DATA_PATH 41 | 42 | Download a built-in dataset 43 | --------------------------- 44 | 45 | We're going to download the raw data files for the MNIST dataset with the 46 | ``fuel-download`` script that was installed with Fuel: 47 | 48 | .. code-block:: bash 49 | 50 | $ fuel-download mnist 51 | 52 | The script is pretty simple: you call it and pass it the name of the dataset 53 | you'd like to download. In order to know which datasets are available to 54 | download via ``fuel-download``, type 55 | 56 | .. code-block:: bash 57 | 58 | $ fuel-download -h 59 | 60 | You can pass dataset-specific arguments to the script. In order to know which 61 | arguments are accepted, append ``-h`` to your dataset choice: 62 | 63 | .. code-block:: bash 64 | 65 | $ fuel-download mnist -h 66 | 67 | Two arguments are always accepted: 68 | 69 | * ``-d DIRECTORY`` : define where the dataset files will be downloaded. By 70 | default, ``fuel-download`` uses the current working directory. 71 | * ``--clear`` : delete the dataset files instead of downloading them, if they 72 | exist. 73 | 74 | Convert downloaded files 75 | ------------------------ 76 | 77 | You should now have four new files in your directory: 78 | 79 | * ``train-images-idx3-ubyte.gz`` 80 | * ``train-labels-idx1-ubyte.gz`` 81 | * ``t10k-images-idx3-ubyte.gz`` 82 | * ``t10k-labels-idx1-ubyte.gz`` 83 | 84 | Those are the original files that can be downloaded off Yann LeCun's website. 85 | We now need to convert those files into a format that the ``MNIST`` dataset 86 | class will recognize. This is done through the ``fuel-convert`` script: 87 | 88 | .. code-block:: bash 89 | 90 | $ fuel-convert mnist 91 | 92 | This will generate an ``mnist.hdf5`` file in your directory, which the 93 | ``MNIST`` class recognizes. 94 | 95 | Once again, the script accepts dataset-specific arguments, which you can discover 96 | by appending ``-h`` to your dataset choice: 97 | 98 | .. code-block:: bash 99 | 100 | $ fuel-convert mnist -h 101 | 102 | Two arguments are always accepted: 103 | 104 | * ``-d DIRECTORY`` : where ``fuel-convert`` should look for the input files. 105 | * ``-o OUTPUT_FILE`` : where to save the converted dataset. 106 | 107 | Let's delete the raw input files, as we don't need them anymore: 108 | 109 | .. code-block:: bash 110 | 111 | $ fuel-download mnist --clear 112 | 113 | Inspect Fuel-generated dataset files 114 | ------------------------------------ 115 | 116 | Six months from now, you may have a bunch of dataset files lying on disk, each 117 | with slight differences that you can't identify or reproduce. At that time, 118 | you'll be glad that ``fuel-info`` exists. 119 | 120 | When a dataset is generated through ``fuel-convert``, the script tags it with 121 | the command that was issued to generate the file and the versions of the 122 | relevant parts of the library at that time. 123 | 124 | You can inspect this metadata by calling ``fuel-info`` and passing an HDF5 file as 125 | an argument: 126 | 127 | .. code-block:: bash 128 | 129 | $ fuel-info mnist.hdf5 130 | 131 | ..
code-block:: text 132 | 133 | Metadata for mnist.hdf5 134 | ======================= 135 | 136 | The command used to generate this file is 137 | 138 | fuel-convert mnist 139 | 140 | Relevant versions are 141 | 142 | H5PYDataset 0.1 143 | fuel.converters 0.1 144 | 145 | 146 | Working with external packages 147 | ------------------------------ 148 | 149 | By default, Fuel looks for downloaders and converters in the 150 | ``fuel.downloaders`` and ``fuel.converters`` modules, respectively, but you're 151 | not limited to that. 152 | 153 | Fuel can be told to look into additional modules by setting the 154 | ``extra_downloaders`` and ``extra_converters`` configuration variables in 155 | ``~/.fuelrc``. These variables are expected to be lists of module names. 156 | 157 | For instance, suppose you'd like to include the following modules: 158 | 159 | * ``package1.extra_downloaders`` 160 | * ``package2.extra_downloaders`` 161 | * ``package1.extra_converters`` 162 | * ``package2.extra_converters`` 163 | 164 | You should include the following in your ``~/.fuelrc``: 165 | 166 | .. code-block:: yaml 167 | 168 | # ~/.fuelrc 169 | extra_downloaders: 170 | - package1.extra_downloaders 171 | - package2.extra_downloaders 172 | extra_converters: 173 | - package1.extra_converters 174 | - package2.extra_converters 175 | 176 | These configuration variables can be overridden through the 177 | ``FUEL_EXTRA_DOWNLOADERS`` and ``FUEL_EXTRA_CONVERTERS`` environment variables, 178 | which are expected to be strings of space-separated module names, like so: 179 | 180 | .. code-block:: bash 181 | 182 | export FUEL_EXTRA_DOWNLOADERS="package1.extra_downloaders package2.extra_downloaders" 183 | export FUEL_EXTRA_CONVERTERS="package1.extra_converters package2.extra_converters" 184 | 185 | This feature lets external developers define their own Fuel dataset 186 | downloader/converter packages, and also makes working with private datasets more 187 | straightforward. 188 | -------------------------------------------------------------------------------- /docs/caching.rst: -------------------------------------------------------------------------------- 1 | Caching datasets locally 2 | ======================== 3 | 4 | In some use cases, it may be desirable to set Fuel's ``data_path`` to 5 | point to a shared network drive. For example, when configuring multiple 6 | machines in a cluster to work on the same data in parallel. 7 | However, this can easily cause network bandwidth to become saturated. 8 | 9 | To avoid this problem, Fuel provides a second configuration variable 10 | named ``local_data_path``, which can be set in ``~/.fuelrc``. This 11 | variable points to a filesystem directory to be used to act as a local 12 | cache for datasets. 13 | 14 | This variable can also be set through an environment variable as follows: 15 | 16 | .. code-block:: bash 17 | 18 | $ export FUEL_LOCAL_DATA_PATH="/LOCAL_PATH/my_local_cache" 19 | 20 | Please note that currently, caching is only implemented in the ``H5PyDataset``. 21 | In order to add caching to other types of datasets, one should use the 22 | :func:`fuel.utils.cache.cache_file` function. 23 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to Fuel's documentation! 2 | ================================ 3 | 4 | .. 
toctree:: 5 | :hidden: 6 | 7 | setup 8 | overview 9 | built_in_datasets 10 | h5py_dataset 11 | new_dataset 12 | extending_fuel 13 | server 14 | caching 15 | api/index 16 | 17 | Fuel is a data pipeline framework which provides your machine learning models 18 | with the data they need. It is planned to be used by both the Blocks_ and 19 | Pylearn2_ neural network libraries. 20 | 21 | * Fuel allows you to easily read different types of data (NumPy binary files, 22 | CSV files, HDF5 files, text files) using a single interface which is based on 23 | Python's iterator types. 24 | * Provides a series of wrappers around frequently used datasets such as 25 | MNIST, CIFAR-10 (vision), the One Billion Word Dataset (text corpus), and 26 | many more. 27 | * Allows you to iterate over data in a variety of ways, e.g. in order, shuffled, 28 | sampled, etc. 29 | * Lets you process your data on the fly through a series of 30 | (chained) transformation procedures. This way you can whiten your data, 31 | add noise, rotate, crop, pad, sort or shuffle it, cache it, and much more. 32 | * Is pickle-friendly, allowing you to stop and resume long-running experiments 33 | in the middle of a pass over your dataset without losing any training 34 | progress. 35 | 36 | .. warning:: 37 | Fuel is a new project which is still under development. As such, certain 38 | (all) parts of the framework are subject to change. The last stable (but 39 | possibly outdated) release can be found in the ``stable`` branch. 40 | 41 | .. tip:: 42 | 43 | That said, if you are interested in using Fuel and run into any problems, 44 | feel free to ask your question on the `mailing list`_. Also, don't hesitate 45 | to file bug reports and feature requests by `making a GitHub issue`_. 46 | 47 | .. _mailing list: https://groups.google.com/d/forum/fuel-users 48 | .. _making a GitHub issue: https://github.com/mila-udem/fuel/issues/new 49 | .. _Blocks: https://github.com/mila-udem/blocks 50 | .. _Pylearn2: https://github.com/lisa-lab/pylearn2 51 | 52 | Motivation 53 | ---------- 54 | 55 | Fuel was originally factored out of the Blocks_ framework in the hope of being 56 | useful to other frameworks such as Pylearn2_ as well. It shares similarities 57 | with the skdata_ package, but with a much heavier focus on data iteration and 58 | processing. 59 | 60 | .. _skdata: https://github.com/jaberg/skdata 61 | 62 | Quickstart 63 | ========== 64 | 65 | The best way to get started with Fuel is to have a look at the 66 | :doc:`overview <overview>` documentation section. 67 | 68 | Indices and tables 69 | ================== 70 | * :ref:`genindex` 71 | * :ref:`modindex` 72 | -------------------------------------------------------------------------------- /docs/setup.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | The easiest way to install Fuel is using the Python package manager ``pip``. 5 | Fuel isn't listed yet on the Python Package Index (PyPI), so you will 6 | have to grab it directly from GitHub. 7 | 8 | .. code-block:: bash 9 | 10 | $ pip install git+git://github.com/mila-udem/fuel.git 11 | 12 | This will give you the cutting-edge development version. The latest stable 13 | release is in the ``stable`` branch and can be installed as follows. 14 | 15 | ..
code-block:: bash 16 | 17 | $ pip install git+git://github.com/mila-udem/fuel.git@stable 18 | 19 | If you don't have administrative rights, add the ``--user`` switch to the 20 | install commands to install the packages in your home folder. If you want to 21 | update Fuel, simply repeat the first command with the ``--upgrade`` switch 22 | added to pull the latest version from GitHub. 23 | 24 | .. warning:: 25 | 26 | Pip may try to install or update NumPy and SciPy if they are not present or 27 | outdated. However, pip's versions might not be linked to an optimized BLAS 28 | implementation. To prevent this from happening, make sure you update NumPy 29 | and SciPy using your system's package manager (e.g. ``apt-get`` or 30 | ``yum``), or use a Python distribution like Anaconda_, before installing 31 | Fuel. You can also pass the ``--no-deps`` switch and install all the 32 | requirements manually. 33 | 34 | If the installation crashes with ``ImportError: No module named 35 | numpy.distutils.core``, install NumPy and try again. 36 | 37 | 38 | Requirements 39 | ------------ 40 | Fuel's requirements are 41 | 42 | * PyYAML_, to parse the configuration file 43 | * six_, to support both Python 2 and 3 with a single codebase 44 | * h5py_ and PyTables_ for the HDF5 storage back-end 45 | * pillow_, providing PIL for image preprocessing 46 | * Cython_, for fast extensions 47 | * pyzmq_, to efficiently send data across processes 48 | * picklable_itertools_, for supporting iterator serialization 49 | * SciPy_, to read from MATLAB's .mat format 50 | * requests_, to download canonical datasets 51 | 52 | nose2_ is an optional requirement, used to run the tests. 53 | 54 | .. _Anaconda: https://store.continuum.io/cshop/anaconda/ 55 | .. _nose2: https://nose2.readthedocs.org/ 56 | .. _PyYAML: http://pyyaml.org/wiki/PyYAML 57 | .. _six: http://pythonhosted.org/six/ 58 | .. _h5py: http://www.h5py.org/ 59 | .. _PyTables: http://www.pytables.org/ 60 | .. _SciPy: http://www.scipy.org/ 61 | .. _pillow: https://python-pillow.github.io/ 62 | .. _Cython: http://cython.org/ 63 | .. _pyzmq: https://zeromq.github.io/pyzmq/ 64 | .. _picklable_itertools: https://github.com/dwf/picklable_itertools 65 | .. _requests: http://docs.python-requests.org/en/latest/ 66 | 67 | Development 68 | ----------- 69 | 70 | If you want to work on Fuel's development, your first step is to `fork Fuel 71 | on GitHub`_. You will now want to install your fork of Fuel in editable mode. 72 | To install in your home directory, use the following command, replacing ``USER`` 73 | with your own GitHub user name: 74 | 75 | .. code-block:: bash 76 | 77 | $ pip install -e git+git@github.com:USER/fuel.git#egg=fuel[test,docs] --src=$HOME 78 | 79 | As with the usual installation, you can use ``--user`` or ``--no-deps`` if you 80 | need to. You can now make changes in the ``fuel`` directory created by pip, 81 | push to your repository and make a pull request. 82 | 83 | If you have already cloned the GitHub repository, you can use the following 84 | command from the folder you cloned Fuel to: 85 | 86 | .. code-block:: bash 87 | 88 | $ pip install -e file:.#egg=fuel[test,docs] 89 | 90 | Fuel contains Cython extensions, which need to be recompiled if you 91 | update the Cython `.pyx` files. Each time these files are modified, you 92 | should run: 93 | 94 | .. code-block:: bash 95 | 96 | $ python setup.py build_ext --inplace 97 | 98 | ..
_fork Fuel on GitHub: https://github.com/mila-udem/fuel/fork 99 | 100 | Documentation 101 | ~~~~~~~~~~~~~ 102 | 103 | If you want to build a local copy of the documentation, you can follow 104 | the instructions in the `documentation development guidelines`_. 105 | 106 | .. _documentation development guidelines: 107 | http://blocks.readthedocs.org/en/latest/development/docs.html 108 | -------------------------------------------------------------------------------- /doctests/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function 2 | 3 | import doctest 4 | import fnmatch 5 | import importlib 6 | import os 7 | import pkgutil 8 | import re 9 | import sys 10 | from doctest import OutputChecker 11 | 12 | import fuel 13 | from tests import skip_if_not_available 14 | 15 | 16 | class Py23DocChecker(OutputChecker): 17 | """Single-source Python 2/3 output checker. 18 | 19 | For more information, see the `original blog post`_. 20 | 21 | .. _original blog post: 22 | https://dirkjan.ochtman.nl/writing/2014/07/06/ 23 | single-source-python-23-doctests.html 24 | 25 | """ 26 | def check_output(self, want, got, optionflags): 27 | if sys.version_info[0] < 3: 28 | got = re.sub("u'(.*?)'", "'\\1'", got) 29 | got = re.sub('u"(.*?)"', '"\\1"', got) 30 | return OutputChecker.check_output(self, want, got, optionflags) 31 | 32 | 33 | def setup(testobj): 34 | skip_if_not_available(modules=['nose2']) 35 | # Not importing unicode_literal because it gives problems 36 | # If needed, see https://dirkjan.ochtman.nl/writing/2014/07/06/ 37 | # single-source-python-23-doctests.html for a solution 38 | testobj.globs['absolute_import'] = absolute_import 39 | testobj.globs['print_function'] = print_function 40 | 41 | 42 | def load_tests(loader, tests, ignore): 43 | # This function loads doctests from all submodules and runs them 44 | # with the __future__ imports necessary for Python 2 45 | for _, module, _ in pkgutil.walk_packages(path=fuel.__path__, 46 | prefix=fuel.__name__ + '.'): 47 | try: 48 | tests.addTests(doctest.DocTestSuite( 49 | module=importlib.import_module(module), setUp=setup, 50 | optionflags=doctest.IGNORE_EXCEPTION_DETAIL, 51 | checker=Py23DocChecker())) 52 | except Exception: 53 | pass 54 | 55 | # This part loads the doctests from the documentation 56 | docs = [] 57 | for root, _, filenames in os.walk(os.path.join(fuel.__path__[0], 58 | '../docs')): 59 | for doc in fnmatch.filter(filenames, '*.rst'): 60 | docs.append(os.path.abspath(os.path.join(root, doc))) 61 | tests.addTests(doctest.DocFileSuite( 62 | *docs, module_relative=False, setUp=setup, 63 | optionflags=doctest.IGNORE_EXCEPTION_DETAIL, 64 | checker=Py23DocChecker())) 65 | 66 | return tests 67 | -------------------------------------------------------------------------------- /fuel/__init__.py: -------------------------------------------------------------------------------- 1 | import fuel.version 2 | from fuel.config_parser import config # noqa 3 | 4 | __version__ = fuel.version.version 5 | -------------------------------------------------------------------------------- /fuel/bin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mila-iqia/fuel/1d6292dc25e3a115544237e392e61bff6631d23c/fuel/bin/__init__.py -------------------------------------------------------------------------------- /fuel/bin/fuel_convert.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Fuel dataset conversion utility.""" 3 | import argparse 4 | import importlib 5 | import os 6 | import sys 7 | 8 | import h5py 9 | 10 | import fuel 11 | from fuel import converters 12 | from fuel.converters.base import MissingInputFiles 13 | from fuel.datasets import H5PYDataset 14 | 15 | 16 | class CheckDirectoryAction(argparse.Action): 17 | def __call__(self, parser, namespace, values, option_string=None): 18 | if os.path.isdir(values): 19 | setattr(namespace, self.dest, values) 20 | else: 21 | raise ValueError('{} is not an existing directory'.format(values)) 22 | 23 | 24 | def main(args=None): 25 | """Entry point for `fuel-convert` script. 26 | 27 | This function can also be imported and used from Python. 28 | 29 | Parameters 30 | ---------- 31 | args : iterable, optional (default: None) 32 | A list of arguments that will be passed to Fuel's conversion 33 | utility. If this argument is not specified, `sys.argv[1:]` will 34 | be used. 35 | 36 | """ 37 | built_in_datasets = dict(converters.all_converters) 38 | if fuel.config.extra_converters: 39 | for name in fuel.config.extra_converters: 40 | extra_datasets = dict( 41 | importlib.import_module(name).all_converters) 42 | if any(key in built_in_datasets for key in extra_datasets.keys()): 43 | raise ValueError('extra converters conflict in name with ' 44 | 'built-in converters') 45 | built_in_datasets.update(extra_datasets) 46 | parser = argparse.ArgumentParser( 47 | description='Conversion script for built-in datasets.') 48 | subparsers = parser.add_subparsers() 49 | parent_parser = argparse.ArgumentParser(add_help=False) 50 | parent_parser.add_argument( 51 | "-d", "--directory", help="directory in which input files reside", 52 | type=str, default=os.getcwd()) 53 | convert_functions = {} 54 | for name, fill_subparser in built_in_datasets.items(): 55 | subparser = subparsers.add_parser( 56 | name, parents=[parent_parser], 57 | help='Convert the {} dataset'.format(name)) 58 | subparser.add_argument( 59 | "-o", "--output-directory", help="where to save the dataset", 60 | type=str, default=os.getcwd(), action=CheckDirectoryAction) 61 | subparser.add_argument( 62 | "-r", "--output_filename", help="new name of the created dataset", 63 | type=str, default=None) 64 | # Allows the parser to know which subparser was called. 65 | subparser.set_defaults(which_=name) 66 | convert_functions[name] = fill_subparser(subparser) 67 | 68 | args = parser.parse_args(args) 69 | args_dict = vars(args) 70 | if args_dict['output_filename'] is not None and\ 71 | os.path.splitext(args_dict['output_filename'])[1] not in\ 72 | ('.hdf5', '.hdf', '.h5'): 73 | args_dict['output_filename'] += '.hdf5' 74 | if args_dict['output_filename'] is None: 75 | args_dict.pop('output_filename') 76 | 77 | convert_function = convert_functions[args_dict.pop('which_')] 78 | try: 79 | output_paths = convert_function(**args_dict) 80 | except MissingInputFiles as e: 81 | intro = "The following required files were not found:\n" 82 | message = "\n".join([intro] + [" * " + f for f in e.filenames]) 83 | message += "\n\nDid you forget to run fuel-download?"
84 | parser.error(message) 85 | 86 | # Tag the newly-created file(s) with H5PYDataset version and command-line 87 | # options 88 | for output_path in output_paths: 89 | h5file = h5py.File(output_path, 'a') 90 | interface_version = H5PYDataset.interface_version.encode('utf-8') 91 | h5file.attrs['h5py_interface_version'] = interface_version 92 | fuel_convert_version = converters.__version__.encode('utf-8') 93 | h5file.attrs['fuel_convert_version'] = fuel_convert_version 94 | command = [os.path.basename(sys.argv[0])] + sys.argv[1:] 95 | h5file.attrs['fuel_convert_command'] = ( 96 | ' '.join(command).encode('utf-8')) 97 | h5file.flush() 98 | h5file.close() 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /fuel/bin/fuel_download.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Fuel dataset downloading utility.""" 3 | import argparse 4 | import importlib 5 | import os 6 | 7 | import fuel 8 | from fuel import downloaders 9 | from fuel.downloaders.base import NeedURLPrefix 10 | 11 | url_prefix_message = """ 12 | Some files for this dataset do not have a download URL. 13 | 14 | Provide a URL prefix with --url-prefix to prepend to the filenames, 15 | e.g. http://path.to/files/ 16 | """.strip() 17 | 18 | 19 | def main(args=None): 20 | """Entry point for `fuel-download` script. 21 | 22 | This function can also be imported and used from Python. 23 | 24 | Parameters 25 | ---------- 26 | args : iterable, optional (default: None) 27 | A list of arguments that will be passed to Fuel's downloading 28 | utility. If this argument is not specified, `sys.argv[1:]` will 29 | be used. 30 | 31 | """ 32 | built_in_datasets = dict(downloaders.all_downloaders) 33 | if fuel.config.extra_downloaders: 34 | for name in fuel.config.extra_downloaders: 35 | extra_datasets = dict( 36 | importlib.import_module(name).all_downloaders) 37 | if any(key in built_in_datasets for key in extra_datasets.keys()): 38 | raise ValueError('extra downloaders conflict in name with ' 39 | 'built-in downloaders') 40 | built_in_datasets.update(extra_datasets) 41 | parser = argparse.ArgumentParser( 42 | description='Download script for built-in datasets.') 43 | parent_parser = argparse.ArgumentParser(add_help=False) 44 | parent_parser.add_argument( 45 | "-d", "--directory", help="where to save the downloaded files", 46 | type=str, default=os.getcwd()) 47 | parent_parser.add_argument( 48 | "--clear", help="clear the downloaded files", action='store_true') 49 | subparsers = parser.add_subparsers() 50 | download_functions = {} 51 | for name, fill_subparser in built_in_datasets.items(): 52 | subparser = subparsers.add_parser( 53 | name, parents=[parent_parser], 54 | help='Download the {} dataset'.format(name)) 55 | # Allows the parser to know which subparser was called. 
56 | subparser.set_defaults(which_=name) 57 | download_functions[name] = fill_subparser(subparser) 58 | args = parser.parse_args(args) 59 | args_dict = vars(args) 60 | download_function = download_functions[args_dict.pop('which_')] 61 | try: 62 | download_function(**args_dict) 63 | except NeedURLPrefix: 64 | parser.error(url_prefix_message) 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /fuel/bin/fuel_info.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Fuel utility for extracting metadata.""" 3 | import argparse 4 | import os 5 | 6 | import h5py 7 | 8 | message_prefix_template = 'Metadata for {}' 9 | message_body_template = """ 10 | 11 | The command used to generate this file is 12 | 13 | {} 14 | 15 | Relevant versions are 16 | 17 | H5PYDataset {} 18 | fuel.converters {} 19 | """ 20 | 21 | 22 | def main(args=None): 23 | """Entry point for `fuel-info` script. 24 | 25 | This function can also be imported and used from Python. 26 | 27 | Parameters 28 | ---------- 29 | args : iterable, optional (default: None) 30 | A list of arguments that will be passed to Fuel's information 31 | utility. If this argument is not specified, `sys.argv[1:]` will 32 | be used. 33 | 34 | """ 35 | parser = argparse.ArgumentParser( 36 | description='Extracts metadata from a Fuel-converted HDF5 file.') 37 | parser.add_argument("filename", help="HDF5 file to analyze") 38 | args = parser.parse_args(args) 39 | 40 | with h5py.File(args.filename, 'r') as h5file: 41 | interface_version = h5file.attrs.get('h5py_interface_version', 'N/A') 42 | fuel_convert_version = h5file.attrs.get('fuel_convert_version', 'N/A') 43 | fuel_convert_command = h5file.attrs.get('fuel_convert_command', 'N/A') 44 | 45 | message_prefix = message_prefix_template.format( 46 | os.path.basename(args.filename)) 47 | message_body = message_body_template.format( 48 | fuel_convert_command, interface_version, fuel_convert_version) 49 | message = ''.join(['\n', message_prefix, '\n', '=' * len(message_prefix), 50 | message_body]) 51 | print(message) 52 | 53 | 54 | if __name__ == "__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /fuel/converters/__init__.py: -------------------------------------------------------------------------------- 1 | """Data conversion modules for built-in datasets. 2 | 3 | Conversion submodules generate an HDF5 file that is compatible with 4 | their corresponding built-in dataset. 5 | 6 | Each conversion submodule exposes a `fill_subparser` function, which 7 | accepts a single argument, `subparser`: an `argparse.ArgumentParser` 8 | instance that it needs to fill with its own specific arguments. It 9 | should return the conversion function, which gets called with the parsed 10 | command-line arguments and is expected to convert the required files.
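
For illustration only, a minimal converter module for a hypothetical `foo`
dataset (the name, files and placeholder arrays below are made up) could
look like this::

    import os

    import h5py
    import numpy

    from fuel.converters.base import fill_hdf5_file


    def convert_foo(directory, output_directory,
                    output_filename='foo.hdf5'):
        # Write a small H5PYDataset-compatible file and return its path.
        output_path = os.path.join(output_directory, output_filename)
        with h5py.File(output_path, mode='w') as h5file:
            data = (('train', 'features',
                     numpy.zeros((8, 3), dtype='float32')),
                    ('test', 'features',
                     numpy.zeros((2, 3), dtype='float32')))
            fill_hdf5_file(h5file, data)
        return (output_path,)


    def fill_subparser(subparser):
        # No dataset-specific command-line arguments are needed here.
        return convert_foo

An external package would then list ('foo', foo.fill_subparser) in its own
`all_converters` tuple so that the `extra_converters` mechanism can pick
it up.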
11 | 12 | """ 13 | from fuel.converters import adult 14 | from fuel.converters import binarized_mnist 15 | from fuel.converters import caltech101_silhouettes 16 | from fuel.converters import celeba 17 | from fuel.converters import cifar10 18 | from fuel.converters import cifar100 19 | from fuel.converters import dogs_vs_cats 20 | from fuel.converters import iris 21 | from fuel.converters import mnist 22 | from fuel.converters import svhn 23 | from fuel.converters import ilsvrc2010 24 | from fuel.converters import ilsvrc2012 25 | from fuel.converters import youtube_audio 26 | 27 | __version__ = '0.2' 28 | all_converters = ( 29 | ('adult', adult.fill_subparser), 30 | ('binarized_mnist', binarized_mnist.fill_subparser), 31 | ('caltech101_silhouettes', caltech101_silhouettes.fill_subparser), 32 | ('celeba', celeba.fill_subparser), 33 | ('cifar10', cifar10.fill_subparser), 34 | ('cifar100', cifar100.fill_subparser), 35 | ('dogs_vs_cats', dogs_vs_cats.fill_subparser), 36 | ('iris', iris.fill_subparser), 37 | ('mnist', mnist.fill_subparser), 38 | ('svhn', svhn.fill_subparser), 39 | ('ilsvrc2010', ilsvrc2010.fill_subparser), 40 | ('ilsvrc2012', ilsvrc2012.fill_subparser), 41 | ('youtube_audio', youtube_audio.fill_subparser)) 42 | -------------------------------------------------------------------------------- /fuel/converters/adult.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import h5py 4 | import numpy 5 | 6 | from fuel.converters.base import fill_hdf5_file 7 | 8 | 9 | def convert_to_one_hot(y): 10 | """ 11 | Converts `y` into a one-hot representation. 12 | 13 | Parameters 14 | ---------- 15 | y : list 16 | A list containing consecutive integer values. 17 | 18 | Returns 19 | ------- 20 | one_hot : numpy.ndarray 21 | A numpy.ndarray object, which is the one-hot representation of `y`. 22 | 23 | """ 24 | max_value = max(y) 25 | min_value = min(y) 26 | length = len(y) 27 | one_hot = numpy.zeros((length, (max_value - min_value + 1))) 28 | one_hot[numpy.arange(length), y] = 1 29 | return one_hot 30 | 31 | 32 | def convert_adult(directory, output_directory, 33 | output_filename='adult.hdf5'): 34 | """ 35 | Convert the Adult dataset to HDF5. 36 | 37 | Converts the Adult dataset to an HDF5 dataset compatible with 38 | :class:`fuel.datasets.Adult`. The converted dataset is saved as 39 | 'adult.hdf5'. 40 | This method assumes the existence of the files `adult.data` and 41 | `adult.test`. 42 | 43 | Parameters 44 | ---------- 45 | directory : str 46 | Directory in which input files reside. 47 | output_directory : str 48 | Directory in which to save the converted dataset. 49 | output_filename : str, optional 50 | Name of the saved dataset. Defaults to `adult.hdf5`. 51 | 52 | Returns 53 | ------- 54 | output_paths : tuple of str 55 | Single-element tuple containing the path to the converted dataset.
56 | 57 | """ 58 | train_path = os.path.join(directory, 'adult.data') 59 | test_path = os.path.join(directory, 'adult.test') 60 | output_path = os.path.join(output_directory, output_filename) 61 | 62 | train_content = open(train_path, 'r').readlines() 63 | test_content = open(test_path, 'r').readlines() 64 | train_content = train_content[:-1] 65 | test_content = test_content[1:-1] 66 | 67 | features_list = [] 68 | targets_list = [] 69 | for content in [train_content, test_content]: 70 | # strip out examples with missing features 71 | content = [line for line in content if line.find('?') == -1] 72 | # strip off endlines, separate entries 73 | content = list(map(lambda l: l[:-1].split(', '), content)) 74 | 75 | features = list(map(lambda l: l[:-1], content)) 76 | targets = list(map(lambda l: l[-1], content)) 77 | del content 78 | y = list(map(lambda l: [l[0] == '>'], targets)) 79 | y = numpy.array(y) 80 | del targets 81 | 82 | # Process features into a matrix 83 | variables = [ 84 | 'age', 'workclass', 'fnlwgt', 'education', 'education-num', 85 | 'marital-status', 'occupation', 'relationship', 'race', 'sex', 86 | 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country' 87 | ] 88 | continuous = set([ 89 | 'age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 90 | 'hours-per-week' 91 | ]) 92 | 93 | pieces = [] 94 | for i, var in enumerate(variables): 95 | data = list(map(lambda l: l[i], features)) 96 | if var in continuous: 97 | data = list(map(lambda l: float(l), data)) 98 | data = numpy.array(data) 99 | data = data.reshape(data.shape[0], 1) 100 | else: 101 | unique_values = list(set(data)) 102 | data = list(map(lambda l: unique_values.index(l), data)) 103 | data = convert_to_one_hot(data) 104 | pieces.append(data) 105 | 106 | X = numpy.concatenate(pieces, axis=1) 107 | 108 | features_list.append(X) 109 | targets_list.append(y) 110 | 111 | # The last variable ('native-country') takes only 40 distinct values in 112 | # the test set, so its one-hot representation has 40 columns, while in 113 | # the training set it has 41. Since it is the last variable, it is safe 114 | # to simply append a column of zeros to the test features.
115 | features_list[1] = numpy.concatenate( 116 | (features_list[1], 117 | numpy.zeros((features_list[1].shape[0], 1), 118 | dtype=features_list[1].dtype)), 119 | axis=1) 120 | h5file = h5py.File(output_path, mode='w') 121 | data = (('train', 'features', features_list[0]), 122 | ('train', 'targets', targets_list[0]), 123 | ('test', 'features', features_list[1]), 124 | ('test', 'targets', targets_list[1])) 125 | 126 | fill_hdf5_file(h5file, data) 127 | h5file['features'].dims[0].label = 'batch' 128 | h5file['features'].dims[1].label = 'feature' 129 | h5file['targets'].dims[0].label = 'batch' 130 | h5file['targets'].dims[1].label = 'index' 131 | 132 | h5file.flush() 133 | h5file.close() 134 | 135 | return (output_path,) 136 | 137 | 138 | def fill_subparser(subparser): 139 | return convert_adult 140 | -------------------------------------------------------------------------------- /fuel/converters/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from contextlib import contextmanager 4 | from six import wraps 5 | 6 | import numpy 7 | from progressbar import (ProgressBar, Percentage, Bar, ETA) 8 | 9 | from fuel.datasets import H5PYDataset 10 | from ..exceptions import MissingInputFiles 11 | 12 | 13 | def check_exists(required_files): 14 | """Decorator that checks if required files exist before running. 15 | 16 | Parameters 17 | ---------- 18 | required_files : list of str 19 | A list of strings indicating the filenames of regular files 20 | (not directories) that should be found in the input directory 21 | (which is the first argument to the wrapped function). 22 | 23 | Returns 24 | ------- 25 | wrapper : function 26 | A function that takes a function and returns a wrapped function. 27 | The function returned by `wrapper` will include input file 28 | existence verification. 29 | 30 | Notes 31 | ----- 32 | Assumes that the directory in which to find the input files is 33 | provided as the first argument, with the argument name `directory`. 34 | 35 | """ 36 | def function_wrapper(f): 37 | @wraps(f) 38 | def wrapped(directory, *args, **kwargs): 39 | missing = [] 40 | for filename in required_files: 41 | if not os.path.isfile(os.path.join(directory, filename)): 42 | missing.append(filename) 43 | if len(missing) > 0: 44 | raise MissingInputFiles('Required files missing', missing) 45 | return f(directory, *args, **kwargs) 46 | return wrapped 47 | return function_wrapper 48 | 49 | 50 | def fill_hdf5_file(h5file, data): 51 | """Fills an HDF5 file in a H5PYDataset-compatible manner. 52 | 53 | Parameters 54 | ---------- 55 | h5file : :class:`h5py.File` 56 | File handle for an HDF5 file. 57 | data : tuple of tuple 58 | One element per split/source pair. Each element consists of a 59 | tuple of (split_name, source_name, data_array, comment), where 60 | 61 | * 'split_name' is a string identifier for the split name 62 | * 'source_name' is a string identifier for the source name 63 | * 'data_array' is a :class:`numpy.ndarray` containing the data 64 | for this split/source pair 65 | * 'comment' is a comment string for the split/source pair 66 | 67 | The 'comment' element can optionally be omitted. 
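To make the split/source tuple format expected by `fill_hdf5_file` concrete, here is a minimal sketch of a hand-rolled converter for a toy dataset. The arrays and the output name `toy.hdf5` are invented for illustration; only `fill_hdf5_file` and the axis-labelling convention are taken from the code in this repository.

# Sketch: writing a tiny H5PYDataset-compatible file with fill_hdf5_file.
import h5py
import numpy

from fuel.converters.base import fill_hdf5_file

train_x = numpy.random.rand(80, 5).astype('float32')
train_y = numpy.random.randint(0, 2, size=(80, 1)).astype('uint8')
test_x = numpy.random.rand(20, 5).astype('float32')
test_y = numpy.random.randint(0, 2, size=(20, 1)).astype('uint8')

h5file = h5py.File('toy.hdf5', mode='w')
data = (('train', 'features', train_x),
        ('train', 'targets', train_y),
        ('test', 'features', test_x),
        ('test', 'targets', test_y))
fill_hdf5_file(h5file, data)

# Label the axes the same way the bundled converters do.
for i, label in enumerate(('batch', 'feature')):
    h5file['features'].dims[i].label = label
for i, label in enumerate(('batch', 'index')):
    h5file['targets'].dims[i].label = label

h5file.flush()
h5file.close()

The resulting file can then be opened like any of the bundled datasets, for example with `H5PYDataset('toy.hdf5', which_sets=('train',))`.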
68 | 69 | """ 70 | # Check that all sources for a split have the same length 71 | split_names = set(split_tuple[0] for split_tuple in data) 72 | for name in split_names: 73 | lengths = [len(split_tuple[2]) for split_tuple in data 74 | if split_tuple[0] == name] 75 | if not all(le == lengths[0] for le in lengths): 76 | raise ValueError("split '{}' has sources that ".format(name) + 77 | "vary in length") 78 | 79 | # Initialize split dictionary 80 | split_dict = dict([(split_name, {}) for split_name in split_names]) 81 | 82 | # Compute total source lengths and check that splits have the same dtype 83 | # across a source 84 | source_names = set(split_tuple[1] for split_tuple in data) 85 | for name in source_names: 86 | splits = [s for s in data if s[1] == name] 87 | indices = numpy.cumsum([0] + [len(s[2]) for s in splits]) 88 | if not all(s[2].dtype == splits[0][2].dtype for s in splits): 89 | raise ValueError("source '{}' has splits that ".format(name) + 90 | "vary in dtype") 91 | if not all(s[2].shape[1:] == splits[0][2].shape[1:] for s in splits): 92 | raise ValueError("source '{}' has splits that ".format(name) + 93 | "vary in shapes") 94 | dataset = h5file.create_dataset( 95 | name, (sum(len(s[2]) for s in splits),) + splits[0][2].shape[1:], 96 | dtype=splits[0][2].dtype) 97 | dataset[...] = numpy.concatenate([s[2] for s in splits], axis=0) 98 | for i, j, s in zip(indices[:-1], indices[1:], splits): 99 | if len(s) == 4: 100 | split_dict[s[0]][name] = (i, j, None, s[3]) 101 | else: 102 | split_dict[s[0]][name] = (i, j) 103 | h5file.attrs['split'] = H5PYDataset.create_split_array(split_dict) 104 | 105 | 106 | @contextmanager 107 | def progress_bar(name, maxval, prefix='Converting'): 108 | """Manages a progress bar for a conversion. 109 | 110 | Parameters 111 | ---------- 112 | name : str 113 | Name of the file being converted. 114 | maxval : int 115 | Total number of steps for the conversion. 116 | 117 | """ 118 | widgets = ['{} {}: '.format(prefix, name), Percentage(), ' ', 119 | Bar(marker='=', left='[', right=']'), ' ', ETA()] 120 | bar = ProgressBar(widgets=widgets, max_value=maxval, fd=sys.stdout).start() 121 | try: 122 | yield bar 123 | finally: 124 | bar.update(maxval) 125 | bar.finish() 126 | -------------------------------------------------------------------------------- /fuel/converters/binarized_mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import h5py 4 | import numpy 5 | 6 | from fuel.converters.base import fill_hdf5_file, check_exists 7 | 8 | 9 | TRAIN_FILE = 'binarized_mnist_train.amat' 10 | VALID_FILE = 'binarized_mnist_valid.amat' 11 | TEST_FILE = 'binarized_mnist_test.amat' 12 | 13 | ALL_FILES = [TRAIN_FILE, VALID_FILE, TEST_FILE] 14 | 15 | 16 | @check_exists(required_files=ALL_FILES) 17 | def convert_binarized_mnist(directory, output_directory, 18 | output_filename='binarized_mnist.hdf5'): 19 | """Converts the binarized MNIST dataset to HDF5. 20 | 21 | Converts the binarized MNIST dataset used in R. Salakhutdinov's DBN 22 | paper [DBN] to an HDF5 dataset compatible with 23 | :class:`fuel.datasets.BinarizedMNIST`. The converted dataset is 24 | saved as 'binarized_mnist.hdf5'. 25 | 26 | This method assumes the existence of the files 27 | `binarized_mnist_{train,valid,test}.amat`, which are accessible 28 | through Hugo Larochelle's website [HUGO]. 29 | 30 | .. 
[DBN] Ruslan Salakhutdinov and Iain Murray, *On the Quantitative 31 | Analysis of Deep Belief Networks*, Proceedings of the 25th 32 | international conference on Machine learning, 2008, pp. 872-879. 33 | 34 | Parameters 35 | ---------- 36 | directory : str 37 | Directory in which input files reside. 38 | output_directory : str 39 | Directory in which to save the converted dataset. 40 | output_filename : str, optional 41 | Name of the saved dataset. Defaults to 'binarized_mnist.hdf5'. 42 | 43 | Returns 44 | ------- 45 | output_paths : tuple of str 46 | Single-element tuple containing the path to the converted dataset. 47 | 48 | """ 49 | output_path = os.path.join(output_directory, output_filename) 50 | h5file = h5py.File(output_path, mode='w') 51 | 52 | train_set = numpy.loadtxt( 53 | os.path.join(directory, TRAIN_FILE)).reshape( 54 | (-1, 1, 28, 28)).astype('uint8') 55 | valid_set = numpy.loadtxt( 56 | os.path.join(directory, VALID_FILE)).reshape( 57 | (-1, 1, 28, 28)).astype('uint8') 58 | test_set = numpy.loadtxt( 59 | os.path.join(directory, TEST_FILE)).reshape( 60 | (-1, 1, 28, 28)).astype('uint8') 61 | data = (('train', 'features', train_set), 62 | ('valid', 'features', valid_set), 63 | ('test', 'features', test_set)) 64 | fill_hdf5_file(h5file, data) 65 | for i, label in enumerate(('batch', 'channel', 'height', 'width')): 66 | h5file['features'].dims[i].label = label 67 | 68 | h5file.flush() 69 | h5file.close() 70 | 71 | return (output_path,) 72 | 73 | 74 | def fill_subparser(subparser): 75 | """Sets up a subparser to convert the binarized MNIST dataset files. 76 | 77 | Parameters 78 | ---------- 79 | subparser : :class:`argparse.ArgumentParser` 80 | Subparser handling the `binarized_mnist` command. 81 | 82 | """ 83 | return convert_binarized_mnist 84 | -------------------------------------------------------------------------------- /fuel/converters/caltech101_silhouettes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | 4 | from scipy.io import loadmat 5 | 6 | from fuel.converters.base import fill_hdf5_file, MissingInputFiles 7 | 8 | 9 | def convert_silhouettes(size, directory, output_directory, 10 | output_filename=None): 11 | """ Convert the CalTech 101 Silhouettes Datasets. 12 | 13 | Parameters 14 | ---------- 15 | size : {16, 28} 16 | Convert either the 16x16 or 28x28 sized version of the dataset. 17 | directory : str 18 | Directory in which the required input files reside. 19 | output_filename : str 20 | Where to save the converted dataset. 
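In normal use the converters in this package are reached through Fuel's `fuel-convert` command-line tool (not reproduced in this excerpt), but they can also be called directly. A sketch for the binarized MNIST converter above, assuming the three `.amat` files have already been downloaded into an illustrative `~/fuel_data` directory:

# Sketch: calling the converter directly rather than through fuel-convert.
# '~/fuel_data' is an illustrative path that must already contain the three
# .amat files; otherwise the check_exists decorator raises MissingInputFiles.
import os

from fuel.converters.binarized_mnist import convert_binarized_mnist

data_dir = os.path.expanduser('~/fuel_data')
(output_path,) = convert_binarized_mnist(data_dir, data_dir)
print(output_path)  # .../binarized_mnist.hdf5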
21 | 22 | """ 23 | if size not in (16, 28): 24 | raise ValueError('size must be 16 or 28') 25 | 26 | if output_filename is None: 27 | output_filename = 'caltech101_silhouettes{}.hdf5'.format(size) 28 | output_file = os.path.join(output_directory, output_filename) 29 | 30 | input_file = 'caltech101_silhouettes_{}_split1.mat'.format(size) 31 | input_file = os.path.join(directory, input_file) 32 | 33 | if not os.path.isfile(input_file): 34 | raise MissingInputFiles('Required files missing', [input_file]) 35 | 36 | with h5py.File(output_file, mode="w") as h5file: 37 | mat = loadmat(input_file) 38 | 39 | train_features = mat['train_data'].reshape([-1, 1, size, size]) 40 | train_targets = mat['train_labels'] 41 | valid_features = mat['val_data'].reshape([-1, 1, size, size]) 42 | valid_targets = mat['val_labels'] 43 | test_features = mat['test_data'].reshape([-1, 1, size, size]) 44 | test_targets = mat['test_labels'] 45 | 46 | data = ( 47 | ('train', 'features', train_features), 48 | ('train', 'targets', train_targets), 49 | ('valid', 'features', valid_features), 50 | ('valid', 'targets', valid_targets), 51 | ('test', 'features', test_features), 52 | ('test', 'targets', test_targets), 53 | ) 54 | fill_hdf5_file(h5file, data) 55 | 56 | for i, label in enumerate(('batch', 'channel', 'height', 'width')): 57 | h5file['features'].dims[i].label = label 58 | 59 | for i, label in enumerate(('batch', 'index')): 60 | h5file['targets'].dims[i].label = label 61 | return (output_file,) 62 | 63 | 64 | def fill_subparser(subparser): 65 | """Sets up a subparser to convert CalTech101 Silhouettes Database files. 66 | 67 | Parameters 68 | ---------- 69 | subparser : :class:`argparse.ArgumentParser` 70 | Subparser handling the `caltech101_silhouettes` command. 71 | 72 | """ 73 | subparser.add_argument( 74 | "size", type=int, choices=(16, 28), 75 | help="height/width of the datapoints") 76 | return convert_silhouettes 77 | -------------------------------------------------------------------------------- /fuel/converters/cifar10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tarfile 3 | 4 | import h5py 5 | import numpy 6 | import six 7 | from six.moves import range, cPickle 8 | 9 | from fuel.converters.base import fill_hdf5_file, check_exists 10 | 11 | DISTRIBUTION_FILE = 'cifar-10-python.tar.gz' 12 | 13 | 14 | @check_exists(required_files=[DISTRIBUTION_FILE]) 15 | def convert_cifar10(directory, output_directory, 16 | output_filename='cifar10.hdf5'): 17 | """Converts the CIFAR-10 dataset to HDF5. 18 | 19 | Converts the CIFAR-10 dataset to an HDF5 dataset compatible with 20 | :class:`fuel.datasets.CIFAR10`. The converted dataset is saved as 21 | 'cifar10.hdf5'. 22 | 23 | It assumes the existence of the following file: 24 | 25 | * `cifar-10-python.tar.gz` 26 | 27 | Parameters 28 | ---------- 29 | directory : str 30 | Directory in which input files reside. 31 | output_directory : str 32 | Directory in which to save the converted dataset. 33 | output_filename : str, optional 34 | Name of the saved dataset. Defaults to 'cifar10.hdf5'. 35 | 36 | Returns 37 | ------- 38 | output_paths : tuple of str 39 | Single-element tuple containing the path to the converted dataset. 
40 | 41 | """ 42 | output_path = os.path.join(output_directory, output_filename) 43 | h5file = h5py.File(output_path, mode='w') 44 | input_file = os.path.join(directory, DISTRIBUTION_FILE) 45 | tar_file = tarfile.open(input_file, 'r:gz') 46 | 47 | train_batches = [] 48 | for batch in range(1, 6): 49 | file = tar_file.extractfile( 50 | 'cifar-10-batches-py/data_batch_%d' % batch) 51 | try: 52 | if six.PY3: 53 | array = cPickle.load(file, encoding='latin1') 54 | else: 55 | array = cPickle.load(file) 56 | train_batches.append(array) 57 | finally: 58 | file.close() 59 | 60 | train_features = numpy.concatenate( 61 | [batch['data'].reshape(batch['data'].shape[0], 3, 32, 32) 62 | for batch in train_batches]) 63 | train_labels = numpy.concatenate( 64 | [numpy.array(batch['labels'], dtype=numpy.uint8) 65 | for batch in train_batches]) 66 | train_labels = numpy.expand_dims(train_labels, 1) 67 | 68 | file = tar_file.extractfile('cifar-10-batches-py/test_batch') 69 | try: 70 | if six.PY3: 71 | test = cPickle.load(file, encoding='latin1') 72 | else: 73 | test = cPickle.load(file) 74 | finally: 75 | file.close() 76 | 77 | test_features = test['data'].reshape(test['data'].shape[0], 78 | 3, 32, 32) 79 | test_labels = numpy.array(test['labels'], dtype=numpy.uint8) 80 | test_labels = numpy.expand_dims(test_labels, 1) 81 | 82 | data = (('train', 'features', train_features), 83 | ('train', 'targets', train_labels), 84 | ('test', 'features', test_features), 85 | ('test', 'targets', test_labels)) 86 | fill_hdf5_file(h5file, data) 87 | h5file['features'].dims[0].label = 'batch' 88 | h5file['features'].dims[1].label = 'channel' 89 | h5file['features'].dims[2].label = 'height' 90 | h5file['features'].dims[3].label = 'width' 91 | h5file['targets'].dims[0].label = 'batch' 92 | h5file['targets'].dims[1].label = 'index' 93 | 94 | h5file.flush() 95 | h5file.close() 96 | 97 | return (output_path,) 98 | 99 | 100 | def fill_subparser(subparser): 101 | """Sets up a subparser to convert the CIFAR10 dataset files. 102 | 103 | Parameters 104 | ---------- 105 | subparser : :class:`argparse.ArgumentParser` 106 | Subparser handling the `cifar10` command. 107 | 108 | """ 109 | return convert_cifar10 110 | -------------------------------------------------------------------------------- /fuel/converters/cifar100.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tarfile 3 | 4 | import h5py 5 | import numpy 6 | import six 7 | from six.moves import cPickle 8 | 9 | from fuel.converters.base import fill_hdf5_file, check_exists 10 | 11 | DISTRIBUTION_FILE = 'cifar-100-python.tar.gz' 12 | 13 | 14 | @check_exists(required_files=[DISTRIBUTION_FILE]) 15 | def convert_cifar100(directory, output_directory, 16 | output_filename='cifar100.hdf5'): 17 | """Converts the CIFAR-100 dataset to HDF5. 18 | 19 | Converts the CIFAR-100 dataset to an HDF5 dataset compatible with 20 | :class:`fuel.datasets.CIFAR100`. The converted dataset is saved as 21 | 'cifar100.hdf5'. 22 | 23 | This method assumes the existence of the following file: 24 | `cifar-100-python.tar.gz` 25 | 26 | Parameters 27 | ---------- 28 | directory : str 29 | Directory in which the required input files reside. 30 | output_directory : str 31 | Directory in which to save the converted dataset. 32 | output_filename : str, optional 33 | Name of the saved dataset. Defaults to 'cifar100.hdf5'. 34 | 35 | Returns 36 | ------- 37 | output_paths : tuple of str 38 | Single-element tuple containing the path to the converted dataset. 
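Once `convert_cifar10` has run, the file it produces can be inspected either with plain `h5py` or through `H5PYDataset`, which understands the `split` attribute written by `fill_hdf5_file`. A sketch with an illustrative path; the expected shapes follow from the converter above (50,000 train plus 10,000 test examples concatenated along the batch axis):

# Sketch: inspecting the converted CIFAR-10 file.
import h5py

from fuel.datasets.hdf5 import H5PYDataset

path = 'cifar10.hdf5'

# Raw view with h5py: sources and axis labels written by the converter.
with h5py.File(path, 'r') as f:
    print(f['features'].shape)                        # (60000, 3, 32, 32)
    print([dim.label for dim in f['features'].dims])  # batch, channel, height, width

# Split-aware view with H5PYDataset.
train = H5PYDataset(path, which_sets=('train',), load_in_memory=False)
print(train.num_examples)  # 50000
print(train.sources)       # ('features', 'targets')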
39 | 40 | """ 41 | output_path = os.path.join(output_directory, output_filename) 42 | h5file = h5py.File(output_path, mode="w") 43 | input_file = os.path.join(directory, 'cifar-100-python.tar.gz') 44 | tar_file = tarfile.open(input_file, 'r:gz') 45 | 46 | file = tar_file.extractfile('cifar-100-python/train') 47 | try: 48 | if six.PY3: 49 | train = cPickle.load(file, encoding='latin1') 50 | else: 51 | train = cPickle.load(file) 52 | finally: 53 | file.close() 54 | 55 | train_features = train['data'].reshape(train['data'].shape[0], 56 | 3, 32, 32) 57 | train_coarse_labels = numpy.array(train['coarse_labels'], 58 | dtype=numpy.uint8) 59 | train_fine_labels = numpy.array(train['fine_labels'], 60 | dtype=numpy.uint8) 61 | 62 | file = tar_file.extractfile('cifar-100-python/test') 63 | try: 64 | if six.PY3: 65 | test = cPickle.load(file, encoding='latin1') 66 | else: 67 | test = cPickle.load(file) 68 | finally: 69 | file.close() 70 | 71 | test_features = test['data'].reshape(test['data'].shape[0], 72 | 3, 32, 32) 73 | test_coarse_labels = numpy.array(test['coarse_labels'], dtype=numpy.uint8) 74 | test_fine_labels = numpy.array(test['fine_labels'], dtype=numpy.uint8) 75 | 76 | data = (('train', 'features', train_features), 77 | ('train', 'coarse_labels', train_coarse_labels.reshape((-1, 1))), 78 | ('train', 'fine_labels', train_fine_labels.reshape((-1, 1))), 79 | ('test', 'features', test_features), 80 | ('test', 'coarse_labels', test_coarse_labels.reshape((-1, 1))), 81 | ('test', 'fine_labels', test_fine_labels.reshape((-1, 1)))) 82 | fill_hdf5_file(h5file, data) 83 | h5file['features'].dims[0].label = 'batch' 84 | h5file['features'].dims[1].label = 'channel' 85 | h5file['features'].dims[2].label = 'height' 86 | h5file['features'].dims[3].label = 'width' 87 | h5file['coarse_labels'].dims[0].label = 'batch' 88 | h5file['coarse_labels'].dims[1].label = 'index' 89 | h5file['fine_labels'].dims[0].label = 'batch' 90 | h5file['fine_labels'].dims[1].label = 'index' 91 | 92 | h5file.flush() 93 | h5file.close() 94 | 95 | return (output_path,) 96 | 97 | 98 | def fill_subparser(subparser): 99 | """Sets up a subparser to convert the CIFAR100 dataset files. 100 | 101 | Parameters 102 | ---------- 103 | subparser : :class:`argparse.ArgumentParser` 104 | Subparser handling the `cifar100` command. 105 | 106 | """ 107 | return convert_cifar100 108 | -------------------------------------------------------------------------------- /fuel/converters/dogs_vs_cats.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | 4 | import h5py 5 | import numpy 6 | from PIL import Image 7 | 8 | from fuel.converters.base import check_exists, progress_bar 9 | from fuel.datasets.hdf5 import H5PYDataset 10 | 11 | TRAIN = 'dogs_vs_cats.train.zip' 12 | TEST = 'dogs_vs_cats.test1.zip' 13 | 14 | 15 | @check_exists(required_files=[TRAIN, TEST]) 16 | def convert_dogs_vs_cats(directory, output_directory, 17 | output_filename='dogs_vs_cats.hdf5'): 18 | """Converts the Dogs vs. Cats dataset to HDF5. 19 | 20 | Converts the Dogs vs. Cats dataset to an HDF5 dataset compatible with 21 | :class:`fuel.datasets.dogs_vs_cats`. The converted dataset is saved as 22 | 'dogs_vs_cats.hdf5'. 23 | 24 | It assumes the existence of the following files: 25 | 26 | * `dogs_vs_cats.train.zip` 27 | * `dogs_vs_cats.test1.zip` 28 | 29 | Parameters 30 | ---------- 31 | directory : str 32 | Directory in which input files reside. 
33 | output_directory : str 34 | Directory in which to save the converted dataset. 35 | output_filename : str, optional 36 | Name of the saved dataset. Defaults to 'dogs_vs_cats.hdf5'. 37 | 38 | Returns 39 | ------- 40 | output_paths : tuple of str 41 | Single-element tuple containing the path to the converted dataset. 42 | 43 | """ 44 | # Prepare output file 45 | output_path = os.path.join(output_directory, output_filename) 46 | h5file = h5py.File(output_path, mode='w') 47 | dtype = h5py.special_dtype(vlen=numpy.dtype('uint8')) 48 | hdf_features = h5file.create_dataset('image_features', (37500,), 49 | dtype=dtype) 50 | hdf_shapes = h5file.create_dataset('image_features_shapes', (37500, 3), 51 | dtype='int32') 52 | hdf_labels = h5file.create_dataset('targets', (25000, 1), dtype='uint8') 53 | 54 | # Attach shape annotations and scales 55 | hdf_features.dims.create_scale(hdf_shapes, 'shapes') 56 | hdf_features.dims[0].attach_scale(hdf_shapes) 57 | 58 | hdf_shapes_labels = h5file.create_dataset('image_features_shapes_labels', 59 | (3,), dtype='S7') 60 | hdf_shapes_labels[...] = ['channel'.encode('utf8'), 61 | 'height'.encode('utf8'), 62 | 'width'.encode('utf8')] 63 | hdf_features.dims.create_scale(hdf_shapes_labels, 'shape_labels') 64 | hdf_features.dims[0].attach_scale(hdf_shapes_labels) 65 | 66 | # Add axis annotations 67 | hdf_features.dims[0].label = 'batch' 68 | hdf_labels.dims[0].label = 'batch' 69 | hdf_labels.dims[1].label = 'index' 70 | 71 | # Convert 72 | i = 0 73 | for split, split_size in zip([TRAIN, TEST], [25000, 12500]): 74 | # Open the ZIP file 75 | filename = os.path.join(directory, split) 76 | zip_file = zipfile.ZipFile(filename, 'r') 77 | image_names = zip_file.namelist()[1:] # Discard the directory name 78 | 79 | # Shuffle the examples 80 | if split == TRAIN: 81 | rng = numpy.random.RandomState(123522) 82 | rng.shuffle(image_names) 83 | else: 84 | image_names.sort(key=lambda fn: int(os.path.splitext(fn[6:])[0])) 85 | 86 | # Convert from JPEG to NumPy arrays 87 | with progress_bar(filename, split_size) as bar: 88 | for image_name in image_names: 89 | # Save image 90 | image = numpy.array(Image.open(zip_file.open(image_name))) 91 | image = image.transpose(2, 0, 1) 92 | hdf_features[i] = image.flatten() 93 | hdf_shapes[i] = image.shape 94 | 95 | # Cats are 0, Dogs are 1 96 | if split == TRAIN: 97 | hdf_labels[i] = 0 if 'cat' in image_name else 1 98 | 99 | # Update progress 100 | i += 1 101 | bar.update(i if split == TRAIN else i - 25000) 102 | 103 | # Add the labels 104 | split_dict = {} 105 | sources = ['image_features', 'targets'] 106 | split_dict['train'] = dict(zip(sources, [(0, 25000)] * 2)) 107 | split_dict['test'] = {sources[0]: (25000, 37500)} 108 | h5file.attrs['split'] = H5PYDataset.create_split_array(split_dict) 109 | 110 | h5file.flush() 111 | h5file.close() 112 | 113 | return (output_path,) 114 | 115 | 116 | def fill_subparser(subparser): 117 | """Sets up a subparser to convert the dogs_vs_cats dataset files. 118 | 119 | Parameters 120 | ---------- 121 | subparser : :class:`argparse.ArgumentParser` 122 | Subparser handling the `dogs_vs_cats` command. 
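Because the Dogs vs. Cats images vary in size, the converter above stores each one as a flat `uint8` buffer in `image_features` and records its shape in the attached `image_features_shapes` scale. A sketch of how one image can be recovered by hand (the path is illustrative; the dataset names are the ones the converter writes):

# Sketch: recovering one variable-sized image from the converted file.
import h5py

with h5py.File('dogs_vs_cats.hdf5', 'r') as f:
    flat = f['image_features'][0]           # 1-D uint8 buffer
    shape = f['image_features_shapes'][0]   # e.g. (3, height, width)
    image = flat.reshape(shape)             # back to channel x height x width
    label = f['targets'][0]                 # 0 = cat, 1 = dog (train split only)
    print(image.shape, label)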
123 | 124 | """ 125 | return convert_dogs_vs_cats 126 | -------------------------------------------------------------------------------- /fuel/converters/iris.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import h5py 4 | import numpy 5 | 6 | from fuel.converters.base import fill_hdf5_file 7 | 8 | 9 | def convert_iris(directory, output_directory, output_filename='iris.hdf5'): 10 | """Convert the Iris dataset to HDF5. 11 | 12 | Converts the Iris dataset to an HDF5 dataset compatible with 13 | :class:`fuel.datasets.Iris`. The converted dataset is 14 | saved as 'iris.hdf5'. 15 | This method assumes the existence of the file `iris.data`. 16 | 17 | Parameters 18 | ---------- 19 | directory : str 20 | Directory in which input files reside. 21 | output_directory : str 22 | Directory in which to save the converted dataset. 23 | output_filename : str, optional 24 | Name of the saved dataset. Defaults to `None`, in which case a name 25 | based on `dtype` will be used. 26 | 27 | Returns 28 | ------- 29 | output_paths : tuple of str 30 | Single-element tuple containing the path to the converted dataset. 31 | 32 | """ 33 | classes = {b'Iris-setosa': 0, b'Iris-versicolor': 1, b'Iris-virginica': 2} 34 | data = numpy.loadtxt( 35 | os.path.join(directory, 'iris.data'), 36 | converters={4: lambda x: classes[x]}, 37 | delimiter=',') 38 | features = data[:, :-1].astype('float32') 39 | targets = data[:, -1].astype('uint8').reshape((-1, 1)) 40 | data = (('all', 'features', features), 41 | ('all', 'targets', targets)) 42 | 43 | output_path = os.path.join(output_directory, output_filename) 44 | h5file = h5py.File(output_path, mode='w') 45 | fill_hdf5_file(h5file, data) 46 | h5file['features'].dims[0].label = 'batch' 47 | h5file['features'].dims[1].label = 'feature' 48 | h5file['targets'].dims[0].label = 'batch' 49 | h5file['targets'].dims[1].label = 'index' 50 | 51 | h5file.flush() 52 | h5file.close() 53 | 54 | return (output_path,) 55 | 56 | 57 | def fill_subparser(subparser): 58 | """Sets up a subparser to convert the Iris dataset file. 59 | 60 | Parameters 61 | ---------- 62 | subparser : :class:`argparse.ArgumentParser` 63 | Subparser handling the `iris` command. 64 | 65 | """ 66 | return convert_iris 67 | -------------------------------------------------------------------------------- /fuel/converters/mnist.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import os 3 | import struct 4 | 5 | import h5py 6 | import numpy 7 | 8 | from fuel.converters.base import fill_hdf5_file, check_exists 9 | 10 | MNIST_IMAGE_MAGIC = 2051 11 | MNIST_LABEL_MAGIC = 2049 12 | 13 | TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' 14 | TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' 15 | TEST_IMAGES = 't10k-images-idx3-ubyte.gz' 16 | TEST_LABELS = 't10k-labels-idx1-ubyte.gz' 17 | 18 | ALL_FILES = [TRAIN_IMAGES, TRAIN_LABELS, TEST_IMAGES, TEST_LABELS] 19 | 20 | 21 | @check_exists(required_files=ALL_FILES) 22 | def convert_mnist(directory, output_directory, output_filename=None, 23 | dtype=None): 24 | """Converts the MNIST dataset to HDF5. 25 | 26 | Converts the MNIST dataset to an HDF5 dataset compatible with 27 | :class:`fuel.datasets.MNIST`. The converted dataset is 28 | saved as 'mnist.hdf5'. 
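The Iris converter shown above is the simplest in the package: a single `loadtxt` call followed by `fill_hdf5_file`. A sketch of running it directly, assuming an illustrative `data/` directory that already contains the UCI `iris.data` file:

# Sketch: running the Iris converter and checking the result.
import h5py

from fuel.converters.iris import convert_iris

(output_path,) = convert_iris('data', 'data')
with h5py.File(output_path, 'r') as f:
    print(f['features'].shape, f['targets'].shape)  # (150, 4) (150, 1)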
29 | 30 | This method assumes the existence of the following files: 31 | `train-images-idx3-ubyte.gz`, `train-labels-idx1-ubyte.gz` 32 | `t10k-images-idx3-ubyte.gz`, `t10k-labels-idx1-ubyte.gz` 33 | 34 | It assumes the existence of the following files: 35 | 36 | * `train-images-idx3-ubyte.gz` 37 | * `train-labels-idx1-ubyte.gz` 38 | * `t10k-images-idx3-ubyte.gz` 39 | * `t10k-labels-idx1-ubyte.gz` 40 | 41 | Parameters 42 | ---------- 43 | directory : str 44 | Directory in which input files reside. 45 | output_directory : str 46 | Directory in which to save the converted dataset. 47 | output_filename : str, optional 48 | Name of the saved dataset. Defaults to `None`, in which case a name 49 | based on `dtype` will be used. 50 | dtype : str, optional 51 | Either 'float32', 'float64', or 'bool'. Defaults to `None`, 52 | in which case images will be returned in their original 53 | unsigned byte format. 54 | 55 | Returns 56 | ------- 57 | output_paths : tuple of str 58 | Single-element tuple containing the path to the converted dataset. 59 | 60 | """ 61 | if not output_filename: 62 | if dtype: 63 | output_filename = 'mnist_{}.hdf5'.format(dtype) 64 | else: 65 | output_filename = 'mnist.hdf5' 66 | output_path = os.path.join(output_directory, output_filename) 67 | h5file = h5py.File(output_path, mode='w') 68 | 69 | train_feat_path = os.path.join(directory, TRAIN_IMAGES) 70 | train_features = read_mnist_images(train_feat_path, dtype) 71 | train_lab_path = os.path.join(directory, TRAIN_LABELS) 72 | train_labels = read_mnist_labels(train_lab_path) 73 | test_feat_path = os.path.join(directory, TEST_IMAGES) 74 | test_features = read_mnist_images(test_feat_path, dtype) 75 | test_lab_path = os.path.join(directory, TEST_LABELS) 76 | test_labels = read_mnist_labels(test_lab_path) 77 | data = (('train', 'features', train_features), 78 | ('train', 'targets', train_labels), 79 | ('test', 'features', test_features), 80 | ('test', 'targets', test_labels)) 81 | fill_hdf5_file(h5file, data) 82 | h5file['features'].dims[0].label = 'batch' 83 | h5file['features'].dims[1].label = 'channel' 84 | h5file['features'].dims[2].label = 'height' 85 | h5file['features'].dims[3].label = 'width' 86 | h5file['targets'].dims[0].label = 'batch' 87 | h5file['targets'].dims[1].label = 'index' 88 | 89 | h5file.flush() 90 | h5file.close() 91 | 92 | return (output_path,) 93 | 94 | 95 | def fill_subparser(subparser): 96 | """Sets up a subparser to convert the MNIST dataset files. 97 | 98 | Parameters 99 | ---------- 100 | subparser : :class:`argparse.ArgumentParser` 101 | Subparser handling the `mnist` command. 102 | 103 | """ 104 | subparser.add_argument( 105 | "--dtype", help="dtype to save to; by default, images will be " + 106 | "returned in their original unsigned byte format", 107 | choices=('float32', 'float64', 'bool'), type=str, default=None) 108 | return convert_mnist 109 | 110 | 111 | def read_mnist_images(filename, dtype=None): 112 | """Read MNIST images from the original ubyte file format. 113 | 114 | Parameters 115 | ---------- 116 | filename : str 117 | Filename/path from which to read images. 118 | 119 | dtype : 'float32', 'float64', or 'bool' 120 | If unspecified, images will be returned in their original 121 | unsigned byte format. 122 | 123 | Returns 124 | ------- 125 | images : :class:`~numpy.ndarray`, shape (n_images, 1, n_rows, n_cols) 126 | An image array, with individual examples indexed along the 127 | first axis and the image dimensions along the second and 128 | third axis. 
129 | 130 | Notes 131 | ----- 132 | If the dtype provided was Boolean, the resulting array will 133 | be Boolean with `True` if the corresponding pixel had a value 134 | greater than or equal to 128, `False` otherwise. 135 | 136 | If the dtype provided was a float dtype, the values will be mapped to 137 | the unit interval [0, 1], with pixel values that were 255 in the 138 | original unsigned byte representation equal to 1.0. 139 | 140 | """ 141 | with gzip.open(filename, 'rb') as f: 142 | magic, number, rows, cols = struct.unpack('>iiii', f.read(16)) 143 | if magic != MNIST_IMAGE_MAGIC: 144 | raise ValueError("Wrong magic number reading MNIST image file") 145 | array = numpy.frombuffer(f.read(), dtype='uint8') 146 | array = array.reshape((number, 1, rows, cols)) 147 | if dtype: 148 | dtype = numpy.dtype(dtype) 149 | 150 | if dtype.kind == 'b': 151 | # If the user wants Booleans, threshold at half the range. 152 | array = array >= 128 153 | elif dtype.kind == 'f': 154 | # Otherwise, just convert. 155 | array = array.astype(dtype) 156 | array /= 255. 157 | else: 158 | raise ValueError("Unknown dtype to convert MNIST to") 159 | return array 160 | 161 | 162 | def read_mnist_labels(filename): 163 | """Read MNIST labels from the original ubyte file format. 164 | 165 | Parameters 166 | ---------- 167 | filename : str 168 | Filename/path from which to read labels. 169 | 170 | Returns 171 | ------- 172 | labels : :class:`~numpy.ndarray`, shape (nlabels, 1) 173 | A one-dimensional unsigned byte array containing the 174 | labels as integers. 175 | 176 | """ 177 | with gzip.open(filename, 'rb') as f: 178 | magic, _ = struct.unpack('>ii', f.read(8)) 179 | if magic != MNIST_LABEL_MAGIC: 180 | raise ValueError("Wrong magic number reading MNIST label file") 181 | array = numpy.frombuffer(f.read(), dtype='uint8') 182 | array = array.reshape(array.size, 1) 183 | return array 184 | -------------------------------------------------------------------------------- /fuel/converters/youtube_audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import sys 4 | 5 | import h5py 6 | import scipy.io.wavfile 7 | 8 | from fuel.converters.base import fill_hdf5_file 9 | 10 | 11 | def convert_youtube_audio(directory, output_directory, youtube_id, channels, 12 | sample, output_filename=None): 13 | """Converts downloaded YouTube audio to HDF5 format. 14 | 15 | Requires `ffmpeg` to be installed and available on the command line 16 | (i.e. available on your `PATH`). 17 | 18 | Parameters 19 | ---------- 20 | directory : str 21 | Directory in which input files reside. 22 | output_directory : str 23 | Directory in which to save the converted dataset. 24 | youtube_id : str 25 | 11-character video ID (taken from YouTube URL) 26 | channels : int 27 | The number of audio channels to use in the PCM Wave file. 28 | sample : int 29 | The sampling rate to use in Hz, e.g. 44100 or 16000. 30 | output_filename : str, optional 31 | Name of the saved dataset. If `None` (the default), 32 | `youtube_id.hdf5` is used. 
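The two IDX readers above can also be used on their own, without producing an HDF5 file, which is convenient for quick experiments. A sketch, assuming the four gzipped archives sit in an illustrative `mnist_raw/` directory:

# Sketch: using the low-level IDX readers directly.
import os

from fuel.converters.mnist import read_mnist_images, read_mnist_labels

data_dir = 'mnist_raw'  # assumed to contain the four *.gz files
images = read_mnist_images(
    os.path.join(data_dir, 'train-images-idx3-ubyte.gz'), dtype='float32')
labels = read_mnist_labels(
    os.path.join(data_dir, 'train-labels-idx1-ubyte.gz'))
print(images.shape, images.dtype)  # (60000, 1, 28, 28) float32, values in [0, 1]
print(labels.shape, labels.dtype)  # (60000, 1) uint8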
33 | 34 | """ 35 | input_file = os.path.join(directory, '{}.m4a'.format(youtube_id)) 36 | wav_filename = '{}.wav'.format(youtube_id) 37 | wav_file = os.path.join(directory, wav_filename) 38 | ffmpeg_not_available = subprocess.call(['ffmpeg', '-version']) 39 | if ffmpeg_not_available: 40 | raise RuntimeError('conversion requires ffmpeg') 41 | subprocess.check_call(['ffmpeg', '-y', '-i', input_file, '-ac', 42 | str(channels), '-ar', str(sample), wav_file], 43 | stdout=sys.stdout) 44 | 45 | # Load WAV into array 46 | _, data = scipy.io.wavfile.read(wav_file) 47 | if data.ndim == 1: 48 | data = data[:, None] 49 | data = data[None, :] 50 | 51 | # Store in HDF5 52 | if output_filename is None: 53 | output_filename = '{}.hdf5'.format(youtube_id) 54 | output_file = os.path.join(output_directory, output_filename) 55 | 56 | with h5py.File(output_file, 'w') as h5file: 57 | fill_hdf5_file(h5file, (('train', 'features', data),)) 58 | h5file['features'].dims[0].label = 'batch' 59 | h5file['features'].dims[1].label = 'time' 60 | h5file['features'].dims[2].label = 'feature' 61 | 62 | return (output_file,) 63 | 64 | 65 | def fill_subparser(subparser): 66 | """Sets up a subparser to convert YouTube audio files. 67 | 68 | Adds the compulsory `--youtube-id` flag as well as the optional 69 | `sample` and `channels` flags. 70 | 71 | Parameters 72 | ---------- 73 | subparser : :class:`argparse.ArgumentParser` 74 | Subparser handling the `youtube_audio` command. 75 | 76 | """ 77 | subparser.add_argument( 78 | '--youtube-id', type=str, required=True, 79 | help=("The YouTube ID of the video from which to extract audio, " 80 | "usually an 11-character string.") 81 | ) 82 | subparser.add_argument( 83 | '--channels', type=int, default=1, 84 | help=("The number of audio channels to convert to. The default of 1" 85 | "means audio is converted to mono.") 86 | ) 87 | subparser.add_argument( 88 | '--sample', type=int, default=16000, 89 | help=("The sampling rate in Hz. 
The default of 16000 is " 90 | "significantly downsampled compared to normal WAVE files; " 91 | "pass 44100 for the usual sampling rate.") 92 | ) 93 | return convert_youtube_audio 94 | -------------------------------------------------------------------------------- /fuel/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from fuel.datasets.base import (Dataset, IterableDataset, 3 | IndexableDataset) 4 | 5 | from fuel.datasets.hdf5 import H5PYDataset 6 | from fuel.datasets.adult import Adult 7 | from fuel.datasets.binarized_mnist import BinarizedMNIST 8 | from fuel.datasets.celeba import CelebA 9 | from fuel.datasets.cifar10 import CIFAR10 10 | from fuel.datasets.cifar100 import CIFAR100 11 | from fuel.datasets.caltech101_silhouettes import CalTech101Silhouettes 12 | from fuel.datasets.dogs_vs_cats import DogsVsCats 13 | from fuel.datasets.iris import Iris 14 | from fuel.datasets.mnist import MNIST 15 | from fuel.datasets.svhn import SVHN 16 | from fuel.datasets.text import TextFile 17 | from fuel.datasets.billion import OneBillionWord 18 | -------------------------------------------------------------------------------- /fuel/datasets/adult.py: -------------------------------------------------------------------------------- 1 | from fuel.datasets import H5PYDataset 2 | from fuel.utils import find_in_data_path 3 | 4 | 5 | class Adult(H5PYDataset): 6 | filename = 'adult.hdf5' 7 | 8 | def __init__(self, which_sets, **kwargs): 9 | kwargs.setdefault('load_in_memory', True) 10 | super(Adult, self).__init__( 11 | file_or_path=find_in_data_path(self.filename), 12 | which_sets=which_sets, **kwargs 13 | ) 14 | -------------------------------------------------------------------------------- /fuel/datasets/billion.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from fuel.datasets import TextFile 4 | from fuel.utils import find_in_data_path 5 | 6 | 7 | class OneBillionWord(TextFile): 8 | """Google's One Billion Word benchmark. 9 | 10 | This monolingual corpus contains 829,250,940 tokens (including sentence 11 | boundary markers). The data is split into 100 partitions, one of which 12 | is the held-out set. This held-out set is further divided into 50 13 | partitions. More information about the dataset can be found in 14 | [CMSG14]. 15 | 16 | .. [CSMG14] Ciprian Chelba, Tomas Mikolov, Mike Schuster, Qi Ge, and 17 | Thorsten Brants, *One Billion Word Benchmark for Measuring Progress 18 | in Statistical Language Modeling*, `arXiv:1312.3005 [cs.CL] 19 | `. 20 | 21 | Parameters 22 | ---------- 23 | which_set : 'training' or 'heldout' 24 | Which dataset to load. 25 | which_partitions : list of ints 26 | For the training set, valid values must lie in [1, 99]. For the 27 | heldout set they must be in [0, 49]. 28 | vocabulary : dict 29 | A dictionary mapping tokens to integers. This dictionary is 30 | expected to contain the tokens ````, ```` and ````, 31 | representing "start of sentence", "end of sentence", and 32 | "out-of-vocabulary" (OoV). The latter will be used whenever a token 33 | cannot be found in the vocabulary. 34 | preprocess : function, optional 35 | A function that takes a string (a sentence including new line) as 36 | an input and returns a modified string. A useful function to pass 37 | could be ``str.lower``. 38 | 39 | See :class:`TextFile` for remaining keyword arguments. 
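Note that although the docstring above calls the vocabulary argument `vocabulary`, the constructor that follows names it `dictionary`, as in `TextFile`. A sketch of instantiating the dataset, assuming the corresponding corpus partitions can be found under Fuel's data path; the five-entry vocabulary is purely illustrative (a real one maps every corpus token, including `<S>`, `</S>` and `<UNK>`, to an integer):

# Sketch: two training partitions of the One Billion Word corpus.
from fuel.datasets.billion import OneBillionWord
from fuel.streams import DataStream

vocabulary = {'<S>': 0, '</S>': 1, '<UNK>': 2, 'the': 3, 'of': 4}
dataset = OneBillionWord('training', [1, 2], vocabulary)

# TextFile-style datasets are iterated without an explicit request.
stream = DataStream(dataset)
sentence, = next(stream.get_epoch_iterator())  # one numberized sentence
print(sentence[:10])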
40 | 41 | """ 42 | def __init__(self, which_set, which_partitions, dictionary, **kwargs): 43 | if which_set not in ('training', 'heldout'): 44 | raise ValueError 45 | if which_set == 'training': 46 | if not all(partition in range(1, 100) 47 | for partition in which_partitions): 48 | raise ValueError 49 | files = [find_in_data_path(os.path.join( 50 | '1-billion-word', 'training-monolingual.tokenized.shuffled', 51 | 'news.en-{:05d}-of-00100'.format(partition))) 52 | for partition in which_partitions] 53 | else: 54 | if not all(partition in range(50) 55 | for partition in which_partitions): 56 | raise ValueError 57 | files = [find_in_data_path(os.path.join( 58 | '1-billion-word', 'heldout-monolingual.tokenized.shuffled', 59 | 'news.en.heldout-{:05d}-of-00050'.format(partition))) 60 | for partition in which_partitions] 61 | super(OneBillionWord, self).__init__(files, dictionary, **kwargs) 62 | -------------------------------------------------------------------------------- /fuel/datasets/binarized_mnist.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from fuel.datasets import H5PYDataset 3 | from fuel.utils import find_in_data_path 4 | 5 | 6 | class BinarizedMNIST(H5PYDataset): 7 | u"""Binarized, unlabeled MNIST dataset. 8 | 9 | MNIST (Mixed National Institute of Standards and Technology) [LBBH] is 10 | a database of handwritten digits. It is one of the most famous datasets 11 | in machine learning and consists of 60,000 training images and 10,000 12 | testing images. The images are grayscale and 28 x 28 pixels large. 13 | 14 | This particular version of the dataset is the one used in R. 15 | Salakhutdinov's DBN paper [DBN] as well as the VAE and NADE papers, and 16 | is accessible through Hugo Larochelle's public website [HUGO]. 17 | 18 | The training set has further been split into a training and a 19 | validation set. All examples were binarized by sampling from a binomial 20 | distribution defined by the pixel values. 21 | 22 | .. [LBBH] Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner, 23 | *Gradient-based learning applied to document recognition*, 24 | Proceedings of the IEEE, November 1998, 86(11):2278-2324. 25 | 26 | Parameters 27 | ---------- 28 | which_sets : tuple of str 29 | Which split to load. Valid values are 'train', 'valid' and 'test', 30 | corresponding to the training set (50,000 examples), the validation 31 | set (10,000 samples) and the test set (10,000 examples). 32 | 33 | """ 34 | filename = 'binarized_mnist.hdf5' 35 | 36 | def __init__(self, which_sets, load_in_memory=True, **kwargs): 37 | super(BinarizedMNIST, self).__init__( 38 | file_or_path=find_in_data_path(self.filename), 39 | which_sets=which_sets, 40 | load_in_memory=load_in_memory, **kwargs) 41 | -------------------------------------------------------------------------------- /fuel/datasets/caltech101_silhouettes.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from fuel.utils import find_in_data_path 3 | from fuel.datasets import H5PYDataset 4 | 5 | 6 | class CalTech101Silhouettes(H5PYDataset): 7 | u"""CalTech 101 Silhouettes dataset. 8 | 9 | This dataset provides the `split1` train/validation/test split of the 10 | CalTech101 Silhouette dataset prepared by Benjamin M. Marlin [MARLIN]. 11 | 12 | This class provides both the 16x16 and the 28x28 pixel sized version. 
13 | The 16x16 version contains 4082 examples in the training set, 2257 14 | examples in the validation set and 2302 examples in the test set. The 15 | 28x28 version contains 4100, 2264 and 2307 examples in the train, valid 16 | and test set. 17 | 18 | Parameters 19 | ---------- 20 | which_sets : tuple of str 21 | Which split to load. Valid values are 'train', 'valid' and 'test'. 22 | size : {16, 28} 23 | Either 16 or 28 to select the 16x16 or 28x28 pixels version 24 | of the dataset (default: 28). 25 | 26 | """ 27 | def __init__(self, which_sets, size=28, load_in_memory=True, **kwargs): 28 | if size not in (16, 28): 29 | raise ValueError('size must be 16 or 28') 30 | 31 | self.filename = 'caltech101_silhouettes{}.hdf5'.format(size) 32 | super(CalTech101Silhouettes, self).__init__( 33 | self.data_path, which_sets=which_sets, 34 | load_in_memory=load_in_memory, **kwargs) 35 | 36 | @property 37 | def data_path(self): 38 | return find_in_data_path(self.filename) 39 | -------------------------------------------------------------------------------- /fuel/datasets/celeba.py: -------------------------------------------------------------------------------- 1 | from fuel.datasets import H5PYDataset 2 | from fuel.transformers.defaults import uint8_pixels_to_floatX 3 | from fuel.utils import find_in_data_path 4 | 5 | 6 | class CelebA(H5PYDataset): 7 | """The CelebFaces Attributes Dataset (CelebA) dataset. 8 | 9 | CelebA is a large-scale face 10 | attributes dataset with more than 200K celebrity images, each 11 | with 40 attribute annotations. The images in this dataset cover 12 | large pose variations and background clutter. CelebA has large 13 | diversities, large quantities, and rich annotations, including: 14 | 15 | * 10,177 number of identities 16 | * 202,599 number of face images 17 | * 5 landmark locations per image 18 | * 40 binary attributes annotations per image. 19 | 20 | The dataset can be employed as the training and test sets for 21 | the following computer vision tasks: 22 | 23 | * face attribute recognition 24 | * face detection 25 | * landmark (or facial part) localization 26 | 27 | Parameters 28 | ---------- 29 | which_format : {'aligned_cropped, '64'} 30 | Either the aligned and cropped version of CelebA, or 31 | a 64x64 version of it. 32 | which_sets : tuple of str 33 | Which split to load. Valid values are 'train', 'valid' and 34 | 'test' corresponding to the training set (162,770 examples), the 35 | validation set (19,867 examples) and the test set (19,962 36 | examples). 37 | 38 | """ 39 | _filename = 'celeba_{}.hdf5' 40 | default_transformers = uint8_pixels_to_floatX(('features',)) 41 | 42 | def __init__(self, which_format, which_sets, **kwargs): 43 | self.which_format = which_format 44 | super(CelebA, self).__init__( 45 | file_or_path=find_in_data_path(self.filename), 46 | which_sets=which_sets, **kwargs) 47 | 48 | @property 49 | def filename(self): 50 | return self._filename.format(self.which_format) 51 | -------------------------------------------------------------------------------- /fuel/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | from fuel.datasets import H5PYDataset 2 | from fuel.transformers.defaults import uint8_pixels_to_floatX 3 | from fuel.utils import find_in_data_path 4 | 5 | 6 | class CIFAR10(H5PYDataset): 7 | """The CIFAR10 dataset of natural images. 8 | 9 | This dataset is a labeled subset of the ``80 million tiny images`` 10 | dataset [TINY]. 
It consists of 60,000 32 x 32 colour images in 10 11 | classes, with 6,000 images per class. There are 50,000 training 12 | images and 10,000 test images [CIFAR10]. 13 | 14 | .. [CIFAR10] Alex Krizhevsky, *Learning Multiple Layers of Features 15 | from Tiny Images*, technical report, 2009. 16 | 17 | Parameters 18 | ---------- 19 | which_sets : tuple of str 20 | Which split to load. Valid values are 'train' and 'test', 21 | corresponding to the training set (50,000 examples) and the test 22 | set (10,000 examples). Note that CIFAR10 does not have a 23 | validation set; usually you will create your own 24 | training/validation split using the `subset` argument. 25 | 26 | """ 27 | filename = 'cifar10.hdf5' 28 | default_transformers = uint8_pixels_to_floatX(('features',)) 29 | 30 | def __init__(self, which_sets, **kwargs): 31 | kwargs.setdefault('load_in_memory', True) 32 | super(CIFAR10, self).__init__( 33 | file_or_path=find_in_data_path(self.filename), 34 | which_sets=which_sets, **kwargs) 35 | -------------------------------------------------------------------------------- /fuel/datasets/cifar100.py: -------------------------------------------------------------------------------- 1 | from fuel.datasets import H5PYDataset 2 | from fuel.transformers.defaults import uint8_pixels_to_floatX 3 | from fuel.utils import find_in_data_path 4 | 5 | 6 | class CIFAR100(H5PYDataset): 7 | """The CIFAR100 dataset of natural images. 8 | 9 | This dataset is a labeled subset of the ``80 million tiny images`` 10 | dataset [TINY]. It consists of 60,000 32 x 32 colour images labelled 11 | into 100 fine-grained classes and 20 super-classes. There are 12 | 600 images per fine-grained class. There are 50,000 training 13 | images and 10,000 test images [CIFAR100]. 14 | 15 | The dataset contains three sources: 16 | - features: the images themselves, 17 | - coarse_labels: the superclasses 1-20, 18 | - fine_labels: the fine-grained classes 1-100. 19 | 20 | .. [TINY] Antonio Torralba, Rob Fergus and William T. Freeman, 21 | *80 million tiny images: a large dataset for non-parametric 22 | object and scene recognition*, Pattern Analysis and Machine 23 | Intelligence, IEEE Transactions on 30.11 (2008): 1958-1970. 24 | 25 | .. [CIFAR100] Alex Krizhevsky, *Learning Multiple Layers of Features 26 | from Tiny Images*, technical report, 2009. 27 | 28 | Parameters 29 | ---------- 30 | which_sets : tuple of str 31 | Which split to load. Valid values are 'train' and 'test', 32 | corresponding to the training set (50,000 examples) and the test 33 | set (10,000 examples). Note that CIFAR100 does not have a 34 | validation set; usually you will create your own 35 | training/validation split using the `subset` argument. 36 | 37 | """ 38 | filename = 'cifar100.hdf5' 39 | default_transformers = uint8_pixels_to_floatX(('features',)) 40 | 41 | def __init__(self, which_sets, **kwargs): 42 | kwargs.setdefault('load_in_memory', True) 43 | super(CIFAR100, self).__init__( 44 | file_or_path=find_in_data_path(self.filename), 45 | which_sets=which_sets, **kwargs) 46 | -------------------------------------------------------------------------------- /fuel/datasets/dogs_vs_cats.py: -------------------------------------------------------------------------------- 1 | from fuel.datasets import H5PYDataset 2 | from fuel.transformers import ScaleAndShift 3 | from fuel.utils import find_in_data_path 4 | 5 | 6 | class DogsVsCats(H5PYDataset): 7 | """The Kaggle Dogs vs. Cats dataset of cats and dogs images. 
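The CIFAR10 and CIFAR100 docstrings above both point to the `subset` argument for carving a validation set out of the training split. A minimal sketch, assuming `subset` (inherited from `H5PYDataset` and not reproduced in this excerpt) accepts a `slice`:

# Sketch: a hand-made training/validation split for CIFAR-10.
from fuel.datasets.cifar10 import CIFAR10

train = CIFAR10(('train',), subset=slice(0, 45000))      # first 45,000 examples
valid = CIFAR10(('train',), subset=slice(45000, 50000))  # last 5,000 examples
print(train.num_examples, valid.num_examples)            # 45000 5000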
8 | 9 | Parameters 10 | ---------- 11 | which_sets : tuple of str 12 | Which split to load. Valid values are 'train' and 'test'. 13 | The test set is the one released on Kaggle. 14 | 15 | Notes 16 | ----- 17 | The Dogs vs. Cats dataset does not provide an official 18 | validation split. Users need to create their own 19 | training / validation split using the `subset` argument. 20 | 21 | """ 22 | filename = 'dogs_vs_cats.hdf5' 23 | 24 | default_transformers = ((ScaleAndShift, [1 / 255.0, 0], 25 | {'which_sources': ('image_features',)}),) 26 | 27 | def __init__(self, which_sets, **kwargs): 28 | super(DogsVsCats, self).__init__( 29 | file_or_path=find_in_data_path(self.filename), 30 | which_sets=which_sets, **kwargs) 31 | -------------------------------------------------------------------------------- /fuel/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from fuel.datasets import H5PYDataset 3 | from fuel.transformers.defaults import rgb_images_from_encoded_bytes 4 | from fuel.utils import find_in_data_path 5 | 6 | 7 | class ILSVRC2010(H5PYDataset): 8 | u"""The ILSVRC2010 Dataset. 9 | 10 | The ImageNet Large-Scale Visual Recognition Challenge [ILSVRC] 11 | is an annual computer vision competition testing object classification 12 | and detection at large-scale. This is a wrapper around the data for 13 | the 2010 competition, which is (as of 2015) the only year for which 14 | test data groundtruth is available. 15 | 16 | Note that the download site for the images is not publicly 17 | accessible. To download the images, you may sign up for an account 18 | at [SIGNUP]. 19 | 20 | .. [ILSVRC] Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, 21 | Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya 22 | Khosla, Michael Bernstein, Alexander C. Berg and Li Fei-Fei. 23 | *ImageNet Large Scale Visual Recognition Challenge*. IJCV, 2015. 24 | 25 | .. [SIGNUP] http://www.image-net.org/signup 26 | 27 | Parameters 28 | ---------- 29 | which_sets : tuple of str 30 | Which split to load. Valid values are 'train' (1.2M examples) 31 | 'valid' (150,000 examples), and 'test' (50,000 examples). 32 | 33 | """ 34 | filename = 'ilsvrc2010.hdf5' 35 | default_transformers = rgb_images_from_encoded_bytes(('encoded_images',)) 36 | 37 | def __init__(self, which_sets, **kwargs): 38 | kwargs.setdefault('load_in_memory', False) 39 | super(ILSVRC2010, self).__init__( 40 | file_or_path=find_in_data_path(self.filename), 41 | which_sets=which_sets, **kwargs) 42 | 43 | 44 | class ILSVRC2012(H5PYDataset): 45 | u"""The ILSVRC2012 Dataset. 46 | 47 | The ImageNet Large-Scale Visual Recognition Challenge [ILSVRC] 48 | is an annual computer vision competition testing object classification 49 | and detection at large-scale. This is a wrapper around the data for 50 | the 2012 competition. 51 | 52 | Note that the download site for the images is not publicly 53 | accessible. To downlaod the images, you may sign up for an account 54 | at [SIGNUP]. 55 | 56 | .. [ILSVRC] Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, 57 | Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya 58 | Khosla, Michael Bernstein, Alexander C. Berg and Li Fei-Fei. 59 | *ImageNet Large Scale Visual Recognition Challenge*. IJCV, 2015. 60 | 61 | .. [SIGNUP] http://www.image-net.org/signup 62 | 63 | Parameters 64 | ---------- 65 | which_sets : tuple of str 66 | Which split to load. 
Valid values are 'train' (1,281,167 examples) 67 | 'valid' (50,000 examples), and 'test' (100,000 examples). 68 | 69 | """ 70 | filename = 'ilsvrc2012.hdf5' 71 | default_transformers = rgb_images_from_encoded_bytes(('encoded_images',)) 72 | 73 | def __init__(self, which_sets, **kwargs): 74 | kwargs.setdefault('load_in_memory', False) 75 | super(ILSVRC2012, self).__init__( 76 | file_or_path=find_in_data_path(self.filename), 77 | which_sets=which_sets, **kwargs) 78 | -------------------------------------------------------------------------------- /fuel/datasets/iris.py: -------------------------------------------------------------------------------- 1 | from fuel.datasets import H5PYDataset 2 | from fuel.utils import find_in_data_path 3 | 4 | 5 | class Iris(H5PYDataset): 6 | u"""Iris dataset. 7 | 8 | Iris [IRIS] is a simple pattern recognition dataset, which consist of 9 | 3 classes of 50 examples each having 4 real-valued features each, where 10 | each class refers to a type of iris plant. It is accessible through the 11 | UCI Machine Learning repository [UCIIRIS]. 12 | 13 | .. [IRIS] Ronald A. Fisher, *The use of multiple measurements in 14 | taxonomic problems*, Annual Eugenics, 7, Part II, 179-188, 15 | September 1936. 16 | .. [UCIIRIS] https://archive.ics.uci.edu/ml/datasets/Iris 17 | 18 | Parameters 19 | ---------- 20 | which_sets : tuple of str 21 | Which split to load. Valid value is 'all' 22 | corresponding to 150 examples. 23 | 24 | """ 25 | filename = 'iris.hdf5' 26 | 27 | def __init__(self, which_sets, **kwargs): 28 | kwargs.setdefault('load_in_memory', True) 29 | super(Iris, self).__init__( 30 | file_or_path=find_in_data_path(self.filename), 31 | which_sets=which_sets, **kwargs) 32 | -------------------------------------------------------------------------------- /fuel/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from fuel.datasets import H5PYDataset 3 | from fuel.transformers.defaults import uint8_pixels_to_floatX 4 | from fuel.utils import find_in_data_path 5 | 6 | 7 | class MNIST(H5PYDataset): 8 | u"""MNIST dataset. 9 | 10 | MNIST (Mixed National Institute of Standards and Technology) [LBBH] is 11 | a database of handwritten digits. It is one of the most famous 12 | datasets in machine learning and consists of 60,000 training images 13 | and 10,000 testing images. The images are grayscale and 28 x 28 pixels 14 | large. It is accessible through Yann LeCun's website [LECUN]. 15 | 16 | .. [LECUN] http://yann.lecun.com/exdb/mnist/ 17 | 18 | Parameters 19 | ---------- 20 | which_sets : tuple of str 21 | Which split to load. Valid values are 'train' and 'test', 22 | corresponding to the training set (60,000 examples) and the test 23 | set (10,000 examples). 
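In practice the dataset wrappers above are consumed through a data stream paired with an iteration scheme. A minimal sketch for MNIST, assuming the converted `mnist.hdf5` is on Fuel's data path; `SequentialScheme` and its `(examples, batch_size)` signature come from `fuel.schemes`, which is not reproduced in this excerpt:

# Sketch: minibatch iteration over the training split of MNIST.
from fuel.datasets.mnist import MNIST
from fuel.schemes import SequentialScheme
from fuel.streams import DataStream

dataset = MNIST(('train',))
scheme = SequentialScheme(dataset.num_examples, 256)
stream = DataStream(dataset, iteration_scheme=scheme)

for features, targets in stream.get_epoch_iterator():
    # A plain DataStream returns the raw uint8 pixels; the
    # uint8_pixels_to_floatX default transformer listed above is only
    # applied when the stream is built with the default transformers.
    print(features.shape, targets.shape)  # (256, 1, 28, 28) (256, 1)
    break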
24 | 25 | """ 26 | filename = 'mnist.hdf5' 27 | default_transformers = uint8_pixels_to_floatX(('features',)) 28 | 29 | def __init__(self, which_sets, **kwargs): 30 | kwargs.setdefault('load_in_memory', True) 31 | super(MNIST, self).__init__( 32 | file_or_path=find_in_data_path(self.filename), 33 | which_sets=which_sets, **kwargs) 34 | -------------------------------------------------------------------------------- /fuel/datasets/svhn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from fuel.datasets import H5PYDataset 3 | from fuel.transformers.defaults import uint8_pixels_to_floatX 4 | from fuel.utils import find_in_data_path 5 | 6 | 7 | class SVHN(H5PYDataset): 8 | """The Street View House Numbers (SVHN) dataset. 9 | 10 | SVHN [SVHN] is a real-world image dataset for developing machine 11 | learning and object recognition algorithms with minimal requirement 12 | on data preprocessing and formatting. It can be seen as similar in 13 | flavor to MNIST [LBBH] (e.g., the images are of small cropped 14 | digits), but incorporates an order of magnitude more labeled data 15 | (over 600,000 digit images) and comes from a significantly harder, 16 | unsolved, real world problem (recognizing digits and numbers in 17 | natural scene images). SVHN is obtained from house numbers in 18 | Google Street View images. 19 | 20 | Parameters 21 | ---------- 22 | which_format : {1, 2} 23 | SVHN format 1 contains the full numbers, whereas SVHN format 2 24 | contains cropped digits. 25 | which_sets : tuple of str 26 | Which split to load. Valid values are 'train', 'test' and 'extra', 27 | corresponding to the training set (73,257 examples), the test 28 | set (26,032 examples) and the extra set (531,131 examples). 29 | Note that SVHN does not have a validation set; usually you will 30 | create your own training/validation split using the `subset` 31 | argument. 32 | 33 | """ 34 | _filename = 'svhn_format_{}.hdf5' 35 | default_transformers = uint8_pixels_to_floatX(('features',)) 36 | 37 | def __init__(self, which_format, which_sets, **kwargs): 38 | self.which_format = which_format 39 | super(SVHN, self).__init__( 40 | file_or_path=find_in_data_path(self.filename), 41 | which_sets=which_sets, **kwargs) 42 | 43 | @property 44 | def filename(self): 45 | return self._filename.format(self.which_format) 46 | -------------------------------------------------------------------------------- /fuel/datasets/text.py: -------------------------------------------------------------------------------- 1 | from picklable_itertools import iter_, chain 2 | 3 | from fuel.datasets import Dataset 4 | from fuel.utils.formats import open_ 5 | 6 | 7 | class TextFile(Dataset): 8 | r"""Reads text files and numberizes them given a dictionary. 9 | 10 | Parameters 11 | ---------- 12 | files : list of str 13 | The names of the files in the order in which they should be read. Each 14 | file is expected to have a sentence per line. If the filename ends 15 | with `.gz` it will be opened using `gzip`. Note however that `gzip` 16 | file handles aren't picklable on legacy Python. 17 | dictionary : str or dict 18 | Either the path to a pickled dictionary mapping tokens to integers, 19 | or the dictionary itself. At the very least this dictionary must 20 | map the unknown word-token to an integer. 21 | bos_token : str or None, optional 22 | The beginning-of-sentence (BOS) token in the dictionary that 23 | denotes the beginning of a sentence. Is ``<S>`` by default. If 24 | passed ``None`` no beginning of sentence markers will be added. 25 | eos_token : str or None, optional 26 | The end-of-sentence (EOS) token is ``</S>`` by default, see 27 | ``bos_token``. 28 | unk_token : str, optional 29 | The token in the dictionary to fall back on when a token could not 30 | be found in the dictionary. ``<UNK>`` by default. Pass ``None`` if 31 | the dataset doesn't contain any out-of-vocabulary words/characters 32 | (the data request will crash if it encounters an unknown symbol). 33 | 34 | level : 'word' or 'character', optional 35 | If 'word' the dictionary is expected to contain full words. The 36 | sentences in the text file will be split at the spaces, and each 37 | word replaced with its number as given by the dictionary, resulting 38 | in each example being a single list of numbers. If 'character' the 39 | dictionary is expected to contain single letters as keys. A single 40 | example will be a list of character numbers, starting with the 41 | first non-whitespace character and finishing with the last one. The 42 | default is 'word'. 43 | preprocess : function, optional 44 | A function which takes a sentence (string) as an input and returns 45 | a modified string. For example ``str.lower`` in order to lowercase 46 | the sentence before numberizing. 47 | encoding : str, optional 48 | The encoding to use to read the file. Defaults to ``None``. Use 49 | UTF-8 if the dictionary you pass contains UTF-8 characters, but 50 | note that this makes the dataset unpicklable on legacy Python. 51 | 52 | Examples 53 | -------- 54 | >>> with open('sentences.txt', 'w') as f: 55 | ... _ = f.write("This is a sentence\n") 56 | ... _ = f.write("This another one") 57 | >>> dictionary = {'<UNK>': 0, '</S>': 1, 'this': 2, 'a': 3, 'one': 4} 58 | >>> def lower(s): 59 | ... return s.lower() 60 | >>> text_data = TextFile(files=['sentences.txt'], 61 | ... dictionary=dictionary, bos_token=None, 62 | ... preprocess=lower) 63 | >>> from fuel.streams import DataStream 64 | >>> for data in DataStream(text_data).get_epoch_iterator(): 65 | ... print(data) 66 | ([2, 0, 3, 0, 1],) 67 | ([2, 0, 4, 1],) 68 | >>> full_dictionary = {'this': 0, 'a': 3, 'is': 4, 'sentence': 5, 69 | ... 'another': 6, 'one': 7} 70 | >>> text_data = TextFile(files=['sentences.txt'], 71 | ... dictionary=full_dictionary, bos_token=None, 72 | ... eos_token=None, unk_token=None, 73 | ... preprocess=lower) 74 | >>> for data in DataStream(text_data).get_epoch_iterator(): 75 | ... print(data) 76 | ([0, 4, 3, 5],) 77 | ([0, 6, 7],) 78 | 79 | .. doctest:: 80 | :hide: 81 | 82 | >>> import os 83 | >>> os.remove('sentences.txt') 84 | 85 | """ 86 | provides_sources = ('features',) 87 | example_iteration_scheme = None 88 | 89 | def __init__(self, files, dictionary, bos_token='<S>', eos_token='</S>', 90 | unk_token='<UNK>', level='word', preprocess=None, 91 | encoding=None): 92 | self.files = files 93 | self.dictionary = dictionary 94 | if bos_token is not None and bos_token not in dictionary: 95 | raise ValueError( 96 | "BOS token '{}' is not in the dictionary".format(bos_token)) 97 | self.bos_token = bos_token 98 | if eos_token is not None and eos_token not in dictionary: 99 | raise ValueError( 100 | "EOS token '{}' is not in the dictionary".format(eos_token)) 101 | self.eos_token = eos_token 102 | if unk_token is not None and unk_token not in dictionary: 103 | raise ValueError( 104 | "UNK token '{}' is not in the dictionary".format(unk_token)) 105 | self.unk_token = unk_token 106 | if level not in ('word', 'character'): 107 | raise ValueError( 108 | "level should be 'word' or 'character', not '{}'" 109 | .format(level)) 110 | self.level = level 111 | self.preprocess = preprocess 112 | self.encoding = encoding 113 | super(TextFile, self).__init__() 114 | 115 | def open(self): 116 | return chain(*[iter_(open_(f, encoding=self.encoding)) 117 | for f in self.files]) 118 | 119 | def _get_from_dictionary(self, symbol): 120 | value = self.dictionary.get(symbol) 121 | if value is not None: 122 | return value 123 | else: 124 | if self.unk_token is None: 125 | raise KeyError("token '{}' not found in dictionary and no " 126 | "`unk_token` given".format(symbol)) 127 | return self.dictionary[self.unk_token] 128 | 129 | def get_data(self, state=None, request=None): 130 | if request is not None: 131 | raise ValueError 132 | sentence = next(state) 133 | if self.preprocess is not None: 134 | sentence = self.preprocess(sentence) 135 | data = [self.dictionary[self.bos_token]] if self.bos_token else [] 136 | if self.level == 'word': 137 | data.extend(self._get_from_dictionary(word) 138 | for word in sentence.split()) 139 | else: 140 | data.extend(self._get_from_dictionary(char) 141 | for char in sentence.strip()) 142 | if self.eos_token: 143 | data.append(self.dictionary[self.eos_token]) 144 | return (data,) 145 | -------------------------------------------------------------------------------- /fuel/datasets/toy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy 4 | 5 | from collections import OrderedDict 6 | 7 | from fuel import config 8 | from fuel.datasets import IndexableDataset 9 | 10 | 11 | class Spiral(IndexableDataset): 12 | u"""Toy dataset containing points sampled from spirals on a 2d plane. 13 | 14 | The dataset contains 3 sources: 15 | 16 | * features -- the (x, y) position of the datapoints 17 | * position -- the relative position on the spiral arm 18 | * label -- the class labels (spiral arm) 19 | 20 | ..
plot:: 21 | 22 | from fuel.datasets.toy import Spiral 23 | 24 | ds = Spiral(classes=3) 25 | features, position, label = ds.get_data(None, slice(0, 500)) 26 | 27 | plt.title("Datapoints drawn from Spiral(classes=3)") 28 | for l, m in enumerate(['o', '^', 'v']): 29 | mask = label == l 30 | plt.scatter(features[mask,0], features[mask,1], 31 | c=position[mask], marker=m, label="label==%d"%l) 32 | plt.xlim(-1.2, 1.2) 33 | plt.ylim(-1.2, 1.2) 34 | plt.legend() 35 | plt.colorbar() 36 | plt.xlabel("features[:,0]") 37 | plt.ylabel("features[:,1]") 38 | plt.show() 39 | 40 | Parameters 41 | ---------- 42 | num_examples : int 43 | Number of datapoints to create. 44 | classes : int 45 | Number of spiral arms. 46 | cycles : float 47 | Number of turns the arms take. 48 | noise : float 49 | Add normal distributed noise with standard deviation *noise*. 50 | 51 | """ 52 | def __init__(self, num_examples=1000, classes=1, cycles=1., noise=0.0, 53 | **kwargs): 54 | seed = kwargs.pop('seed', config.default_seed) 55 | rng = numpy.random.RandomState(seed) 56 | # Create dataset 57 | pos = rng.uniform(size=num_examples, low=0, high=cycles) 58 | label = rng.randint(size=num_examples, low=0, high=classes) 59 | radius = (2 * pos + 1) / 3. 60 | phase_offset = label * (2*numpy.pi) / classes 61 | 62 | features = numpy.zeros(shape=(num_examples, 2), dtype='float32') 63 | 64 | features[:, 0] = radius * numpy.sin(2*numpy.pi*pos + phase_offset) 65 | features[:, 1] = radius * numpy.cos(2*numpy.pi*pos + phase_offset) 66 | features += noise * rng.normal(size=(num_examples, 2)) 67 | 68 | data = OrderedDict([ 69 | ('features', features), 70 | ('position', pos), 71 | ('label', label), 72 | ]) 73 | 74 | super(Spiral, self).__init__(data, **kwargs) 75 | 76 | 77 | class SwissRoll(IndexableDataset): 78 | """Dataset containing points from a 3-dimensional Swiss roll. 79 | 80 | The dataset contains 2 sources: 81 | 82 | * features -- the x, y and z position of the datapoints 83 | * position -- radial and z position on the manifold 84 | 85 | .. plot:: 86 | 87 | from fuel.datasets.toy import SwissRoll 88 | import mpl_toolkits.mplot3d.axes3d as p3 89 | import numpy as np 90 | 91 | ds = SwissRoll() 92 | features, pos = ds.get_data(None, slice(0, 1000)) 93 | 94 | color = pos[:,0] 95 | color -= color.min() 96 | color /= color.max() 97 | 98 | fig = plt.figure() 99 | ax = fig.gca(projection="3d") 100 | ax.scatter(features[:,0], features[:,1], features[:,2], 101 | 'x', c=color) 102 | ax.set_xlim(-1, 1) 103 | ax.set_ylim(-1, 1) 104 | ax.set_zlim(-1, 1) 105 | ax.view_init(10., 10.) 106 | plt.show() 107 | 108 | Parameters 109 | ---------- 110 | num_examples : int 111 | Number of datapoints to create. 112 | noise : float 113 | Add normal distributed noise with standard deviation *noise*. 
114 | 115 | """ 116 | def __init__(self, num_examples=1000, noise=0.0, **kwargs): 117 | cycles = 1.5 118 | seed = kwargs.pop('seed', config.default_seed) 119 | rng = numpy.random.RandomState(seed) 120 | pos = rng.uniform(size=num_examples, low=0, high=1) 121 | phi = cycles * numpy.pi * (1 + 2*pos) 122 | radius = (1 + 2 * pos) / 3 123 | 124 | x = radius * numpy.cos(phi) 125 | y = radius * numpy.sin(phi) 126 | z = rng.uniform(size=num_examples, low=-1, high=1) 127 | 128 | features = numpy.zeros(shape=(num_examples, 3), dtype='float32') 129 | features[:, 0] = x 130 | features[:, 1] = y 131 | features[:, 2] = z 132 | features += noise * rng.normal(size=(num_examples, 3)) 133 | 134 | position = numpy.zeros(shape=(num_examples, 2), dtype='float32') 135 | position[:, 0] = pos 136 | position[:, 1] = z 137 | 138 | data = OrderedDict([ 139 | ('features', features), 140 | ('position', position), 141 | ]) 142 | 143 | super(SwissRoll, self).__init__(data, **kwargs) 144 | -------------------------------------------------------------------------------- /fuel/datasets/youtube_audio.py: -------------------------------------------------------------------------------- 1 | from fuel.datasets.hdf5 import H5PYDataset 2 | from fuel.utils import find_in_data_path 3 | 4 | 5 | class YouTubeAudio(H5PYDataset): 6 | r"""Dataset of audio from YouTube video. 7 | 8 | Assumes the existence of a dataset file with the name 9 | `youtube_id.hdf5`. These datasets don't have any split; the entire 10 | audio sequence is considered training. 11 | 12 | Note that the data structured in the form `(batch, time, features)` 13 | where `features` are the audio channels (dimension 1 or 2) and batch is 14 | equal to 1 in this case (since there is only one audiotrack). 15 | 16 | Parameters 17 | ---------- 18 | youtube_id : str 19 | 11-character video ID (taken from YouTube URL) 20 | \*\*kwargs 21 | Passed to the `H5PYDataset` class. 22 | 23 | """ 24 | def __init__(self, youtube_id, **kwargs): 25 | super(YouTubeAudio, self).__init__( 26 | file_or_path=find_in_data_path('{}.hdf5'.format(youtube_id)), 27 | which_sets=('train',), **kwargs 28 | ) 29 | -------------------------------------------------------------------------------- /fuel/downloaders/__init__.py: -------------------------------------------------------------------------------- 1 | """Download modules for built-in datasets. 2 | 3 | Download functions accept two arguments: 4 | 5 | * `save_directory` : Where to save the downloaded files 6 | * `clear` : If `True`, clear the downloaded files. Defaults to `False`. 
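Each submodule also defines a ``fill_subparser`` function that configures the ``fuel-download`` subcommand for its dataset and returns the callable (usually :func:`fuel.downloaders.base.default_downloader`) that performs the actual download.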
7 | 8 | """ 9 | from fuel.downloaders import adult 10 | from fuel.downloaders import binarized_mnist 11 | from fuel.downloaders import caltech101_silhouettes 12 | from fuel.downloaders import celeba 13 | from fuel.downloaders import cifar10 14 | from fuel.downloaders import cifar100 15 | from fuel.downloaders import dogs_vs_cats 16 | from fuel.downloaders import iris 17 | from fuel.downloaders import mnist 18 | from fuel.downloaders import svhn 19 | from fuel.downloaders import ilsvrc2010 20 | from fuel.downloaders import ilsvrc2012 21 | from fuel.downloaders import youtube_audio 22 | 23 | all_downloaders = ( 24 | ('adult', adult.fill_subparser), 25 | ('binarized_mnist', binarized_mnist.fill_subparser), 26 | ('caltech101_silhouettes', caltech101_silhouettes.fill_subparser), 27 | ('celeba', celeba.fill_subparser), 28 | ('cifar10', cifar10.fill_subparser), 29 | ('cifar100', cifar100.fill_subparser), 30 | ('iris', iris.fill_subparser), 31 | ('mnist', mnist.fill_subparser), 32 | ('svhn', svhn.fill_subparser), 33 | ('ilsvrc2010', ilsvrc2010.fill_subparser), 34 | ('ilsvrc2012', ilsvrc2012.fill_subparser), 35 | ('dogs_vs_cats', dogs_vs_cats.fill_subparser), 36 | ('youtube_audio', youtube_audio.fill_subparser)) 37 | -------------------------------------------------------------------------------- /fuel/downloaders/adult.py: -------------------------------------------------------------------------------- 1 | from fuel.downloaders.base import default_downloader 2 | 3 | 4 | def fill_subparser(subparser): 5 | """Set up a subparser to download the adult dataset file. 6 | 7 | The Adult dataset file `adult.data` and `adult.test` is downloaded from 8 | the UCI Machine Learning Repository [UCIADULT]. 9 | 10 | .. [UCIADULT] https://archive.ics.uci.edu/ml/datasets/Adult 11 | 12 | Parameters 13 | ---------- 14 | subparser : :class:`argparse.ArgumentParser` 15 | Subparser handling the adult command. 16 | 17 | """ 18 | subparser.set_defaults( 19 | urls=['https://archive.ics.uci.edu/ml/machine-learning-databases/' 20 | 'adult/adult.data', 21 | 'https://archive.ics.uci.edu/ml/machine-learning-databases/' 22 | 'adult/adult.test'], 23 | filenames=['adult.data', 'adult.test']) 24 | return default_downloader 25 | -------------------------------------------------------------------------------- /fuel/downloaders/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | from contextlib import contextmanager 5 | 6 | import requests 7 | from progressbar import (ProgressBar, Percentage, Bar, ETA, FileTransferSpeed, 8 | Timer, UnknownLength) 9 | from six.moves import zip, urllib 10 | from ..exceptions import NeedURLPrefix 11 | 12 | 13 | @contextmanager 14 | def progress_bar(name, maxval): 15 | """Manages a progress bar for a download. 16 | 17 | Parameters 18 | ---------- 19 | name : str 20 | Name of the downloaded file. 21 | maxval : int 22 | Total size of the download, in bytes. 23 | 24 | """ 25 | if maxval is not UnknownLength: 26 | widgets = ['{}: '.format(name), Percentage(), ' ', 27 | Bar(marker='=', left='[', right=']'), ' ', ETA(), ' ', 28 | FileTransferSpeed()] 29 | else: 30 | widgets = ['{}: '.format(name), ' ', Timer(), ' ', FileTransferSpeed()] 31 | bar = ProgressBar(widgets=widgets, max_value=maxval, fd=sys.stdout).start() 32 | try: 33 | yield bar 34 | finally: 35 | bar.update(maxval) 36 | bar.finish() 37 | 38 | 39 | def filename_from_url(url, path=None): 40 | """Parses a URL to determine a file name. 
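    If the server's response carries a ``Content-Disposition`` header, the
    file name is taken from it; otherwise the last component of the URL path
    is used (e.g. a hypothetical ``http://example.com/files/data.tar.gz``
    would yield ``data.tar.gz``).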
41 | 42 | Parameters 43 | ---------- 44 | url : str 45 | URL to parse. 46 | 47 | """ 48 | r = requests.get(url, stream=True) 49 | if 'Content-Disposition' in r.headers: 50 | filename = re.findall(r'filename=([^;]+)', 51 | r.headers['Content-Disposition'])[0].strip('"\"') 52 | else: 53 | filename = os.path.basename(urllib.parse.urlparse(url).path) 54 | return filename 55 | 56 | 57 | def download(url, file_handle, chunk_size=1024): 58 | """Downloads a given URL to a specific file. 59 | 60 | Parameters 61 | ---------- 62 | url : str 63 | URL to download. 64 | file_handle : file 65 | Where to save the downloaded URL. 66 | 67 | """ 68 | r = requests.get(url, stream=True) 69 | total_length = r.headers.get('content-length') 70 | if total_length is None: 71 | maxval = UnknownLength 72 | else: 73 | maxval = int(total_length) 74 | name = file_handle.name 75 | with progress_bar(name=name, maxval=maxval) as bar: 76 | for i, chunk in enumerate(r.iter_content(chunk_size)): 77 | if total_length: 78 | bar.update(i * chunk_size) 79 | file_handle.write(chunk) 80 | 81 | 82 | def ensure_directory_exists(directory): 83 | """Create directory (with parents) if does not exist, raise on failure. 84 | 85 | Parameters 86 | ---------- 87 | directory : str 88 | The directory to create 89 | 90 | """ 91 | if os.path.isdir(directory): 92 | return 93 | os.makedirs(directory) 94 | 95 | 96 | def default_downloader(directory, urls, filenames, url_prefix=None, 97 | clear=False): 98 | """Downloads or clears files from URLs and filenames. 99 | 100 | Parameters 101 | ---------- 102 | directory : str 103 | The directory in which downloaded files are saved. 104 | urls : list 105 | A list of URLs to download. 106 | filenames : list 107 | A list of file names for the corresponding URLs. 108 | url_prefix : str, optional 109 | If provided, this is prepended to filenames that 110 | lack a corresponding URL. 111 | clear : bool, optional 112 | If `True`, delete the given filenames from the given 113 | directory rather than download them. 114 | 115 | """ 116 | # Parse file names from URL if not provided 117 | for i, url in enumerate(urls): 118 | filename = filenames[i] 119 | if not filename: 120 | filename = filename_from_url(url) 121 | if not filename: 122 | raise ValueError("no filename available for URL '{}'".format(url)) 123 | filenames[i] = filename 124 | files = [os.path.join(directory, f) for f in filenames] 125 | 126 | if clear: 127 | for f in files: 128 | if os.path.isfile(f): 129 | os.remove(f) 130 | else: 131 | print('Downloading ' + ', '.join(filenames) + '\n') 132 | ensure_directory_exists(directory) 133 | 134 | for url, f, n in zip(urls, files, filenames): 135 | if not url: 136 | if url_prefix is None: 137 | raise NeedURLPrefix 138 | url = url_prefix + n 139 | with open(f, 'wb') as file_handle: 140 | download(url, file_handle) 141 | -------------------------------------------------------------------------------- /fuel/downloaders/binarized_mnist.py: -------------------------------------------------------------------------------- 1 | from fuel.downloaders.base import default_downloader 2 | 3 | 4 | def fill_subparser(subparser): 5 | """Sets up a subparser to download the binarized MNIST dataset files. 6 | 7 | The binarized MNIST dataset files 8 | (`binarized_mnist_{train,valid,test}.amat`) are downloaded from 9 | Hugo Larochelle's website [HUGO]. 10 | 11 | .. 
[HUGO] http://www.cs.toronto.edu/~larocheh/public/datasets/ 12 | binarized_mnist/binarized_mnist_{train,valid,test}.amat 13 | 14 | Parameters 15 | ---------- 16 | subparser : :class:`argparse.ArgumentParser` 17 | Subparser handling the `binarized_mnist` command. 18 | 19 | """ 20 | sets = ['train', 'valid', 'test'] 21 | urls = ['http://www.cs.toronto.edu/~larocheh/public/datasets/' + 22 | 'binarized_mnist/binarized_mnist_{}.amat'.format(s) for s in sets] 23 | filenames = ['binarized_mnist_{}.amat'.format(s) for s in sets] 24 | subparser.set_defaults(urls=urls, filenames=filenames) 25 | return default_downloader 26 | -------------------------------------------------------------------------------- /fuel/downloaders/caltech101_silhouettes.py: -------------------------------------------------------------------------------- 1 | from fuel.downloaders.base import default_downloader 2 | 3 | 4 | BASE_URL = 'https://people.cs.umass.edu/~marlin/data/' 5 | FILENAME = 'caltech101_silhouettes_{}_split1.mat' 6 | 7 | 8 | def silhouettes_downloader(size, **kwargs): 9 | if size not in (16, 28): 10 | raise ValueError("size must be 16 or 28") 11 | 12 | actual_filename = FILENAME.format(size) 13 | actual_url = BASE_URL + actual_filename 14 | default_downloader(urls=[actual_url], 15 | filenames=[actual_filename], **kwargs) 16 | 17 | 18 | def fill_subparser(subparser): 19 | """Sets up a subparser to download the Silhouettes dataset files. 20 | 21 | The following CalTech 101 Silhouette dataset files can be downloaded 22 | from Benjamin M. Marlin's website [MARLIN]: 23 | `caltech101_silhouettes_16_split1.mat` and 24 | `caltech101_silhouettes_28_split1.mat`. 25 | 26 | .. [MARLIN] https://people.cs.umass.edu/~marlin/data.shtml 27 | 28 | Parameters 29 | ---------- 30 | subparser : :class:`argparse.ArgumentParser` 31 | Subparser handling the `caltech101_silhouettes` command. 32 | 33 | """ 34 | subparser.add_argument( 35 | "size", type=int, choices=(16, 28), 36 | help="height/width of the datapoints") 37 | return silhouettes_downloader 38 | -------------------------------------------------------------------------------- /fuel/downloaders/celeba.py: -------------------------------------------------------------------------------- 1 | from fuel.downloaders.base import default_downloader 2 | 3 | 4 | def fill_subparser(subparser): 5 | """Sets up a subparser to download the CelebA dataset file. 6 | 7 | Parameters 8 | ---------- 9 | subparser : :class:`argparse.ArgumentParser` 10 | Subparser handling the `celeba` command. 11 | 12 | """ 13 | urls = ['https://www.dropbox.com/sh/8oqt9vytwxb3s4r/' 14 | 'AAC7-uCaJkmPmvLX2_P5qy0ga/Anno/list_attr_celeba.txt?dl=1', 15 | 'https://www.dropbox.com/sh/8oqt9vytwxb3s4r/' 16 | 'AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=1'] 17 | filenames = ['list_attr_celeba.txt', 'img_align_celeba.zip'] 18 | subparser.set_defaults(urls=urls, filenames=filenames) 19 | return default_downloader 20 | -------------------------------------------------------------------------------- /fuel/downloaders/cifar10.py: -------------------------------------------------------------------------------- 1 | from fuel.downloaders.base import default_downloader 2 | 3 | 4 | def fill_subparser(subparser): 5 | """Sets up a subparser to download the CIFAR-10 dataset file. 6 | 7 | The CIFAR-10 dataset file is downloaded from Alex Krizhevsky's 8 | website [ALEX]. 9 | 10 | Parameters 11 | ---------- 12 | subparser : :class:`argparse.ArgumentParser` 13 | Subparser handling the `cifar10` command. 
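    .. [ALEX] http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz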
14 | 15 | """ 16 | url = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' 17 | filename = 'cifar-10-python.tar.gz' 18 | subparser.set_defaults(urls=[url], filenames=[filename]) 19 | return default_downloader 20 | -------------------------------------------------------------------------------- /fuel/downloaders/cifar100.py: -------------------------------------------------------------------------------- 1 | from fuel.downloaders.base import default_downloader 2 | 3 | 4 | def fill_subparser(subparser): 5 | """Sets up a subparser to download the CIFAR-100 dataset file. 6 | 7 | The CIFAR-100 dataset file is downloaded from Alex Krizhevsky's 8 | website [ALEX]. 9 | 10 | .. [ALEX] http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz 11 | 12 | Parameters 13 | ---------- 14 | subparser : :class:`argparse.ArgumentParser` 15 | Subparser handling the `cifar100` command. 16 | 17 | """ 18 | url = 'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' 19 | filename = 'cifar-100-python.tar.gz' 20 | subparser.set_defaults(urls=[url], filenames=[filename]) 21 | return default_downloader 22 | -------------------------------------------------------------------------------- /fuel/downloaders/dogs_vs_cats.py: -------------------------------------------------------------------------------- 1 | from fuel.downloaders.base import default_downloader 2 | 3 | 4 | def fill_subparser(subparser): 5 | """Sets up a subparser to download the Dogs vs. Cats dataset file. 6 | 7 | Kaggle's Dogs vs. Cats [KAGGLE] dataset is downloaded from Dropbox 8 | since Kaggle requires user authentication. 9 | 10 | .. [KAGGLE] https://www.kaggle.com/c/dogs-vs-cats 11 | 12 | Parameters 13 | ---------- 14 | subparser : :class:`argparse.ArgumentParser` 15 | Subparser handling the `dogs_vs_cats` command. 16 | 17 | """ 18 | urls = ['https://www.dropbox.com/s/s3u30quvpxqdbz6/train.zip?dl=1', 19 | 'https://www.dropbox.com/s/21rwu6drnplsbkb/test1.zip?dl=1'] 20 | filenames = ['dogs_vs_cats.train.zip', 'dogs_vs_cats.test1.zip'] 21 | subparser.set_defaults(urls=urls, filenames=filenames) 22 | return default_downloader 23 | -------------------------------------------------------------------------------- /fuel/downloaders/ilsvrc2010.py: -------------------------------------------------------------------------------- 1 | from fuel.converters.ilsvrc2010 import IMAGE_TARS 2 | from fuel.downloaders.base import default_downloader 3 | 4 | 5 | def fill_subparser(subparser): 6 | """Sets up a subparser to download the ILSVRC2010 dataset files. 7 | 8 | Note that you will need to use `--url-prefix` to download the 9 | non-public files (namely, the TARs of images). This is a single 10 | prefix that is common to all distributed files, which you can 11 | obtain by registering at the ImageNet website [DOWNLOAD]. 12 | 13 | Note that these files are quite large and you may be better off 14 | simply downloading them separately and running ``fuel-convert``. 15 | 16 | .. [DOWNLOAD] http://www.image-net.org/download-images 17 | 18 | 19 | Parameters 20 | ---------- 21 | subparser : :class:`argparse.ArgumentParser` 22 | Subparser handling the `ilsvrc2010` command. 
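    Examples
    --------
    A sketch of a typical invocation (the prefix shown is a placeholder;
    substitute the private URL prefix obtained after registering)::

        fuel-download ilsvrc2010 -P http://path.to.your.imagenet.prefix/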
23 | 24 | """ 25 | urls = [ 26 | ('http://www.image-net.org/challenges/LSVRC/2010/' 27 | 'ILSVRC2010_test_ground_truth.txt'), 28 | ('http://www.image-net.org/challenges/LSVRC/2010/' 29 | 'download/ILSVRC2010_devkit-1.0.tar.gz'), 30 | ] + ([None] * len(IMAGE_TARS)) 31 | filenames = [None, None] + list(IMAGE_TARS) 32 | subparser.set_defaults(urls=urls, filenames=filenames) 33 | subparser.add_argument('-P', '--url-prefix', type=str, default=None, 34 | help="URL prefix to prepend to the filenames of " 35 | "non-public files, in order to download them. " 36 | "Be sure to include the trailing slash.") 37 | return default_downloader 38 | -------------------------------------------------------------------------------- /fuel/downloaders/ilsvrc2012.py: -------------------------------------------------------------------------------- 1 | from fuel.converters.ilsvrc2012 import ALL_FILES 2 | from fuel.downloaders.base import default_downloader 3 | 4 | 5 | def fill_subparser(subparser): 6 | """Sets up a subparser to download the ILSVRC2012 dataset files. 7 | 8 | Note that you will need to use `--url-prefix` to download the 9 | non-public files (namely, the TARs of images). This is a single 10 | prefix that is common to all distributed files, which you can 11 | obtain by registering at the ImageNet website [DOWNLOAD]. 12 | 13 | Note that these files are quite large and you may be better off 14 | simply downloading them separately and running ``fuel-convert``. 15 | 16 | .. [DOWNLOAD] http://www.image-net.org/download-images 17 | 18 | 19 | Parameters 20 | ---------- 21 | subparser : :class:`argparse.ArgumentParser` 22 | Subparser handling the `ilsvrc2012` command. 23 | 24 | """ 25 | urls = ([None] * len(ALL_FILES)) 26 | filenames = list(ALL_FILES) 27 | subparser.set_defaults(urls=urls, filenames=filenames) 28 | subparser.add_argument('-P', '--url-prefix', type=str, default=None, 29 | help="URL prefix to prepend to the filenames of " 30 | "non-public files, in order to download them. " 31 | "Be sure to include the trailing slash.") 32 | return default_downloader 33 | -------------------------------------------------------------------------------- /fuel/downloaders/iris.py: -------------------------------------------------------------------------------- 1 | from fuel.downloaders.base import default_downloader 2 | 3 | 4 | def fill_subparser(subparser): 5 | """Set up a subparser to download the Iris dataset file. 6 | 7 | The Iris dataset file `iris.data` is downloaded from the UCI 8 | Machine Learning Repository [UCIIRIS]. 9 | 10 | Parameters 11 | ---------- 12 | subparser : :class:`argparse.ArgumentParser` 13 | Subparser handling the iris command. 14 | 15 | """ 16 | subparser.set_defaults( 17 | urls=['https://archive.ics.uci.edu/ml/machine-learning-databases/' 18 | 'iris/iris.data'], 19 | filenames=['iris.data']) 20 | return default_downloader 21 | -------------------------------------------------------------------------------- /fuel/downloaders/mnist.py: -------------------------------------------------------------------------------- 1 | from fuel.downloaders.base import default_downloader 2 | 3 | 4 | def fill_subparser(subparser): 5 | """Sets up a subparser to download the MNIST dataset files. 6 | 7 | The following MNIST dataset files are downloaded from Yann LeCun's 8 | website [LECUN]: 9 | `train-images-idx3-ubyte.gz`, `train-labels-idx1-ubyte.gz`, 10 | `t10k-images-idx3-ubyte.gz`, `t10k-labels-idx1-ubyte.gz`. 
11 | 12 | Parameters 13 | ---------- 14 | subparser : :class:`argparse.ArgumentParser` 15 | Subparser handling the `mnist` command. 16 | 17 | """ 18 | filenames = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz', 19 | 't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz'] 20 | urls = ['http://yann.lecun.com/exdb/mnist/' + f for f in filenames] 21 | subparser.set_defaults(urls=urls, filenames=filenames) 22 | return default_downloader 23 | -------------------------------------------------------------------------------- /fuel/downloaders/svhn.py: -------------------------------------------------------------------------------- 1 | from fuel.downloaders.base import default_downloader 2 | 3 | 4 | def svhn_downloader(which_format, directory, clear=False): 5 | suffix = {1: '.tar.gz', 2: '_32x32.mat'}[which_format] 6 | sets = ['train', 'test', 'extra'] 7 | default_downloader( 8 | directory=directory, 9 | urls=[None for f in sets], 10 | filenames=['{}{}'.format(s, suffix) for s in sets], 11 | url_prefix='http://ufldl.stanford.edu/housenumbers/', 12 | clear=clear) 13 | 14 | 15 | def fill_subparser(subparser): 16 | """Sets up a subparser to download the SVHN dataset files. 17 | 18 | The SVHN dataset files (`{train,test,extra}{.tar.gz,_32x32.mat}`) 19 | are downloaded from the official website [SVHNSITE]. 20 | 21 | Parameters 22 | ---------- 23 | subparser : :class:`argparse.ArgumentParser` 24 | Subparser handling the `svhn` command. 25 | 26 | """ 27 | subparser.add_argument( 28 | "which_format", help="which dataset format", type=int, choices=(1, 2)) 29 | return svhn_downloader 30 | -------------------------------------------------------------------------------- /fuel/downloaders/youtube_audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | try: 4 | import pafy 5 | PAFY_AVAILABLE = True 6 | except ImportError: 7 | PAFY_AVAILABLE = False 8 | 9 | 10 | def download(directory, youtube_id, clear=False): 11 | """Download the audio of a YouTube video. 12 | 13 | The audio is downloaded in the highest available quality. Progress is 14 | printed to `stdout`. The file is named `youtube_id.m4a`, where 15 | `youtube_id` is the 11-character code identifiying the YouTube video 16 | (can be determined from the URL). 17 | 18 | Parameters 19 | ---------- 20 | directory : str 21 | The directory in which to save the downloaded audio file. 22 | youtube_id : str 23 | 11-character video ID (taken from YouTube URL) 24 | clear : bool 25 | If `True`, it deletes the downloaded video. Otherwise it downloads 26 | it. Defaults to `False`. 27 | 28 | """ 29 | filepath = os.path.join(directory, '{}.m4a'.format(youtube_id)) 30 | if clear: 31 | os.remove(filepath) 32 | return 33 | if not PAFY_AVAILABLE: 34 | raise ImportError("pafy is required to download YouTube videos") 35 | url = 'https://www.youtube.com/watch?v={}'.format(youtube_id) 36 | video = pafy.new(url) 37 | audio = video.getbestaudio() 38 | audio.download(quiet=False, filepath=filepath) 39 | 40 | 41 | def fill_subparser(subparser): 42 | """Sets up a subparser to download audio of YouTube videos. 43 | 44 | Adds the compulsory `--youtube-id` flag. 45 | 46 | Parameters 47 | ---------- 48 | subparser : :class:`argparse.ArgumentParser` 49 | Subparser handling the `youtube_audio` command. 
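    Examples
    --------
    A sketch of a typical invocation (the video ID shown is a placeholder)::

        fuel-download youtube_audio --youtube-id XXXXXXXXXXX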
50 | 51 | """ 52 | subparser.add_argument( 53 | '--youtube-id', type=str, required=True, 54 | help=("The YouTube ID of the video from which to extract audio, " 55 | "usually an 11-character string.") 56 | ) 57 | return download 58 | -------------------------------------------------------------------------------- /fuel/exceptions.py: -------------------------------------------------------------------------------- 1 | class AxisLabelsMismatchError(ValueError): 2 | """Raised when a pair of axis labels tuples do not match.""" 3 | 4 | 5 | class ConfigurationError(Exception): 6 | """Error raised when a configuration value is requested but not set.""" 7 | 8 | 9 | class MissingInputFiles(Exception): 10 | """Exception raised by a converter when input files are not found. 11 | 12 | Parameters 13 | ---------- 14 | message : str 15 | The error message to be associated with this exception. 16 | filenames : list 17 | A list of filenames that were not found. 18 | 19 | """ 20 | def __init__(self, message, filenames): 21 | self.filenames = filenames 22 | super(MissingInputFiles, self).__init__(message, filenames) 23 | 24 | 25 | class NeedURLPrefix(Exception): 26 | """Raised when a URL is not provided for a file.""" 27 | -------------------------------------------------------------------------------- /fuel/iterator.py: -------------------------------------------------------------------------------- 1 | import six 2 | 3 | 4 | class DataIterator(six.Iterator): 5 | """An iterator over data, representing a single epoch. 6 | 7 | Parameters 8 | ---------- 9 | data_stream : :class:`DataStream` or :class:`Transformer` 10 | The data stream over which to iterate. 11 | request_iterator : iterator 12 | An iterator which returns the request to pass to the data stream 13 | for each step. 14 | as_dict : bool, optional 15 | If `True`, return dictionaries mapping source names to data 16 | from each source. If `False` (default), return tuples in the 17 | same order as `data_stream.sources`. 18 | 19 | """ 20 | def __init__(self, data_stream, request_iterator=None, as_dict=False): 21 | self.data_stream = data_stream 22 | self.request_iterator = request_iterator 23 | self.as_dict = as_dict 24 | 25 | def __iter__(self): 26 | return self 27 | 28 | def __next__(self): 29 | if self.request_iterator is not None: 30 | data = self.data_stream.get_data(next(self.request_iterator)) 31 | else: 32 | data = self.data_stream.get_data() 33 | if self.as_dict: 34 | return dict(zip(self.data_stream.sources, data)) 35 | else: 36 | return data 37 | -------------------------------------------------------------------------------- /fuel/server.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy 4 | import zmq 5 | from numpy.lib.format import header_data_from_array_1_0 6 | 7 | from fuel.utils import buffer_ 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def send_arrays(socket, arrays, stop=False): 13 | """Send NumPy arrays using the buffer interface and some metadata. 14 | 15 | Parameters 16 | ---------- 17 | socket : :class:`zmq.Socket` 18 | The socket to send data over. 19 | arrays : list 20 | A list of :class:`numpy.ndarray` to transfer. 21 | stop : bool, optional 22 | Instead of sending a series of NumPy arrays, send a JSON object 23 | with a single `stop` key. The :func:`recv_arrays` will raise 24 | ``StopIteration`` when it receives this. 
25 | 26 | Notes 27 | ----- 28 | The protocol is very simple: A single JSON object describing the array 29 | format (using the same specification as ``.npy`` files) is sent first. 30 | Subsequently the arrays are sent as bytestreams (through NumPy's 31 | support of the buffering protocol). 32 | 33 | """ 34 | if arrays: 35 | # The buffer protocol only works on contiguous arrays 36 | arrays = [numpy.ascontiguousarray(array) for array in arrays] 37 | if stop: 38 | headers = {'stop': True} 39 | socket.send_json(headers) 40 | else: 41 | headers = [header_data_from_array_1_0(array) for array in arrays] 42 | socket.send_json(headers, zmq.SNDMORE) 43 | for array in arrays[:-1]: 44 | socket.send(array, zmq.SNDMORE) 45 | socket.send(arrays[-1]) 46 | 47 | 48 | def recv_arrays(socket): 49 | """Receive a list of NumPy arrays. 50 | 51 | Parameters 52 | ---------- 53 | socket : :class:`zmq.Socket` 54 | The socket to receive the arrays on. 55 | 56 | Returns 57 | ------- 58 | list 59 | A list of :class:`numpy.ndarray` objects. 60 | 61 | Raises 62 | ------ 63 | StopIteration 64 | If the first JSON object received contains the key `stop`, 65 | signifying that the server has finished a single epoch. 66 | 67 | """ 68 | headers = socket.recv_json() 69 | if 'stop' in headers: 70 | raise StopIteration 71 | arrays = [] 72 | for header in headers: 73 | data = socket.recv(copy=False) 74 | buf = buffer_(data) 75 | array = numpy.frombuffer(buf, dtype=numpy.dtype(header['descr'])) 76 | array.shape = header['shape'] 77 | if header['fortran_order']: 78 | array.shape = header['shape'][::-1] 79 | array = array.transpose() 80 | arrays.append(array) 81 | return arrays 82 | 83 | 84 | def start_server(data_stream, port=5557, hwm=10): 85 | """Start a data processing server. 86 | 87 | This command starts a server in the current process that performs the 88 | actual data processing (by retrieving data from the given data stream). 89 | It also starts a second process, the broker, which mediates between the 90 | server and the client. The broker also keeps a buffer of batches in 91 | memory. 92 | 93 | Parameters 94 | ---------- 95 | data_stream : :class:`.DataStream` 96 | The data stream to return examples from. 97 | port : int, optional 98 | The port the server and the client (training loop) will use to 99 | communicate. Defaults to 5557. 100 | hwm : int, optional 101 | The `ZeroMQ high-water mark (HWM) 102 | `_ on the 103 | sending socket. Increasing this increases the buffer, which can be 104 | useful if your data preprocessing times are very random. However, 105 | it will increase memory usage. There is no easy way to tell how 106 | many batches will actually be queued with a particular HWM. 107 | Defaults to 10. Be sure to set the corresponding HWM on the 108 | receiving end as well. 
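    Examples
    --------
    A rough sketch of a server script; it assumes a converted ``mnist.hdf5``
    is available in Fuel's data path, but any data stream works the same
    way. The training process would then read from the socket with
    :class:`fuel.streams.ServerDataStream`. ::

        >>> from fuel.datasets.mnist import MNIST              # doctest: +SKIP
        >>> from fuel.schemes import SequentialScheme           # doctest: +SKIP
        >>> from fuel.streams import DataStream                 # doctest: +SKIP
        >>> dataset = MNIST(('train',))                         # doctest: +SKIP
        >>> stream = DataStream.default_stream(                 # doctest: +SKIP
        ...     dataset,
        ...     iteration_scheme=SequentialScheme(dataset.num_examples, 128))
        >>> start_server(stream, port=5557, hwm=10)             # doctest: +SKIP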
109 | 110 | """ 111 | logging.basicConfig(level='INFO') 112 | 113 | context = zmq.Context() 114 | socket = context.socket(zmq.PUSH) 115 | socket.set_hwm(hwm) 116 | socket.bind('tcp://*:{}'.format(port)) 117 | 118 | it = data_stream.get_epoch_iterator() 119 | 120 | logger.info('server started') 121 | while True: 122 | try: 123 | data = next(it) 124 | stop = False 125 | logger.debug("sending {} arrays".format(len(data))) 126 | except StopIteration: 127 | it = data_stream.get_epoch_iterator() 128 | data = None 129 | stop = True 130 | logger.debug("sending StopIteration") 131 | send_arrays(socket, data, stop=stop) 132 | -------------------------------------------------------------------------------- /fuel/transformers/_image.pyx: -------------------------------------------------------------------------------- 1 | cimport cython 2 | from cython.parallel cimport prange 3 | 4 | 5 | ctypedef long Py_intptr_t 6 | 7 | ctypedef fused image_dtype: 8 | float 9 | double 10 | unsigned char 11 | 12 | 13 | @cython.boundscheck(False) 14 | @cython.wraparound(False) 15 | cpdef window_batch_bchw(image_dtype[:, :, :, :] batch, 16 | long[:] height_offsets, long[:] width_offsets, 17 | image_dtype[:, :, :, :] out): 18 | """window_batch_bchw(batch, window_height, window_width, 19 | height_offsets, width_offsets, out) 20 | 21 | Perform windowing on a (batch, channels, height, width) image tensor. 22 | 23 | Parameters 24 | ---------- 25 | batch : memoryview, 4-dimensional 26 | A 4-d tensor containing a batch of images in the expected 27 | format above. 28 | height_offsets : memoryview, integer, 1-dimensional 29 | An array of offsets for the height dimension of each image. 30 | Assumed that batch.shape[0] <= height_offsets.shape[0]. 31 | width_offsets : memoryview, integer, 1-dimensional 32 | An array of offsets for the width dimension of each image. 33 | Assumed that batch.shape[0] <= width_offsets.shape[0]. 34 | out : memoryview 35 | The array to which to write output. It is assumed that 36 | `out.shape[2] + height_offsets[i] <= batch.shape[2]` and 37 | `out.shape[3] + width_offsets[i] <= batch.shape[3]`, for 38 | all values of `i`. 39 | 40 | Notes 41 | ----- 42 | Operates on a batch in parallel via OpenMP. Set `OMP_NUM_THREADS` 43 | to benefit from this parallelism. 44 | 45 | This is a low-level utility that, for the sake of speed, does 46 | not check its input for validity. Some amount of protection is 47 | provided by Cython memoryview objects. 
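    Examples
    --------
    A rough usage sketch (shapes and offsets chosen arbitrarily; ``out``
    simply receives a 20x20 crop of each 32x32 image)::

        >>> import numpy
        >>> batch = numpy.random.rand(4, 3, 32, 32)
        >>> out = numpy.empty((4, 3, 20, 20))
        >>> h_off = numpy.random.randint(0, 13, size=4)
        >>> w_off = numpy.random.randint(0, 13, size=4)
        >>> window_batch_bchw(batch, h_off, w_off, out)  # doctest: +SKIP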
48 | 49 | """ 50 | cdef Py_intptr_t index 51 | cdef Py_intptr_t window_width = out.shape[3] 52 | cdef Py_intptr_t window_height = out.shape[2] 53 | cdef Py_intptr_t h_off, w_off, h_extent, w_extent 54 | with nogil: 55 | for index in prange(batch.shape[0]): 56 | h_off = height_offsets[index] 57 | w_off = width_offsets[index] 58 | h_extent = h_off + window_height 59 | w_extent = w_off + window_width 60 | out[index] = batch[index, :, h_off:h_extent, w_off:w_extent] 61 | -------------------------------------------------------------------------------- /fuel/transformers/defaults.py: -------------------------------------------------------------------------------- 1 | """Commonly-used default transformers.""" 2 | from fuel.transformers import ScaleAndShift, Cast, SourcewiseTransformer 3 | from fuel.transformers.image import ImagesFromBytes 4 | 5 | 6 | def uint8_pixels_to_floatX(which_sources): 7 | return ( 8 | (ScaleAndShift, [1 / 255.0, 0], {'which_sources': which_sources}), 9 | (Cast, ['floatX'], {'which_sources': which_sources})) 10 | 11 | 12 | class ToBytes(SourcewiseTransformer): 13 | """Transform a stream of ndarray examples to bytes. 14 | 15 | Notes 16 | ----- 17 | Used for retrieving variable-length byte data stored as, e.g. a uint8 18 | ragged array. 19 | 20 | """ 21 | def __init__(self, stream, **kwargs): 22 | kwargs.setdefault('produces_examples', stream.produces_examples) 23 | axis_labels = (stream.axis_labels 24 | if stream.axis_labels is not None 25 | else {}) 26 | for source in kwargs.get('which_sources', stream.sources): 27 | axis_labels[source] = (('batch', 'bytes') 28 | if 'batch' in axis_labels.get(source, ()) 29 | else ('bytes',)) 30 | kwargs.setdefault('axis_labels', axis_labels) 31 | super(ToBytes, self).__init__(stream, **kwargs) 32 | 33 | def transform_source_example(self, example, _): 34 | return example.tostring() 35 | 36 | def transform_source_batch(self, batch, _): 37 | return [example.tostring() for example in batch] 38 | 39 | 40 | def rgb_images_from_encoded_bytes(which_sources): 41 | return ((ToBytes, [], {'which_sources': which_sources}), 42 | (ImagesFromBytes, [], {'which_sources': which_sources})) 43 | -------------------------------------------------------------------------------- /fuel/transformers/sequences.py: -------------------------------------------------------------------------------- 1 | from fuel.transformers import Transformer 2 | 3 | 4 | class Window(Transformer): 5 | """Return pairs of source and target windows from a stream. 6 | 7 | This data stream wrapper takes as an input a data stream outputting 8 | sequences of potentially varying lengths (e.g. sentences, audio tracks, 9 | etc.). It then returns two sliding windows (source and target) over 10 | these sequences. 11 | 12 | For example, to train an n-gram model set `source_window` to n, 13 | `target_window` to 1, no offset, and `overlapping` to false. This will 14 | give chunks [1, N] and [N + 1]. To train an RNN you often want to set 15 | the source and target window to the same size and use an offset of 1 16 | with overlap, this would give you chunks [1, N] and [2, N + 1]. 17 | 18 | Parameters 19 | ---------- 20 | offset : int 21 | The offset from the source window where the target window starts. 22 | source_window : int 23 | The size of the source window. 24 | target_window : int 25 | The size of the target window. 26 | overlapping : bool 27 | If true, the source and target windows overlap i.e. the offset of 28 | the target window is taken to be from the beginning of the source 29 | window. 
If false, the target window offset is taken to be from the 30 | end of the source window. 31 | data_stream : :class:`.DataStream` instance 32 | The data stream providing sequences. Each example is assumed to be 33 | an object that supports slicing. 34 | target_source : str, optional 35 | This data stream adds a new source for the target words. By default 36 | this source is 'targets'. 37 | 38 | """ 39 | def __init__(self, offset, source_window, target_window, 40 | overlapping, data_stream, target_source='targets', **kwargs): 41 | if not data_stream.produces_examples: 42 | raise ValueError('the wrapped data stream must produce examples, ' 43 | 'not batches of examples.') 44 | if len(data_stream.sources) > 1: 45 | raise ValueError('{} expects only one source' 46 | .format(self.__class__.__name__)) 47 | 48 | super(Window, self).__init__(data_stream, produces_examples=True, 49 | **kwargs) 50 | self.sources = self.sources + (target_source,) 51 | 52 | self.offset = offset 53 | self.source_window = source_window 54 | self.target_window = target_window 55 | self.overlapping = overlapping 56 | 57 | self.sentence = [] 58 | self._set_index() 59 | 60 | def _set_index(self): 61 | """Set the starting index of the source window.""" 62 | self.index = 0 63 | # If offset is negative, target window might start before 0 64 | self.index = -min(0, self._get_target_index()) 65 | 66 | def _get_target_index(self): 67 | """Return the index where the target window starts.""" 68 | return (self.index + self.source_window * (not self.overlapping) + 69 | self.offset) 70 | 71 | def _get_end_index(self): 72 | """Return the end of both windows.""" 73 | return max(self.index + self.source_window, 74 | self._get_target_index() + self.target_window) 75 | 76 | def get_data(self, request=None): 77 | if request is not None: 78 | raise ValueError 79 | while not self._get_end_index() <= len(self.sentence): 80 | self.sentence, = next(self.child_epoch_iterator) 81 | self._set_index() 82 | source = self.sentence[self.index:self.index + self.source_window] 83 | target = self.sentence[self._get_target_index(): 84 | self._get_target_index() + self.target_window] 85 | self.index += 1 86 | return (source, target) 87 | 88 | 89 | class NGrams(Window): 90 | """Return n-grams from a stream. 91 | 92 | This data stream wrapper takes as an input a data stream outputting 93 | sentences. From these sentences n-grams of a fixed order (e.g. bigrams, 94 | trigrams, etc.) are extracted and returned. It also creates a 95 | ``targets`` data source. For each example, the target is the word 96 | immediately following that n-gram. It is normally used for language 97 | modeling, where we try to predict the next word from the previous *n* 98 | words. 99 | 100 | .. note:: 101 | 102 | Unlike the :class:`Window` stream, the target returned by 103 | :class:`NGrams` is a single element instead of a window. 104 | 105 | Parameters 106 | ---------- 107 | ngram_order : int 108 | The order of the n-grams to output e.g. 3 for trigrams. 109 | data_stream : :class:`.DataStream` instance 110 | The data stream providing sentences. Each example is assumed to be 111 | a list of integers. 112 | target_source : str, optional 113 | This data stream adds a new source for the target words. By default 114 | this source is 'targets'. 
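    Examples
    --------
    A small sketch with a single toy sentence (trigrams plus the following
    word as target)::

        >>> from fuel.datasets import IterableDataset        # doctest: +SKIP
        >>> from fuel.streams import DataStream              # doctest: +SKIP
        >>> stream = NGrams(3, DataStream(                    # doctest: +SKIP
        ...     IterableDataset([[1, 2, 3, 4, 5]])))
        >>> list(stream.get_epoch_iterator())                 # doctest: +SKIP
        [([1, 2, 3], 4), ([2, 3, 4], 5)]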
115 | 116 | """ 117 | def __init__(self, ngram_order, *args, **kwargs): 118 | super(NGrams, self).__init__( 119 | 0, ngram_order, 1, False, *args, **kwargs) 120 | 121 | def get_data(self, *args, **kwargs): 122 | source, target = super(NGrams, self).get_data(*args, **kwargs) 123 | return (source, target[0]) 124 | -------------------------------------------------------------------------------- /fuel/utils/disk.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Some of code below is taken from 4 | # [pylearn2](https://github.com/lisa-lab/pylearn2) framework developed under 5 | # the copyright: 6 | # 7 | # Copyright (c) 2011--2014, Université de Montréal 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # 1. Redistributions of source code must retain the above copyright notice, 14 | # this list of conditions and the following disclaimer. 15 | # 16 | # 2. Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # 3. Neither the name of the copyright holder nor the names of its contributors 21 | # may be used to endorse or promote products derived from this software 22 | # without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 27 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 28 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 29 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 30 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 31 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 32 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 33 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 34 | # POSSIBILITY OF SUCH DAMAGE. 35 | """Filesystem utility code.""" 36 | import os 37 | 38 | 39 | def disk_usage(path): 40 | """Return free usage about the given path, in bytes. 41 | 42 | Parameters 43 | ---------- 44 | path : str 45 | Folder for which to return disk usage 46 | 47 | Returns 48 | ------- 49 | output : tuple 50 | Tuple containing total space in the folder and currently 51 | used space in the folder 52 | 53 | """ 54 | st = os.statvfs(path) 55 | total = st.f_blocks * st.f_frsize 56 | used = (st.f_blocks - st.f_bfree) * st.f_frsize 57 | return total, used 58 | 59 | 60 | def safe_mkdir(folder_name, force_perm=None): 61 | """Create the specified folder. 62 | 63 | If the parent folders do not exist, they are also created. 64 | If the folder already exists, nothing is done. 65 | 66 | Parameters 67 | ---------- 68 | folder_name : str 69 | Name of the folder to create. 70 | force_perm : str 71 | Mode to use for folder creation. 
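    Examples
    --------
    A hypothetical call (the path is purely illustrative); missing
    intermediate folders are created and existing ones are left untouched::

        >>> safe_mkdir('/tmp/fuel_scratch/runs/2016')  # doctest: +SKIP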
72 | 73 | """ 74 | if os.path.exists(folder_name): 75 | return 76 | intermediary_folders = folder_name.split(os.path.sep) 77 | 78 | # Remove invalid elements from intermediary_folders 79 | if intermediary_folders[-1] == "": 80 | intermediary_folders = intermediary_folders[:-1] 81 | if force_perm: 82 | force_perm_path = folder_name.split(os.path.sep) 83 | if force_perm_path[-1] == "": 84 | force_perm_path = force_perm_path[:-1] 85 | 86 | for i in range(1, len(intermediary_folders)): 87 | folder_to_create = os.path.sep.join(intermediary_folders[:i + 1]) 88 | 89 | if os.path.exists(folder_to_create): 90 | continue 91 | os.mkdir(folder_to_create) 92 | if force_perm: 93 | os.chmod(folder_to_create, force_perm) 94 | 95 | 96 | def check_enough_space(dataset_local_dir, remote_fname, local_fname, 97 | max_disk_usage=0.9): 98 | """Check if the given local folder has enough space. 99 | 100 | Check if the given local folder has enough space to store 101 | the specified remote file. 102 | 103 | Parameters 104 | ---------- 105 | remote_fname : str 106 | Path to the remote file 107 | remote_fname : str 108 | Path to the local folder 109 | max_disk_usage : float 110 | Fraction indicating how much of the total space in the 111 | local folder can be used before the local cache must stop 112 | adding to it. 113 | 114 | Returns 115 | ------- 116 | output : boolean 117 | True if there is enough space to store the remote file. 118 | 119 | """ 120 | storage_need = os.path.getsize(remote_fname) 121 | storage_total, storage_used = disk_usage(dataset_local_dir) 122 | 123 | # Instead of only looking if there's enough space, we ensure we do not 124 | # go over max disk usage level to avoid filling the disk/partition 125 | return ((storage_used + storage_need) < 126 | (storage_total * max_disk_usage)) 127 | -------------------------------------------------------------------------------- /fuel/utils/formats.py: -------------------------------------------------------------------------------- 1 | """Low-level utilities for reading a variety of source formats.""" 2 | import codecs 3 | import gzip 4 | import io 5 | import tarfile 6 | import six 7 | 8 | 9 | def open_(filename, mode='r', encoding=None): 10 | """Open a text file with encoding and optional gzip compression. 11 | 12 | Note that on legacy Python any encoding other than ``None`` or opening 13 | GZipped files will return an unpicklable file-like object. 14 | 15 | Parameters 16 | ---------- 17 | filename : str 18 | The filename to read. 19 | mode : str, optional 20 | The mode with which to open the file. Defaults to `r`. 21 | encoding : str, optional 22 | The encoding to use (see the codecs documentation_ for supported 23 | values). Defaults to ``None``. 24 | 25 | .. _documentation: 26 | https://docs.python.org/3/library/codecs.html#standard-encodings 27 | 28 | """ 29 | if filename.endswith('.gz'): 30 | if six.PY2: 31 | zf = io.BufferedReader(gzip.open(filename, mode)) 32 | if encoding: 33 | return codecs.getreader(encoding)(zf) 34 | else: 35 | return zf 36 | else: 37 | return io.BufferedReader(gzip.open(filename, mode, 38 | encoding=encoding)) 39 | if six.PY2: 40 | if encoding: 41 | return codecs.open(filename, mode, encoding=encoding) 42 | else: 43 | return open(filename, mode) 44 | else: 45 | return open(filename, mode, encoding=encoding) 46 | 47 | 48 | def tar_open(f): 49 | """Open either a filename or a file-like object as a TarFile. 50 | 51 | Parameters 52 | ---------- 53 | f : str or file-like object 54 | The filename or file-like object from which to read. 
55 | 56 | Returns 57 | ------- 58 | TarFile 59 | A `TarFile` instance. 60 | 61 | """ 62 | if isinstance(f, six.string_types): 63 | return tarfile.open(name=f) 64 | else: 65 | return tarfile.open(fileobj=f) 66 | -------------------------------------------------------------------------------- /fuel/utils/parallel.py: -------------------------------------------------------------------------------- 1 | """Utilities for speeding things up through parallelism. 2 | 3 | Currently including: 4 | 5 | * A very simple PUSH-PULL reusable producer-consumer pattern 6 | using a ZeroMQ socket instead of the (slow, unnecessarily 7 | copying) multiprocessing.Queue. See :func:`producer_consumer`. 8 | 9 | """ 10 | from multiprocessing import Process 11 | import zmq 12 | 13 | 14 | def _producer_wrapper(f, port, addr='tcp://127.0.0.1'): 15 | """A shim that sets up a socket and starts the producer callable. 16 | 17 | Parameters 18 | ---------- 19 | f : callable 20 | Callable that takes a single argument, a handle 21 | for a ZeroMQ PUSH socket. Must be picklable. 22 | port : int 23 | The port on which the socket should connect. 24 | addr : str, optional 25 | Address to which the socket should connect. Defaults 26 | to localhost ('tcp://127.0.0.1'). 27 | 28 | """ 29 | try: 30 | context = zmq.Context() 31 | socket = context.socket(zmq.PUSH) 32 | socket.connect(':'.join([addr, str(port)])) 33 | f(socket) 34 | finally: 35 | # Works around a Python 3.x bug. 36 | context.destroy() 37 | 38 | 39 | def _spawn_producer(f, port, addr='tcp://127.0.0.1'): 40 | """Start a process that sends results on a PUSH socket. 41 | 42 | Parameters 43 | ---------- 44 | f : callable 45 | Callable that takes a single argument, a handle 46 | for a ZeroMQ PUSH socket. Must be picklable. 47 | 48 | Returns 49 | ------- 50 | process : multiprocessing.Process 51 | The process handle of the created producer process. 52 | 53 | """ 54 | process = Process(target=_producer_wrapper, args=(f, port, addr)) 55 | process.start() 56 | return process 57 | 58 | 59 | def producer_consumer(producer, consumer, addr='tcp://127.0.0.1', 60 | port=None, context=None): 61 | """A producer-consumer pattern. 62 | 63 | Parameters 64 | ---------- 65 | producer : callable 66 | Callable that takes a single argument, a handle 67 | for a ZeroMQ PUSH socket. Must be picklable. 68 | consumer : callable 69 | Callable that takes a single argument, a handle 70 | for a ZeroMQ PULL socket. 71 | addr : str, optional 72 | Address to which the socket should connect. Defaults 73 | to localhost ('tcp://127.0.0.1'). 74 | port : int, optional 75 | The port on which the consumer should listen. 76 | context : zmq.Context, optional 77 | The ZeroMQ Context to use. One will be created otherwise. 78 | 79 | Returns 80 | ------- 81 | result 82 | Passes along whatever `consumer` returns. 83 | 84 | Notes 85 | ----- 86 | This sets up a PULL socket in the calling process and forks 87 | a process that calls `producer` on a PUSH socket. When the 88 | consumer returns, the producer process is terminated. 89 | 90 | Wrap `consumer` or `producer` in a `functools.partial` object 91 | in order to send additional arguments; the callables passed in 92 | should expect only one required, positional argument, the socket 93 | handle. 
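    Examples
    --------
    A minimal sketch: the producer pushes a few Python objects and the
    consumer sums them (``send_pyobj``/``recv_pyobj`` are standard pyzmq
    socket helpers)::

        >>> def produce(socket):
        ...     for i in range(10):
        ...         socket.send_pyobj(i)
        >>> def consume(socket):
        ...     return sum(socket.recv_pyobj() for _ in range(10))
        >>> producer_consumer(produce, consume)  # doctest: +SKIP
        45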
94 | 95 | """ 96 | context_created = False 97 | if context is None: 98 | context_created = True 99 | context = zmq.Context() 100 | try: 101 | consumer_socket = context.socket(zmq.PULL) 102 | if port is None: 103 | port = consumer_socket.bind_to_random_port(addr) 104 | try: 105 | process = _spawn_producer(producer, port) 106 | result = consumer(consumer_socket) 107 | finally: 108 | process.terminate() 109 | return result 110 | finally: 111 | # Works around a Python 3.x bug. 112 | if context_created: 113 | context.destroy() 114 | -------------------------------------------------------------------------------- /fuel/version.py: -------------------------------------------------------------------------------- 1 | version = '0.2.0' 2 | -------------------------------------------------------------------------------- /req-rtd.txt: -------------------------------------------------------------------------------- 1 | picklable-itertools==0.1.1 2 | progressbar2==3.6.2 3 | pyyaml==3.11 4 | requests==2.20.0 5 | six==1.10.0 6 | -------------------------------------------------------------------------------- /req-travis-conda.txt: -------------------------------------------------------------------------------- 1 | coverage==4.0.3 2 | h5py==2.8.0 3 | mock==1.3.0 4 | nose==1.3.7 5 | numpy==1.10.4 6 | pillow==5.2.0 7 | pytables==3.4.4 8 | pyyaml==3.13 9 | pyzmq==15.2.0 10 | scipy==0.17.0 11 | six==1.10.0 12 | progressbar2==3.10.0 13 | -------------------------------------------------------------------------------- /req-travis-pip.txt: -------------------------------------------------------------------------------- 1 | nose2[coverage_plugin]==0.8.0 2 | coveralls==1.1 3 | picklable-itertools==0.1.1 4 | requests==2.20.0 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py==2.8.0 2 | numpy==1.10.4 3 | tables==3.4.4 4 | picklable-itertools==0.1.1 5 | pillow==5.2.0 6 | progressbar2==3.6.2 7 | pyyaml==3.13 8 | pyzmq==15.2.0 9 | requests==2.20.0 10 | scipy==0.17.0 11 | six==1.10.0 12 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.rst 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Installation script.""" 2 | from os import path 3 | import sys 4 | from io import open 5 | from setuptools import find_packages, setup 6 | from distutils.extension import Extension 7 | 8 | HERE = path.abspath(path.dirname(__file__)) 9 | 10 | with open(path.join(HERE, 'README.rst'), encoding='utf-8') as f: 11 | LONG_DESCRIPTION = f.read().strip() 12 | 13 | # Visual C++ apparently doesn't respect/know what to do with this flag. 14 | # Windows users may thus see unused function warnings. Oh well. 
15 | if sys.platform != 'win32': 16 | extra_compile_args = ['-Wno-unused-function'] 17 | else: 18 | extra_compile_args = [] 19 | 20 | exec_results = {} 21 | with open(path.join(path.dirname(__file__), 'fuel/version.py')) as file_: 22 | exec(file_.read(), exec_results) 23 | version = exec_results['version'] 24 | 25 | setup( 26 | name='fuel', 27 | version=version, # PEP 440 compliant 28 | description='Data pipeline framework for machine learning', 29 | long_description=LONG_DESCRIPTION, 30 | url='https://github.com/mila-udem/fuel.git', 31 | download_url='https://github.com/mila-udem/fuel/tarball/v' + version, 32 | author='Universite de Montreal', 33 | license='MIT', 34 | # See https://pypi.python.org/pypi?%3Aaction=list_classifiers 35 | classifiers=[ 36 | 'Development Status :: 3 - Alpha', 37 | 'Intended Audience :: Developers', 38 | 'Topic :: Utilities', 39 | 'Topic :: Scientific/Engineering', 40 | 'License :: OSI Approved :: MIT License', 41 | 'Programming Language :: Python :: 2', 42 | 'Programming Language :: Python :: 2.7', 43 | 'Programming Language :: Python :: 3', 44 | 'Programming Language :: Python :: 3.4', 45 | ], 46 | keywords='dataset data iteration pipeline processing', 47 | packages=find_packages(exclude=['tests']), 48 | install_requires=['numpy', 'six', 'picklable_itertools', 'pyyaml', 49 | 'h5py', 'tables', 50 | 'progressbar2', 'pyzmq', 'scipy', 'pillow>=3.3.2', 51 | 'requests'], 52 | extras_require={ 53 | 'test': ['mock', 'nose', 'nose2'], 54 | 'docs': ['sphinx', 'sphinx-rtd-theme'] 55 | }, 56 | entry_points={ 57 | 'console_scripts': ['fuel-convert = fuel.bin.fuel_convert:main', 58 | 'fuel-download = fuel.bin.fuel_download:main', 59 | 'fuel-info = fuel.bin.fuel_info:main'] 60 | }, 61 | ext_modules=[Extension("fuel.transformers._image", 62 | ["fuel/transformers/_image.c"], 63 | extra_compile_args=extra_compile_args)] 64 | ) 65 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | from unittest.case import SkipTest 3 | 4 | from fuel.utils import find_in_data_path 5 | from fuel import config 6 | 7 | 8 | def skip_if_not_available(modules=None, datasets=None, configurations=None): 9 | """Raises a SkipTest exception when requirements are not met. 10 | 11 | Parameters 12 | ---------- 13 | modules : list 14 | A list of strings of module names. If one of the modules fails to 15 | import, the test will be skipped. 16 | datasets : list 17 | A list of strings of folder names. If the data path is not 18 | configured, or the folder does not exist, the test is skipped. 19 | configurations : list 20 | A list of of strings of configuration names. If this configuration 21 | is not set and does not have a default, the test will be skipped. 
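Examples
--------
Typical use at the top of a test function, as in the test modules of
this repository (`mnist.hdf5` is just one example of a dataset file)::

    def test_mnist_train():
        skip_if_not_available(datasets=['mnist.hdf5'])
        ...  # actual test body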
22 | 23 | """ 24 | if modules is None: 25 | modules = [] 26 | if datasets is None: 27 | datasets = [] 28 | if configurations is None: 29 | configurations = [] 30 | for module in modules: 31 | try: 32 | import_module(module) 33 | except Exception: 34 | raise SkipTest 35 | if datasets and not hasattr(config, 'data_path'): 36 | raise SkipTest 37 | for dataset in datasets: 38 | try: 39 | find_in_data_path(dataset) 40 | except IOError: 41 | raise SkipTest 42 | for configuration in configurations: 43 | if not hasattr(config, configuration): 44 | raise SkipTest 45 | -------------------------------------------------------------------------------- /tests/converters/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mila-iqia/fuel/1d6292dc25e3a115544237e392e61bff6631d23c/tests/converters/__init__.py -------------------------------------------------------------------------------- /tests/converters/test_convert_ilsvrc2012.py: -------------------------------------------------------------------------------- 1 | import tarfile 2 | 3 | import numpy 4 | import six 5 | 6 | from .test_convert_ilsvrc2010 import MockH5PYFile 7 | # from fuel.server import recv_arrays, send_arrays 8 | from fuel.converters.ilsvrc2012 import (prepare_hdf5_file, 9 | prepare_metadata, 10 | read_devkit, 11 | read_metadata_mat_file, 12 | DEVKIT_META_PATH, 13 | DEVKIT_ARCHIVE, 14 | TEST_IMAGES_TAR) 15 | from fuel.utils import find_in_data_path 16 | from tests import skip_if_not_available 17 | 18 | 19 | def test_prepare_metadata(): 20 | skip_if_not_available(datasets=[DEVKIT_ARCHIVE, TEST_IMAGES_TAR]) 21 | devkit_path = find_in_data_path(DEVKIT_ARCHIVE) 22 | n_train, v_gt, n_test, wnid_map = prepare_metadata(devkit_path) 23 | assert n_train == 1281167 24 | assert len(v_gt) == 50000 25 | assert n_test == 100000 26 | assert sorted(wnid_map.values()) == list(range(1000)) 27 | assert all(isinstance(k, six.string_types) and len(k) == 9 28 | for k in wnid_map) 29 | 30 | 31 | def test_prepare_hdf5_file(): 32 | hdf5_file = MockH5PYFile() 33 | prepare_hdf5_file(hdf5_file, 10, 5, 2) 34 | 35 | def get_start_stop(hdf5_file, split): 36 | rows = [r for r in hdf5_file.attrs['split'] if 37 | (r['split'].decode('utf8') == split)] 38 | return dict([(r['source'].decode('utf8'), (r['start'], r['stop'])) 39 | for r in rows if r['stop'] - r['start'] > 0]) 40 | 41 | # Verify properties of the train split. 42 | train_splits = get_start_stop(hdf5_file, 'train') 43 | assert all(v == (0, 10) for v in train_splits.values()) 44 | assert set(train_splits.keys()) == set([u'encoded_images', u'targets', 45 | u'filenames']) 46 | 47 | # Verify properties of the valid split. 48 | valid_splits = get_start_stop(hdf5_file, 'valid') 49 | assert all(v == (10, 15) for v in valid_splits.values()) 50 | assert set(valid_splits.keys()) == set([u'encoded_images', u'targets', 51 | u'filenames']) 52 | 53 | # Verify properties of the test split. 54 | test_splits = get_start_stop(hdf5_file, 'test') 55 | assert all(v == (15, 17) for v in test_splits.values()) 56 | assert set(test_splits.keys()) == set([u'encoded_images', u'filenames']) 57 | 58 | from numpy import dtype 59 | 60 | # Verify properties of the encoded_images HDF5 dataset. 
61 | assert hdf5_file['encoded_images'].shape[0] == 17 62 | assert len(hdf5_file['encoded_images'].shape) == 1 63 | assert hdf5_file['encoded_images'].dtype.kind == 'O' 64 | assert hdf5_file['encoded_images'].dtype.metadata['vlen'] == dtype('uint8') 65 | 66 | # Verify properties of the filenames dataset. 67 | assert hdf5_file['filenames'].shape[0] == 17 68 | assert len(hdf5_file['filenames'].shape) == 2 69 | assert hdf5_file['filenames'].dtype == dtype('S32') 70 | 71 | # Verify properties of the targets dataset. 72 | assert hdf5_file['targets'].shape[0] == 15 73 | assert hdf5_file['targets'].shape[1] == 1 74 | assert len(hdf5_file['targets'].shape) == 2 75 | assert hdf5_file['targets'].dtype == dtype('int16') 76 | 77 | 78 | def test_read_devkit(): 79 | skip_if_not_available(datasets=[DEVKIT_ARCHIVE]) 80 | synsets, raw_valid_gt = read_devkit(find_in_data_path(DEVKIT_ARCHIVE)) 81 | # synset sanity tests appear in test_read_metadata_mat_file 82 | assert raw_valid_gt.min() == 1 83 | assert raw_valid_gt.max() == 1000 84 | assert raw_valid_gt.dtype.kind == 'i' 85 | assert raw_valid_gt.shape == (50000,) 86 | 87 | 88 | def test_read_metadata_mat_file(): 89 | skip_if_not_available(datasets=[DEVKIT_ARCHIVE]) 90 | with tarfile.open(find_in_data_path(DEVKIT_ARCHIVE)) as tar: 91 | meta_mat = tar.extractfile(DEVKIT_META_PATH) 92 | synsets = read_metadata_mat_file(meta_mat) 93 | assert (synsets['ILSVRC2012_ID'] == 94 | numpy.arange(1, len(synsets) + 1)).all() 95 | assert synsets['num_train_images'][1000:].sum() == 0 96 | assert (synsets['num_train_images'][:1000] > 0).all() 97 | assert synsets.ndim == 1 98 | assert synsets['wordnet_height'].min() == 0 99 | assert synsets['wordnet_height'].max() == 19 100 | assert synsets['WNID'].dtype == numpy.dtype('S9') 101 | assert (synsets['num_children'][:1000] == 0).all() 102 | assert (synsets['children'][:1000] == -1).all() 103 | -------------------------------------------------------------------------------- /tests/test_adult.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | from numpy.testing import assert_raises, assert_equal, assert_allclose 4 | 5 | from fuel.datasets import Adult 6 | from tests import skip_if_not_available 7 | 8 | 9 | def test_adult_test(): 10 | skip_if_not_available(datasets=['adult.hdf5']) 11 | 12 | dataset = Adult(('test',), load_in_memory=False) 13 | handle = dataset.open() 14 | data, labels = dataset.get_data(handle, slice(0, 10)) 15 | 16 | assert data.shape == (10, 104) 17 | assert labels.shape == (10, 1) 18 | known = numpy.array( 19 | [25., 38., 28., 44., 34., 63., 24., 55., 65., 36.]) 20 | assert_allclose(data[:, 0], known) 21 | assert dataset.num_examples == 15060 22 | dataset.close(handle) 23 | 24 | dataset = Adult(('train',), load_in_memory=False) 25 | handle = dataset.open() 26 | data, labels = dataset.get_data(handle, slice(0, 10)) 27 | 28 | assert data.shape == (10, 104) 29 | assert labels.shape == (10, 1) 30 | known = numpy.array( 31 | [39., 50., 38., 53., 28., 37., 49., 52., 31., 42.]) 32 | assert_allclose(data[:, 0], known) 33 | assert dataset.num_examples == 30162 34 | dataset.close(handle) 35 | 36 | 37 | def test_adult_axes(): 38 | skip_if_not_available(datasets=['adult.hdf5']) 39 | 40 | dataset = Adult(('test',), load_in_memory=False) 41 | assert_equal(dataset.axis_labels['features'], 42 | ('batch', 'feature')) 43 | 44 | dataset = Adult(('train',), load_in_memory=False) 45 | assert_equal(dataset.axis_labels['features'], 46 | ('batch', 'feature')) 47 | 48 | 49 | 
def test_adult_invalid_split(): 50 | skip_if_not_available(datasets=['adult.hdf5']) 51 | 52 | assert_raises(ValueError, Adult, ('dummy',)) 53 | -------------------------------------------------------------------------------- /tests/test_billion.py: -------------------------------------------------------------------------------- 1 | from numpy.testing import assert_raises 2 | 3 | from fuel.datasets.billion import OneBillionWord 4 | 5 | 6 | class TestOneBillionWord(object): 7 | def setUp(self): 8 | all_chars = ([chr(ord('a') + i) for i in range(26)] + 9 | [chr(ord('0') + i) for i in range(10)] + 10 | [',', '.', '!', '?', '<UNK>'] + 11 | [' ', '<S>', '</S>']) 12 | code2char = dict(enumerate(all_chars)) 13 | self.char2code = {v: k for k, v in code2char.items()} 14 | 15 | def test_value_error_wrong_set(self): 16 | assert_raises( 17 | ValueError, OneBillionWord, 'dummy', [0, 1], self.char2code) 18 | 19 | def test_value_error_training_partition(self): 20 | assert_raises( 21 | ValueError, OneBillionWord, 'training', [101], self.char2code) 22 | 23 | def test_value_error_heldout_partition(self): 24 | assert_raises( 25 | ValueError, OneBillionWord, 'heldout', [101], self.char2code) 26 | -------------------------------------------------------------------------------- /tests/test_binarized_mnist.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | 3 | from numpy.testing import assert_raises, assert_equal 4 | 5 | from fuel.datasets import BinarizedMNIST 6 | from tests import skip_if_not_available 7 | 8 | 9 | def test_binarized_mnist_train(): 10 | skip_if_not_available(datasets=['binarized_mnist.hdf5']) 11 | 12 | dataset = BinarizedMNIST(('train',), load_in_memory=False) 13 | handle = dataset.open() 14 | data, = dataset.get_data(handle, slice(0, 10)) 15 | assert data.dtype == 'uint8' 16 | assert data.shape == (10, 1, 28, 28) 17 | assert hashlib.md5(data).hexdigest() == '0922fefc9a9d097e3b086b89107fafce' 18 | assert dataset.num_examples == 50000 19 | dataset.close(handle) 20 | 21 | 22 | def test_binarized_mnist_valid(): 23 | skip_if_not_available(datasets=['binarized_mnist.hdf5']) 24 | 25 | dataset = BinarizedMNIST(('valid',), load_in_memory=False) 26 | handle = dataset.open() 27 | data, = dataset.get_data(handle, slice(0, 10)) 28 | assert data.dtype == 'uint8' 29 | assert data.shape == (10, 1, 28, 28) 30 | assert hashlib.md5(data).hexdigest() == '65e8099613162b3110a7618037011617' 31 | assert dataset.num_examples == 10000 32 | dataset.close(handle) 33 | 34 | 35 | def test_binarized_mnist_test(): 36 | skip_if_not_available(datasets=['binarized_mnist.hdf5']) 37 | 38 | dataset = BinarizedMNIST(('test',), load_in_memory=False) 39 | handle = dataset.open() 40 | data, = dataset.get_data(handle, slice(0, 10)) 41 | assert data.dtype == 'uint8' 42 | assert data.shape == (10, 1, 28, 28) 43 | assert hashlib.md5(data).hexdigest() == '0fa539ed8cb008880a61be77f744f06a' 44 | assert dataset.num_examples == 10000 45 | dataset.close(handle) 46 | 47 | 48 | def test_binarized_mnist_axes(): 49 | skip_if_not_available(datasets=['binarized_mnist.hdf5']) 50 | 51 | dataset = BinarizedMNIST(('train',), load_in_memory=False) 52 | assert_equal(dataset.axis_labels['features'], 53 | ('batch', 'channel', 'height', 'width')) 54 | 55 | 56 | def test_binarized_mnist_invalid_split(): 57 | assert_raises(ValueError, BinarizedMNIST, ('dummy',)) 58 | -------------------------------------------------------------------------------- /tests/test_caltech101_silhouettes.py:
-------------------------------------------------------------------------------- 1 | import numpy 2 | from numpy.testing import assert_raises 3 | 4 | from fuel.datasets import CalTech101Silhouettes 5 | from tests import skip_if_not_available 6 | 7 | 8 | def test_caltech101_silhouettes16(): 9 | skip_if_not_available(datasets=['caltech101_silhouettes16.hdf5']) 10 | for which_set, size, num_examples in ( 11 | ('train', 16, 4082), ('valid', 16, 2257), ('test', 16, 2302)): 12 | ds = CalTech101Silhouettes(which_sets=[which_set], size=size, 13 | load_in_memory=False) 14 | 15 | assert ds.num_examples == num_examples 16 | 17 | handle = ds.open() 18 | features, targets = ds.get_data(handle, slice(0, 10)) 19 | 20 | assert features.shape == (10, 1, size, size) 21 | assert targets.shape == (10, 1) 22 | 23 | assert features.dtype == numpy.uint8 24 | assert targets.dtype == numpy.uint8 25 | 26 | 27 | def test_caltech101_silhouettes_unkn_size(): 28 | assert_raises(ValueError, CalTech101Silhouettes, 29 | which_sets=['test'], size=10) 30 | 31 | 32 | def test_caltech101_silhouettes28(): 33 | skip_if_not_available(datasets=['caltech101_silhouettes28.hdf5']) 34 | for which_set, size, num_examples in ( 35 | ('train', 28, 4100), ('valid', 28, 2264), ('test', 28, 2307)): 36 | ds = CalTech101Silhouettes(which_sets=[which_set], size=size, 37 | load_in_memory=False) 38 | 39 | assert ds.num_examples == num_examples 40 | 41 | handle = ds.open() 42 | features, targets = ds.get_data(handle, slice(0, 10)) 43 | 44 | assert features.shape == (10, 1, size, size) 45 | assert targets.shape == (10, 1) 46 | 47 | assert features.dtype == numpy.uint8 48 | assert targets.dtype == numpy.uint8 49 | -------------------------------------------------------------------------------- /tests/test_celeba.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import h5py 4 | import numpy 5 | from numpy.testing import assert_equal 6 | 7 | from fuel import config 8 | from fuel.datasets import H5PYDataset, CelebA 9 | 10 | 11 | def test_celeba(): 12 | data_path = config.data_path 13 | try: 14 | config.data_path = '.' 
15 | f = h5py.File('celeba_64.hdf5', 'w') 16 | f['features'] = numpy.arange( 17 | 10 * 3 * 64 * 64, dtype='uint8').reshape((10, 3, 64, 64)) 18 | f['targets'] = numpy.arange( 19 | 10 * 40, dtype='uint8').reshape((10, 40)) 20 | split_dict = {'train': {'features': (0, 6), 'targets': (0, 6)}, 21 | 'valid': {'features': (6, 8), 'targets': (6, 8)}, 22 | 'test': {'features': (8, 10), 'targets': (8, 10)}} 23 | f.attrs['split'] = H5PYDataset.create_split_array(split_dict) 24 | f.close() 25 | dataset = CelebA(which_format='64', which_sets=('train',)) 26 | assert_equal(dataset.filename, 'celeba_64.hdf5') 27 | finally: 28 | config.data_path = data_path 29 | os.remove('celeba_64.hdf5') 30 | -------------------------------------------------------------------------------- /tests/test_cifar10.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from numpy.testing import assert_raises 3 | 4 | from fuel import config 5 | from fuel.datasets import CIFAR10 6 | from fuel.streams import DataStream 7 | from fuel.schemes import SequentialScheme 8 | 9 | 10 | def test_cifar10(): 11 | train = CIFAR10(('train',), load_in_memory=False) 12 | assert train.num_examples == 50000 13 | handle = train.open() 14 | features, targets = train.get_data(handle, slice(49990, 50000)) 15 | assert features.shape == (10, 3, 32, 32) 16 | assert targets.shape == (10, 1) 17 | train.close(handle) 18 | 19 | test = CIFAR10(('test',), load_in_memory=False) 20 | handle = test.open() 21 | features, targets = test.get_data(handle, slice(0, 10)) 22 | assert features.shape == (10, 3, 32, 32) 23 | assert targets.shape == (10, 1) 24 | assert features.dtype == numpy.uint8 25 | assert targets.dtype == numpy.uint8 26 | test.close(handle) 27 | 28 | stream = DataStream.default_stream( 29 | test, iteration_scheme=SequentialScheme(10, 10)) 30 | data = next(stream.get_epoch_iterator())[0] 31 | assert data.min() >= 0.0 and data.max() <= 1.0 32 | assert data.dtype == config.floatX 33 | 34 | assert_raises(ValueError, CIFAR10, ('valid',)) 35 | 36 | assert_raises(ValueError, CIFAR10, 37 | ('train',), subset=slice(50000, 60000)) 38 | -------------------------------------------------------------------------------- /tests/test_cifar100.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from numpy.testing import assert_raises 3 | 4 | from fuel import config 5 | from fuel.datasets import CIFAR100 6 | from fuel.streams import DataStream 7 | from fuel.schemes import SequentialScheme 8 | 9 | 10 | def test_cifar100(): 11 | train = CIFAR100(('train',), load_in_memory=False) 12 | assert train.num_examples == 50000 13 | handle = train.open() 14 | coarse_labels, features, fine_labels = train.get_data(handle, 15 | slice(49990, 50000)) 16 | 17 | assert features.shape == (10, 3, 32, 32) 18 | assert coarse_labels.shape == (10, 1) 19 | assert fine_labels.shape == (10, 1) 20 | train.close(handle) 21 | 22 | test = CIFAR100(('test',), load_in_memory=False) 23 | handle = test.open() 24 | coarse_labels, features, fine_labels = test.get_data(handle, 25 | slice(0, 10)) 26 | 27 | assert features.shape == (10, 3, 32, 32) 28 | assert coarse_labels.shape == (10, 1) 29 | assert fine_labels.shape == (10, 1) 30 | 31 | assert features.dtype == numpy.uint8 32 | assert coarse_labels.dtype == numpy.uint8 33 | assert fine_labels.dtype == numpy.uint8 34 | 35 | test.close(handle) 36 | 37 | stream = DataStream.default_stream( 38 | test, iteration_scheme=SequentialScheme(10, 10)) 39 | data = 
next(stream.get_epoch_iterator())[1] 40 | 41 | assert data.min() >= 0.0 and data.max() <= 1.0 42 | assert data.dtype == config.floatX 43 | 44 | assert_raises(ValueError, CIFAR100, ('valid',)) 45 | -------------------------------------------------------------------------------- /tests/test_config_parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | from numpy.testing import assert_equal, assert_raises 5 | 6 | from fuel.config_parser import (Configuration, ConfigurationError, 7 | extra_downloader_converter) 8 | 9 | 10 | class TestExtraDownloaderConverter(object): 11 | def test_iterable(self): 12 | assert_equal(extra_downloader_converter(['a.b.c', 'd.e.f']), 13 | ['a.b.c', 'd.e.f']) 14 | 15 | def test_str(self): 16 | assert_equal(extra_downloader_converter("a.b.c d.e.f"), 17 | ['a.b.c', 'd.e.f']) 18 | 19 | def test_str_one_element(self): 20 | assert_equal(extra_downloader_converter("a.b.c"), ['a.b.c']) 21 | 22 | 23 | def test_config_parser(): 24 | _environ = dict(os.environ) 25 | try: 26 | 27 | with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: 28 | f.write('data_path: yaml_path') 29 | filename = f.name 30 | os.environ['FUEL_CONFIG'] = filename 31 | if 'FUEL_DATA_PATH' in os.environ: 32 | del os.environ['FUEL_DATA_PATH'] 33 | config = Configuration() 34 | config.add_config('data_path', str, env_var='FUEL_DATA_PATH') 35 | config.add_config('config_with_default', int, default='1', 36 | env_var='FUEL_CONFIG_TEST') 37 | config.add_config('config_without_default', str) 38 | config.load_yaml() 39 | assert config.data_path == 'yaml_path' 40 | os.environ['FUEL_DATA_PATH'] = 'env_path' 41 | assert config.data_path == 'env_path' 42 | assert config.config_with_default == 1 43 | os.environ['FUEL_CONFIG_TEST'] = '2' 44 | assert config.config_with_default == 2 45 | assert_raises(AttributeError, getattr, config, 46 | 'non_existing_config') 47 | assert_raises(ConfigurationError, getattr, config, 48 | 'config_without_default') 49 | config.data_path = 'manual_path' 50 | assert config.data_path == 'manual_path' 51 | config.new_config = 'new_config' 52 | assert config.new_config == 'new_config' 53 | finally: 54 | os.environ.clear() 55 | os.environ.update(_environ) 56 | -------------------------------------------------------------------------------- /tests/test_dogs_vs_cats.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tempfile 4 | import zipfile 5 | 6 | import h5py 7 | import numpy 8 | import six 9 | from PIL import Image 10 | from numpy.testing import assert_raises 11 | 12 | from fuel import config 13 | from fuel.converters.dogs_vs_cats import convert_dogs_vs_cats 14 | from fuel.datasets.dogs_vs_cats import DogsVsCats 15 | from fuel.streams import DataStream 16 | from fuel.schemes import SequentialScheme 17 | 18 | 19 | def setup(): 20 | config._old_data_path = config.data_path 21 | config.data_path = tempfile.mkdtemp() 22 | _make_dummy_data(config.data_path[0]) 23 | 24 | 25 | def _make_dummy_data(output_directory): 26 | data = six.BytesIO() 27 | Image.new('RGB', (1, 1)).save(data, 'JPEG') 28 | image = data.getvalue() 29 | 30 | output_files = [os.path.join(output_directory, 31 | 'dogs_vs_cats.{}.zip'.format(set_)) 32 | for set_ in ['train', 'test1']] 33 | with zipfile.ZipFile(output_files[0], 'w') as zip_file: 34 | zif = zipfile.ZipInfo('train/') 35 | zip_file.writestr(zif, "") 36 | for i in range(25000): 37 | 
zip_file.writestr('train/cat.{}.jpeg'.format(i), image) 38 | with zipfile.ZipFile(output_files[1], 'w') as zip_file: 39 | zif = zipfile.ZipInfo('test1/') 40 | zip_file.writestr(zif, "") 41 | for i in range(12500): 42 | zip_file.writestr('test1/{}.jpeg'.format(i), image) 43 | 44 | 45 | def teardown(): 46 | shutil.rmtree(config.data_path[0]) 47 | config.data_path = config._old_data_path 48 | del config._old_data_path 49 | 50 | 51 | def test_dogs_vs_cats(): 52 | _test_conversion() 53 | _test_dataset() 54 | 55 | 56 | def _test_conversion(): 57 | convert_dogs_vs_cats(config.data_path[0], config.data_path[0]) 58 | output_file = "dogs_vs_cats.hdf5" 59 | output_file = os.path.join(config.data_path[0], output_file) 60 | with h5py.File(output_file, 'r') as h5: 61 | assert numpy.all(h5['targets'][:25000] == 0) 62 | assert numpy.all(h5['targets'][25000:] == 1) 63 | assert numpy.all(numpy.array( 64 | [img for img in h5['image_features'][:]]) == 0) 65 | assert numpy.all(h5['image_features_shapes'][:, 0] == 3) 66 | assert numpy.all(h5['image_features_shapes'][:, 1:] == 1) 67 | 68 | 69 | def _test_dataset(): 70 | train = DogsVsCats(('train',)) 71 | assert train.num_examples == 25000 72 | assert_raises(ValueError, DogsVsCats, ('valid',)) 73 | 74 | test = DogsVsCats(('test',)) 75 | stream = DataStream.default_stream( 76 | test, iteration_scheme=SequentialScheme(10, 10)) 77 | data = next(stream.get_epoch_iterator())[0][0] 78 | assert data.dtype.kind == 'f' 79 | 80 | 81 | test_dogs_vs_cats.setup = setup 82 | test_dogs_vs_cats.teardown = teardown 83 | -------------------------------------------------------------------------------- /tests/test_iris.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | from numpy.testing import assert_raises, assert_equal, assert_allclose 4 | 5 | from fuel.datasets import Iris 6 | from tests import skip_if_not_available 7 | 8 | 9 | def test_iris_all(): 10 | skip_if_not_available(datasets=['iris.hdf5']) 11 | 12 | dataset = Iris(('all',), load_in_memory=False) 13 | handle = dataset.open() 14 | data, labels = dataset.get_data(handle, slice(0, 10)) 15 | assert data.dtype == 'float32' 16 | assert data.shape == (10, 4) 17 | assert labels.shape == (10, 1) 18 | known = numpy.array([5.1, 3.5, 1.4, 0.2]) 19 | assert_allclose(data[0], known) 20 | assert labels[0][0] == 0 21 | assert dataset.num_examples == 150 22 | dataset.close(handle) 23 | 24 | 25 | def test_iris_axes(): 26 | skip_if_not_available(datasets=['iris.hdf5']) 27 | 28 | dataset = Iris(('all',), load_in_memory=False) 29 | assert_equal(dataset.axis_labels['features'], 30 | ('batch', 'feature')) 31 | 32 | 33 | def test_iris_invalid_split(): 34 | skip_if_not_available(datasets=['iris.hdf5']) 35 | 36 | assert_raises(ValueError, Iris, ('dummy',)) 37 | -------------------------------------------------------------------------------- /tests/test_mnist.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | from numpy.testing import assert_raises, assert_equal, assert_allclose 4 | 5 | from fuel import config 6 | from fuel.datasets import MNIST 7 | from fuel.streams import DataStream 8 | from fuel.schemes import SequentialScheme 9 | from tests import skip_if_not_available 10 | 11 | 12 | def test_mnist_train(): 13 | skip_if_not_available(datasets=['mnist.hdf5']) 14 | 15 | dataset = MNIST(('train',), load_in_memory=False) 16 | handle = dataset.open() 17 | data, labels = dataset.get_data(handle, slice(0, 10)) 18 | assert data.dtype 
== 'uint8' 19 | assert data.shape == (10, 1, 28, 28) 20 | assert labels.shape == (10, 1) 21 | known = numpy.array([0, 0, 0, 0, 0, 0, 0, 0, 30, 36, 94, 154, 170, 253, 22 | 253, 253, 253, 253, 225, 172, 253, 242, 195, 64, 0, 23 | 0, 0, 0]) 24 | assert_allclose(data[0][0][6], known) 25 | assert labels[0][0] == 5 26 | assert dataset.num_examples == 60000 27 | dataset.close(handle) 28 | 29 | stream = DataStream.default_stream( 30 | dataset, iteration_scheme=SequentialScheme(10, 10)) 31 | data = next(stream.get_epoch_iterator())[0] 32 | assert data.min() >= 0.0 and data.max() <= 1.0 33 | assert data.dtype == config.floatX 34 | 35 | 36 | def test_mnist_test(): 37 | skip_if_not_available(datasets=['mnist.hdf5']) 38 | 39 | dataset = MNIST(('test',), load_in_memory=False) 40 | handle = dataset.open() 41 | data, labels = dataset.get_data(handle, slice(0, 10)) 42 | assert data.dtype == 'uint8' 43 | assert data.shape == (10, 1, 28, 28) 44 | assert labels.shape == (10, 1) 45 | known = numpy.array([0, 0, 0, 0, 0, 0, 84, 185, 159, 151, 60, 36, 0, 0, 0, 46 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) 47 | assert_allclose(data[0][0][7], known) 48 | assert labels[0][0] == 7 49 | assert dataset.num_examples == 10000 50 | dataset.close(handle) 51 | 52 | stream = DataStream.default_stream( 53 | dataset, iteration_scheme=SequentialScheme(10, 10)) 54 | data = next(stream.get_epoch_iterator())[0] 55 | assert data.min() >= 0.0 and data.max() <= 1.0 56 | assert data.dtype == config.floatX 57 | 58 | 59 | def test_mnist_axes(): 60 | skip_if_not_available(datasets=['mnist.hdf5']) 61 | 62 | dataset = MNIST(('train',), load_in_memory=False) 63 | assert_equal(dataset.axis_labels['features'], 64 | ('batch', 'channel', 'height', 'width')) 65 | 66 | 67 | def test_mnist_invalid_split(): 68 | skip_if_not_available(datasets=['mnist.hdf5']) 69 | 70 | assert_raises(ValueError, MNIST, ('dummy',)) 71 | -------------------------------------------------------------------------------- /tests/test_sequences.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import numpy 4 | from numpy.testing import assert_raises 5 | from six import BytesIO 6 | from six.moves import cPickle 7 | 8 | from fuel.datasets import TextFile, IterableDataset, IndexableDataset 9 | from fuel.schemes import SequentialScheme 10 | from fuel.streams import DataStream 11 | from fuel.transformers.sequences import Window, NGrams 12 | 13 | 14 | def lower(s): 15 | return s.lower() 16 | 17 | 18 | def test_text(): 19 | # Test word level and epochs. 20 | with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: 21 | sentences1 = f.name 22 | f.write("This is a sentence\n") 23 | f.write("This another one") 24 | with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: 25 | sentences2 = f.name 26 | f.write("More sentences\n") 27 | f.write("The last one") 28 | dictionary = {'<UNK>': 0, '</S>': 1, 'this': 2, 'a': 3, 'one': 4} 29 | text_data = TextFile(files=[sentences1, sentences2], 30 | dictionary=dictionary, bos_token=None, 31 | preprocess=lower) 32 | stream = DataStream(text_data) 33 | epoch = stream.get_epoch_iterator() 34 | assert len(list(epoch)) == 4 35 | epoch = stream.get_epoch_iterator() 36 | for sentence in zip(range(3), epoch): 37 | pass 38 | f = BytesIO() 39 | cPickle.dump(epoch, f) 40 | sentence = next(epoch) 41 | f.seek(0) 42 | epoch = cPickle.load(f) 43 | assert next(epoch) == sentence 44 | assert_raises(StopIteration, next, epoch) 45 | 46 | # Test character level.
47 | dictionary = dict([(chr(ord('a') + i), i) for i in range(26)] + 48 | [(' ', 26)] + [('<S>', 27)] + 49 | [('</S>', 28)] + [('<UNK>', 29)]) 50 | text_data = TextFile(files=[sentences1, sentences2], 51 | dictionary=dictionary, preprocess=lower, 52 | level="character") 53 | sentence = next(DataStream(text_data).get_epoch_iterator())[0] 54 | assert sentence[:3] == [27, 19, 7] 55 | assert sentence[-3:] == [2, 4, 28] 56 | 57 | 58 | def test_ngram_stream(): 59 | sentences = [list(numpy.random.randint(10, size=sentence_length)) 60 | for sentence_length in [3, 5, 7]] 61 | stream = DataStream(IterableDataset(sentences)) 62 | ngrams = NGrams(4, stream) 63 | assert len(list(ngrams.get_epoch_iterator())) == 4 64 | 65 | 66 | def test_window_stream(): 67 | sentences = [list(numpy.random.randint(10, size=sentence_length)) 68 | for sentence_length in [3, 5, 7]] 69 | stream = DataStream(IterableDataset(sentences)) 70 | windows = Window(0, 4, 4, True, stream) 71 | for i, (source, target) in enumerate(windows.get_epoch_iterator()): 72 | assert source == target 73 | assert i == 5 # Total of 6 windows 74 | 75 | # Make sure that negative indices work 76 | windows = Window(-2, 4, 4, False, stream) 77 | for i, (source, target) in enumerate(windows.get_epoch_iterator()): 78 | assert source[-2:] == target[:2] 79 | assert i == 1 # Should get 2 examples 80 | 81 | # Even for overlapping negative indices should work 82 | windows = Window(-2, 4, 4, True, stream) 83 | for i, (source, target) in enumerate(windows.get_epoch_iterator()): 84 | assert source[:2] == target[-2:] 85 | assert i == 1 # Should get 2 examples 86 | 87 | 88 | def test_ngram_stream_error_on_multiple_sources(): 89 | # Check that NGram accepts only data streams with one source 90 | sentences = [list(numpy.random.randint(10, size=sentence_length)) 91 | for sentence_length in [3, 5, 7]] 92 | stream = DataStream(IterableDataset(sentences)) 93 | stream.sources = ('1', '2') 94 | assert_raises(ValueError, NGrams, 4, stream) 95 | 96 | 97 | def test_ngram_stream_raises_error_on_batch_stream(): 98 | sentences = [list(numpy.random.randint(10, size=sentence_length)) 99 | for sentence_length in [3, 5, 7]] 100 | stream = DataStream( 101 | IndexableDataset(sentences), iteration_scheme=SequentialScheme(3, 1)) 102 | assert_raises(ValueError, NGrams, 4, stream) 103 | 104 | 105 | def test_ngram_stream_raises_error_on_request(): 106 | sentences = [list(numpy.random.randint(10, size=sentence_length)) 107 | for sentence_length in [3, 5, 7]] 108 | stream = DataStream(IterableDataset(sentences)) 109 | ngrams = NGrams(4, stream) 110 | assert_raises(ValueError, ngrams.get_data, [0, 1]) 111 | -------------------------------------------------------------------------------- /tests/test_serialization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import numpy 5 | from six.moves import cPickle 6 | 7 | from fuel.streams import DataStream 8 | from fuel.datasets import MNIST 9 | from fuel.schemes import SequentialScheme 10 | from tests import skip_if_not_available 11 | 12 | 13 | def test_in_memory(): 14 | skip_if_not_available(datasets=['mnist.hdf5']) 15 | # Load MNIST and get two batches 16 | mnist = MNIST(('train',), load_in_memory=True) 17 | data_stream = DataStream(mnist, iteration_scheme=SequentialScheme( 18 | examples=mnist.num_examples, batch_size=256)) 19 | epoch = data_stream.get_epoch_iterator() 20 | for i, (features, targets) in enumerate(epoch): 21 | if i == 1: 22 | break 23 | handle = mnist.open() 24 |
known_features, _ = mnist.get_data(handle, slice(256, 512)) 25 | mnist.close(handle) 26 | assert numpy.all(features == known_features) 27 | 28 | # Pickle the epoch and make sure that the data wasn't dumped 29 | with tempfile.NamedTemporaryFile(delete=False) as f: 30 | filename = f.name 31 | cPickle.dump(epoch, f) 32 | assert os.path.getsize(filename) < 1024 * 1024 # Less than 1MB 33 | 34 | # Reload the epoch and make sure that the state was maintained 35 | del epoch 36 | with open(filename, 'rb') as f: 37 | epoch = cPickle.load(f) 38 | features, targets = next(epoch) 39 | handle = mnist.open() 40 | known_features, _ = mnist.get_data(handle, slice(512, 768)) 41 | mnist.close(handle) 42 | assert numpy.all(features == known_features) 43 | -------------------------------------------------------------------------------- /tests/test_server.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Process 2 | 3 | from numpy.testing import assert_allclose, assert_raises 4 | from six.moves import cPickle 5 | from nose.exc import SkipTest 6 | 7 | from fuel.datasets import MNIST 8 | from fuel.schemes import SequentialScheme 9 | from fuel.server import start_server 10 | from fuel.streams import DataStream, ServerDataStream 11 | 12 | 13 | def get_stream(): 14 | return DataStream( 15 | MNIST(('train',)), iteration_scheme=SequentialScheme(1500, 500)) 16 | 17 | 18 | class TestServer(object): 19 | def setUp(self): 20 | self.server_process = Process( 21 | target=start_server, args=(get_stream(),)) 22 | self.server_process.start() 23 | self.stream = ServerDataStream(('f', 't'), False) 24 | 25 | def tearDown(self): 26 | self.server_process.terminate() 27 | self.stream = None 28 | 29 | def test_server(self): 30 | server_data = self.stream.get_epoch_iterator() 31 | expected_data = get_stream().get_epoch_iterator() 32 | for _, s, e in zip(range(3), server_data, expected_data): 33 | for data in zip(s, e): 34 | assert_allclose(*data) 35 | assert_raises(StopIteration, next, server_data) 36 | 37 | def test_pickling(self): 38 | try: 39 | self.stream = cPickle.loads(cPickle.dumps(self.stream)) 40 | # regression test: pickling of an unpickled stream used it fail 41 | cPickle.dumps(self.stream) 42 | server_data = self.stream.get_epoch_iterator() 43 | expected_data = get_stream().get_epoch_iterator() 44 | for _, s, e in zip(range(3), server_data, expected_data): 45 | for data in zip(s, e): 46 | assert_allclose(*data, rtol=1e-3) 47 | except AssertionError as e: 48 | raise SkipTest("Skip test_that failed with: {}".format(e)) 49 | assert_raises(StopIteration, next, server_data) 50 | 51 | def test_value_error_on_request(self): 52 | assert_raises(ValueError, self.stream.get_data, [0, 1]) 53 | 54 | def test_close(self): 55 | self.stream.close() 56 | 57 | def test_next_epoch(self): 58 | self.stream.next_epoch() 59 | 60 | def test_reset(self): 61 | self.stream.reset() 62 | -------------------------------------------------------------------------------- /tests/test_streams.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from numpy.testing import assert_equal, assert_raises 3 | 4 | from fuel.datasets import IterableDataset, IndexableDataset 5 | from fuel.schemes import SequentialExampleScheme, SequentialScheme 6 | from fuel.streams import AbstractDataStream, DataStream 7 | 8 | 9 | class DummyDataStream(AbstractDataStream): 10 | def reset(self): 11 | pass 12 | 13 | def close(self): 14 | pass 15 | 16 | def 
next_epoch(self): 17 | pass 18 | 19 | def get_epoch_iterator(self, as_dict=False): 20 | pass 21 | 22 | def get_data(self, request=None): 23 | pass 24 | 25 | 26 | class TestAbstractDataStream(object): 27 | def test_raises_value_error_on_no_scheme_no_produces_examples(self): 28 | stream = DummyDataStream() 29 | assert_raises(ValueError, getattr, stream, 'produces_examples') 30 | 31 | def test_raises_value_error_when_setting_produces_examples_if_scheme(self): 32 | stream = DummyDataStream(SequentialExampleScheme(2)) 33 | assert_raises(ValueError, setattr, stream, 'produces_examples', True) 34 | 35 | 36 | class TestDataStream(object): 37 | def setUp(self): 38 | self.dataset = IterableDataset(numpy.eye(2)) 39 | 40 | def test_sources_setter(self): 41 | stream = DataStream(self.dataset) 42 | stream.sources = ('features',) 43 | assert_equal(stream.sources, ('features',)) 44 | 45 | def test_no_axis_labels(self): 46 | stream = DataStream(self.dataset) 47 | assert stream.axis_labels is None 48 | 49 | def test_axis_labels_on_produces_examples(self): 50 | axis_labels = {'data': ('batch', 'features')} 51 | self.dataset.axis_labels = axis_labels 52 | stream = DataStream(self.dataset) 53 | assert_equal(stream.axis_labels, {'data': ('features',)}) 54 | 55 | def test_axis_labels_on_produces_batches(self): 56 | dataset = IndexableDataset(numpy.eye(2)) 57 | axis_labels = {'data': ('batch', 'features')} 58 | dataset.axis_labels = axis_labels 59 | stream = DataStream(dataset, iteration_scheme=SequentialScheme(2, 2)) 60 | assert_equal(stream.axis_labels, axis_labels) 61 | 62 | def test_produces_examples(self): 63 | stream = DataStream(self.dataset, 64 | iteration_scheme=SequentialExampleScheme(2)) 65 | assert stream.produces_examples 66 | -------------------------------------------------------------------------------- /tests/test_svhn.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import h5py 4 | import numpy 5 | from numpy.testing import assert_equal 6 | 7 | from fuel import config 8 | from fuel.datasets import H5PYDataset, SVHN 9 | 10 | 11 | def test_svhn(): 12 | data_path = config.data_path 13 | try: 14 | config.data_path = '.' 15 | f = h5py.File('svhn_format_2.hdf5', 'w') 16 | f['features'] = numpy.arange(100, dtype='uint8').reshape((10, 10)) 17 | f['targets'] = numpy.arange(10, dtype='uint8').reshape((10, 1)) 18 | split_dict = {'train': {'features': (0, 8), 'targets': (0, 8)}, 19 | 'test': {'features': (8, 10), 'targets': (8, 10)}} 20 | f.attrs['split'] = H5PYDataset.create_split_array(split_dict) 21 | f.close() 22 | dataset = SVHN(which_format=2, which_sets=('train',)) 23 | assert_equal(dataset.filename, 'svhn_format_2.hdf5') 24 | finally: 25 | config.data_path = data_path 26 | os.remove('svhn_format_2.hdf5') 27 | -------------------------------------------------------------------------------- /tests/test_toy.py: -------------------------------------------------------------------------------- 1 | 2 | from fuel.datasets.toy import Spiral, SwissRoll 3 | 4 | 5 | def test_spiral(): 6 | ds = Spiral(num_examples=1000, classes=2) 7 | 8 | features, position, label = ds.get_data(None, slice(0, 1000)) 9 | 10 | assert features.ndim == 2 11 | assert features.shape[0] == 1000 12 | assert features.shape[1] == 2 13 | 14 | assert position.ndim == 1 15 | assert position.shape[0] == 1000 16 | 17 | assert label.ndim == 1 18 | assert label.shape[0] == 1000 19 | 20 | assert features.max() <= 1. 21 | assert position.max() <= 1. 
22 | assert label.max() == 1 23 | 24 | 25 | def test_swissroll(): 26 | ds = SwissRoll(num_examples=1000) 27 | 28 | features, position = ds.get_data(None, slice(0, 1000)) 29 | 30 | assert features.ndim == 2 31 | assert features.shape[0] == 1000 32 | assert features.shape[1] == 3 33 | 34 | assert position.ndim == 2 35 | assert position.shape[0] == 1000 36 | assert position.shape[1] == 2 37 | 38 | assert features.max() <= 1. 39 | assert position.max() <= 1. 40 | -------------------------------------------------------------------------------- /tests/transformers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mila-iqia/fuel/1d6292dc25e3a115544237e392e61bff6631d23c/tests/transformers/__init__.py --------------------------------------------------------------------------------