├── .coveragerc ├── .gitignore ├── .travis.yml ├── CHANGELOG.md ├── LICENSE.txt ├── README.rst ├── docs ├── .gitignore ├── Makefile ├── api │ ├── .gitignore │ ├── index.rst │ ├── tfsnippet.dataflows.rst │ ├── tfsnippet.datasets.rst │ ├── tfsnippet.layers.rst │ ├── tfsnippet.ops.rst │ ├── tfsnippet.preprocessing.rst │ ├── tfsnippet.rst │ └── tfsnippet.utils.rst ├── conf.py ├── index.rst └── make.bat ├── requirements-dev.txt ├── requirements-docs.txt ├── requirements.txt ├── scripts ├── pkg_export.py ├── travis-docker-entry.sh └── travis-run-tests.sh ├── setup.py ├── tests ├── __init__.py ├── dataflows │ ├── __init__.py │ ├── test_array_flow.py │ ├── test_base.py │ ├── test_data_mappers.py │ ├── test_gather_flow.py │ ├── test_iterator_flow.py │ ├── test_mapper_flow.py │ ├── test_seq_flow.py │ └── test_threading_flow.py ├── datasets │ ├── __init__.py │ ├── helper.py │ ├── test_cifar.py │ ├── test_fashion_mnist.py │ └── test_mnist.py ├── distributions │ ├── __init__.py │ ├── test_base.py │ ├── test_batch_to_value.py │ ├── test_discretized.py │ ├── test_flow.py │ ├── test_mixture.py │ ├── test_multivariate.py │ ├── test_univariate.py │ ├── test_utils.py │ └── test_wrapper.py ├── evaluation │ ├── __init__.py │ └── test_collect_outputs.py ├── examples │ ├── __init__.py │ ├── helper.py │ ├── test_examples.py │ └── utils │ │ ├── __init__.py │ │ ├── test_mlconfig.py │ │ └── test_mlresult.py ├── helper.py ├── layers │ ├── __init__.py │ ├── activation │ │ ├── __init__.py │ │ ├── test_base.py │ │ └── test_leaky_relu.py │ ├── convolutional │ │ ├── __init__.py │ │ ├── helper.py │ │ ├── test_conv2d.py │ │ ├── test_pixelcnn.py │ │ ├── test_pooling.py │ │ ├── test_resnet.py │ │ ├── test_shifted.py │ │ └── test_utils.py │ ├── core │ │ ├── __init__.py │ │ ├── test_dense.py │ │ ├── test_dropout.py │ │ └── test_gated.py │ ├── flows │ │ ├── __init__.py │ │ ├── helper.py │ │ ├── test_base.py │ │ ├── test_coupling.py │ │ ├── test_invert.py │ │ ├── test_linear.py │ │ ├── test_planar_nf.py │ │ ├── test_rearrangement.py │ │ ├── test_reshape.py │ │ ├── test_sequential.py │ │ ├── test_split.py │ │ └── test_utils.py │ ├── helper.py │ ├── normalization │ │ ├── __init__.py │ │ ├── test_act_norm.py │ │ └── test_weight_norm.py │ ├── test_base.py │ ├── test_initialization.py │ ├── test_regularization.py │ └── test_utils.py ├── ops │ ├── __init__.py │ ├── test_assertions.py │ ├── test_classification.py │ ├── test_control_flows.py │ ├── test_convolution.py │ ├── test_evaluation.py │ ├── test_loop.py │ ├── test_misc.py │ ├── test_shape_utils.py │ ├── test_shift.py │ └── test_type_utils.py ├── preprocessing │ ├── __init__.py │ └── test_samplers.py ├── scaffold │ ├── __init__.py │ ├── test_checkpoint.py │ ├── test_logs.py │ ├── test_scheduled_var.py │ └── test_train_loop.py ├── test_bayes.py ├── test_stochastic.py ├── trainer │ ├── __init__.py │ ├── test_base_trainer.py │ ├── test_dynamic_values.py │ ├── test_evaluator.py │ ├── test_feed_dict.py │ └── test_trainer.py ├── utils │ ├── __init__.py │ ├── _div_op.py │ ├── _true_div_op.py │ ├── assets │ │ ├── payload.rar │ │ ├── payload.tar │ │ ├── payload.tar.bz2 │ │ ├── payload.tar.gz │ │ ├── payload.tar.xz │ │ └── payload.zip │ ├── test_caching.py │ ├── test_concepts.py │ ├── test_config_utils.py │ ├── test_console_table.py │ ├── test_data_utils.py │ ├── test_debugging.py │ ├── test_deprecation.py │ ├── test_doc_utils.py │ ├── test_events.py │ ├── test_extractor.py │ ├── test_invertible_matrix.py │ ├── test_misc.py │ ├── test_model_vars.py │ ├── test_random.py │ ├── test_registry.py 
│ ├── test_reuse.py │ ├── test_scope.py │ ├── test_session.py │ ├── test_settings.py │ ├── test_shape_utils.py │ ├── test_statistics.py │ ├── test_summary_collector.py │ ├── test_tensor_spec.py │ ├── test_tensor_wrapper.py │ ├── test_tfver.py │ └── test_typeutils.py └── variational │ ├── __init__.py │ ├── test_chain.py │ ├── test_estimators.py │ ├── test_evaluation.py │ ├── test_inference.py │ └── test_objectives.py └── tfsnippet ├── __init__.py ├── bayes.py ├── dataflows ├── __init__.py ├── array_flow.py ├── base.py ├── data_mappers.py ├── gather_flow.py ├── iterator_flow.py ├── mapper_flow.py ├── seq_flow.py └── threading_flow.py ├── datasets ├── __init__.py ├── cifar.py ├── fashion_mnist.py └── mnist.py ├── distributions ├── __init__.py ├── base.py ├── batch_to_value.py ├── discretized.py ├── flow.py ├── mixture.py ├── multivariate.py ├── univariate.py ├── utils.py └── wrapper.py ├── evaluation ├── __init__.py └── collect_outputs_.py ├── examples ├── README.rst ├── __init__.py ├── auto_encoders │ ├── __init__.py │ ├── bernoulli_latent_vae.py │ ├── conv_vae.py │ ├── dense_real_nvp.py │ ├── gm_vae.py │ ├── mixture_vae.py │ ├── planar_nf.py │ └── vae.py ├── classification │ ├── __init__.py │ ├── cifar10.py │ ├── cifar10_conv.py │ ├── mnist.py │ └── mnist_conv.py └── utils │ ├── __init__.py │ ├── dataflows_factory.py │ ├── evaluation.py │ ├── graph.py │ ├── jsonutils.py │ ├── misc.py │ ├── mlconfig.py │ ├── mlresults.py │ └── multi_gpu.py ├── layers ├── __init__.py ├── activations │ ├── __init__.py │ ├── base.py │ └── leaky_relu.py ├── base.py ├── convolutional │ ├── __init__.py │ ├── conv2d_.py │ ├── pixelcnn.py │ ├── pooling.py │ ├── resnet.py │ ├── shifted.py │ └── utils.py ├── core │ ├── __init__.py │ ├── dense_.py │ ├── dropout_.py │ └── gated.py ├── flows │ ├── __init__.py │ ├── base.py │ ├── branch.py │ ├── coupling.py │ ├── invert.py │ ├── linear.py │ ├── planar_nf.py │ ├── rearrangement.py │ ├── reshape.py │ ├── sequential.py │ └── utils.py ├── initialization.py ├── normalization │ ├── __init__.py │ ├── act_norm_.py │ └── weight_norm_.py ├── regularization.py └── utils.py ├── ops ├── __init__.py ├── assertions.py ├── classification.py ├── control_flows.py ├── convolution.py ├── evaluation.py ├── loop.py ├── misc.py ├── shape_utils.py ├── shifting.py └── type_utils.py ├── preprocessing ├── __init__.py └── samplers.py ├── scaffold ├── __init__.py ├── checkpoint.py ├── event_keys.py ├── logging_.py ├── scheduled_var.py └── train_loop_.py ├── shortcuts.py ├── stochastic.py ├── trainer ├── __init__.py ├── base_trainer.py ├── dynamic_values.py ├── evaluator.py ├── feed_dict.py ├── loss_trainer.py ├── trainer.py └── validator.py ├── utils ├── __init__.py ├── archive_file.py ├── caching.py ├── concepts.py ├── config_utils.py ├── console_table.py ├── data_utils.py ├── debugging.py ├── deprecation.py ├── doc_utils.py ├── events.py ├── graph_keys.py ├── imported.py ├── invertible_matrix.py ├── misc.py ├── model_vars.py ├── random.py ├── registry.py ├── reuse.py ├── scope.py ├── session.py ├── settings_.py ├── shape_utils.py ├── statistics.py ├── summary_collector.py ├── tensor_spec.py ├── tensor_wrapper.py ├── tfver.py └── type_utils.py └── variational ├── __init__.py ├── chain.py ├── estimators.py ├── evaluation.py ├── inference.py ├── objectives.py └── utils.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | source = tfsnippet 4 | 5 | [report] 6 | exclude_lines = 7 | pragma: no cover 8 | if self\.debug 9 | raise 
AssertionError 10 | raise NotImplementedError 11 | if 0: 12 | if __name__ == .__main__.: 13 | ignore_errors = True 14 | omit = 15 | tests/* 16 | scripts/* 17 | tfsnippet/utils/imported.py 18 | tfsnippet/examples/* 19 | setup.py 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .cache 3 | .pytest_cache 4 | .vscode 5 | *.iml 6 | *.lock 7 | /config.py 8 | /debug.py 9 | /.coverage 10 | *.pyc 11 | .DS_Store 12 | *.*~ 13 | results/ 14 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | sudo: required 3 | services: 4 | - docker 5 | env: 6 | matrix: 7 | - PYTHON_VERSION=2 TENSORFLOW_VERSION=1.5 8 | - PYTHON_VERSION=3 TENSORFLOW_VERSION=1.5 9 | - PYTHON_VERSION=2 TENSORFLOW_VERSION=1.12 10 | - PYTHON_VERSION=3 TENSORFLOW_VERSION=1.12 11 | cache: 12 | directories: 13 | - /home/travis/.tfsnippet 14 | - /home/travis/.keras 15 | before_install: 16 | - docker pull haowenxu/travis-tensorflow-docker:py${PYTHON_VERSION}tf${TENSORFLOW_VERSION} 17 | script: 18 | - bash scripts/travis-run-tests.sh 19 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | 4 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 5 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 6 | 7 | ## v0.2.0-alpha.4 8 | This version introduces breaking changes. Existing code may be better off sticking to [v0.1.2](https://github.com/haowen-xu/tfsnippet/tree/v0.1.2). 9 | 10 | ### Added 11 | - Utilities have been exported to the root package; it is now recommended to use TFSnippet via ``import tfsnippet as spt``. 12 | - Added the `layers` package, including dense layers, convolutional layers, normalization layers, and flow layers. 13 | - Added a class `Config` to define user configs. 14 | - Added the global config object `settings`. 15 | - Added `model_variable` and `get_model_variables`; all layer variables are now created via the `model_variable` function instead of `tf.get_variable`. 16 | - Added `CheckpointSaver`. 17 | - Added `utils.EventSource`. 18 | 19 | ### Changed 20 | - Pinned the `ZhuSuan` dependency to the last commit (48c0f4e) of 3.x. 21 | - `global_reuse`, `instance_reuse`, `reopen_variable_scope`, `root_variable_scope` and `VarScopeObject` have been rewritten, and their behaviors have been slightly changed. This might cause existing code to malfunction, if that code relies heavily on the precise variable scope or name scope of certain variables or tensors. 22 | - `Trainer` now accepts a `summaries` argument on construction. 23 | - The `flows` package has been moved to `layers.flows`, and all its contents 24 | can be found directly under the `layers` namespace. The interface of flows has been re-designed. 25 | - Some utilities in `utils` have been migrated to `ops`. 26 | - Added `ScheduledVariable` and `AnnealingScheduledVariable`, to replace `SimpleDynamicValue` and `AnnealingDynamicValue`. `DynamicValue` is still retained. 27 | - The `flow` argument has been removed from `BayesianNet.add`. 28 | - The hook facility of `BaseTrainer` and `Evaluator` has been rewritten with `utils.EventSource`.
29 | - `TrainLoop` now supports making checkpoints, and recovering from them. 30 | - Several utilities of `utils.shape_utils` and `utils.type_utils` have been moved from the `utils` package to the `ops` package. 31 | 32 | ### Removed 33 | - The `modules` package has been removed from this project entirely, including the `VAE` class. 34 | - The `mathops` package has been removed. Some of its members have been migrated to `ops`. 35 | - `auto_reuse_variables` has been removed. 36 | - `VariableSaver` has been removed. 37 | - `EarlyStopping` has been removed. 38 | - `VariationalTrainingObjectives.rws_wake` has been removed. 39 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2017 TFSnippet Contributors 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | /_build -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = TFSnippet 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/api/.gitignore: -------------------------------------------------------------------------------- 1 | *.rst 2 | 3 | !index.rst 4 | 5 | !tfsnippet.rst 6 | !tfsnippet.dataflows.rst 7 | !tfsnippet.datasets.rst 8 | !tfsnippet.layers.rst 9 | !tfsnippet.ops.rst 10 | !tfsnippet.preprocessing.rst 11 | !tfsnippet.utils.rst 12 | -------------------------------------------------------------------------------- /docs/api/index.rst: -------------------------------------------------------------------------------- 1 | API Docs 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | tfsnippet 8 | tfsnippet.dataflows 9 | tfsnippet.datasets 10 | tfsnippet.layers 11 | tfsnippet.ops 12 | tfsnippet.preprocessing 13 | tfsnippet.utils 14 | -------------------------------------------------------------------------------- /docs/api/tfsnippet.dataflows.rst: -------------------------------------------------------------------------------- 1 | tfsnippet.dataflows 2 | =================== 3 | 4 | .. automodapi:: tfsnippet.dataflows 5 | -------------------------------------------------------------------------------- /docs/api/tfsnippet.datasets.rst: -------------------------------------------------------------------------------- 1 | tfsnippet\.datasets 2 | =================== 3 | 4 | .. automodapi:: tfsnippet.datasets 5 | -------------------------------------------------------------------------------- /docs/api/tfsnippet.layers.rst: -------------------------------------------------------------------------------- 1 | tfsnippet\.layers 2 | ================= 3 | 4 | .. automodapi:: tfsnippet.layers 5 | -------------------------------------------------------------------------------- /docs/api/tfsnippet.ops.rst: -------------------------------------------------------------------------------- 1 | tfsnippet\.ops 2 | ============== 3 | 4 | .. automodapi:: tfsnippet.ops 5 | -------------------------------------------------------------------------------- /docs/api/tfsnippet.preprocessing.rst: -------------------------------------------------------------------------------- 1 | tfsnippet\.preprocessing 2 | ======================== 3 | 4 | .. automodapi:: tfsnippet.preprocessing 5 | -------------------------------------------------------------------------------- /docs/api/tfsnippet.rst: -------------------------------------------------------------------------------- 1 | tfsnippet 2 | ========= 3 | 4 | .. automodapi:: tfsnippet 5 | -------------------------------------------------------------------------------- /docs/api/tfsnippet.utils.rst: -------------------------------------------------------------------------------- 1 | tfsnippet.utils 2 | =============== 3 | 4 | .. automodapi:: tfsnippet.utils 5 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. TFSnippet documentation master file, created by 2 | sphinx-quickstart on Wed Nov 29 11:03:04 2017. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to TFSnippet 7 | ==================== 8 | 9 | TFSnippet is a set of utilities for writing and testing TensorFlow models. 10 | 11 | The design philosophy of TFSnippet is `non-interfering`. 
It aims to provide a 12 | set of useful utilities that can be used alongside any other TensorFlow 13 | libraries and frameworks. 14 | 15 | Installation 16 | ------------ 17 | 18 | .. code-block:: bash 19 | 20 | pip install git+https://github.com/thu-ml/zhusuan.git 21 | pip install git+https://github.com/haowen-xu/tfsnippet.git 22 | 23 | 24 | Documentation 25 | ------------- 26 | 27 | .. toctree:: 28 | :maxdepth: 2 29 | 30 | api/index 31 | 32 | Indices and tables 33 | ------------------ 34 | 35 | * :ref:`genindex` 36 | * :ref:`modindex` 37 | * :ref:`search` 38 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=TFSnippet 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | 3 | coverage >= 4.3.4 4 | click >= 7.0 5 | fs >= 2.1.2 6 | imageio >= 2.3.0 7 | mock >= 2.0.0 8 | pytest >= 3.0 9 | rarfile >= 3.0 10 | 11 | # dependencies for documentation 12 | sphinx >= 1.6.3 13 | sphinx_rtd_theme 14 | sphinx_automodapi 15 | 16 | # dependencies required only for running the examples 17 | scikit-learn >= 0.20 18 | matplotlib >= 2.1.0 19 | git+https://github.com/haowen-xu/mlsnippet.git 20 | 21 | # !!! DO NOT INCLUDE TENSORFLOW IN THIS FILE !!! 22 | # Installing plain `tensorflow` (without `-gpu`) would cause an existing GPU build to stop working. 
23 | -------------------------------------------------------------------------------- /requirements-docs.txt: -------------------------------------------------------------------------------- 1 | -r requirements-dev.txt 2 | 3 | tensorflow == 1.5.0 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | backports.tempfile >= 1.0 ; python_version < '3.2' 2 | filelock >= 3.0.10 3 | frozendict >= 1.2.0 4 | idx2numpy >= 1.2.2 5 | lazy-object-proxy >= 1.3.1 6 | natsort >= 5.3.3 7 | numpy >= 1.12.1 8 | pathlib2 >= 2.3.0 ; python_version < '3.5' 9 | PyYAML >= 3.13 10 | requests >= 2.18.4 11 | scipy >= 1.2.0 12 | semver >= 2.7.9 13 | six >= 1.11.0 14 | tqdm >= 4.23.0 15 | git+https://github.com/thu-ml/zhusuan.git@48c0f4e 16 | -------------------------------------------------------------------------------- /scripts/pkg_export.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import codecs 3 | import os 4 | 5 | 6 | def parse_all_list(python_file): 7 | with codecs.open(python_file, 'rb', 'utf-8') as f: 8 | tree = ast.parse(f.read(), python_file) 9 | for node in ast.iter_child_nodes(tree): 10 | if isinstance(node, ast.Assign) and len(node.targets) == 1 and \ 11 | node.targets[0].id == '__all__': 12 | return list(ast.literal_eval(node.value)) 13 | 14 | 15 | def process_dir(module_dir): 16 | module_all_list = [] 17 | module_import_list = [] 18 | 19 | for name in os.listdir(module_dir): 20 | path = os.path.join(module_dir, name) 21 | if name.endswith('.py') and name != '__init__.py': 22 | all_list = parse_all_list(path) 23 | if all_list is None: 24 | print('Warning: cannot parse __all__ list of: {}'.format(path)) 25 | else: 26 | module_all_list.extend(all_list) 27 | module_import_list.append(name[:-3]) 28 | 29 | elif name not in ('__pycache__',) and os.path.isdir(path): 30 | all_list = process_dir(path) 31 | if all_list: 32 | module_all_list.extend(all_list) 33 | module_import_list.append(name) 34 | 35 | module_import_list.sort() 36 | module_all_list.sort() 37 | 38 | module_all_list_lines = [' '] 39 | for n in module_all_list: 40 | new_s = module_all_list_lines[-1] + repr(n) + ', ' 41 | if len(new_s) >= 81: 42 | module_all_list_lines.append(' {!r}, '.format(n)) 43 | else: 44 | module_all_list_lines[-1] = new_s 45 | module_all_list_lines = [s.rstrip() for s in module_all_list_lines 46 | if s.strip()] 47 | 48 | init_content = '\n'.join( 49 | ['from .{} import *'.format(n) for n in module_import_list] + 50 | [''] + 51 | ['__all__ = ['] + 52 | module_all_list_lines + 53 | [']'] + 54 | [''] 55 | ) 56 | 57 | module_init_file = os.path.join(module_dir, '__init__.py') 58 | with codecs.open(module_init_file, 'wb', 'utf-8') as f: 59 | f.write(init_content) 60 | print(module_init_file) 61 | 62 | return module_all_list 63 | 64 | 65 | if __name__ == '__main__': 66 | tfsnippet_root = os.path.join( 67 | os.path.split(os.path.abspath(__file__))[0], '../tfsnippet') 68 | for name in os.listdir(tfsnippet_root): 69 | path = os.path.join(tfsnippet_root, name) 70 | if name not in ('examples', '__pycache__') and os.path.isdir(path): 71 | process_dir(path) 72 | -------------------------------------------------------------------------------- /scripts/travis-docker-entry.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | apt-get -y update && apt-get -y install unrar 6 | pip 
install . 7 | pip install -r requirements-dev.txt 8 | mkdir -p /root/.config/matplotlib 9 | echo 'backend : Agg' > /root/.config/matplotlib/matplotlibrc 10 | export PYTHONPATH="$(pwd):${PYTHONPATH}" 11 | 12 | if [ "${TENSORFLOW_VERSION}" = "*" ]; then 13 | python -m pytest \ 14 | tests/utils/test_reuse.py \ 15 | tests/utils/test_scope.py \ 16 | tests/utils/test_session.py \ 17 | tests/utils/test_shape_utils.py \ 18 | tests/utils/test_tensor_wrapper.py \ 19 | tests/utils/test_tfver.py \ 20 | tests/utils/test_typeutils.py 21 | else 22 | coverage run -m pytest && coveralls; 23 | fi 24 | -------------------------------------------------------------------------------- /scripts/travis-run-tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | function runTest() { 4 | PY_VER="$1" 5 | TF_VER="$2" 6 | RUN_EXAMPLES_TEST_CASE="$3" 7 | RUN_DATASETS_TEST_CASE="${RUN_EXAMPLES_TEST_CASE}" 8 | 9 | echo "TFSnippet Tests 10 | 11 | PYTHON_VERSION=${PY_VER} 12 | TENSORFLOW_VERSION=${TF_VER} 13 | RUN_EXAMPLES_TEST_CASE=${RUN_EXAMPLES_TEST_CASE} 14 | RUN_DATASETS_TEST_CASE=${RUN_DATASETS_TEST_CASE} 15 | " 16 | 17 | IMAGE_NAME="haowenxu/travis-tensorflow-docker:py${PY_VER}tf${TF_VER}" 18 | docker run \ 19 | -v "$(pwd)":"$(pwd)" \ 20 | -v "/home/travis/.tfsnippet":"/root/.tfsnippet" \ 21 | -v "/home/travis/.keras":"/root/.keras" \ 22 | -w "$(pwd)" \ 23 | -e TRAVIS="${TRAVIS}" \ 24 | -e TRAVIS_JOB_ID="${TRAVIS_JOB_ID}" \ 25 | -e TRAVIS_BRANCH="${TRAVIS_BRANCH}" \ 26 | -e PYTHON_VERSION="${PYTHON_VERSION}" \ 27 | -e TENSORFLOW_VERSION="${TENSORFLOW_VERSION}" \ 28 | -e RUN_EXAMPLES_TEST_CASE="${RUN_EXAMPLES_TEST_CASE}" \ 29 | -e RUN_DATASETS_TEST_CASE="${RUN_DATASETS_TEST_CASE}" \ 30 | "${IMAGE_NAME}" \ 31 | bash "scripts/travis-docker-entry.sh" 32 | } 33 | 34 | if [[ "${TRAVIS_BRANCH}" = "master" || "${TRAVIS_BRANCH}" = "develop" ]]; then 35 | if [[ "${TENSORFLOW_VERSION}" = "*" ]]; then 36 | for TF_VER in 1.5 1.6 1.7 1.8 1.9 1.10 1.11 1.12; do 37 | runTest "${PYTHON_VERSION}" "${TF_VER}" "0"; 38 | done 39 | else 40 | runTest "${PYTHON_VERSION}" "${TENSORFLOW_VERSION}" "1"; 41 | fi 42 | else 43 | if [[ "${TENSORFLOW_VERSION}" != "*" ]]; then 44 | runTest "${PYTHON_VERSION}" "${TENSORFLOW_VERSION}" "0"; 45 | fi 46 | fi 47 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | TF Snippet 3 | ---------- 4 | 5 | TF Snippet is a set of utilities for writing and testing TensorFlow models. 6 | These utilities are in an early development stage, and might be migrated to 7 | a new dedicated project once they are mature enough. 
8 | """ 9 | import ast 10 | import codecs 11 | import os 12 | import re 13 | import sys 14 | from setuptools import setup, find_packages 15 | 16 | 17 | _version_re = re.compile(r'__version__\s+=\s+(.*)') 18 | _source_dir = os.path.split(os.path.abspath(__file__))[0] 19 | 20 | if sys.version_info[0] == 2: 21 | def read_file(path): 22 | with open(path, 'rb') as f: 23 | return f.read() 24 | else: 25 | def read_file(path): 26 | with codecs.open(path, 'rb', 'utf-8') as f: 27 | return f.read() 28 | 29 | version = str(ast.literal_eval(_version_re.search( 30 | read_file(os.path.join(_source_dir, 'tfsnippet/__init__.py'))).group(1))) 31 | 32 | requirements_list = list(filter( 33 | lambda v: v and not v.startswith('#'), 34 | (s.strip() for s in read_file( 35 | os.path.join(_source_dir, 'requirements.txt')).split('\n')) 36 | )) 37 | dependency_links = [s for s in requirements_list if s.startswith('git+')] 38 | install_requires = [s for s in requirements_list if not s.startswith('git+')] 39 | 40 | 41 | setup( 42 | name='TFSnippet', 43 | version=version, 44 | url='https://github.com/haowen-xu/tfsnippet/', 45 | license='MIT', 46 | author='Haowen Xu', 47 | author_email='haowen.xu@outlook.com', 48 | description='A set of utilities for writing and testing TensorFlow models.', 49 | long_description=__doc__, 50 | packages=find_packages('.', include=['tfsnippet', 'tfsnippet.*']), 51 | zip_safe=False, 52 | platforms='any', 53 | setup_requires=['setuptools'], 54 | install_requires=install_requires, 55 | dependency_links=dependency_links, 56 | classifiers=[ 57 | 'Development Status :: 2 - Alpha', 58 | 'Intended Audience :: Developers', 59 | 'License :: OSI Approved :: MIT License', 60 | 'Operating System :: OS Independent', 61 | 'Programming Language :: Python', 62 | 'Programming Language :: Python :: 2', 63 | 'Programming Language :: Python :: 2.7', 64 | 'Programming Language :: Python :: 3', 65 | 'Programming Language :: Python :: 3.5', 66 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 67 | 'Topic :: Software Development :: Libraries :: Python Modules' 68 | ] 69 | ) 70 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/__init__.py -------------------------------------------------------------------------------- /tests/dataflows/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/dataflows/__init__.py -------------------------------------------------------------------------------- /tests/dataflows/test_base.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import pytest 5 | from mock import MagicMock 6 | 7 | from tfsnippet.dataflows import DataFlow 8 | from tfsnippet.dataflows.array_flow import ArrayFlow 9 | 10 | 11 | class _DataFlow(DataFlow): 12 | 13 | def __init__(self): 14 | self._minibatch_iterator = MagicMock(return_value=[123]) 15 | 16 | 17 | class DataFlowTestCase(unittest.TestCase): 18 | 19 | def test_iter(self): 20 | df = _DataFlow() 21 | 22 | self.assertFalse(df._is_iter_entered) 23 | self.assertEqual(0, df._minibatch_iterator.call_count) 24 | 25 | for x in df: 26 | self.assertEqual(123, x) 27 | 
self.assertTrue(df._is_iter_entered) 28 | self.assertEqual(1, df._minibatch_iterator.call_count) 29 | 30 | with pytest.raises( 31 | RuntimeError, match='_DataFlow.__iter__ is not reentrant'): 32 | for _ in df: 33 | pass 34 | 35 | self.assertFalse(df._is_iter_entered) 36 | self.assertEqual(1, df._minibatch_iterator.call_count) 37 | 38 | def test_get_arrays(self): 39 | with pytest.raises(ValueError, match='empty, cannot convert to arrays'): 40 | _ = DataFlow.arrays([np.arange(0)], batch_size=5).get_arrays() 41 | 42 | # test one batch 43 | df = DataFlow.arrays([np.arange(5), np.arange(5, 10)], batch_size=6) 44 | arrays = df.get_arrays() 45 | np.testing.assert_equal(np.arange(5), arrays[0]) 46 | np.testing.assert_equal(np.arange(5, 10), arrays[1]) 47 | 48 | # test two batches 49 | df = DataFlow.arrays([np.arange(10), np.arange(10, 20)], batch_size=6) 50 | arrays = df.get_arrays() 51 | np.testing.assert_equal(np.arange(10), arrays[0]) 52 | np.testing.assert_equal(np.arange(10, 20), arrays[1]) 53 | 54 | # test to_arrays_flow 55 | df2 = df.to_arrays_flow(batch_size=6) 56 | self.assertIsInstance(df2, ArrayFlow) 57 | 58 | def test_implicit_iterator(self): 59 | df = DataFlow.arrays([np.arange(3)], batch_size=2) 60 | self.assertIsNone(df.current_batch) 61 | 62 | np.testing.assert_equal([[0, 1]], df.next_batch()) 63 | np.testing.assert_equal([[0, 1]], df.current_batch) 64 | np.testing.assert_equal([[2]], df.next_batch()) 65 | np.testing.assert_equal([[2]], df.current_batch) 66 | with pytest.raises(StopIteration): 67 | _ = df.next_batch() 68 | self.assertIsNone(df.current_batch) 69 | 70 | np.testing.assert_equal([[0, 1]], df.next_batch()) 71 | np.testing.assert_equal([[0, 1]], df.current_batch) 72 | -------------------------------------------------------------------------------- /tests/dataflows/test_data_mappers.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import pytest 5 | from mock import Mock 6 | 7 | from tfsnippet.dataflows import DataMapper, SlidingWindow 8 | 9 | 10 | class DataMapperTestCase(unittest.TestCase): 11 | 12 | def test_error(self): 13 | dm = DataMapper() 14 | dm._transform = Mock(return_value=np.array([1, 2, 3])) 15 | with pytest.raises(TypeError, match='The output of .* is neither ' 16 | 'a tuple, nor a list'): 17 | dm(np.array([1, 2, 3])) 18 | 19 | 20 | class SlidingWindowTestCase(unittest.TestCase): 21 | 22 | def test_props(self): 23 | arr = np.arange(13) 24 | sw = SlidingWindow(arr, window_size=3) 25 | self.assertIs(arr, sw.data_array) 26 | self.assertEqual(3, sw.window_size) 27 | 28 | def test_transform(self): 29 | arr = np.arange(13) 30 | sw = SlidingWindow(arr, window_size=3) 31 | np.testing.assert_equal( 32 | [[0, 1, 2], [5, 6, 7], [3, 4, 5]], 33 | sw(np.asarray([0, 5, 3]))[0] 34 | ) 35 | 36 | def test_as_flow(self): 37 | arr = np.arange(13) 38 | sw = SlidingWindow(arr, window_size=3) 39 | batches = list(sw.as_flow(batch_size=4)) 40 | self.assertEqual(3, len(batches)) 41 | np.testing.assert_equal( 42 | [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5]], 43 | batches[0][0] 44 | ) 45 | np.testing.assert_equal( 46 | [[4, 5, 6], [5, 6, 7], [6, 7, 8], [7, 8, 9]], 47 | batches[1][0] 48 | ) 49 | np.testing.assert_equal( 50 | [[8, 9, 10], [9, 10, 11], [10, 11, 12]], 51 | batches[2][0] 52 | ) 53 | -------------------------------------------------------------------------------- /tests/dataflows/test_gather_flow.py: -------------------------------------------------------------------------------- 1 | 
import unittest 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from tfsnippet.dataflows import DataFlow 7 | from tfsnippet.dataflows.gather_flow import GatherFlow 8 | 9 | 10 | class GatherFlowTestCase(unittest.TestCase): 11 | 12 | def test_flow(self): 13 | x_flow = DataFlow.arrays([np.arange(10)], batch_size=4) 14 | y_flow = DataFlow.arrays([np.arange(10, 17)], batch_size=4) 15 | flow = DataFlow.gather([x_flow, y_flow]) 16 | self.assertIsInstance(flow, GatherFlow) 17 | self.assertEqual((x_flow, y_flow), flow.flows) 18 | batches = list(flow) 19 | self.assertEqual(2, len(batches)) 20 | np.testing.assert_equal(np.arange(4), batches[0][0]) 21 | np.testing.assert_equal(np.arange(10, 14), batches[0][1]) 22 | np.testing.assert_equal(np.arange(4, 8), batches[1][0]) 23 | np.testing.assert_equal(np.arange(14, 17), batches[1][1]) 24 | 25 | def test_errors(self): 26 | with pytest.raises( 27 | ValueError, match='At least one flow must be specified'): 28 | _ = DataFlow.gather([]) 29 | with pytest.raises(TypeError, match='Not a DataFlow'): 30 | _ = DataFlow.gather([1]) 31 | -------------------------------------------------------------------------------- /tests/dataflows/test_iterator_flow.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from tfsnippet.dataflows import DataFlow 6 | 7 | 8 | class IteratorFactoryFlowTestCase(unittest.TestCase): 9 | 10 | def test_iterator_factory(self): 11 | x_flow = DataFlow.arrays([np.arange(5)], batch_size=3) 12 | y_flow = DataFlow.arrays([np.arange(5, 10)], batch_size=3) 13 | flow = DataFlow.iterator_factory(lambda: ( 14 | (x, y) for (x,), (y,) in zip(x_flow, y_flow) 15 | )) 16 | 17 | b = list(flow) 18 | self.assertEqual(2, len(b)) 19 | self.assertEqual(2, len(b[0])) 20 | np.testing.assert_array_equal([0, 1, 2], b[0][0]) 21 | np.testing.assert_array_equal([5, 6, 7], b[0][1]) 22 | np.testing.assert_array_equal([3, 4], b[1][0]) 23 | np.testing.assert_array_equal([8, 9], b[1][1]) 24 | -------------------------------------------------------------------------------- /tests/dataflows/test_seq_flow.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from tfsnippet.dataflows import DataFlow 7 | from tfsnippet.dataflows.seq_flow import SeqFlow 8 | 9 | 10 | class SeqFlowTestCase(unittest.TestCase): 11 | 12 | def test_seq(self): 13 | df = DataFlow.seq( 14 | 1, 9, 2, batch_size=3, shuffle=False, skip_incomplete=False, 15 | dtype=np.int64 16 | ) 17 | self.assertIsInstance(df, SeqFlow) 18 | self.assertEqual(1, df.array_count) 19 | self.assertEqual(4, df.data_length) 20 | self.assertEqual(((),), df.data_shapes) 21 | self.assertEqual(3, df.batch_size) 22 | self.assertFalse(df.is_shuffled) 23 | self.assertFalse(df.skip_incomplete) 24 | self.assertEqual(1, df.start) 25 | self.assertEqual(9, df.stop) 26 | self.assertEqual(2, df.step) 27 | 28 | def test_property(self): 29 | df = SeqFlow( 30 | 1, 9, 2, batch_size=3, shuffle=True, skip_incomplete=True, 31 | dtype=np.int64 32 | ) 33 | self.assertEqual(1, df.array_count) 34 | self.assertEqual(4, df.data_length) 35 | self.assertEqual(((),), df.data_shapes) 36 | self.assertEqual(3, df.batch_size) 37 | self.assertTrue(df.skip_incomplete) 38 | self.assertTrue(df.is_shuffled) 39 | self.assertEqual(1, df.start) 40 | self.assertEqual(9, df.stop) 41 | self.assertEqual(2, df.step) 42 | 43 | # test default options 44 | df = SeqFlow(1, 9, batch_size=3) 45 
| self.assertFalse(df.skip_incomplete) 46 | self.assertFalse(df.is_shuffled) 47 | self.assertEqual(1, df.step) 48 | 49 | def test_errors(self): 50 | with pytest.raises( 51 | ValueError, match='`batch_size` is required'): 52 | _ = SeqFlow(1, 9, 2) 53 | 54 | def test_iterator(self): 55 | # test single array, without shuffle, no ignore 56 | b = [a[0] for a in SeqFlow(1, 9, 2, batch_size=3)] 57 | self.assertEqual(2, len(b)) 58 | np.testing.assert_array_equal([1, 3, 5], b[0]) 59 | np.testing.assert_array_equal([7], b[1]) 60 | 61 | # test single array, without shuffle, ignore 62 | b = [a[0] for a in SeqFlow(1, 9, 2, batch_size=3, skip_incomplete=True)] 63 | self.assertEqual(1, len(b)) 64 | np.testing.assert_array_equal([1, 3, 5], b[0]) 65 | 66 | # test single array, with shuffle, no ignore 67 | b = [a[0] for a in SeqFlow(1, 9, 2, batch_size=3, shuffle=True)] 68 | self.assertEqual(2, len(b)) 69 | np.testing.assert_array_equal( 70 | np.arange(1, 9, 2), sorted(np.concatenate(b))) 71 | -------------------------------------------------------------------------------- /tests/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/datasets/__init__.py -------------------------------------------------------------------------------- /tests/datasets/helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | 5 | def skipUnlessRunDatasetsTests(): 6 | return unittest.skipUnless( 7 | os.environ.get('RUN_DATASETS_TEST_CASE') == '1', 8 | 'RUN_DATASETS_TEST_CASE is not set to 1, thus skipped' 9 | ) 10 | -------------------------------------------------------------------------------- /tests/datasets/test_cifar.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from tests.datasets.helper import skipUnlessRunDatasetsTests 7 | from tfsnippet.datasets import * 8 | 9 | 10 | class CifarTestCase(unittest.TestCase): 11 | 12 | @skipUnlessRunDatasetsTests() 13 | def test_fetch_cifar10(self): 14 | # test channels_last = True, normalize_x = False 15 | (train_x, train_y), (test_x, test_y) = load_cifar10() 16 | self.assertTupleEqual(train_x.shape, (50000, 32, 32, 3)) 17 | self.assertTupleEqual(train_y.shape, (50000,)) 18 | self.assertTupleEqual(test_x.shape, (10000, 32, 32, 3)) 19 | self.assertTupleEqual(test_y.shape, (10000,)) 20 | 21 | self.assertGreater(np.max(train_x), 128.) 22 | self.assertEqual(np.max(train_y), 9) 23 | 24 | # test channels_last = False, normalize_x = True 25 | (train_x, train_y), (test_x, test_y) = load_cifar10(channels_last=False, 26 | normalize_x=True) 27 | self.assertTupleEqual(train_x.shape, (50000, 3, 32, 32)) 28 | self.assertTupleEqual(train_y.shape, (50000,)) 29 | self.assertTupleEqual(test_x.shape, (10000, 3, 32, 32)) 30 | self.assertTupleEqual(test_y.shape, (10000,)) 31 | 32 | self.assertLess(np.max(train_x), 1. 
+ 1e-5) 33 | 34 | # test x_shape 35 | (train_x, train_y), (test_x, test_y) = load_cifar10(x_shape=(1024, 3)) 36 | self.assertTupleEqual(train_x.shape, (50000, 1024, 3)) 37 | self.assertTupleEqual(test_x.shape, (10000, 1024, 3)) 38 | 39 | with pytest.raises(ValueError, 40 | match='`x_shape` does not product to 3072'): 41 | _ = load_cifar10(x_shape=(1, 2, 3)) 42 | 43 | @skipUnlessRunDatasetsTests() 44 | def test_fetch_cifar100(self): 45 | # test channels_last = True, normalize_x = False 46 | (train_x, train_y), (test_x, test_y) = load_cifar100() 47 | self.assertTupleEqual(train_x.shape, (50000, 32, 32, 3)) 48 | self.assertTupleEqual(train_y.shape, (50000,)) 49 | self.assertTupleEqual(test_x.shape, (10000, 32, 32, 3)) 50 | self.assertTupleEqual(test_y.shape, (10000,)) 51 | 52 | self.assertGreater(np.max(train_x), 128.) 53 | self.assertEqual(np.max(train_y), 99) 54 | 55 | # test channels_last = False, normalize_x = True 56 | (train_x, train_y), (test_x, test_y) = load_cifar100( 57 | label_mode='coarse', channels_last=False, normalize_x=True) 58 | self.assertTupleEqual(train_x.shape, (50000, 3, 32, 32)) 59 | self.assertTupleEqual(train_y.shape, (50000,)) 60 | self.assertTupleEqual(test_x.shape, (10000, 3, 32, 32)) 61 | self.assertTupleEqual(test_y.shape, (10000,)) 62 | 63 | self.assertLess(np.max(train_x), 1. + 1e-5) 64 | self.assertEqual(np.max(train_y), 19) 65 | 66 | # test x_shape 67 | (train_x, train_y), (test_x, test_y) = load_cifar100(x_shape=(1024, 3)) 68 | self.assertTupleEqual(train_x.shape, (50000, 1024, 3)) 69 | self.assertTupleEqual(test_x.shape, (10000, 1024, 3)) 70 | 71 | with pytest.raises(ValueError, 72 | match='`x_shape` does not product to 3072'): 73 | _ = load_cifar100(x_shape=(1, 2, 3)) 74 | -------------------------------------------------------------------------------- /tests/datasets/test_fashion_mnist.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from tests.datasets.helper import skipUnlessRunDatasetsTests 7 | from tfsnippet.datasets import * 8 | 9 | 10 | class FashionMnistTestCase(unittest.TestCase): 11 | 12 | @skipUnlessRunDatasetsTests() 13 | def test_fetch_fashion_mnist(self): 14 | # test normalize_x = False 15 | (train_x, train_y), (test_x, test_y) = load_fashion_mnist() 16 | self.assertTupleEqual(train_x.shape, (60000, 28, 28)) 17 | self.assertTupleEqual(train_y.shape, (60000,)) 18 | self.assertTupleEqual(test_x.shape, (10000, 28, 28)) 19 | self.assertTupleEqual(test_y.shape, (10000,)) 20 | 21 | self.assertGreater(np.max(train_x), 128.) 22 | 23 | # test normalize_x = True 24 | (train_x, train_y), (test_x, test_y) = \ 25 | load_fashion_mnist(normalize_x=True) 26 | self.assertTupleEqual(train_x.shape, (60000, 28, 28)) 27 | self.assertTupleEqual(train_y.shape, (60000,)) 28 | self.assertTupleEqual(test_x.shape, (10000, 28, 28)) 29 | self.assertTupleEqual(test_y.shape, (10000,)) 30 | 31 | self.assertLess(np.max(train_x), 1. 
+ 1e-5) 32 | 33 | # test x_shape 34 | (train_x, train_y), (test_x, test_y) = \ 35 | load_fashion_mnist(x_shape=(784,)) 36 | self.assertTupleEqual(train_x.shape, (60000, 784)) 37 | self.assertTupleEqual(test_x.shape, (10000, 784)) 38 | 39 | with pytest.raises(ValueError, 40 | match='`x_shape` does not product to 784'): 41 | _ = load_mnist(x_shape=(1, 2, 3)) 42 | -------------------------------------------------------------------------------- /tests/datasets/test_mnist.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from tests.datasets.helper import skipUnlessRunDatasetsTests 7 | from tfsnippet.datasets import * 8 | 9 | 10 | class MnistTestCase(unittest.TestCase): 11 | 12 | @skipUnlessRunDatasetsTests() 13 | def test_fetch_mnist(self): 14 | # test normalize_x = False 15 | (train_x, train_y), (test_x, test_y) = load_mnist() 16 | self.assertTupleEqual(train_x.shape, (60000, 28, 28)) 17 | self.assertTupleEqual(train_y.shape, (60000,)) 18 | self.assertTupleEqual(test_x.shape, (10000, 28, 28)) 19 | self.assertTupleEqual(test_y.shape, (10000,)) 20 | 21 | self.assertGreater(np.max(train_x), 128.) 22 | 23 | # test normalize_x = True 24 | (train_x, train_y), (test_x, test_y) = load_mnist(normalize_x=True) 25 | self.assertTupleEqual(train_x.shape, (60000, 28, 28)) 26 | self.assertTupleEqual(train_y.shape, (60000,)) 27 | self.assertTupleEqual(test_x.shape, (10000, 28, 28)) 28 | self.assertTupleEqual(test_y.shape, (10000,)) 29 | 30 | self.assertLess(np.max(train_x), 1. + 1e-5) 31 | 32 | # test x_shape 33 | (train_x, train_y), (test_x, test_y) = load_mnist(x_shape=(784,)) 34 | self.assertTupleEqual(train_x.shape, (60000, 784)) 35 | self.assertTupleEqual(test_x.shape, (10000, 784)) 36 | 37 | with pytest.raises(ValueError, 38 | match='`x_shape` does not product to 784'): 39 | _ = load_mnist(x_shape=(1, 2, 3)) 40 | -------------------------------------------------------------------------------- /tests/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/distributions/__init__.py -------------------------------------------------------------------------------- /tests/distributions/test_base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from tfsnippet.distributions import Distribution, reduce_group_ndims 5 | 6 | 7 | class DistributionTestCase(tf.test.TestCase): 8 | 9 | def test_basic(self): 10 | class _Distribution(Distribution): 11 | def log_prob(self, given, group_ndims=0, name=None): 12 | return reduce_group_ndims( 13 | tf.reduce_sum, 14 | tf.convert_to_tensor(given) - 1., 15 | group_ndims 16 | ) 17 | 18 | with self.test_session() as sess: 19 | distrib = _Distribution( 20 | dtype=tf.float32, 21 | is_reparameterized=True, 22 | is_continuous=True, 23 | batch_shape=tf.constant([]), 24 | batch_static_shape=tf.TensorShape([]), 25 | value_ndims=0, 26 | ) 27 | self.assertIs(distrib.base_distribution, distrib) 28 | x = np.asarray([0., 1., 2.]) 29 | np.testing.assert_allclose( 30 | sess.run(distrib.prob(x, group_ndims=0)), 31 | np.exp(x - 1.) 
32 | ) 33 | np.testing.assert_allclose( 34 | sess.run(distrib.prob(x, group_ndims=1)), 35 | np.exp(np.sum(x - 1., -1)) 36 | ) 37 | -------------------------------------------------------------------------------- /tests/distributions/test_multivariate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from tfsnippet.distributions import * 5 | 6 | 7 | class OnehotCategoricalTestCase(tf.test.TestCase): 8 | 9 | def test_props(self): 10 | logits = np.arange(24, dtype=np.float32).reshape([2, 3, 4]) 11 | with self.test_session(): 12 | one_hot_categorical = OnehotCategorical(logits=tf.constant(logits)) 13 | self.assertEqual(one_hot_categorical.value_ndims, 1) 14 | self.assertEqual(one_hot_categorical.n_categories, 4) 15 | np.testing.assert_allclose( 16 | one_hot_categorical.logits.eval(), logits) 17 | 18 | 19 | class ConcreteCategoricalTestCase(tf.test.TestCase): 20 | 21 | def test_props(self): 22 | logits = np.arange(24, dtype=np.float32).reshape([2, 3, 4]) 23 | with self.test_session(): 24 | concrete = Concrete(temperature=.5, logits=tf.constant(logits)) 25 | self.assertEqual(concrete.value_ndims, 1) 26 | self.assertEqual(concrete.temperature.eval(), .5) 27 | self.assertEqual(concrete.n_categories, 4) 28 | np.testing.assert_allclose(concrete.logits.eval(), logits) 29 | 30 | 31 | class ExpConcreteCategoricalTestCase(tf.test.TestCase): 32 | 33 | def test_props(self): 34 | logits = np.arange(24, dtype=np.float32).reshape([2, 3, 4]) 35 | with self.test_session(): 36 | exp_concrete = ExpConcrete( 37 | temperature=.5, logits=tf.constant(logits)) 38 | self.assertEqual(exp_concrete.value_ndims, 1) 39 | self.assertEqual(exp_concrete.temperature.eval(), .5) 40 | self.assertEqual(exp_concrete.n_categories, 4) 41 | np.testing.assert_allclose(exp_concrete.logits.eval(), logits) 42 | -------------------------------------------------------------------------------- /tests/distributions/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import six 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from tfsnippet.distributions import reduce_group_ndims 7 | 8 | if six.PY2: 9 | LONG_MAX = (long(1) << 63) - long(1)  # 2 ** 63 - 1; `-` binds tighter than `<<` 10 | else: 11 | LONG_MAX = (1 << 63) - 1 12 | 13 | 14 | class ReduceGroupNdimsTestCase(tf.test.TestCase): 15 | 16 | def test_errors(self): 17 | for o in [object(), None, 1.2, LONG_MAX, 18 | tf.constant(1.2, dtype=tf.float32), 19 | tf.constant(LONG_MAX, dtype=tf.int64)]: 20 | with pytest.raises( 21 | TypeError, 22 | match='group_ndims cannot be converted to int32'): 23 | _ = reduce_group_ndims(tf.reduce_sum, tf.constant(0.), o) 24 | 25 | with pytest.raises( 26 | ValueError, match='group_ndims must be non-negative'): 27 | _ = reduce_group_ndims(tf.reduce_sum, tf.constant(0.), -1) 28 | 29 | with self.test_session(): 30 | with pytest.raises( 31 | Exception, match='group_ndims must be non-negative'): 32 | _ = reduce_group_ndims(tf.reduce_sum, tf.constant(0.), 33 | tf.constant(-1, dtype=tf.int32)).eval() 34 | 35 | def test_output(self): 36 | tensor = tf.reshape(tf.range(24, dtype=tf.float32), [2, 3, 4]) 37 | tensor_sum_1 = tf.reduce_sum(tensor, axis=-1) 38 | tensor_sum_2 = tf.reduce_sum(tensor, axis=[-2, -1]) 39 | tensor_prod = tf.reduce_prod(tensor, axis=-1) 40 | g0 = tf.constant(0, dtype=tf.int32) 41 | g1 = tf.constant(1, dtype=tf.int32) 42 | g2 = tf.constant(2, dtype=tf.int32) 43 | 44 | with self.test_session(): 45 | # static group_ndims 46 | 
np.testing.assert_equal( 47 | tensor.eval(), 48 | reduce_group_ndims(tf.reduce_sum, tensor, 0).eval() 49 | ) 50 | np.testing.assert_equal( 51 | tensor_sum_1.eval(), 52 | reduce_group_ndims(tf.reduce_sum, tensor, 1).eval() 53 | ) 54 | np.testing.assert_equal( 55 | tensor_sum_2.eval(), 56 | reduce_group_ndims(tf.reduce_sum, tensor, 2).eval() 57 | ) 58 | np.testing.assert_equal( 59 | tensor_prod.eval(), 60 | reduce_group_ndims(tf.reduce_prod, tensor, 1).eval() 61 | ) 62 | 63 | # dynamic group_ndims 64 | np.testing.assert_equal( 65 | tensor.eval(), 66 | reduce_group_ndims(tf.reduce_sum, tensor, g0).eval() 67 | ) 68 | np.testing.assert_equal( 69 | tensor_sum_1.eval(), 70 | reduce_group_ndims(tf.reduce_sum, tensor, g1).eval() 71 | ) 72 | np.testing.assert_equal( 73 | tensor_sum_2.eval(), 74 | reduce_group_ndims(tf.reduce_sum, tensor, g2).eval() 75 | ) 76 | np.testing.assert_equal( 77 | tensor_prod.eval(), 78 | reduce_group_ndims(tf.reduce_prod, tensor, g1).eval() 79 | ) 80 | -------------------------------------------------------------------------------- /tests/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/evaluation/__init__.py -------------------------------------------------------------------------------- /tests/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/examples/__init__.py -------------------------------------------------------------------------------- /tests/examples/helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | 5 | def skipUnlessRunExamplesTests(): 6 | return unittest.skipUnless( 7 | os.environ.get('RUN_EXAMPLES_TEST_CASE') == '1', 8 | 'RUN_EXAMPLES_TEST_CASE is not set to 1, thus skipped' 9 | ) 10 | -------------------------------------------------------------------------------- /tests/examples/test_examples.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import copy 3 | import os 4 | import re 5 | import subprocess 6 | import sys 7 | import time 8 | import unittest 9 | 10 | from tfsnippet.utils import TemporaryDirectory, humanize_duration 11 | from tests.examples.helper import skipUnlessRunExamplesTests 12 | 13 | 14 | class ExamplesTestCase(unittest.TestCase): 15 | """ 16 | Test case to ensure all examples can run for at least one step. 17 | """ 18 | 19 | @skipUnlessRunExamplesTests() 20 | def test_examples_can_run_one_step(self): 21 | timer = -time.time() 22 | 23 | # discover all example scripts 24 | def walk(pa, dst): 25 | for fn in os.listdir(pa): 26 | fp = os.path.join(pa, fn) 27 | if os.path.isdir(fp): 28 | walk(fp, dst) 29 | elif fp.endswith('.py'): 30 | with codecs.open(fp, 'rb', 'utf-8') as f: 31 | cnt = f.read() 32 | if re.search( 33 | r'''if\s+__name__\s*==\s+(['"])__main__\1:''', 34 | cnt): 35 | if 'max_step=config.max_step' not in cnt: 36 | raise RuntimeError('Example script does not have ' 37 | 'max_step configuration: {}'. 
38 | format(fp)) 39 | dst.append(fp) 40 | return dst 41 | 42 | examples_dir = os.path.join( 43 | os.path.split(os.path.abspath(__file__))[0], 44 | '../../tfsnippet/examples' 45 | ) 46 | examples_scripts = walk(examples_dir, []) 47 | 48 | # run all examples scripts for just max_step 49 | env_dict = copy.copy(os.environ) 50 | 51 | for example_script in examples_scripts: 52 | print('Run {} ...'.format(example_script)) 53 | 54 | with TemporaryDirectory() as tempdir: 55 | args = [sys.executable, '-u', 56 | example_script, '--max_step=1'] 57 | subprocess.check_call(args, cwd=tempdir, env=env_dict) 58 | print('') 59 | 60 | # report finished tests 61 | print('Finished to run {} example scripts in {}.'.format( 62 | len(examples_scripts), humanize_duration(time.time() + timer))) 63 | -------------------------------------------------------------------------------- /tests/examples/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/examples/utils/__init__.py -------------------------------------------------------------------------------- /tests/examples/utils/test_mlconfig.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import pytest 4 | 5 | from tfsnippet.examples.utils import MLConfig 6 | 7 | 8 | class MLConfigTestCase(unittest.TestCase): 9 | 10 | def test_assign(self): 11 | class MyConfig(MLConfig): 12 | a = 123 13 | 14 | config = MyConfig() 15 | self.assertEqual(config.a, 123) 16 | config.a = 234 17 | self.assertEqual(config.a, 234) 18 | 19 | with pytest.raises(AttributeError, match='Config key \'non_exist\' ' 20 | 'does not exist'): 21 | config.non_exist = 12345 22 | 23 | def test_defaults_and_to_dict(self): 24 | self.assertDictEqual(MLConfig.defaults(), {}) 25 | self.assertDictEqual(MLConfig().to_dict(), {}) 26 | 27 | class MyConfig(MLConfig): 28 | a = 123 29 | b = 456 30 | 31 | self.assertDictEqual(MyConfig.defaults(), {'a': 123, 'b': 456}) 32 | config = MyConfig() 33 | self.assertDictEqual(config.to_dict(), {'a': 123, 'b': 456}) 34 | config.a = 333 35 | self.assertDictEqual(config.to_dict(), {'a': 333, 'b': 456}) 36 | 37 | class MyConfig2(MyConfig): 38 | a = 234 39 | c = 1234 40 | 41 | self.assertDictEqual(MyConfig2.defaults(), 42 | {'a': 234, 'b': 456, 'c': 1234}) 43 | config = MyConfig2() 44 | self.assertDictEqual(config.to_dict(), {'a': 234, 'b': 456, 'c': 1234}) 45 | config.a = 333 46 | config.c = 444 47 | self.assertDictEqual(config.to_dict(), {'a': 333, 'b': 456, 'c': 444}) 48 | 49 | -------------------------------------------------------------------------------- /tests/examples/utils/test_mlresult.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | import numpy as np 5 | 6 | from tfsnippet.examples.utils import MLResults 7 | from tfsnippet.utils import TemporaryDirectory 8 | 9 | 10 | def head_of_file(path, n): 11 | with open(path, 'rb') as f: 12 | return f.read(n) 13 | 14 | 15 | class MLResultTestCase(unittest.TestCase): 16 | 17 | def test_imwrite(self): 18 | with TemporaryDirectory() as tmpdir: 19 | results = MLResults(tmpdir) 20 | im = np.zeros([32, 32], dtype=np.uint8) 21 | im[16:, ...] 
= 255 22 | 23 | results.save_image('test.bmp', im) 24 | file_path = os.path.join(tmpdir, 'test.bmp') 25 | self.assertTrue(os.path.isfile(file_path)) 26 | self.assertEqual(head_of_file(file_path, 2), b'\x42\x4d') 27 | 28 | results.save_image('test.png', im) 29 | file_path = os.path.join(tmpdir, 'test.png') 30 | self.assertTrue(os.path.isfile(file_path)) 31 | self.assertEqual(head_of_file(file_path, 8), 32 | b'\x89\x50\x4e\x47\x0d\x0a\x1a\x0a') 33 | 34 | results.save_image('test.jpg', im) 35 | file_path = os.path.join(tmpdir, 'test.jpg') 36 | self.assertTrue(os.path.isfile(file_path)) 37 | self.assertEqual(head_of_file(file_path, 3), b'\xff\xd8\xff') 38 | -------------------------------------------------------------------------------- /tests/helper.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | __all__ = ['assert_variables'] 4 | 5 | 6 | def assert_variables(names, exist=True, trainable=None, scope=None, 7 | collections=None): 8 | """ 9 | Assert that the variables in `names` meet certain criteria. 10 | 11 | Args: 12 | names (Iterable[str]): The variable name(s) to check. 13 | exist (bool): Assert variables exist or not. 14 | trainable: Assert variables are trainable or not. 15 | scope (None or str): The scope prefix to be prepended to the names. 16 | collections (Iterable[str]): Additional graph collections, where 17 | to ensure the variables are in. 18 | """ 19 | def normalize_name(n): 20 | return n.rsplit(':', 1)[0] 21 | 22 | names = tuple(names) 23 | 24 | if scope: 25 | scope = str(scope).rstrip('/') 26 | names = tuple('{}/{}'.format(scope, name) for name in names) 27 | 28 | global_vars = {normalize_name(v.name): v 29 | for v in tf.global_variables()} 30 | trainable_vars = {normalize_name(v.name): v 31 | for v in tf.trainable_variables()} 32 | collections = list(collections or ()) 33 | collection_vars = [ 34 | {normalize_name(v.name): v for v in tf.get_collection(c)} 35 | for c in collections 36 | ] 37 | 38 | for name in names: 39 | if exist: 40 | if name not in global_vars: 41 | raise AssertionError('Variable `{}` is expected to exist, but ' 42 | 'turned out to be non-existent.'.format(name)) 43 | 44 | # check trainable 45 | if trainable is False: 46 | if name in trainable_vars: 47 | raise AssertionError('Variable `{}` is expected not to be ' 48 | 'trainable, but turned out to be ' 49 | 'trainable'.format(name)) 50 | elif trainable is True: 51 | if name not in trainable_vars: 52 | raise AssertionError('Variable `{}` is expected to be ' 53 | 'trainable, but turned out not to be ' 54 | 'trainable'.format(name)) 55 | 56 | # check collections 57 | for coll, coll_vars in zip(collections, collection_vars): 58 | if name not in coll_vars: 59 | raise AssertionError('Variable `{}` is expected to be ' 60 | 'in the collection `{}`, but turned ' 61 | 'out not.'.format(name, coll)) 62 | 63 | else: 64 | if name in global_vars: 65 | raise AssertionError('Variable `{}` is expected not to exist, ' 66 | 'but turned out to exist.'.format(name)) 67 | -------------------------------------------------------------------------------- /tests/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/layers/__init__.py -------------------------------------------------------------------------------- /tests/layers/activation/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/layers/activation/__init__.py -------------------------------------------------------------------------------- /tests/layers/activation/test_leaky_relu.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import numpy as np 4 | import pytest 5 | import tensorflow as tf 6 | 7 | from tfsnippet.layers import * 8 | 9 | 10 | class LeakyReLUTestCase(tf.test.TestCase): 11 | 12 | def test_leaky_relu(self): 13 | assert_allclose = functools.partial( 14 | np.testing.assert_allclose, rtol=1e-5, atol=1e-6) 15 | 16 | np.random.seed(12345) 17 | x = np.random.normal(size=[11, 31, 51]).astype(np.float32) 18 | y = np.maximum(x * .2, x) 19 | log_det = np.where( 20 | x < 0, 21 | np.ones_like(x, dtype=np.float32) * np.log(.2).astype(np.float32), 22 | np.zeros_like(x, dtype=np.float32) 23 | ) 24 | leaky_relu = LeakyReLU(alpha=.2) 25 | 26 | with self.test_session() as sess: 27 | # test value_ndims = 0 28 | y_out, log_det_out = sess.run( 29 | leaky_relu.transform(x)) 30 | self.assertTupleEqual(y_out.shape, (11, 31, 51)) 31 | self.assertTupleEqual(log_det_out.shape, (11, 31, 51)) 32 | x2_out, log_det2_out = sess.run( 33 | leaky_relu.inverse_transform(y)) 34 | assert_allclose(x2_out, x) 35 | assert_allclose(y_out, y) 36 | assert_allclose(log_det_out, log_det) 37 | assert_allclose(log_det2_out, -log_det) 38 | 39 | # test value_ndims = 1 40 | y_out, log_det_out = sess.run( 41 | leaky_relu.transform(x, value_ndims=1)) 42 | self.assertTupleEqual(y_out.shape, (11, 31, 51)) 43 | self.assertTupleEqual(log_det_out.shape, (11, 31)) 44 | x2_out, log_det2_out = sess.run( 45 | leaky_relu.inverse_transform(y, value_ndims=1)) 46 | assert_allclose(x2_out, x) 47 | assert_allclose(y_out, y) 48 | assert_allclose(log_det_out, np.sum(log_det, axis=-1)) 49 | assert_allclose(log_det2_out, -np.sum(log_det, axis=-1)) 50 | 51 | # test call 52 | assert_allclose(sess.run(leaky_relu(x)), y) 53 | 54 | with pytest.raises(ValueError, 55 | match='`alpha` must be a float number, ' 56 | 'and 0 < alpha < 1: got 0'): 57 | _ = LeakyReLU(alpha=0) 58 | with pytest.raises(ValueError, match='`alpha` must be a float number, ' 59 | 'and 0 < alpha < 1: got 1'): 60 | _ = LeakyReLU(alpha=1) 61 | with pytest.raises(ValueError, match='`alpha` must be a float number, ' 62 | 'and 0 < alpha < 1: got -1'): 63 | _ = LeakyReLU(alpha=-1) 64 | with pytest.raises(ValueError, match='`alpha` must be a float number, ' 65 | 'and 0 < alpha < 1: got 2'): 66 | _ = LeakyReLU(alpha=2) 67 | -------------------------------------------------------------------------------- /tests/layers/convolutional/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/layers/convolutional/__init__.py -------------------------------------------------------------------------------- /tests/layers/convolutional/helper.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tfsnippet.utils import is_tensor_object 4 | 5 | __all__ = [ 6 | 'strides_tuple_to_channels_last', 7 | 'input_maybe_to_channels_last', 8 | 'output_maybe_to_channels_first', 9 | ] 10 | 11 | 12 | def strides_tuple_to_channels_last(stride_tuples, channels_last=None, 13 | data_format=None): 14 | def to_tensor(x): 15 | x = list(x) 16 | if any(is_tensor_object(t) for t in 
x): 17 | return tf.stack(list(x)) 18 | else: 19 | return tuple(x) 20 | 21 | if channels_last is None and data_format is None: 22 | raise ValueError('At least one of `channels_last` and `data_format` ' 23 | 'should be specified.') 24 | 25 | if channels_last is False or data_format == 'NCHW': 26 | stride_tuples = tuple( 27 | to_tensor(strides[i] for i in (0, 2, 3, 1)) 28 | for strides in stride_tuples 29 | ) 30 | 31 | return stride_tuples 32 | 33 | 34 | def input_maybe_to_channels_last(input, channels_last=None, data_format=None): 35 | if channels_last is None and data_format is None: 36 | raise ValueError('At least one of `channels_last` and `data_format` ' 37 | 'should be specified.') 38 | 39 | if channels_last is False or data_format == 'NCHW': 40 | return tf.transpose(input, (0, 2, 3, 1)) 41 | return input 42 | 43 | 44 | def output_maybe_to_channels_first(output, channels_last=None, 45 | data_format=None): 46 | if channels_last is None and data_format is None: 47 | raise ValueError('At least one of `channels_last` and `data_format` ' 48 | 'should be specified.') 49 | 50 | if channels_last is False or data_format == 'NCHW': 51 | return tf.transpose(output, (0, 3, 1, 2)) 52 | return output 53 | -------------------------------------------------------------------------------- /tests/layers/convolutional/test_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from tfsnippet.layers.convolutional.utils import get_deconv_output_length 5 | from tfsnippet.utils import get_static_shape 6 | 7 | 8 | class GetDeconv2dOutputLengthTestCase(tf.test.TestCase): 9 | 10 | def test_output_length(self): 11 | def check(input_size, kernel_size, strides, padding): 12 | output_size = get_deconv_output_length( 13 | input_size, kernel_size, strides, padding) 14 | self.assertGreater(output_size, 0) 15 | 16 | # assert input <- output 17 | x = tf.nn.conv2d( 18 | np.zeros([1, output_size, output_size, 1], dtype=np.float32), 19 | filter=np.zeros([kernel_size, kernel_size, 1, 1]), 20 | strides=[1, strides, strides, 1], 21 | padding=padding.upper(), 22 | data_format='NHWC' 23 | ) 24 | h, w = get_static_shape(x)[1:3] 25 | self.assertEqual(input_size, h) 26 | 27 | check(input_size=7, kernel_size=1, strides=1, padding='same') 28 | check(input_size=7, kernel_size=1, strides=1, padding='valid') 29 | 30 | check(input_size=7, kernel_size=2, strides=1, padding='same') 31 | check(input_size=7, kernel_size=2, strides=1, padding='valid') 32 | check(input_size=7, kernel_size=1, strides=2, padding='same') 33 | check(input_size=7, kernel_size=1, strides=2, padding='valid') 34 | 35 | check(input_size=7, kernel_size=3, strides=1, padding='same') 36 | check(input_size=7, kernel_size=3, strides=1, padding='valid') 37 | check(input_size=7, kernel_size=1, strides=3, padding='same') 38 | check(input_size=7, kernel_size=1, strides=3, padding='valid') 39 | 40 | check(input_size=7, kernel_size=2, strides=3, padding='same') 41 | check(input_size=7, kernel_size=2, strides=3, padding='valid') 42 | check(input_size=7, kernel_size=3, strides=2, padding='same') 43 | check(input_size=7, kernel_size=3, strides=2, padding='valid') 44 | -------------------------------------------------------------------------------- /tests/layers/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/layers/core/__init__.py 
-------------------------------------------------------------------------------- /tests/layers/core/test_dropout.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from mock import mock, Mock 4 | 5 | from tfsnippet.layers import dropout 6 | 7 | 8 | class DropoutTestCase(tf.test.TestCase): 9 | 10 | def test_dropout(self): 11 | with self.test_session() as sess: 12 | x = np.random.normal(size=[2, 3, 4, 5]).astype(np.float64) 13 | 14 | # test noise_shape = None 15 | noise = np.random.uniform(size=x.shape, low=0., high=1.) 16 | mask = (noise <= .6).astype(np.float64) 17 | noise_tensor = tf.convert_to_tensor(noise) 18 | 19 | with mock.patch('tensorflow.random_uniform', 20 | Mock(return_value=noise_tensor)) as m: 21 | # training = True 22 | y = dropout(x, rate=0.4, training=True) 23 | self.assertDictEqual(dict(m.call_args[1]), { 24 | 'shape': (2, 3, 4, 5), 25 | 'dtype': tf.float64, 26 | 'minval': 0., 27 | 'maxval': 1., 28 | }) 29 | np.testing.assert_allclose(sess.run(y), x * mask / .6) 30 | m.reset_mock() 31 | 32 | # training = False 33 | y = dropout(x, rate=0.4, training=False) 34 | self.assertFalse(m.called) 35 | np.testing.assert_allclose(sess.run(y), x) 36 | 37 | # test specify noise shape, and dynamic training 38 | noise = np.random.uniform(size=[3, 1, 5], low=0., high=1.) 39 | mask = (noise <= .4).astype(np.float64) 40 | noise_tensor = tf.convert_to_tensor(noise) 41 | training = tf.placeholder(dtype=tf.bool, shape=()) 42 | 43 | with mock.patch('tensorflow.random_uniform', 44 | Mock(return_value=noise_tensor)) as m: 45 | y = dropout(x, rate=tf.constant(0.6, dtype=tf.float32), 46 | training=training, noise_shape=(3, 1, 5)) 47 | self.assertDictEqual(dict(m.call_args[1]), { 48 | 'shape': (3, 1, 5), 49 | 'dtype': tf.float64, 50 | 'minval': 0., 51 | 'maxval': 1., 52 | }) 53 | 54 | np.testing.assert_allclose( 55 | sess.run(y, feed_dict={training: True}), 56 | x * mask / .4 57 | ) 58 | np.testing.assert_allclose( 59 | sess.run(y, feed_dict={training: False}), x) 60 | -------------------------------------------------------------------------------- /tests/layers/flows/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/layers/flows/__init__.py -------------------------------------------------------------------------------- /tests/layers/flows/helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from tfsnippet.layers import BaseFlow 5 | 6 | 7 | class Ops(object): 8 | pass 9 | 10 | 11 | tfops = Ops() 12 | for attr in ('log', 'abs', 'exp', 'sign'): 13 | setattr(tfops, attr, getattr(tf, attr)) 14 | 15 | npyops = Ops() 16 | for attr in ('log', 'abs', 'exp', 'sign'): 17 | setattr(npyops, attr, getattr(np, attr)) 18 | 19 | 20 | def safe_pow(ops, x, e): 21 | return ops.sign(x) * ops.exp(e * ops.log(ops.abs(x))) 22 | 23 | 24 | def quadratic_transform(ops, x, a, b): 25 | return a * x ** 3 + b, ops.log(3. * a * (x ** 2)) 26 | 27 | 28 | def quadratic_inverse_transform(ops, y, a, b): 29 | return ( 30 | safe_pow(ops, (y - b) / a, 1./3), 31 | ops.log(ops.abs(safe_pow(ops, (y - b) / a, -2. / 3) / (3. 
* a))) 32 | ) 33 | 34 | 35 | class QuadraticFlow(BaseFlow): 36 | 37 | def __init__(self, a, b, value_ndims=0): 38 | super(QuadraticFlow, self).__init__(x_value_ndims=value_ndims, 39 | y_value_ndims=value_ndims) 40 | self.a = a 41 | self.b = b 42 | 43 | def _build(self, input=None): 44 | pass 45 | 46 | @property 47 | def explicitly_invertible(self): 48 | return True 49 | 50 | def _transform(self, x, compute_y, compute_log_det): 51 | y, log_det = quadratic_transform(tfops, x, self.a, self.b) 52 | if self.x_value_ndims > 0: 53 | log_det = tf.reduce_sum( 54 | log_det, axis=tf.range(-self.x_value_ndims, 0, dtype=tf.int32)) 55 | if not compute_y: 56 | y = None 57 | if not compute_log_det: 58 | log_det = None 59 | return y, log_det 60 | 61 | def _inverse_transform(self, y, compute_x, compute_log_det): 62 | x, log_det = quadratic_inverse_transform(tfops, y, self.a, self.b) 63 | if self.y_value_ndims > 0: 64 | log_det = tf.reduce_sum( 65 | log_det, axis=tf.range(-self.y_value_ndims, 0, dtype=tf.int32)) 66 | if not compute_x: 67 | x = None 68 | if not compute_log_det: 69 | log_det = None 70 | return x, log_det 71 | 72 | 73 | def invertible_flow_standard_check(self, flow, session, x, feed_dict=None, 74 | atol=0., rtol=1e-5): 75 | x = tf.convert_to_tensor(x) 76 | self.assertTrue(flow.explicitly_invertible) 77 | 78 | # test mapping from x -> y -> x 79 | y, log_det_y = flow.transform(x) 80 | x2, log_det_x = flow.inverse_transform(y) 81 | 82 | x_out, y_out, log_det_y_out, x2_out, log_det_x_out = \ 83 | session.run([x, y, log_det_y, x2, log_det_x], feed_dict=feed_dict) 84 | np.testing.assert_allclose(x2_out, x_out, atol=atol, rtol=rtol) 85 | 86 | np.testing.assert_allclose( 87 | -log_det_x_out, log_det_y_out, atol=atol, rtol=rtol) 88 | self.assertEqual(np.size(x_out), np.size(y_out)) 89 | 90 | x_batch_shape = x_out.shape 91 | y_batch_shape = y_out.shape 92 | if flow.x_value_ndims > 0: 93 | x_batch_shape = x_batch_shape[:-flow.x_value_ndims] 94 | if flow.y_value_ndims > 0: 95 | y_batch_shape = y_batch_shape[:-flow.y_value_ndims] 96 | self.assertTupleEqual(log_det_y_out.shape, x_batch_shape) 97 | self.assertTupleEqual(log_det_x_out.shape, y_batch_shape) 98 | self.assertTupleEqual(log_det_y_out.shape, log_det_x_out.shape) 99 | -------------------------------------------------------------------------------- /tests/layers/flows/test_invert.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import tensorflow as tf 4 | 5 | from tests.layers.flows.helper import (QuadraticFlow, quadratic_transform, 6 | npyops, invertible_flow_standard_check) 7 | from tfsnippet import FlowDistribution, Normal 8 | from tfsnippet.layers import InvertFlow, PlanarNormalizingFlow, BaseFlow 9 | 10 | 11 | class InvertFlowTestCase(tf.test.TestCase): 12 | 13 | def test_invert_flow(self): 14 | with self.test_session() as sess: 15 | # test invert a normal flow 16 | flow = QuadraticFlow(2., 5.) 17 | inv_flow = flow.invert() 18 | 19 | self.assertIsInstance(inv_flow, InvertFlow) 20 | self.assertEqual(inv_flow.x_value_ndims, 0) 21 | self.assertEqual(inv_flow.y_value_ndims, 0) 22 | self.assertFalse(inv_flow.require_batch_dims) 23 | 24 | test_x = np.arange(12, dtype=np.float32) + 1. 25 | test_y, test_log_det = quadratic_transform(npyops, test_x, 2., 5.) 
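            # A sketch of the InvertFlow contract being exercised here: for
            # inv_flow = flow.invert(), one expects
            #     inv_flow.transform(y)         ~ flow.inverse_transform(y)
            #     inv_flow.inverse_transform(x) ~ flow.transform(x)
            # so calling `inverse_transform(test_x)` below should reproduce
            # `test_y` and `test_log_det` computed above.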
26 | 27 | self.assertFalse(flow._has_built) 28 | y, log_det_y = inv_flow.inverse_transform(tf.constant(test_x)) 29 | self.assertTrue(flow._has_built) 30 | 31 | np.testing.assert_allclose(sess.run(y), test_y) 32 | np.testing.assert_allclose(sess.run(log_det_y), test_log_det) 33 | invertible_flow_standard_check(self, inv_flow, sess, test_y) 34 | 35 | # test invert an InvertFlow 36 | inv_inv_flow = inv_flow.invert() 37 | self.assertIs(inv_inv_flow, flow) 38 | 39 | # test use with FlowDistribution 40 | normal = Normal(mean=1., std=2.) 41 | inv_flow = QuadraticFlow(2., 5.).invert() 42 | distrib = FlowDistribution(normal, inv_flow) 43 | distrib_log_det = distrib.log_prob(test_x) 44 | np.testing.assert_allclose( 45 | *sess.run([distrib_log_det, 46 | normal.log_prob(test_y) + test_log_det]) 47 | ) 48 | 49 | def test_property(self): 50 | class _Flow(BaseFlow): 51 | @property 52 | def explicitly_invertible(self): 53 | return True 54 | 55 | inv_flow = _Flow(x_value_ndims=2, y_value_ndims=3, 56 | require_batch_dims=True).invert() 57 | self.assertTrue(inv_flow.require_batch_dims) 58 | self.assertEqual(inv_flow.x_value_ndims, 3) 59 | self.assertEqual(inv_flow.y_value_ndims, 2) 60 | 61 | def test_errors(self): 62 | with pytest.raises(ValueError, match='`flow` must be an explicitly ' 63 | 'invertible flow'): 64 | _ = InvertFlow(object()) 65 | 66 | with pytest.raises(ValueError, match='`flow` must be an explicitly ' 67 | 'invertible flow'): 68 | _ = PlanarNormalizingFlow().invert() 69 | -------------------------------------------------------------------------------- /tests/layers/flows/test_rearrangement.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from tests.helper import assert_variables 5 | from tests.layers.flows.helper import invertible_flow_standard_check 6 | from tfsnippet.layers import FeatureShufflingFlow 7 | 8 | 9 | class FeatureShufflingFlowTestCase(tf.test.TestCase): 10 | 11 | def test_feature_shuffling_flow(self): 12 | np.random.seed(1234) 13 | 14 | with self.test_session() as sess: 15 | # axis = -1, value_ndims = 1 16 | x = np.random.normal(size=[3, 4, 5, 6]).astype(np.float32) 17 | x_ph = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 6]) 18 | permutation = np.arange(6, dtype=np.int32) 19 | np.random.shuffle(permutation) 20 | y = x[..., permutation] 21 | log_det = np.zeros([3, 4, 5]).astype(np.float32) 22 | 23 | layer = FeatureShufflingFlow(axis=-1, value_ndims=1) 24 | y_out, log_det_out = layer.transform(x_ph) 25 | sess.run(tf.assign(layer._permutation, permutation)) 26 | y_out, log_det_out = sess.run( 27 | [y_out, log_det_out], feed_dict={x_ph: x}) 28 | 29 | np.testing.assert_equal(y_out, y) 30 | np.testing.assert_equal(log_det_out, log_det) 31 | 32 | invertible_flow_standard_check( 33 | self, layer, sess, x_ph, feed_dict={x_ph: x}) 34 | 35 | assert_variables(['permutation'], trainable=False, 36 | scope='feature_shuffling_flow', 37 | collections=[tf.GraphKeys.MODEL_VARIABLES]) 38 | 39 | # axis = -2, value_ndims = 3 40 | x = np.random.normal(size=[3, 4, 5, 6]).astype(np.float32) 41 | x_ph = tf.placeholder(dtype=tf.float32, shape=[None, None, 5, None]) 42 | permutation = np.arange(5, dtype=np.int32) 43 | np.random.shuffle(permutation) 44 | y = x[..., permutation, :] 45 | log_det = np.zeros([3]).astype(np.float32) 46 | 47 | layer = FeatureShufflingFlow(axis=-2, value_ndims=3) 48 | y_out, log_det_out = layer.transform(x_ph) 49 | sess.run(tf.assign(layer._permutation, permutation)) 50 | 
y_out, log_det_out = sess.run( 51 | [y_out, log_det_out], feed_dict={x_ph: x}) 52 | 53 | np.testing.assert_equal(y_out, y) 54 | np.testing.assert_equal(log_det_out, log_det) 55 | 56 | invertible_flow_standard_check( 57 | self, layer, sess, x_ph, feed_dict={x_ph: x}) 58 | -------------------------------------------------------------------------------- /tests/layers/flows/test_sequential.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from tfsnippet.layers import SequentialFlow, BaseFlow 6 | from tests.layers.flows.test_base import MultiLayerQuadraticFlow 7 | from tests.layers.flows.helper import (QuadraticFlow, 8 | invertible_flow_standard_check) 9 | 10 | 11 | class SequentialFlowTestCase(tf.test.TestCase): 12 | 13 | def test_errors(self): 14 | class _Flow(BaseFlow): 15 | pass 16 | 17 | with pytest.raises(TypeError, match='`flows` must not be empty'): 18 | _ = SequentialFlow([]) 19 | 20 | with pytest.raises( 21 | TypeError, match='The 0-th flow in `flows` is not an instance ' 22 | 'of `BaseFlow`: 123'): 23 | _ = SequentialFlow([123]) 24 | 25 | with pytest.raises( 26 | TypeError, match='`x_value_ndims` of the 1-th flow != ' 27 | '`y_value_ndims` of the 0-th flow: 2 vs 3'): 28 | _ = SequentialFlow([ 29 | _Flow(x_value_ndims=1, y_value_ndims=3), 30 | _Flow(x_value_ndims=2), 31 | ]) 32 | 33 | def test_sequential_with_quadratic_flows(self): 34 | n_layers = 3 35 | flow1 = MultiLayerQuadraticFlow(n_layers) 36 | flow2 = SequentialFlow([ 37 | QuadraticFlow(i + 1., i * 2. + 1.) 38 | for i in range(n_layers) 39 | ]) 40 | self.assertTrue(flow2.explicitly_invertible) 41 | self.assertEqual(len(flow2.flows), n_layers) 42 | for i in range(n_layers): 43 | self.assertEqual(flow2.flows[i].a, i + 1.) 44 | self.assertEqual(flow2.flows[i].b, i * 2. + 1.) 45 | 46 | x = tf.range(12, dtype=tf.float32) + 1. 
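        # Note: the inputs start from 1 rather than 0, since the log-det of
        # quadratic_transform involves log(3. * a * x ** 2), which would be
        # -inf at x = 0; keeping the test data strictly positive avoids that.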
47 | 48 | with self.test_session() as sess: 49 | invertible_flow_standard_check(self, flow2, sess, x) 50 | 51 | # transform 52 | y1, log_det_y1 = flow1.transform(x) 53 | y2, log_det_y2 = flow2.transform(x) 54 | np.testing.assert_allclose(*sess.run([y1, y2])) 55 | np.testing.assert_allclose(*sess.run([log_det_y1, log_det_y2])) 56 | 57 | # inverse transform 58 | x1, log_det_x1 = flow1.inverse_transform(y1) 59 | x2, log_det_x2 = flow2.inverse_transform(y2) 60 | np.testing.assert_allclose(*sess.run([x1, x2])) 61 | np.testing.assert_allclose(*sess.run([log_det_x1, log_det_x2])) 62 | 63 | def test_property(self): 64 | class _Flow(BaseFlow): 65 | @property 66 | def explicitly_invertible(self): 67 | return False 68 | 69 | flow = SequentialFlow([ 70 | _Flow(x_value_ndims=1, y_value_ndims=2), 71 | _Flow(x_value_ndims=2, y_value_ndims=3), 72 | ]) 73 | self.assertFalse(flow.explicitly_invertible) 74 | self.assertEqual(flow.x_value_ndims, 1) 75 | self.assertEqual(flow.y_value_ndims, 3) 76 | 77 | flow = SequentialFlow([ 78 | QuadraticFlow(2., 3.), 79 | _Flow(x_value_ndims=0), 80 | ]) 81 | self.assertFalse(flow.explicitly_invertible) 82 | -------------------------------------------------------------------------------- /tests/layers/helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def l2_normalize(x, axis, epsilon=1e-12): 5 | out = x / np.sqrt( 6 | np.maximum( 7 | np.sum(np.square(x), axis=axis, keepdims=True), 8 | epsilon 9 | ) 10 | ) 11 | return out 12 | -------------------------------------------------------------------------------- /tests/layers/normalization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/layers/normalization/__init__.py -------------------------------------------------------------------------------- /tests/layers/test_base.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tensorflow as tf 3 | from mock import Mock 4 | 5 | from tfsnippet.layers import BaseLayer 6 | 7 | 8 | class MyLayer(BaseLayer): 9 | 10 | def __init__(self, output, **kwargs): 11 | super(MyLayer, self).__init__(**kwargs) 12 | self._build = Mock(wraps=self._build) 13 | self._apply = Mock(return_value=output) 14 | 15 | def _build(self, input=None): 16 | assert(tf.get_variable_scope().name == self.variable_scope.name) 17 | 18 | 19 | class BaseLayerTestCase(tf.test.TestCase): 20 | 21 | def test_skeleton(self): 22 | input = tf.constant(123.) 23 | inputs = [tf.constant(12.), tf.constant(3.)] 24 | output = tf.constant(456.)
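        # MyLayer wraps `_build` and `_apply` in Mock objects, so this test can
        # assert not only on return values but also on how many times, and with
        # which arguments, each hook is invoked.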
25 | 26 | # test calling `build` manually 27 | layer = MyLayer(output) 28 | self.assertFalse(layer._has_built) 29 | layer.build() 30 | self.assertTrue(layer._has_built) 31 | self.assertEqual(layer._build.call_args, ((None,), {})) 32 | self.assertFalse(layer._apply.called) 33 | self.assertIs(layer.apply(input), output) 34 | self.assertEqual(layer._build.call_count, 1) 35 | self.assertEqual(layer._apply.call_args, ((input,), {})) 36 | 37 | with pytest.raises(RuntimeError, match='Layer has already been built'): 38 | _ = layer.build() 39 | 40 | layer = MyLayer(output) 41 | layer._build_require_input = True 42 | with pytest.raises(ValueError, 43 | match='`MyLayer` requires `input` to build'): 44 | _ = layer.build() 45 | 46 | # test calling `build` automatically on first use 47 | layer = MyLayer(output) 48 | self.assertFalse(layer._has_built) 49 | self.assertIs(layer(inputs), output) 50 | self.assertEqual(layer._build.call_args, ((inputs,), {})) 51 | self.assertEqual(layer._apply.call_args, ((inputs,), {})) 52 | -------------------------------------------------------------------------------- /tests/layers/test_initialization.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tfsnippet.layers import * 4 | 5 | 6 | class DefaultKernelInitializerTestCase(tf.test.TestCase): 7 | 8 | def test_default_kernel_initializer(self): 9 | i = default_kernel_initializer(weight_norm=True) 10 | self.assertEqual(i.stddev, .05) 11 | 12 | i = default_kernel_initializer(weight_norm=(lambda t: t)) 13 | self.assertEqual(i.stddev, .05) 14 | 15 | i = default_kernel_initializer(weight_norm=False) 16 | self.assertFalse(hasattr(i, 'stddev')) 17 | 18 | i = default_kernel_initializer(weight_norm=None) 19 | self.assertFalse(hasattr(i, 'stddev')) 20 | -------------------------------------------------------------------------------- /tests/layers/test_regularization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from tfsnippet.layers import l2_regularizer 5 | 6 | 7 | class L2RegularizerTestCase(tf.test.TestCase): 8 | 9 | def test_l2_regularizer(self): 10 | with self.test_session() as sess: 11 | w = np.random.random(size=[10, 11, 12]) 12 | lambda_ = .75 13 | loss = lambda_ * .5 * np.sum(w ** 2) 14 | np.testing.assert_allclose( 15 | sess.run(l2_regularizer(lambda_)(w)), 16 | loss 17 | ) 18 | 19 | self.assertIsNone(l2_regularizer(None)) 20 | -------------------------------------------------------------------------------- /tests/layers/test_utils.py: -------------------------------------------------------------------------------- 1 | import mock 2 | import pytest 3 | import tensorflow as tf 4 | 5 | from tfsnippet.layers.utils import validate_weight_norm_arg 6 | 7 | 8 | class ValidateWeightNormArgTestCase(tf.test.TestCase): 9 | 10 | def test_validate_weight_norm_arg(self): 11 | # noinspection PyUnresolvedReferences 12 | import tfsnippet.layers.utils 13 | 14 | # test callable 15 | f = lambda t: t 16 | self.assertIs(validate_weight_norm_arg(f, -1, True), f) 17 | 18 | # test True: should generate a function that wraps true weight_norm 19 | with mock.patch('tfsnippet.layers.normalization.weight_norm') as m: 20 | f = validate_weight_norm_arg(True, -2, False) 21 | t = tf.reshape(tf.range(6, dtype=tf.float32), [1, 2, 3]) 22 | _ = f(t) 23 | self.assertEqual( 24 | m.call_args, ((t,), {'axis': -2, 'use_scale': False})) 25 | 26 | # test False: should return None
self.assertIsNone(validate_weight_norm_arg(False, -1, True)) 28 | self.assertIsNone(validate_weight_norm_arg(None, -1, True)) 29 | 30 | # test others: should raise error 31 | with pytest.raises(TypeError, 32 | match='Invalid value for argument `weight_norm`: ' 33 | 'expected a bool or a callable function'): 34 | _ = validate_weight_norm_arg(123, -1, True) 35 | -------------------------------------------------------------------------------- /tests/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/ops/__init__.py -------------------------------------------------------------------------------- /tests/ops/test_classification.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from tfsnippet.ops import * 5 | 6 | 7 | class ClassificationAccuracyTestCase(tf.test.TestCase): 8 | 9 | def test_classification_accuracy(self): 10 | y_pred = np.asarray([0, 1, 3, 3, 2, 5, 4]) 11 | y_true = np.asarray([0, 2, 3, 1, 3, 5, 5]) 12 | acc = np.mean(y_pred == y_true) 13 | 14 | with self.test_session() as sess: 15 | self.assertAllClose( 16 | acc, 17 | sess.run(classification_accuracy(y_pred, y_true)) 18 | ) 19 | 20 | def test_softmax_classification_output(self): 21 | with self.test_session() as sess: 22 | np.random.seed(1234) 23 | 24 | # test 2d input, 1 class 25 | logits = np.random.random(size=[100, 1]) 26 | ans = np.argmax(logits, axis=-1) 27 | np.testing.assert_equal( 28 | sess.run(softmax_classification_output(logits)), 29 | ans 30 | ) 31 | 32 | # test 2d input, 5 classes 33 | logits = np.random.random(size=[100, 5]) 34 | ans = np.argmax(logits, axis=-1) 35 | np.testing.assert_equal( 36 | sess.run(softmax_classification_output(logits)), 37 | ans 38 | ) 39 | 40 | # test 3d input, 7 classes 41 | logits = np.random.random(size=[10, 100, 7]) 42 | ans = np.argmax(logits, axis=-1) 43 | np.testing.assert_equal( 44 | sess.run(softmax_classification_output(logits)), 45 | ans 46 | ) 47 | 48 | -------------------------------------------------------------------------------- /tests/ops/test_control_flows.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tfsnippet.ops import * 4 | 5 | __all__ = ['SmartCondTestCase'] 6 | 7 | 8 | class SmartCondTestCase(tf.test.TestCase): 9 | 10 | def test_smart_cond(self): 11 | with self.test_session() as sess: 12 | # test static condition 13 | self.assertEqual(1, smart_cond(True, (lambda: 1), (lambda: 2))) 14 | self.assertEqual(2, smart_cond(False, (lambda: 1), (lambda: 2))) 15 | 16 | # test dynamic condition 17 | cond_in = tf.placeholder(dtype=tf.bool, shape=()) 18 | value = smart_cond( 19 | cond_in, lambda: tf.constant(1), lambda: tf.constant(2)) 20 | self.assertIsInstance(value, tf.Tensor) 21 | self.assertEqual(sess.run(value, feed_dict={cond_in: True}), 1) 22 | self.assertEqual(sess.run(value, feed_dict={cond_in: False}), 2) 23 | -------------------------------------------------------------------------------- /tests/ops/test_evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from tfsnippet.ops import bits_per_dimension 5 | 6 | 7 | class BitsPerDimensionTestCase(tf.test.TestCase): 8 | 9 | def test_bits_per_dimension(self): 10 | with self.test_session() as sess: 11 | log_p = 
np.random.normal(size=[2, 3, 4, 5]) 12 | 13 | np.testing.assert_allclose( 14 | sess.run(bits_per_dimension(log_p, 1., scale=None)), 15 | -log_p / np.log(2) 16 | ) 17 | np.testing.assert_allclose( 18 | sess.run(bits_per_dimension(log_p, 1024 * 3, scale=None)), 19 | -log_p / (np.log(2) * 1024 * 3) 20 | ) 21 | np.testing.assert_allclose( 22 | sess.run(bits_per_dimension(log_p, 1., scale=256.)), 23 | -(log_p - np.log(256)) / np.log(2) 24 | ) 25 | np.testing.assert_allclose( 26 | sess.run(bits_per_dimension(log_p, 1024 * 3, scale=256)), 27 | -(log_p - np.log(256) * 1024 * 3) / (np.log(2) * 1024 * 3) 28 | ) 29 | -------------------------------------------------------------------------------- /tests/ops/test_shift.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import tensorflow as tf 4 | 5 | from tfsnippet.ops import shift 6 | 7 | 8 | class ShiftTestCase(tf.test.TestCase): 9 | 10 | def test_shift(self): 11 | x = np.random.normal(size=[3, 4, 5, 6]) 12 | t = tf.convert_to_tensor(x) 13 | 14 | with self.test_session() as sess: 15 | # test shift a scalar will do nothing 16 | t0 = tf.constant(0.) 17 | self.assertIs(shift(t0, []), t0) 18 | 19 | # test shift by zeros should result in `t` 20 | self.assertIs(shift(t, [0, 0, 0, 0]), t) 21 | 22 | # test shift all contents outside the size 23 | y = np.zeros_like(x) 24 | np.testing.assert_allclose( 25 | sess.run(shift(x, [3, 4, 5, 6])), 26 | y 27 | ) 28 | np.testing.assert_allclose( 29 | sess.run(shift(x, [-3, -4, -5, -6])), 30 | y 31 | ) 32 | 33 | # test shift by various distances 34 | y = np.zeros_like(x) 35 | y[1:, :-2, 3:, :] = x[:-1, 2:, :-3, :] 36 | np.testing.assert_allclose( 37 | sess.run(shift(x, [1, -2, 3, 0])), 38 | y 39 | ) 40 | for i in range(4): 41 | s = [0] * 4 42 | with pytest.raises(ValueError, match='Cannot shift `input`: ' 43 | 'input .* vs shift .*'): 44 | s[i] = 4 + i 45 | _ = shift(x, s) 46 | with pytest.raises(ValueError, match='Cannot shift `input`: ' 47 | 'input .* vs shift .*'): 48 | s[i] = -(4 + i) 49 | _ = shift(x, s) 50 | 51 | # test shift dynamic shape 52 | ph = tf.placeholder(dtype=tf.float64, shape=[None] * 4) 53 | np.testing.assert_allclose( 54 | sess.run(shift(ph, [1, -2, 3, 0]), feed_dict={ph: x}), 55 | y 56 | ) 57 | for i in range(4): 58 | s = [0] * 4 59 | s[i] = 4 + i 60 | 61 | output = shift(ph, s) 62 | with pytest.raises(Exception, match='Cannot shift `input`: ' 63 | 'input .* vs shift .*'): 64 | _ = sess.run(output, feed_dict={ph: x}) 65 | 66 | s[i] = -(4 + i) 67 | output = shift(ph, s) 68 | with pytest.raises(Exception, match='Cannot shift `input`: ' 69 | 'input .* vs shift .*'): 70 | _ = sess.run(output, feed_dict={ph: x}) 71 | 72 | with pytest.raises(ValueError, match='The rank of `shape` is required ' 73 | 'to be deterministic:'): 74 | _ = shift(tf.placeholder(dtype=tf.float64, shape=None), [0]) 75 | 76 | with pytest.raises(ValueError, 77 | match='The length of `shift` is required to equal ' 78 | 'the rank of `input`: shift .* vs input .*'): 79 | _ = shift(x, [0, 1, 2]) 80 | 81 | with pytest.raises(ValueError, 82 | match='The length of `shift` is required to equal ' 83 | 'the rank of `input`: shift .* vs input .*'): 84 | _ = shift(tf.constant(0.), [0]) 85 | -------------------------------------------------------------------------------- /tests/ops/test_type_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from tfsnippet.ops import 
convert_to_tensor_and_cast 5 | 6 | 7 | class ConvertToTensorAndCastTestCase(tf.test.TestCase): 8 | 9 | def test_convert_to_tensor_and_cast(self): 10 | def check(x, dtype=None): 11 | z = tf.convert_to_tensor(x) 12 | y = convert_to_tensor_and_cast(x, dtype) 13 | self.assertIsInstance(y, tf.Tensor) 14 | if dtype is not None: 15 | self.assertEqual(y.dtype, dtype) 16 | else: 17 | self.assertEqual(y.dtype, z.dtype) 18 | 19 | check(np.arange(10, dtype=np.float32)) 20 | check(np.arange(10, dtype=np.float32), np.float64) 21 | check(np.arange(10, dtype=np.float32), tf.float64) 22 | check(tf.range(10, dtype=tf.float32)) 23 | check(tf.range(10, dtype=tf.float32), np.float64) 24 | check(tf.range(10, dtype=tf.float32), tf.float64) 25 | -------------------------------------------------------------------------------- /tests/preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/preprocessing/__init__.py -------------------------------------------------------------------------------- /tests/preprocessing/test_samplers.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from tfsnippet.preprocessing import * 6 | 7 | 8 | class BaseSamplerTestCase(unittest.TestCase): 9 | 10 | def test_sample(self): 11 | class _MySampler(BaseSampler): 12 | def sample(self, x): 13 | return x 14 | 15 | sampler = _MySampler() 16 | x = np.arange(12).reshape([3, 4]) 17 | self.assertIs(sampler.sample(x), x) 18 | self.assertEqual(sampler(x), (x,)) 19 | 20 | 21 | class BernoulliSamplerTestCase(unittest.TestCase): 22 | 23 | def test_property(self): 24 | self.assertEqual(BernoulliSampler().dtype, np.int32) 25 | self.assertEqual(BernoulliSampler(np.float64).dtype, np.float64) 26 | 27 | def test_sample(self): 28 | np.random.seed(1234) 29 | x = np.linspace(0, 1, 1001, dtype=np.float32) 30 | 31 | # test output is int32 arrays 32 | sampler = BernoulliSampler() 33 | y = sampler.sample(x) 34 | self.assertEqual(y.shape, x.shape) 35 | self.assertEqual(y.dtype, np.int32) 36 | self.assertLessEqual(np.max(y), 1) 37 | self.assertGreaterEqual(np.min(y), 0) 38 | 39 | # test output is float32 arrays 40 | sampler = BernoulliSampler(dtype=np.float32) 41 | y = sampler.sample(x) 42 | self.assertEqual(y.shape, x.shape) 43 | self.assertEqual(y.dtype, np.float32) 44 | self.assertLessEqual(np.max(y), 1 + 1e-5) 45 | self.assertGreaterEqual(np.min(y), 0 - 1e-5) 46 | 47 | 48 | class UniformNoiseSamplerTestCase(unittest.TestCase): 49 | 50 | def test_property(self): 51 | sampler = UniformNoiseSampler() 52 | self.assertIsNone(sampler.dtype) 53 | 54 | sampler = UniformNoiseSampler(minval=-2., maxval=2., dtype=np.float64) 55 | self.assertEqual(sampler.minval, -2.) 56 | self.assertEqual(sampler.maxval, 2.) 57 | self.assertEqual(sampler.dtype, np.float64) 58 | 59 | def test_sample(self): 60 | np.random.seed(1234) 61 | x = np.arange(0, 1000, dtype=np.float64) 62 | 63 | # test output dtype equals to input 64 | sampler = UniformNoiseSampler() 65 | y = sampler.sample(x) 66 | self.assertEqual(y.shape, x.shape) 67 | self.assertEqual(y.dtype, np.float64) 68 | self.assertLess(np.max(y - x), 1.) 69 | self.assertGreaterEqual(np.min(y - x), 0.) 70 | 71 | # test output is float32 arrays, and min&max val is not 0.&1. 
72 | x = x * 4 73 | sampler = UniformNoiseSampler(minval=-2., maxval=2., dtype=np.float32) 74 | y = sampler.sample(x) 75 | self.assertEqual(y.shape, x.shape) 76 | self.assertEqual(y.dtype, np.float32) 77 | self.assertLess(np.max(y - x), 2.) 78 | self.assertGreaterEqual(np.min(y - x), -2.) 79 | -------------------------------------------------------------------------------- /tests/scaffold/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/scaffold/__init__.py -------------------------------------------------------------------------------- /tests/scaffold/test_scheduled_var.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | 4 | from tests.helper import assert_variables 5 | from tfsnippet.scaffold import ScheduledVariable, AnnealingVariable 6 | from tfsnippet.utils import ensure_variables_initialized, TemporaryDirectory 7 | 8 | 9 | class ScheduledVariableTestCase(tf.test.TestCase): 10 | 11 | def test_ScheduledVariable(self): 12 | v = ScheduledVariable('v', 123., dtype=tf.int32, model_var=True, 13 | collections=['my_variables']) 14 | assert_variables(['v'], trainable=False, collections=['my_variables']) 15 | 16 | with TemporaryDirectory() as tmpdir: 17 | saver = tf.train.Saver(var_list=[v.variable]) 18 | save_path = os.path.join(tmpdir, 'saved_var') 19 | 20 | with self.test_session() as sess: 21 | ensure_variables_initialized() 22 | 23 | self.assertEqual(v.get(), 123) 24 | self.assertEqual(sess.run(v), 123) 25 | self.assertEqual(v.set(456), 456) 26 | self.assertEqual(v.get(), 456) 27 | 28 | saver.save(sess, save_path) 29 | 30 | with self.test_session() as sess: 31 | saver.restore(sess, save_path) 32 | self.assertEqual(v.get(), 456) 33 | 34 | sess.run(v.assign_op, feed_dict={v.assign_ph: 789}) 35 | self.assertEqual(v.get(), 789) 36 | 37 | def test_AnnealingDynamicValue(self): 38 | with self.test_session() as sess: 39 | # test without min_value 40 | v = AnnealingVariable('v', 1, 2) 41 | ensure_variables_initialized() 42 | self.assertEqual(v.get(), 1) 43 | 44 | self.assertEqual(v.anneal(), 2) 45 | self.assertEqual(v.get(), 2) 46 | self.assertEqual(v.anneal(), 4) 47 | self.assertEqual(v.get(), 4) 48 | 49 | self.assertEqual(v.set(2), 2) 50 | self.assertEqual(v.get(), 2) 51 | self.assertEqual(v.anneal(), 4) 52 | self.assertEqual(v.get(), 4) 53 | 54 | # test with min_value 55 | v = AnnealingVariable('v2', 1, .5, 2) 56 | ensure_variables_initialized() 57 | self.assertEqual(v.get(), 2) 58 | 59 | v = AnnealingVariable('v3', 1, .5, .5) 60 | ensure_variables_initialized() 61 | self.assertEqual(v.get(), 1) 62 | self.assertEqual(v.anneal(), .5) 63 | self.assertEqual(v.get(), .5) 64 | self.assertEqual(v.anneal(), .5) 65 | self.assertEqual(v.get(), .5) 66 | -------------------------------------------------------------------------------- /tests/trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/trainer/__init__.py -------------------------------------------------------------------------------- /tests/trainer/test_feed_dict.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import six 4 | import tensorflow as tf 5 | 6 | from tfsnippet.trainer import * 7 | from tfsnippet.scaffold 
import ScheduledVariable 8 | from tfsnippet.utils import ensure_variables_initialized 9 | 10 | 11 | class MyDynamicValue(DynamicValue): 12 | 13 | def __init__(self, value): 14 | self.value = value 15 | 16 | def get(self): 17 | return self.value 18 | 19 | 20 | class ResolveFeedDictTestCase(tf.test.TestCase): 21 | 22 | def test_copy(self): 23 | with self.test_session(): 24 | d = { 25 | 'a': 12, 26 | 'b': ScheduledVariable('b', 34), 27 | 'c': MyDynamicValue(56), 28 | 'd': lambda: 78, 29 | } 30 | ensure_variables_initialized() 31 | d2 = resolve_feed_dict(d) 32 | self.assertIsNot(d2, d) 33 | self.assertDictEqual({'a': 12, 'b': 34, 'c': 56, 'd': 78}, d2) 34 | self.assertIsInstance(d['b'], ScheduledVariable) 35 | self.assertIsInstance(d['c'], MyDynamicValue) 36 | 37 | def test_inplace(self): 38 | with self.test_session(): 39 | d = { 40 | 'a': 12, 41 | 'b': ScheduledVariable('b', 34), 42 | 'c': MyDynamicValue(56), 43 | 'd': lambda: 78, 44 | } 45 | ensure_variables_initialized() 46 | self.assertIs(d, resolve_feed_dict(d, inplace=True)) 47 | self.assertDictEqual({'a': 12, 'b': 34, 'c': 56, 'd': 78}, d) 48 | 49 | 50 | class MergeFeedDictTestCase(unittest.TestCase): 51 | 52 | def test_merge(self): 53 | self.assertDictEqual( 54 | {'a': 10, 'b': 200, 'c': 300, 'd': 4}, 55 | merge_feed_dict( 56 | None, 57 | {'a': 1, 'b': 2, 'c': 3, 'd': 4}, 58 | iter([('a', 10), ('b', 20), ('c', 30)]), 59 | None, 60 | six.iteritems({'b': 200, 'c': 300}) 61 | ) 62 | ) 63 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/utils/__init__.py -------------------------------------------------------------------------------- /tests/utils/_div_op.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | def regular_div(x, y): 5 | return x / y 6 | 7 | 8 | def floor_div(x, y): 9 | return x // y 10 | -------------------------------------------------------------------------------- /tests/utils/_true_div_op.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import division 4 | 5 | 6 | def true_div(x, y): 7 | return x / y 8 | -------------------------------------------------------------------------------- /tests/utils/assets/payload.rar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/utils/assets/payload.rar -------------------------------------------------------------------------------- /tests/utils/assets/payload.tar.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/utils/assets/payload.tar.bz2 -------------------------------------------------------------------------------- /tests/utils/assets/payload.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/utils/assets/payload.tar.gz -------------------------------------------------------------------------------- /tests/utils/assets/payload.tar.xz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/utils/assets/payload.tar.xz -------------------------------------------------------------------------------- /tests/utils/assets/payload.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/utils/assets/payload.zip -------------------------------------------------------------------------------- /tests/utils/test_debugging.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from tfsnippet.utils import * 6 | 7 | 8 | class AssertionTestCase(tf.test.TestCase): 9 | 10 | def test_assert_deps(self): 11 | ph = tf.placeholder(dtype=tf.bool, shape=()) 12 | op = tf.assert_equal(ph, True, message='abcdefg') 13 | 14 | # test ops are empty 15 | with assert_deps([None]) as asserted: 16 | self.assertFalse(asserted) 17 | 18 | # test assertion enabled, and ops are not empty 19 | with self.test_session() as sess, \ 20 | scoped_set_config(settings, enable_assertions=True): 21 | with assert_deps([op, None]) as asserted: 22 | self.assertTrue(asserted) 23 | out = tf.constant(1.) 24 | with pytest.raises(Exception, match='abcdefg'): 25 | self.assertEqual(sess.run(out, feed_dict={ph: False}), 1.) 26 | 27 | # test assertion disabled 28 | with self.test_session() as sess, \ 29 | scoped_set_config(settings, enable_assertions=False): 30 | with assert_deps([op, None]) as asserted: 31 | self.assertFalse(asserted) 32 | out = tf.constant(1.) 33 | self.assertEqual(sess.run(out, feed_dict={ph: False}), 1.) 
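        # Note: with `enable_assertions=False`, `assert_deps` discards the
        # assertion op (yielding False), which is why evaluating `out` with
        # ph=False above runs without raising.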
34 | 35 | 36 | class CheckNumericsTestCase(tf.test.TestCase): 37 | 38 | def test_check_numerics(self): 39 | # test enabled 40 | ph = tf.placeholder(dtype=tf.float32, shape=()) 41 | with scoped_set_config(settings, check_numerics=True): 42 | x = maybe_check_numerics(ph, message='numerical issues') 43 | with pytest.raises(Exception, match='numerical issues'): 44 | with self.test_session() as sess: 45 | _ = sess.run(x, feed_dict={ph: np.nan}) 46 | 47 | # test disabled 48 | with scoped_set_config(settings, check_numerics=False): 49 | x = maybe_check_numerics( 50 | tf.constant(np.nan), message='numerical issues') 51 | with self.test_session() as sess: 52 | self.assertTrue(np.isnan(sess.run(x))) 53 | 54 | 55 | class AddHistogramTestCase(tf.test.TestCase): 56 | 57 | def test_add_histogram(self): 58 | with tf.name_scope('parent'): 59 | x = tf.constant(0., name='x') 60 | y = tf.constant(1., name='y') 61 | z = tf.constant(2., name='z') 62 | w = tf.constant(3., name='w') 63 | 64 | # test enabled 65 | with scoped_set_config(settings, auto_histogram=True): 66 | maybe_add_histogram(x, strip_scope=True) 67 | maybe_add_histogram(y, summary_name='the_tensor') 68 | maybe_add_histogram(z, collections=[tf.GraphKeys.SUMMARIES]) 69 | 70 | # test disabled 71 | with scoped_set_config(settings, auto_histogram=False): 72 | maybe_add_histogram(w) 73 | 74 | self.assertListEqual( 75 | [op.name for op in tf.get_collection(GraphKeys.AUTO_HISTOGRAM)], 76 | ['maybe_add_histogram/x:0', 'maybe_add_histogram_1/the_tensor:0'] 77 | ) 78 | self.assertListEqual( 79 | [op.name for op in tf.get_collection(tf.GraphKeys.SUMMARIES)], 80 | ['maybe_add_histogram_2/parent/z:0'] 81 | ) 82 | -------------------------------------------------------------------------------- /tests/utils/test_extractor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import sys 4 | import unittest 5 | 6 | import pytest 7 | 8 | from tfsnippet.utils import * 9 | 10 | 11 | class ExtractorTestCase(unittest.TestCase): 12 | 13 | def check_archive_file(self, extractor_class, archive_file, alias=None): 14 | if alias is not None: 15 | with TemporaryDirectory() as tmpdir: 16 | new_archive_file = os.path.join(tmpdir, alias) 17 | shutil.copy(archive_file, new_archive_file) 18 | self.check_archive_file(extractor_class, new_archive_file) 19 | else: 20 | with Extractor.open(archive_file) as e: 21 | self.assertIsInstance(e, extractor_class) 22 | files = [(n, f.read()) for n, f in e.iter_extract()] 23 | self.assertListEqual( 24 | [ 25 | ('a/1.txt', b'a/1.txt'), 26 | ('b/2.txt', b'b/2.txt'), 27 | ('c.txt', b'c.txt'), 28 | ], 29 | files 30 | ) 31 | 32 | def get_asset(self, name): 33 | return os.path.join( 34 | os.path.split(os.path.abspath(__file__))[0], 35 | 'assets', 36 | name 37 | ) 38 | 39 | def test_zip(self): 40 | self.check_archive_file(ZipExtractor, self.get_asset('payload.zip')) 41 | 42 | def test_rar(self): 43 | self.check_archive_file(RarExtractor, self.get_asset('payload.rar')) 44 | 45 | def test_tar(self): 46 | self.check_archive_file(TarExtractor, self.get_asset('payload.tar')) 47 | # xz 48 | if sys.version_info[:2] >= (3, 3): 49 | self.check_archive_file( 50 | TarExtractor, self.get_asset('payload.tar.xz')) 51 | self.check_archive_file( 52 | TarExtractor, self.get_asset('payload.tar.xz'), 'payload.txz') 53 | # gz 54 | self.check_archive_file(TarExtractor, self.get_asset('payload.tar.gz')) 55 | self.check_archive_file(TarExtractor, self.get_asset('payload.tar.gz'), 56 | 'payload.tgz') 57 | # 
bz2 58 | self.check_archive_file(TarExtractor, self.get_asset('payload.tar.bz2')) 59 | self.check_archive_file(TarExtractor, self.get_asset('payload.tar.bz2'), 60 | 'payload.tbz') 61 | self.check_archive_file(TarExtractor, self.get_asset('payload.tar.bz2'), 62 | 'payload.tbz2') 63 | self.check_archive_file(TarExtractor, self.get_asset('payload.tar.bz2'), 64 | 'payload.tb2') 65 | 66 | def test_errors(self): 67 | with TemporaryDirectory() as tmpdir: 68 | archive_file = os.path.join(tmpdir, 'payload.txt') 69 | with open(archive_file, 'wb') as f: 70 | f.write(b'') 71 | with pytest.raises( 72 | IOError, match='File is not a supported archive file'): 73 | with Extractor.open(archive_file): 74 | pass 75 | -------------------------------------------------------------------------------- /tests/utils/test_model_vars.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tests.helper import assert_variables 4 | from tfsnippet import model_variable, get_model_variables 5 | 6 | 7 | class ModelVariableTestCase(tf.test.TestCase): 8 | 9 | def test_model_variable(self): 10 | a = model_variable('a', shape=(), dtype=tf.float32) 11 | b = model_variable('b', shape=(), dtype=tf.float32, trainable=False, 12 | collections=['my_collection']) 13 | c = tf.get_variable('c', shape=(), dtype=tf.float32) 14 | 15 | assert_variables(['a'], trainable=True, 16 | collections=[tf.GraphKeys.MODEL_VARIABLES]) 17 | assert_variables(['b'], trainable=False, 18 | collections=[tf.GraphKeys.MODEL_VARIABLES, 19 | 'my_collection']) 20 | self.assertEqual(get_model_variables(), [a, b]) 21 | -------------------------------------------------------------------------------- /tests/utils/test_random.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from tfsnippet.utils import VarScopeRandomState, set_random_seed 7 | 8 | 9 | class SetRandomSeedTestCase(tf.test.TestCase): 10 | 11 | def test_set_random_seed(self): 12 | with tf.Graph().as_default(): 13 | with self.test_session() as sess: 14 | set_random_seed(0) 15 | np_x = np.random.randn() 16 | tf_x = sess.run(tf.random_normal(shape=[], seed=0)) 17 | vsrs_seed = VarScopeRandomState._global_seed 18 | 19 | set_random_seed(1) 20 | self.assertNotEqual(np.random.randn(), np_x) 21 | self.assertNotEqual( 22 | sess.run(tf.random_normal(shape=[], seed=0)), tf_x) 23 | self.assertNotEqual(VarScopeRandomState._global_seed, vsrs_seed) 24 | 25 | with tf.Graph().as_default(): 26 | with self.test_session() as sess: 27 | set_random_seed(0) 28 | self.assertEqual(np.random.randn(), np_x) 29 | self.assertEqual( 30 | sess.run(tf.random_normal(shape=[], seed=0)), tf_x) 31 | self.assertEqual(VarScopeRandomState._global_seed, vsrs_seed) 32 | 33 | 34 | class VarScopeRandomStateTestCase(tf.test.TestCase): 35 | 36 | def test_VarScopeRandomState(self): 37 | def get_seq(): 38 | state = VarScopeRandomState(tf.get_variable_scope()) 39 | return state.randint(0, 0xffffffff, size=[100]) 40 | 41 | with tf.Graph().as_default(): 42 | VarScopeRandomState.set_global_seed(0) 43 | 44 | with tf.variable_scope('a'): 45 | a = get_seq() 46 | 47 | with tf.variable_scope('a'): 48 | np.testing.assert_equal(get_seq(), a) 49 | 50 | with tf.variable_scope('b'): 51 | self.assertFalse(np.all(get_seq() == a)) 52 | 53 | with tf.Graph().as_default(): 54 | VarScopeRandomState.set_global_seed(0) 55 | 56 | with tf.variable_scope('a'): 57 | 
np.testing.assert_equal(get_seq(), a) 58 | 59 | VarScopeRandomState.set_global_seed(1) 60 | 61 | with tf.variable_scope('a'): 62 | self.assertFalse(np.all(get_seq() == a)) 63 | 64 | VarScopeRandomState.set_global_seed(0) 65 | 66 | with tf.variable_scope('a'): 67 | np.testing.assert_equal(get_seq(), a) 68 | -------------------------------------------------------------------------------- /tests/utils/test_registry.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import pytest 4 | 5 | from tfsnippet.utils import BaseRegistry, ClassRegistry 6 | 7 | 8 | class RegistryTestCase(unittest.TestCase): 9 | 10 | def test_base_registry(self): 11 | a = object() 12 | b = object() 13 | 14 | # test not ignore case 15 | r = BaseRegistry(ignore_case=False) 16 | self.assertFalse(r.ignore_case) 17 | 18 | r.register('a', a) 19 | self.assertIs(r.get('a'), a) 20 | with pytest.raises(KeyError, match='Object not registered: \'A\''): 21 | _ = r.get('A') 22 | self.assertListEqual(list(r), ['a']) 23 | 24 | with pytest.raises(KeyError, match='Object already registered: \'a\''): 25 | _ = r.register('a', a) 26 | with pytest.raises(KeyError, match='Object not registered: \'b\''): 27 | _ = r.get('b') 28 | 29 | r.register('A', b) 30 | self.assertIs(r.get('A'), b) 31 | self.assertListEqual(list(r), ['a', 'A']) 32 | 33 | # test ignore case 34 | r = BaseRegistry(ignore_case=True) 35 | self.assertTrue(r.ignore_case) 36 | 37 | r.register('a', a) 38 | self.assertIs(r.get('a'), a) 39 | self.assertIs(r.get('A'), a) 40 | self.assertListEqual(list(r), ['a']) 41 | 42 | with pytest.raises(KeyError, match='Object already registered: \'A\''): 43 | _ = r.register('A', a) 44 | with pytest.raises(KeyError, match='Object not registered: \'b\''): 45 | _ = r.get('b') 46 | 47 | r.register('B', b) 48 | self.assertIs(r.get('b'), b) 49 | self.assertIs(r.get('B'), b) 50 | self.assertListEqual(list(r), ['a', 'B']) 51 | 52 | def test_class_registry(self): 53 | r = ClassRegistry() 54 | 55 | with pytest.raises(TypeError, match='`obj` is not a class: 123'): 56 | r.register('int', 123) 57 | 58 | class MyClass(object): 59 | def __init__(self, value, message): 60 | self.value = value 61 | self.message = message 62 | 63 | r.register('MyClass', MyClass) 64 | self.assertIs(r.get('MyClass'), MyClass) 65 | o = r.construct('MyClass', 123, message='message') 66 | self.assertIsInstance(o, MyClass) 67 | self.assertEqual(o.value, 123) 68 | self.assertEqual(o.message, 'message') 69 | -------------------------------------------------------------------------------- /tests/utils/test_settings.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import tfsnippet as spt 4 | 5 | 6 | class TFSnippetConfigTestCase(unittest.TestCase): 7 | 8 | def test_tfsnippet_settings(self): 9 | self.assertTrue(spt.settings.enable_assertions) 10 | self.assertFalse(spt.settings.check_numerics) 11 | self.assertFalse(spt.settings.auto_histogram) 12 | -------------------------------------------------------------------------------- /tests/utils/test_tfver.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import tensorflow as tf 4 | 5 | from tfsnippet.utils import is_tensorflow_version_higher_or_equal 6 | 7 | 8 | class IsTensorflowVersionHigherOrEqualTestCase(unittest.TestCase): 9 | 10 | def test_is_tensorflow_version_higher_or_equal(self): 11 | # test compatibility with current version 12 | tf_version = 
tf.__version__ 13 | self.assertTrue(is_tensorflow_version_higher_or_equal(tf_version), 14 | msg='{} >= {} not hold'.format(tf_version, tf_version)) 15 | 16 | # test various cases 17 | try: 18 | versions = [ 19 | '0.1.0', '0.9.0', '0.12.0', '0.12.1', 20 | '1.0.0-rc0', '1.0.0-rc1', '1.0.0', '1.0.1', 21 | ] 22 | for i, v0 in enumerate(versions): 23 | tf.__version__ = v0 24 | for v in versions[:i+1]: 25 | self.assertTrue(is_tensorflow_version_higher_or_equal(v), 26 | msg='{} >= {} not hold'.format(v0, v)) 27 | for v in versions[i+1:]: 28 | self.assertFalse(is_tensorflow_version_higher_or_equal(v), 29 | msg='{} < {} not hold'.format(v0, v)) 30 | finally: 31 | tf.__version__ = tf_version 32 | -------------------------------------------------------------------------------- /tests/variational/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tests/variational/__init__.py -------------------------------------------------------------------------------- /tests/variational/test_chain.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from mock import Mock 4 | 5 | from tfsnippet.variational import VariationalChain, VariationalInference 6 | 7 | 8 | class VariationalChainTestCase(tf.test.TestCase): 9 | 10 | def prepare_model(self): 11 | variational_local_log_probs = Mock( 12 | return_value=[tf.constant(1.), tf.constant(2.)]) 13 | variational = Mock( 14 | local_log_probs=Mock( 15 | wraps=lambda names: variational_local_log_probs(tuple(names))), 16 | __iter__=Mock(return_value=iter(['a', 'b'])), 17 | ) 18 | model_local_log_probs = Mock( 19 | return_value=[tf.constant(3.), tf.constant(4.)]) 20 | model = Mock( 21 | local_log_probs=Mock( 22 | wraps=lambda names: model_local_log_probs(tuple(names))), 23 | __iter__=Mock(return_value=iter(['c', 'd'])), 24 | ) 25 | return (variational_local_log_probs, variational, 26 | model_local_log_probs, model) 27 | 28 | def test_default_args(self): 29 | (variational_local_log_probs, variational, 30 | model_local_log_probs, model) = self.prepare_model() 31 | 32 | chain = VariationalChain(variational, model) 33 | self.assertEqual(variational_local_log_probs.call_args, 34 | ((('a', 'b'),),)) 35 | self.assertEqual(model_local_log_probs.call_args, 36 | ((('c', 'd'),),)) 37 | 38 | self.assertIs(chain.variational, variational) 39 | self.assertIs(chain.model, model) 40 | self.assertEqual(chain.latent_names, ('a', 'b')) 41 | self.assertIsNone(chain.latent_axis) 42 | self.assertIsInstance(chain.vi, VariationalInference) 43 | 44 | with self.test_session() as sess: 45 | np.testing.assert_allclose(chain.log_joint.eval(), 7.) 46 | np.testing.assert_allclose(chain.vi.log_joint.eval(), 7.) 47 | np.testing.assert_allclose(sess.run(chain.vi.latent_log_probs), 48 | [1., 2.]) 49 | 50 | def test_log_joint_arg(self): 51 | (variational_local_log_probs, variational, 52 | model_local_log_probs, model) = self.prepare_model() 53 | 54 | chain = VariationalChain(variational, model, log_joint=tf.constant(-1.)) 55 | self.assertEqual(variational_local_log_probs.call_args, 56 | ((('a', 'b'),),)) 57 | self.assertFalse(model_local_log_probs.called) 58 | 59 | with self.test_session(): 60 | np.testing.assert_allclose(chain.log_joint.eval(), -1.) 61 | np.testing.assert_allclose(chain.vi.log_joint.eval(), -1.) 
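        # When `log_joint` is given explicitly, the chain is expected to reuse
        # that tensor directly instead of querying the model's local_log_probs
        # (hence the assertFalse above).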
62 | 63 | def test_latent_names_arg(self): 64 | (variational_local_log_probs, variational, 65 | model_local_log_probs, model) = self.prepare_model() 66 | 67 | chain = VariationalChain(variational, model, latent_names=iter(['a'])) 68 | self.assertEqual(variational_local_log_probs.call_args, 69 | ((('a',),),)) 70 | self.assertEqual(model_local_log_probs.call_args, 71 | ((('c', 'd'),),)) 72 | self.assertEqual(chain.latent_names, ('a',)) 73 | 74 | def test_latent_axis_arg(self): 75 | (variational_local_log_probs, variational, 76 | model_local_log_probs, model) = self.prepare_model() 77 | 78 | chain = VariationalChain(variational, model, latent_axis=1) 79 | self.assertEqual(chain.latent_axis, 1) 80 | self.assertEqual(chain.vi.axis, 1) 81 | -------------------------------------------------------------------------------- /tests/variational/test_evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import tensorflow as tf 4 | 5 | from tfsnippet.ops import log_mean_exp 6 | from tfsnippet.variational import * 7 | 8 | 9 | def prepare_test_payload(): 10 | np.random.seed(1234) 11 | log_p = tf.constant(np.random.normal(size=[13]), dtype=tf.float32) 12 | log_q = tf.constant(np.random.normal(size=[7, 13]), dtype=tf.float32) 13 | return log_p, log_q 14 | 15 | 16 | def assert_allclose(a, b): 17 | np.testing.assert_allclose(a, b, atol=1e-4) 18 | 19 | 20 | class ImportanceSamplingLogLikelihoodTestCase(tf.test.TestCase): 21 | 22 | def test_error(self): 23 | with pytest.raises(ValueError, 24 | match='importance sampling log-likelihood requires ' 25 | 'multi-samples of latent variables'): 26 | log_p, log_q = prepare_test_payload() 27 | _ = importance_sampling_log_likelihood(log_p, log_q, axis=None) 28 | 29 | def test_importance_sampling_log_likelihood(self): 30 | with self.test_session() as sess: 31 | log_p, log_q = prepare_test_payload() 32 | 33 | ll = importance_sampling_log_likelihood(log_p, log_q, axis=0) 34 | ll_shape = ll.get_shape().as_list() 35 | assert_allclose(*sess.run([ 36 | ll, 37 | log_mean_exp(log_p - log_q, axis=0) 38 | ])) 39 | 40 | ll_k = importance_sampling_log_likelihood( 41 | log_p, log_q, axis=0, keepdims=True) 42 | self.assertListEqual( 43 | [1] + ll_shape, ll_k.get_shape().as_list()) 44 | assert_allclose(*sess.run([ 45 | ll_k, 46 | log_mean_exp(log_p - log_q, axis=0, keepdims=True) 47 | ])) 48 | -------------------------------------------------------------------------------- /tests/variational/test_objectives.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import tensorflow as tf 4 | 5 | from tfsnippet.ops import log_mean_exp 6 | from tfsnippet.variational import * 7 | 8 | 9 | def prepare_test_payload(): 10 | np.random.seed(1234) 11 | log_p = tf.constant(np.random.normal(size=[13]), dtype=tf.float32) 12 | log_q = tf.constant(np.random.normal(size=[7, 13]), dtype=tf.float32) 13 | return log_p, log_q 14 | 15 | 16 | def assert_allclose(a, b): 17 | np.testing.assert_allclose(a, b, atol=1e-4) 18 | 19 | 20 | class ELBOObjectiveTestCase(tf.test.TestCase): 21 | 22 | def test_elbo(self): 23 | with self.test_session() as sess: 24 | log_p, log_q = prepare_test_payload() 25 | 26 | obj = elbo_objective(log_p, log_q) 27 | obj_shape = obj.get_shape().as_list() 28 | assert_allclose(*sess.run([ 29 | obj, 30 | log_p - log_q 31 | ])) 32 | 33 | obj_r = elbo_objective(log_p, log_q, axis=0) 34 | self.assertListEqual( 35 | obj_shape[1:],
obj_r.get_shape().as_list()) 36 | assert_allclose(*sess.run([ 37 | obj_r, 38 | tf.reduce_mean(log_p - log_q, axis=0) 39 | ])) 40 | 41 | obj_rk = elbo_objective(log_p, log_q, axis=0, keepdims=True) 42 | self.assertListEqual( 43 | [1] + obj_shape[1:], obj_rk.get_shape().as_list()) 44 | assert_allclose(*sess.run([ 45 | obj_rk, 46 | tf.reduce_mean(log_p - log_q, axis=0, keepdims=True) 47 | ])) 48 | 49 | 50 | class MonteCarloObjectiveTestCase(tf.test.TestCase): 51 | 52 | def test_error(self): 53 | with pytest.raises(ValueError, 54 | match='monte carlo objective requires multi-samples ' 55 | 'of latent variables'): 56 | log_p, log_q = prepare_test_payload() 57 | _ = monte_carlo_objective(log_p, log_q, axis=None) 58 | 59 | def test_monte_carlo_objective(self): 60 | with self.test_session() as sess: 61 | log_p, log_q = prepare_test_payload() 62 | 63 | obj = monte_carlo_objective(log_p, log_q, axis=0) 64 | obj_shape = obj.get_shape().as_list() 65 | assert_allclose(*sess.run([ 66 | obj, 67 | log_mean_exp(log_p - log_q, axis=0) 68 | ])) 69 | 70 | obj_k = monte_carlo_objective(log_p, log_q, axis=0, keepdims=True) 71 | self.assertListEqual( 72 | [1] + obj_shape, obj_k.get_shape().as_list()) 73 | assert_allclose(*sess.run([ 74 | obj_k, 75 | log_mean_exp(log_p - log_q, axis=0, keepdims=True) 76 | ])) 77 | -------------------------------------------------------------------------------- /tfsnippet/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.2.0a4' 2 | 3 | 4 | from . import (dataflows, datasets, distributions, evaluation, layers, 5 | ops, preprocessing, scaffold, trainer, utils, variational, 6 | bayes, shortcuts, stochastic) 7 | from .distributions import * 8 | from .scaffold import * 9 | from .trainer import * 10 | from .variational import * 11 | from .bayes import * 12 | from .shortcuts import * 13 | from .stochastic import * 14 | 15 | 16 | def _exports(): 17 | exports = [ 18 | # export modules 19 | 'dataflows', 'datasets', 'distributions', 'evaluation', 'layers', 20 | 'ops', 'preprocessing', 'scaffold', 'trainer', 'utils', 'variational', 21 | 'bayes', 'stochastic', 22 | ] 23 | 24 | # recursively export classes and functions 25 | for pkg in (distributions, scaffold, trainer, variational, bayes, 26 | shortcuts, stochastic): 27 | exports += list(pkg.__all__) 28 | 29 | # remove `_exports` from root namespace 30 | import sys 31 | del sys.modules[__name__]._exports 32 | 33 | return exports 34 | 35 | 36 | __all__ = _exports() 37 | -------------------------------------------------------------------------------- /tfsnippet/dataflows/__init__.py: -------------------------------------------------------------------------------- 1 | from .array_flow import * 2 | from .base import * 3 | from .data_mappers import * 4 | from .gather_flow import * 5 | from .iterator_flow import * 6 | from .mapper_flow import * 7 | from .seq_flow import * 8 | from .threading_flow import * 9 | 10 | __all__ = [ 11 | 'ArrayFlow', 'DataFlow', 'DataMapper', 'ExtraInfoDataFlow', 'GatherFlow', 12 | 'IteratorFactoryFlow', 'MapperFlow', 'SeqFlow', 'SlidingWindow', 13 | 'ThreadingFlow', 14 | ] 15 | -------------------------------------------------------------------------------- /tfsnippet/dataflows/gather_flow.py: -------------------------------------------------------------------------------- 1 | from .base import DataFlow 2 | 3 | __all__ = ['GatherFlow'] 4 | 5 | 6 | class GatherFlow(DataFlow): 7 | """ 8 | Gathering multiple data flows into a single flow.
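Each mini-batch of the gathered flow joins the corresponding mini-batches of all source flows into one tuple, and iteration stops as soon as the shortest source flow is exhausted.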
9 | 10 | Usage:: 11 | 12 | x_flow = DataFlow.arrays([x], batch_size=256) 13 | y_flow = DataFlow.arrays([y], batch_size=256) 14 | xy_flow = DataFlow.gather([x_flow, y_flow]) 15 | """ 16 | 17 | def __init__(self, flows): 18 | """ 19 | Construct a :class:`GatherFlow`. 20 | 21 | Args: 22 | flows (Iterable[DataFlow]): The data flows to gather. 23 | At least one data flow should be specified, otherwise a 24 | :class:`ValueError` will be raised. 25 | 26 | Raises: 27 | ValueError: If not even one data flow is specified. 28 | TypeError: If a specified flow is not a :class:`DataFlow`. 29 | """ 30 | flows = tuple(flows) 31 | if not flows: 32 | raise ValueError('At least one flow must be specified.') 33 | for flow in flows: 34 | if not isinstance(flow, DataFlow): 35 | raise TypeError('Not a DataFlow: {!r}'.format(flow)) 36 | self._flows = flows 37 | 38 | @property 39 | def flows(self): 40 | """ 41 | Get the data flows to be gathered. 42 | 43 | Returns: 44 | tuple[DataFlow]: The data flows to be gathered. 45 | """ 46 | return self._flows 47 | 48 | def _minibatch_iterator(self): 49 | for batches in zip(*self._flows): 50 | yield sum([tuple(b) for b in batches], ()) 51 | -------------------------------------------------------------------------------- /tfsnippet/dataflows/iterator_flow.py: -------------------------------------------------------------------------------- 1 | from .base import DataFlow 2 | 3 | __all__ = ['IteratorFactoryFlow'] 4 | 5 | 6 | class IteratorFactoryFlow(DataFlow): 7 | """ 8 | Data flow constructed from an iterator factory. 9 | 10 | Usage:: 11 | 12 | x_flow = DataFlow.arrays([x], batch_size=256) 13 | y_flow = DataFlow.arrays([y], batch_size=256) 14 | xy_flow = DataFlow.iterator_factory(lambda: ( 15 | (x, y) for (x,), (y,) in zip(x_flow, y_flow) 16 | )) 17 | """ 18 | 19 | def __init__(self, factory): 20 | """ 21 | Construct an :class:`IteratorFactoryFlow`. 22 | 23 | Args: 24 | factory (() -> Iterator or Iterable): A factory method for 25 | constructing the mini-batch iterators for each epoch. 26 | """ 27 | self._factory = factory 28 | 29 | def _minibatch_iterator(self): 30 | for batch in self._factory(): 31 | yield batch 32 | -------------------------------------------------------------------------------- /tfsnippet/dataflows/mapper_flow.py: -------------------------------------------------------------------------------- 1 | from .base import DataFlow 2 | 3 | __all__ = ['MapperFlow'] 4 | 5 | 6 | class MapperFlow(DataFlow): 7 | """ 8 | Data flow which transforms the mini-batch arrays from source flow 9 | by a specified mapper function. 10 | 11 | Usage:: 12 | 13 | source_flow = DataFlow.arrays([x, y], batch_size=256) 14 | mapper_flow = source_flow.map(lambda x, y: (x + y,)) 15 | """ 16 | 17 | def __init__(self, source, mapper, array_indices=None): 18 | """ 19 | Construct a :class:`MapperFlow`. 20 | 21 | Args: 22 | source (DataFlow): The source data flow. 23 | mapper ((\*np.ndarray) -> tuple[np.ndarray]): The mapper 24 | function, which transforms numpy arrays into a tuple 25 | of other numpy arrays. 26 | array_indices (int or Iterable[int]): The indices of the arrays 27 | to be processed within a mini-batch. 28 | 29 | If specified, will apply the mapper only on these selected 30 | arrays. This will require the mapper to produce exactly 31 | the same number of output arrays as the inputs. 32 | 33 | If not specified, apply the mapper on all arrays, and do 34 | not require the number of output arrays to match the inputs.
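                For example (a sketch, assuming ``DataFlow.map`` forwards
                ``array_indices`` to :class:`MapperFlow`)::

                    flow = DataFlow.arrays([x, y], batch_size=256)
                    # double `x`, leave `y` untouched
                    flow2 = flow.map(lambda arr: (arr * 2,), array_indices=[0])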
35 | """ 36 | if array_indices is not None: 37 | try: 38 | array_indices = (int(array_indices),) 39 | except TypeError: 40 | array_indices = tuple(map(int, array_indices)) 41 | self._source = source 42 | self._mapper = mapper 43 | self._array_indices = array_indices 44 | 45 | @property 46 | def source(self): 47 | """Get the source data flow.""" 48 | return self._source 49 | 50 | @property 51 | def array_indices(self): 52 | """Get the indices of the arrays to be processed.""" 53 | return self._array_indices 54 | 55 | def _validate_outputs(self, outputs): 56 | if isinstance(outputs, list): 57 | outputs = tuple(outputs) 58 | elif not isinstance(outputs, tuple): 59 | raise TypeError('The output of the mapper is expected to ' 60 | 'be a tuple or a list, but got a {}.'. 61 | format(outputs.__class__.__name__)) 62 | return outputs 63 | 64 | def _minibatch_iterator(self): 65 | for batch in self._source: 66 | if self._array_indices is not None: 67 | mapped_b = list(batch) 68 | inputs = [mapped_b[i] for i in self._array_indices] 69 | outputs = self._validate_outputs(self._mapper(*inputs)) 70 | if len(outputs) != len(inputs): 71 | raise ValueError('The number of output arrays of the ' 72 | 'mapper is required to match the inputs, ' 73 | 'since `array_indices` is specified: ' 74 | 'outputs {} != inputs {}.'. 75 | format(len(outputs), len(inputs))) 76 | for i, o in zip(self._array_indices, outputs): 77 | mapped_b[i] = o 78 | mapped_b = tuple(mapped_b) 79 | else: 80 | mapped_b = self._validate_outputs(self._mapper(*batch)) 81 | yield mapped_b 82 | -------------------------------------------------------------------------------- /tfsnippet/dataflows/seq_flow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .array_flow import ArrayFlow 4 | 5 | __all__ = ['SeqFlow'] 6 | 7 | 8 | class SeqFlow(ArrayFlow): 9 | """ 10 | Using number sequence as data source flow. 11 | 12 | This :class:`SeqFlow` is particularly used for generating the `seed` 13 | number indices, then fetch the actual data by :class:`MapperFlow` 14 | according to the seed numbers. 15 | 16 | Usage:: 17 | 18 | seq_flow = DataFlow.seq(0, len(x), batch_size=256) 19 | mapper_flow = seq_flow.map(lambda idx: np.stack( 20 | [fetch_data_by_index(i) for i in idx] 21 | )) 22 | """ 23 | 24 | def __init__(self, start, stop, step=1, batch_size=None, shuffle=False, 25 | skip_incomplete=False, dtype=np.int32, random_state=None): 26 | """ 27 | Construct a :class:`SeqFlow`. 28 | 29 | Args: 30 | start: The starting number of the sequence. 31 | stop: The ending number of the sequence. 32 | step: The step of the sequence. (default ``1``) 33 | batch_size: Batch size of the data flow. Required. 34 | shuffle (bool): Whether or not to shuffle the numbers before 35 | iterating? (default :obj:`False`) 36 | skip_incomplete (bool): Whether or not to exclude the last 37 | mini-batch if it is incomplete? (default :obj:`False`) 38 | dtype: Data type of the numbers. (default ``np.int32``) 39 | random_state (RandomState): Optional numpy RandomState for 40 | shuffling data before each epoch. (default :obj:`None`, 41 | construct a new :class:`RandomState`). 
42 | """ 43 | # check the parameters 44 | if batch_size is None: 45 | raise ValueError('`batch_size` is required.') 46 | 47 | # memorize the parameters 48 | super(SeqFlow, self).__init__( 49 | arrays=[np.arange(start, stop, step, dtype=dtype)], 50 | batch_size=batch_size, 51 | shuffle=shuffle, 52 | skip_incomplete=skip_incomplete, 53 | random_state=random_state 54 | ) 55 | self._start = start 56 | self._stop = stop 57 | self._step = step 58 | 59 | @property 60 | def start(self): 61 | """Get the starting number of the sequence.""" 62 | return self._start 63 | 64 | @property 65 | def stop(self): 66 | """Get the ending number of the sequence.""" 67 | return self._stop 68 | 69 | @property 70 | def step(self): 71 | """Get the step of the sequence.""" 72 | return self._step 73 | -------------------------------------------------------------------------------- /tfsnippet/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar import * 2 | from .fashion_mnist import * 3 | from .mnist import * 4 | 5 | __all__ = [ 6 | 'load_cifar10', 'load_cifar100', 'load_fashion_mnist', 'load_mnist', 7 | ] 8 | -------------------------------------------------------------------------------- /tfsnippet/datasets/fashion_mnist.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import hashlib 3 | 4 | import numpy as np 5 | import idx2numpy 6 | 7 | from tfsnippet.utils import CacheDir 8 | 9 | __all__ = ['load_fashion_mnist'] 10 | 11 | 12 | TRAIN_X_URI = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz' 13 | TRAIN_X_MD5 = '8d4fb7e6c68d591d4c3dfef9ec88bf0d' 14 | TRAIN_Y_URI = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz' 15 | TRAIN_Y_MD5 = '25c81989df183df01b3e8a0aad5dffbe' 16 | TEST_X_URI = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz' 17 | TEST_X_MD5 = 'bef4ecab320f06d8554ea6380940ec79' 18 | TEST_Y_URI = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz' 19 | TEST_Y_MD5 = 'bb300cfdad3c16e7a12a480ee83cd310' 20 | 21 | 22 | def _fetch_array(uri, md5): 23 | """Fetch an MNIST array from the `uri` with cache.""" 24 | path = CacheDir('fashion_mnist').download( 25 | uri, hasher=hashlib.md5(), expected_hash=md5) 26 | with gzip.open(path, 'rb') as f: 27 | return idx2numpy.convert_from_file(f) 28 | 29 | 30 | def _validate_x_shape(x_shape): 31 | x_shape = tuple([int(v) for v in x_shape]) 32 | if np.prod(x_shape) != 784: 33 | raise ValueError('`x_shape` does not product to 784: {!r}'. 34 | format(x_shape)) 35 | return x_shape 36 | 37 | 38 | def load_fashion_mnist(x_shape=(28, 28), x_dtype=np.float32, 39 | y_dtype=np.int32, normalize_x=False): 40 | """ 41 | Load the Fashion MNIST dataset as NumPy arrays. 42 | 43 | Homepage: https://github.com/zalandoresearch/fashion-mnist 44 | 45 | Args: 46 | x_shape: Reshape each digit into this shape. Default ``(784,)``. 47 | x_dtype: Cast each digit into this data type. Default `np.float32`. 48 | y_dtype: Cast each label into this data type. Default `np.int32`. 49 | normalize_x (bool): Whether or not to normalize x into ``[0, 1]``, 50 | by dividing each pixel value with 255.? 
(default :obj:`False`) 51 | 52 | Returns: 53 | (np.ndarray, np.ndarray), (np.ndarray, np.ndarray): The 54 | (train_x, train_y), (test_x, test_y) 55 | """ 56 | # check arguments 57 | x_shape = _validate_x_shape(x_shape) 58 | 59 | # load data 60 | train_x = _fetch_array(TRAIN_X_URI, TRAIN_X_MD5).astype(x_dtype) 61 | train_y = _fetch_array(TRAIN_Y_URI, TRAIN_Y_MD5).astype(y_dtype) 62 | test_x = _fetch_array(TEST_X_URI, TEST_X_MD5).astype(x_dtype) 63 | test_y = _fetch_array(TEST_Y_URI, TEST_Y_MD5).astype(y_dtype) 64 | 65 | assert(len(train_x) == len(train_y) == 60000) 66 | assert(len(test_x) == len(test_y) == 10000) 67 | 68 | # change shape 69 | train_x = train_x.reshape([len(train_x)] + list(x_shape)) 70 | test_x = test_x.reshape([len(test_x)] + list(x_shape)) 71 | 72 | # normalize x 73 | if normalize_x: 74 | train_x /= np.asarray(255., dtype=train_x.dtype) 75 | test_x /= np.asarray(255., dtype=test_x.dtype) 76 | 77 | return (train_x, train_y), (test_x, test_y) 78 | -------------------------------------------------------------------------------- /tfsnippet/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import hashlib 3 | 4 | import numpy as np 5 | import idx2numpy 6 | 7 | from tfsnippet.utils import CacheDir 8 | 9 | __all__ = ['load_mnist'] 10 | 11 | 12 | TRAIN_X_URI = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz' 13 | TRAIN_X_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873' 14 | TRAIN_Y_URI = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz' 15 | TRAIN_Y_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432' 16 | TEST_X_URI = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz' 17 | TEST_X_MD5 = '9fb629c4189551a2d022fa330f9573f3' 18 | TEST_Y_URI = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz' 19 | TEST_Y_MD5 = 'ec29112dd5afa0611ce80d1b7f02629c' 20 | 21 | 22 | def _fetch_array(uri, md5): 23 | """Fetch an MNIST array from the `uri` with cache.""" 24 | path = CacheDir('mnist').download( 25 | uri, hasher=hashlib.md5(), expected_hash=md5) 26 | with gzip.open(path, 'rb') as f: 27 | return idx2numpy.convert_from_file(f) 28 | 29 | 30 | def _validate_x_shape(x_shape): 31 | x_shape = tuple([int(v) for v in x_shape]) 32 | if np.prod(x_shape) != 784: 33 | raise ValueError('The product of `x_shape` must be 784: {!r}'. 34 | format(x_shape)) 35 | return x_shape 36 | 37 | 38 | def load_mnist(x_shape=(28, 28), x_dtype=np.float32, y_dtype=np.int32, 39 | normalize_x=False): 40 | """ 41 | Load the MNIST dataset as NumPy arrays. 42 | 43 | Args: 44 | x_shape: Reshape each digit into this shape. Default ``(28, 28)``. 45 | x_dtype: Cast each digit into this data type. Default `np.float32`. 46 | y_dtype: Cast each label into this data type. Default `np.int32`. 47 | normalize_x (bool): Whether or not to normalize x into ``[0, 1]``, 48 | by dividing each pixel value by 255?
(default :obj:`False`) 49 | 50 | Returns: 51 | (np.ndarray, np.ndarray), (np.ndarray, np.ndarray): The 52 | (train_x, train_y), (test_x, test_y) 53 | """ 54 | # check arguments 55 | x_shape = _validate_x_shape(x_shape) 56 | 57 | # load data 58 | train_x = _fetch_array(TRAIN_X_URI, TRAIN_X_MD5).astype(x_dtype) 59 | train_y = _fetch_array(TRAIN_Y_URI, TRAIN_Y_MD5).astype(y_dtype) 60 | test_x = _fetch_array(TEST_X_URI, TEST_X_MD5).astype(x_dtype) 61 | test_y = _fetch_array(TEST_Y_URI, TEST_Y_MD5).astype(y_dtype) 62 | 63 | assert(len(train_x) == len(train_y) == 60000) 64 | assert(len(test_x) == len(test_y) == 10000) 65 | 66 | # change shape 67 | train_x = train_x.reshape([len(train_x)] + list(x_shape)) 68 | test_x = test_x.reshape([len(test_x)] + list(x_shape)) 69 | 70 | # normalize x 71 | if normalize_x: 72 | train_x /= np.asarray(255., dtype=train_x.dtype) 73 | test_x /= np.asarray(255., dtype=test_x.dtype) 74 | 75 | return (train_x, train_y), (test_x, test_y) 76 | -------------------------------------------------------------------------------- /tfsnippet/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .batch_to_value import * 3 | from .discretized import * 4 | from .flow import * 5 | from .mixture import * 6 | from .multivariate import * 7 | from .univariate import * 8 | from .utils import * 9 | from .wrapper import * 10 | 11 | __all__ = [ 12 | 'BatchToValueDistribution', 'Bernoulli', 'Categorical', 'Concrete', 13 | 'Discrete', 'DiscretizedLogistic', 'Distribution', 'ExpConcrete', 14 | 'FlowDistribution', 'FlowDistributionDerivedTensor', 'Mixture', 'Normal', 15 | 'OnehotCategorical', 'Uniform', 'as_distribution', 'reduce_group_ndims', 16 | ] 17 | -------------------------------------------------------------------------------- /tfsnippet/distributions/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tfsnippet.utils import is_tensor_object, validate_group_ndims_arg 4 | 5 | __all__ = ['reduce_group_ndims'] 6 | 7 | 8 | def reduce_group_ndims(operation, tensor, group_ndims, name=None): 9 | """ 10 | Reduce the last `group_ndims` dimensions in `tensor`, using `operation`. 11 | 12 | In :class:`~tfsnippet.distributions.Distribution`, when computing the 13 | (log-)densities of certain `tensor`, the last few dimensions 14 | may represent a group of events, thus should be accounted together. 15 | This method can be used to reduce these dimensions, for example: 16 | 17 | .. code-block:: python 18 | 19 | log_prob = reduce_group_ndims(tf.reduce_sum, log_prob, group_ndims) 20 | prob = reduce_group_ndims(tf.reduce_prod, log_prob, group_ndims) 21 | 22 | Args: 23 | operation: The operation for reducing the last `group_ndims` 24 | dimensions. It must receive `tensor` as the 1st argument, and 25 | `axis` as the 2nd argument. 26 | tensor: The tensor to be reduced. 27 | group_ndims: The number of dimensions at the end of `tensor` to be 28 | reduced. If it is a constant integer and is zero, then no 29 | operation will take place. 30 | name: TensorFlow name scope of the graph nodes. (default 31 | "reduce_group_ndims") 32 | 33 | Returns: 34 | tf.Tensor: The reduced tensor. 35 | 36 | Raises: 37 | ValueError: If `group_ndims` cannot be validated by 38 | :meth:`validate_group_ndims`. 
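        For example, with ``operation=tf.reduce_sum``, a `tensor` of shape
        ``[batch, height, width]`` and ``group_ndims=2``, the reduced tensor
        has shape ``[batch]``.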
39 | """ 40 | group_ndims = validate_group_ndims_arg(group_ndims) 41 | with tf.name_scope(name, default_name='reduce_group_ndims'): 42 | if is_tensor_object(group_ndims): 43 | tensor = tf.cond( 44 | group_ndims > 0, 45 | lambda: operation(tensor, tf.range(-group_ndims, 0)), 46 | lambda: tensor 47 | ) 48 | else: 49 | if group_ndims > 0: 50 | tensor = operation(tensor, tf.range(-group_ndims, 0)) 51 | return tensor 52 | 53 | 54 | def compute_density_immediately(t): 55 | """ 56 | Compute the prob and log_prob of `t` immediately. 57 | 58 | Args: 59 | t (StochasticTensor): The stochastic tensor. 60 | """ 61 | with tf.name_scope('compute_density_immediately'): 62 | log_p = t.log_prob() 63 | t._self_prob = tf.exp(log_p) 64 | -------------------------------------------------------------------------------- /tfsnippet/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .collect_outputs_ import * 2 | 3 | __all__ = [ 4 | 'collect_outputs', 5 | ] 6 | -------------------------------------------------------------------------------- /tfsnippet/examples/README.rst: -------------------------------------------------------------------------------- 1 | Examples for TFSnippet 2 | ====================== 3 | 4 | If you want to run any example script, you must first install the dependencies 5 | from `requirements-dev.txt`. 6 | -------------------------------------------------------------------------------- /tfsnippet/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tfsnippet/examples/__init__.py -------------------------------------------------------------------------------- /tfsnippet/examples/auto_encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tfsnippet/examples/auto_encoders/__init__.py -------------------------------------------------------------------------------- /tfsnippet/examples/classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haowen-xu/tfsnippet/63adaf04d2ffff8dec299623627d55d4bacac598/tfsnippet/examples/classification/__init__.py -------------------------------------------------------------------------------- /tfsnippet/examples/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataflows_factory import * 2 | from .evaluation import * 3 | from .graph import * 4 | from .jsonutils import * 5 | from .misc import * 6 | from .mlconfig import * 7 | from .mlresults import * 8 | from .multi_gpu import * 9 | -------------------------------------------------------------------------------- /tfsnippet/examples/utils/dataflows_factory.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tfsnippet.dataflows import DataFlow 4 | from tfsnippet.preprocessing import BernoulliSampler 5 | 6 | __all__ = ['bernoulli_flow'] 7 | 8 | 9 | def _create_sampled_dataflow(arrays, sampler, sample_now, **kwargs): 10 | if sample_now: 11 | arrays = sampler(*arrays) 12 | df = DataFlow.arrays(arrays, **kwargs) 13 | if not sample_now: 14 | df = df.map(sampler) 15 | return df 16 | 17 | 18 | def bernoulli_flow(x, batch_size, shuffle=False, skip_incomplete=False, 19 | 
sample_now=False, dtype=np.int32, random_state=None): 20 | """ 21 | Construct a new :class:`DataFlow`, which samples 0/1 binary images 22 | according to the given `x` array. 23 | 24 | Args: 25 | x: The `train_x` or `test_x` of an image dataset. The pixel values 26 | must be 8-bit integers, having the range of ``[0, 255]``. 27 | batch_size (int): Size of each mini-batch. 28 | shuffle (bool): Whether or not to shuffle data before iterating? 29 | (default :obj:`False`) 30 | skip_incomplete (bool): Whether or not to exclude the last 31 | mini-batch if it is incomplete? (default :obj:`False`) 32 | sample_now (bool): Whether or not to sample immediately instead 33 | of sampling at the beginning of each epoch? (default :obj:`False`) 34 | dtype: The data type of the sampled array. Default `np.int32`. 35 | random_state (RandomState): Optional numpy RandomState for 36 | shuffling data before each epoch. (default :obj:`None`, 37 | construct a new :class:`RandomState`). 38 | 39 | Returns: 40 | DataFlow: The Bernoulli `x` flow. 41 | """ 42 | x = np.asarray(x) 43 | 44 | # prepare the sampler 45 | x = x / np.asarray(255., dtype=x.dtype) 46 | sampler = BernoulliSampler(dtype=dtype, random_state=random_state) 47 | 48 | # compose the data flow 49 | return _create_sampled_dataflow( 50 | [x], sampler, sample_now, batch_size=batch_size, shuffle=shuffle, 51 | skip_incomplete=skip_incomplete, random_state=random_state 52 | ) 53 | -------------------------------------------------------------------------------- /tfsnippet/examples/utils/graph.py: -------------------------------------------------------------------------------- 1 | import six 2 | import tensorflow as tf 3 | 4 | __all__ = [ 5 | 'add_name_scope', 6 | 'add_variable_scope', 7 | ] 8 | 9 | 10 | def add_name_scope(method): 11 | """ 12 | Automatically open a new name scope when calling the method. 13 | 14 | Usage:: 15 | 16 | @add_name_scope 17 | def dense(inputs, name=None): 18 | return tf.layers.dense(inputs) 19 | 20 | Args: 21 | method: The method to decorate. It must accept an optional named 22 | argument `name`, to receive the inbound name argument. 23 | If the `name` argument is not specified as named argument during 24 | calling, the name of the method will be used as `name`. 25 | 26 | Returns: 27 | The decorated method. 28 | """ 29 | method_name = method.__name__ 30 | 31 | @six.wraps(method) 32 | def wrapper(*args, **kwargs): 33 | if kwargs.get('name') is None: 34 | kwargs['name'] = method_name 35 | with tf.name_scope(kwargs['name']): 36 | return method(*args, **kwargs) 37 | return wrapper 38 | 39 | 40 | def add_variable_scope(method): 41 | """ 42 | Automatically open a new variable scope when calling the method. 43 | 44 | Usage:: 45 | 46 | @add_variable_scope 47 | def dense(inputs): 48 | return tf.layers.dense(inputs) 49 | 50 | Args: 51 | method: The method to decorate. 52 | If the `name` argument is not specified as named argument during 53 | calling, the name of the method will be used as `name`. 54 | 55 | Returns: 56 | The decorated method. 
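        For example (a sketch; `my_dense` is a hypothetical layer function)::

            @add_variable_scope
            def my_dense(inputs, units):
                return tf.layers.dense(inputs, units)

            h1 = my_dense(x, 100)                 # variable scope "my_dense"
            h2 = my_dense(x, 100, name='hidden')  # variable scope "hidden"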
57 | """ 58 | method_name = method.__name__ 59 | 60 | @six.wraps(method) 61 | def wrapper(*args, **kwargs): 62 | name = kwargs.pop('name', None) 63 | with tf.variable_scope(name, default_name=method_name): 64 | return method(*args, **kwargs) 65 | return wrapper 66 | -------------------------------------------------------------------------------- /tfsnippet/examples/utils/misc.py: -------------------------------------------------------------------------------- 1 | import six 2 | 3 | from tfsnippet.utils import is_integer 4 | 5 | __all__ = [ 6 | 'validate_strides_or_kernel_size', 7 | 'cached', 8 | 'print_with_title', 9 | ] 10 | 11 | 12 | def validate_strides_or_kernel_size(arg_name, arg_value): 13 | """ 14 | Validate the `strides` or `filter` arg, to ensure it is a tuple of 15 | two integers. 16 | 17 | Args: 18 | arg_name (str): The name of the argument, for formatting error. 19 | arg_value: The value of the argument. 20 | 21 | Returns: 22 | (int, int): The validated argument. 23 | """ 24 | 25 | if not is_integer(arg_value) and (not isinstance(arg_value, tuple) or 26 | len(arg_value) != 2 or 27 | not is_integer(arg_value[0]) or 28 | not is_integer(arg_value[1])): 29 | raise TypeError('`{}` must be a int or a tuple (int, int).'. 30 | format(arg_name)) 31 | if not isinstance(arg_value, tuple): 32 | arg_value = (arg_value, arg_value) 33 | arg_value = tuple(int(v) for v in arg_value) 34 | return arg_value 35 | 36 | 37 | def cached(method): 38 | """ 39 | Decorate `method`, to cache its result. 40 | 41 | Args: 42 | method: The method whose result should be cached. 43 | 44 | Returns: 45 | The decorated method. 46 | """ 47 | results = {} 48 | 49 | @six.wraps(method) 50 | def wrapper(*args, **kwargs): 51 | cache_key = (args, tuple((k, kwargs[k]) for k, v in sorted(kwargs))) 52 | if cache_key not in results: 53 | results[cache_key] = method(*args, **kwargs) 54 | return results[cache_key] 55 | 56 | return wrapper 57 | 58 | 59 | def print_with_title(title, content, before='', after='', hl='='): 60 | """ 61 | Print a content section with title. 62 | 63 | Args: 64 | title (str): The title of the section. 65 | content (str): The multi-line content. 66 | before (str): String to print before the title. 67 | after (str): String to print after the content. 68 | hl (str): The character for horizon line. 
69 | """ 70 | cont_maxlen = max(len(s) for s in content.split('\n')) 71 | hl_len = max(cont_maxlen, len(title)) 72 | print('{}{}\n{}\n{}{}'.format(before, title, hl * hl_len, content, after)) 73 | -------------------------------------------------------------------------------- /tfsnippet/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import * 2 | from .base import * 3 | from .convolutional import * 4 | from .core import * 5 | from .flows import * 6 | from .initialization import * 7 | from .normalization import * 8 | from .regularization import * 9 | from .utils import * 10 | 11 | __all__ = [ 12 | 'ActNorm', 'BaseFlow', 'BaseLayer', 'CouplingLayer', 'FeatureMappingFlow', 13 | 'FeatureShufflingFlow', 'InvertFlow', 'InvertibleActivation', 14 | 'InvertibleActivationFlow', 'InvertibleConv2d', 'InvertibleDense', 15 | 'LeakyReLU', 'MultiLayerFlow', 'PixelCNN2DOutput', 'PlanarNormalizingFlow', 16 | 'ReshapeFlow', 'SequentialFlow', 'SpaceToDepthFlow', 'SplitFlow', 17 | 'act_norm', 'as_gated', 'avg_pool2d', 'broadcast_log_det_against_input', 18 | 'conv2d', 'deconv2d', 'default_kernel_initializer', 'dense', 'dropout', 19 | 'global_avg_pool2d', 'l2_regularizer', 'max_pool2d', 'pixelcnn_2d_input', 20 | 'pixelcnn_2d_output', 'pixelcnn_conv2d_resnet', 'planar_normalizing_flows', 21 | 'resnet_conv2d_block', 'resnet_deconv2d_block', 'resnet_general_block', 22 | 'shifted_conv2d', 'weight_norm', 23 | ] 24 | -------------------------------------------------------------------------------- /tfsnippet/layers/activations/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .leaky_relu import * 3 | 4 | __all__ = [ 5 | 'InvertibleActivation', 'InvertibleActivationFlow', 'LeakyReLU', 6 | ] 7 | -------------------------------------------------------------------------------- /tfsnippet/layers/activations/leaky_relu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from .base import InvertibleActivation 5 | 6 | __all__ = ['LeakyReLU'] 7 | 8 | 9 | class LeakyReLU(InvertibleActivation): 10 | """ 11 | Leaky ReLU activation function. 12 | 13 | `y = x if x >= 0 else alpha * x` 14 | """ 15 | 16 | def __init__(self, alpha=0.2): 17 | alpha = float(alpha) 18 | if alpha <= 0 or alpha >= 1: 19 | raise ValueError('`alpha` must be a float number, and 0 < alpha < ' 20 | '1: got {}'.format(alpha)) 21 | 22 | self._alpha = alpha 23 | self._inv_alpha = 1. 
/ alpha 24 | self._log_alpha = float(np.log(alpha)) 25 | 26 | def _transform_or_inverse_transform(self, x, compute_y, compute_log_det, 27 | reverse=False): 28 | y = None 29 | if compute_y: 30 | if reverse: 31 | y = tf.minimum(x * self._inv_alpha, x) 32 | else: 33 | y = tf.maximum(x * self._alpha, x) 34 | 35 | log_det = None 36 | if compute_log_det: 37 | log_det = tf.cast(tf.less(x, 0), dtype=tf.float32) 38 | if reverse: 39 | log_det *= -self._log_alpha 40 | else: 41 | log_det *= self._log_alpha 42 | 43 | return y, log_det 44 | 45 | def _transform(self, x, compute_y, compute_log_det): 46 | return self._transform_or_inverse_transform( 47 | x=x, compute_y=compute_y, compute_log_det=compute_log_det, 48 | reverse=False 49 | ) 50 | 51 | def _inverse_transform(self, y, compute_x, compute_log_det): 52 | return self._transform_or_inverse_transform( 53 | x=y, compute_y=compute_x, compute_log_det=compute_log_det, 54 | reverse=True 55 | ) 56 | -------------------------------------------------------------------------------- /tfsnippet/layers/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from tfsnippet.shortcuts import VarScopeObject 5 | from tfsnippet.utils import (DocInherit, add_name_and_scope_arg_doc, 6 | get_default_scope_name, is_tensor_object, 7 | reopen_variable_scope) 8 | 9 | __all__ = ['BaseLayer'] 10 | 11 | 12 | # We choose to derive `BaseLayer` from `VarScopeObject`, even if it does not 13 | # need such a VarScopeObject. This is because we can enjoy having a 14 | # uniquified layer name for each Layer object, and add its name to the name 15 | # scopes generated by its method, so as to make the debugging messages of 16 | # TensorFlow much clearer. 17 | 18 | @DocInherit 19 | class BaseLayer(VarScopeObject): 20 | """ 21 | Base class for all neural network layers. 22 | """ 23 | 24 | _build_require_input = False #: whether or not `build` requires input 25 | 26 | @add_name_and_scope_arg_doc 27 | def __init__(self, 28 | name=None, 29 | scope=None): 30 | """ 31 | Construct a new :class:`BaseLayer`. 32 | """ 33 | super(BaseLayer, self).__init__(name=name, scope=scope) 34 | 35 | self._has_built = False 36 | 37 | def _build(self, input=None): 38 | raise NotImplementedError() 39 | 40 | def build(self, input=None): 41 | """ 42 | Build the layer, creating all required variables. 43 | 44 | Args: 45 | input (Tensor or list[Tensor] or None): If :meth:`build` is called 46 | within :meth:`apply`, it will be the input tensor(s). 47 | Otherwise if it is called separately, it will be :obj:`None`. 48 | """ 49 | if self._has_built: 50 | raise RuntimeError('Layer has already been built: {!r}'. 51 | format(self)) 52 | if self._build_require_input and input is None: 53 | raise ValueError('`{}` requires `input` to build.'. 54 | format(self.__class__.__name__)) 55 | with reopen_variable_scope(self.variable_scope): 56 | self._build(input) 57 | self._has_built = True 58 | 59 | def _apply(self, input): 60 | raise NotImplementedError() 61 | 62 | def apply(self, input): 63 | """ 64 | Apply the layer on `input`, to produce output. 65 | 66 | Args: 67 | input (Tensor or list[Tensor]): The input tensor, or a list of 68 | input tensors. 69 | 70 | Returns: 71 | The output tensor, or a list of output tensors. 
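        A minimal usage sketch (assuming ``MyLayer`` implements `_build`
        and `_apply`)::

            layer = MyLayer()   # no variables are created yet
            y = layer(x)        # `build(x)` runs on the first call
            y2 = layer(x2)      # later calls reuse the built variables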
72 | """ 73 | if is_tensor_object(input) or isinstance(input, np.ndarray): 74 | input = tf.convert_to_tensor(input) 75 | ns_values = [input] 76 | else: 77 | input = [tf.convert_to_tensor(i) for i in input] 78 | ns_values = input 79 | 80 | if not self._has_built: 81 | self.build(input) 82 | 83 | with tf.name_scope(get_default_scope_name('apply', self), 84 | values=ns_values): 85 | return self._apply(input) 86 | 87 | def __call__(self, input): 88 | return self.apply(input) 89 | -------------------------------------------------------------------------------- /tfsnippet/layers/convolutional/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv2d_ import * 2 | from .pixelcnn import * 3 | from .pooling import * 4 | from .resnet import * 5 | from .shifted import * 6 | 7 | __all__ = [ 8 | 'PixelCNN2DOutput', 'avg_pool2d', 'conv2d', 'deconv2d', 9 | 'global_avg_pool2d', 'max_pool2d', 'pixelcnn_2d_input', 10 | 'pixelcnn_2d_output', 'pixelcnn_conv2d_resnet', 'resnet_conv2d_block', 11 | 'resnet_deconv2d_block', 'resnet_general_block', 'shifted_conv2d', 12 | ] 13 | -------------------------------------------------------------------------------- /tfsnippet/layers/convolutional/utils.py: -------------------------------------------------------------------------------- 1 | from tfsnippet.utils import (validate_int_tuple_arg, InputSpec, 2 | get_static_shape, validate_enum_arg) 3 | 4 | 5 | def validate_conv2d_input(input, channels_last, arg_name='input'): 6 | """ 7 | Validate the input for 2-d convolution. 8 | 9 | Args: 10 | input: The input tensor, must be at least 4-d. 11 | channels_last (bool): Whether or not the last dimension is the 12 | channels dimension? (i.e., `data_format` is "NHWC") 13 | arg_name (str): Name of the input argument. 14 | 15 | Returns: 16 | (tf.Tensor, int, str): The validated input tensor, the number of input 17 | channels, and the data format. 18 | """ 19 | if channels_last: 20 | input_spec = InputSpec(shape=('...', '?', '?', '?', '*')) 21 | channel_axis = -1 22 | data_format = 'NHWC' 23 | else: 24 | input_spec = InputSpec(shape=('...', '?', '*', '?', '?')) 25 | channel_axis = -3 26 | data_format = 'NCHW' 27 | input = input_spec.validate(arg_name, input) 28 | input_shape = get_static_shape(input) 29 | in_channels = input_shape[channel_axis] 30 | 31 | return input, in_channels, data_format 32 | 33 | 34 | def validate_conv2d_size_tuple(arg_name, arg_value): 35 | """ 36 | Validate the `arg_value`, ensure it is one or two positive integers, 37 | such that it can be used as the kernel size. 38 | 39 | Args: 40 | arg_name: Name of the argument. 41 | arg_value: An integer, or a tuple of two integers. 42 | 43 | Returns: 44 | (int, int): The validated two integers. 45 | """ 46 | arg_value = validate_int_tuple_arg(arg_name, arg_value) 47 | if len(arg_value) not in (1, 2) or any(a < 1 for a in arg_value): 48 | raise ValueError('Invalid value for argument `{}`: expected to be ' 49 | 'one or two positive integers, but got {!r}.'. 50 | format(arg_name, arg_value)) 51 | if len(arg_value) == 1: 52 | arg_value = arg_value * 2 53 | return arg_value 54 | 55 | 56 | def validate_conv2d_strides_tuple(arg_name, arg_value, channels_last): 57 | """ 58 | Validate the `arg_value`, ensure it is one or two positive integers, 59 | such that is can be used as the strides. 60 | 61 | Args: 62 | arg_name: Name of the argument. 63 | arg_value: An integer, or a tuple of two integers. 64 | channels_last: Whether or not the last axis is the channel dimension? 
65 | 66 | Returns: 67 | (int, int, int, int): The validated two integers, plus two `1` as 68 | the strides for batch and channels dimensions. 69 | """ 70 | value = validate_conv2d_size_tuple(arg_name, arg_value) 71 | if channels_last: 72 | value = (1,) + value + (1,) 73 | else: 74 | value = (1, 1) + value 75 | return value 76 | 77 | 78 | def get_deconv_output_length(input_length, kernel_size, strides, padding): 79 | """ 80 | Get the output length of deconvolution at a specific dimension. 81 | 82 | Args: 83 | input_length: Input tensor length. 84 | kernel_size: The size of the kernel. 85 | strides: The stride of convolution. 86 | padding: One of {"same", "valid"}, case in-sensitive 87 | 88 | Returns: 89 | int: The output length of deconvolution. 90 | """ 91 | padding = validate_enum_arg( 92 | 'padding', str(padding).upper(), ['SAME', 'VALID']) 93 | output_length = input_length * strides 94 | if padding == 'VALID': 95 | output_length += max(kernel_size - strides, 0) 96 | return output_length 97 | -------------------------------------------------------------------------------- /tfsnippet/layers/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .dense_ import * 2 | from .dropout_ import * 3 | from .gated import * 4 | 5 | __all__ = [ 6 | 'as_gated', 'dense', 'dropout', 7 | ] 8 | -------------------------------------------------------------------------------- /tfsnippet/layers/core/dropout_.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.framework import add_arg_scope 3 | 4 | from tfsnippet.ops import smart_cond, convert_to_tensor_and_cast 5 | from tfsnippet.utils import add_name_arg_doc, get_shape 6 | 7 | __all__ = ['dropout'] 8 | 9 | 10 | @add_arg_scope 11 | @add_name_arg_doc 12 | def dropout(input, rate=.5, noise_shape=None, training=False, name=None): 13 | """ 14 | Apply dropout on `input`. 15 | 16 | Args: 17 | input (Tensor): The input tensor. 18 | rate (float or tf.Tensor): The rate of dropout. 19 | noise_shape (tuple[int] or tf.Tensor): Shape of the noise. 20 | If not specified, use the shape of `input`. 21 | training (bool or tf.Tensor): Whether or not the model is under 22 | training stage? 23 | 24 | Returns: 25 | tf.Tensor: The dropout transformed tensor. 26 | """ 27 | input = tf.convert_to_tensor(input) 28 | 29 | with tf.name_scope(name, default_name='dropout', values=[input]): 30 | dtype = input.dtype.base_dtype 31 | retain_prob = convert_to_tensor_and_cast(1. - rate, dtype=dtype) 32 | inv_retain_prob = 1. 
/ retain_prob 33 | if noise_shape is None: 34 | noise_shape = get_shape(input) 35 | 36 | def training_branch(): 37 | noise = tf.random_uniform( 38 | shape=noise_shape, minval=0., maxval=1., dtype=dtype) 39 | mask = tf.cast(noise < retain_prob, dtype=dtype) 40 | return input * mask * inv_retain_prob 41 | 42 | def testing_branch(): 43 | return input 44 | 45 | return smart_cond( 46 | training, 47 | training_branch, 48 | testing_branch, 49 | ) 50 | -------------------------------------------------------------------------------- /tfsnippet/layers/core/gated.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import tensorflow as tf 4 | from tensorflow.contrib.framework import add_arg_scope 5 | 6 | from tfsnippet.utils import maybe_add_histogram 7 | 8 | __all__ = ['as_gated'] 9 | 10 | 11 | def as_gated(layer_fn, sigmoid_bias=2., default_name=None): 12 | """ 13 | Wrap a layer function into a gated layer function. 14 | 15 | For example, the following `gated_dense` function:: 16 | 17 | @add_arg_scope 18 | def gated_dense(inputs, units, activation_fn=None, sigmoid_bias=2., 19 | name=None, scope=None, **kwargs): 20 | with tf.name_scope(scope, default_name=name): 21 | gate = tf.sigmoid(sigmoid_bias + 22 | dense(inputs, units, scope='gate', **kwargs)) 23 | return gate * dense( 24 | inputs, units, activation_fn=activation_fn, scope='main', 25 | **kwargs 26 | ) 27 | 28 | can be deduced by applying this function:: 29 | 30 | gated_dense = as_gated(dense) 31 | 32 | Args: 33 | layer_fn: The layer function to be wrapped. 34 | sigmoid_bias: The constant bias added to the `gate` before 35 | applying the sigmoid activation. 36 | default_name: Default name of variable scope. 37 | 38 | Returns: 39 | The wrapped gated layer function. 40 | 41 | Notes: 42 | If a layer supports `gated` argument (e.g., :func:`spt.layers.dense`), 43 | it is generally better to use that argument, instead of using this 44 | :func:`as_gated` wrapper on the layer. 
45 | """ 46 | if not default_name: 47 | if getattr(layer_fn, '__name__', None): 48 | default_name = 'gated_' + layer_fn.__name__ 49 | if not default_name: 50 | raise ValueError('`default_name` cannot be inferred, you must specify ' 51 | 'this argument.') 52 | 53 | @add_arg_scope 54 | def gated_layer(*args, **kwargs): 55 | name = kwargs.pop('name', None) 56 | scope = kwargs.pop('scope', None) 57 | activation_fn = kwargs.pop('activation_fn', None) 58 | 59 | with tf.variable_scope(scope, default_name=name or default_name): 60 | # the gate branch 61 | gate_kwargs = copy.copy(kwargs) 62 | gate_kwargs['scope'] = 'gate' 63 | gate = tf.sigmoid(sigmoid_bias + layer_fn(*args, **gate_kwargs)) 64 | 65 | # the main branch 66 | main_kwargs = copy.copy(kwargs) 67 | main_kwargs['scope'] = 'main' 68 | main_kwargs['activation_fn'] = activation_fn 69 | main = layer_fn(*args, **main_kwargs) 70 | 71 | # compose the final output 72 | output = main * gate 73 | maybe_add_histogram(output, 'output') 74 | 75 | return output 76 | 77 | return gated_layer 78 | -------------------------------------------------------------------------------- /tfsnippet/layers/flows/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .branch import * 3 | from .coupling import * 4 | from .invert import * 5 | from .linear import * 6 | from .planar_nf import * 7 | from .rearrangement import * 8 | from .reshape import * 9 | from .sequential import * 10 | from .utils import * 11 | 12 | __all__ = [ 13 | 'BaseFlow', 'CouplingLayer', 'FeatureMappingFlow', 'FeatureShufflingFlow', 14 | 'InvertFlow', 'InvertibleConv2d', 'InvertibleDense', 'MultiLayerFlow', 15 | 'PlanarNormalizingFlow', 'ReshapeFlow', 'SequentialFlow', 16 | 'SpaceToDepthFlow', 'SplitFlow', 'broadcast_log_det_against_input', 17 | 'planar_normalizing_flows', 18 | ] 19 | -------------------------------------------------------------------------------- /tfsnippet/layers/flows/invert.py: -------------------------------------------------------------------------------- 1 | from .base import BaseFlow 2 | 3 | __all__ = ['InvertFlow'] 4 | 5 | 6 | class InvertFlow(BaseFlow): 7 | """ 8 | Turn a :class:`BaseFlow` into its inverted flow. 9 | 10 | This class is particularly useful when the flow is (theoretically) defined 11 | in the opposite direction to the direction of network initialization. 12 | For example, define `z -> x`, but initialized by feeding `x`. 13 | """ 14 | 15 | def __init__(self, flow, name=None, scope=None): 16 | """ 17 | Construct a new :class:`InvertFlow`. 18 | 19 | Args: 20 | flow (BaseFlow): The underlying flow. 21 | """ 22 | if not isinstance(flow, BaseFlow) or not flow.explicitly_invertible: 23 | raise ValueError('`flow` must be an explicitly invertible flow: ' 24 | 'got {!r}'.format(flow)) 25 | self._flow = flow 26 | 27 | super(InvertFlow, self).__init__( 28 | x_value_ndims=flow.y_value_ndims, 29 | y_value_ndims=flow.x_value_ndims, 30 | require_batch_dims=flow.require_batch_dims, 31 | name=name, 32 | scope=scope 33 | ) 34 | 35 | def invert(self): 36 | """ 37 | Get the original flow, inverted by this :class:`InvertFlow`. 38 | 39 | Returns: 40 | BaseFlow: The original flow. 
41 | """ 42 | return self._flow 43 | 44 | @property 45 | def explicitly_invertible(self): 46 | return True 47 | 48 | def build(self, input=None): # pragma: no cover 49 | # since `flow` should be inverted, we should build `flow` in 50 | # `inverse_transform` rather than in `transform` or `build` 51 | pass 52 | 53 | def transform(self, x, compute_y=True, compute_log_det=True, name=None): 54 | return self._flow.inverse_transform( 55 | y=x, compute_x=compute_y, compute_log_det=compute_log_det, 56 | name=name 57 | ) 58 | 59 | def inverse_transform(self, y, compute_x=True, compute_log_det=True, 60 | name=None): 61 | return self._flow.transform( 62 | x=y, compute_y=compute_x, compute_log_det=compute_log_det, 63 | name=name 64 | ) 65 | 66 | def _build(self, input=None): 67 | raise RuntimeError('Should never be called.') # pragma: no cover 68 | 69 | def _transform(self, x, compute_y, compute_log_det): 70 | raise RuntimeError('Should never be called.') # pragma: no cover 71 | 72 | def _inverse_transform(self, y, compute_x, compute_log_det): 73 | raise RuntimeError('Should never be called.') # pragma: no cover 74 | -------------------------------------------------------------------------------- /tfsnippet/layers/flows/sequential.py: -------------------------------------------------------------------------------- 1 | from tfsnippet.utils import add_name_and_scope_arg_doc 2 | from .base import BaseFlow, MultiLayerFlow 3 | 4 | __all__ = ['SequentialFlow'] 5 | 6 | 7 | class SequentialFlow(MultiLayerFlow): 8 | """ 9 | Compose a large flow from a sequential of :class:`BaseFlow`. 10 | """ 11 | 12 | @add_name_and_scope_arg_doc 13 | def __init__(self, flows, name=None, scope=None): 14 | """ 15 | Construct a new :class:`SequentialFlow`. 16 | 17 | Args: 18 | flows (Iterable[BaseFlow]): The flow list. 19 | """ 20 | flows = tuple(flows) # type: tuple[BaseFlow] 21 | if not flows: 22 | raise TypeError('`flows` must not be empty.') 23 | 24 | for i, flow in enumerate(flows): 25 | if not isinstance(flow, BaseFlow): 26 | raise TypeError('The {}-th flow in `flows` is not an instance ' 27 | 'of `BaseFlow`: {!r}'.format(i, flow)) 28 | 29 | for i, (flow1, flow2) in enumerate(zip(flows[:-1], flows[1:])): 30 | if flow2.x_value_ndims != flow1.y_value_ndims: 31 | raise TypeError( 32 | '`x_value_ndims` of the {}-th flow != `y_value_ndims` ' 33 | 'of the {}-th flow: {} vs {}.'. 34 | format(i + 1, i, flow2.x_value_ndims, flow1.y_value_ndims) 35 | ) 36 | 37 | super(SequentialFlow, self).__init__( 38 | n_layers=len(flows), x_value_ndims=flows[0].x_value_ndims, 39 | y_value_ndims=flows[-1].y_value_ndims, name=name, scope=scope 40 | ) 41 | self._flows = flows 42 | self._explicitly_invertible = \ 43 | all(flow.explicitly_invertible for flow in flows) 44 | 45 | def _build(self, input=None): 46 | # do nothing, the building procedure of every flows are automatically 47 | # called within their `apply` methods. 48 | pass 49 | 50 | @property 51 | def flows(self): 52 | """ 53 | Get the immutable flow list. 54 | 55 | Returns: 56 | tuple[BaseFlow]: The immutable flow list. 
57 | """ 58 | return self._flows 59 | 60 | @property 61 | def explicitly_invertible(self): 62 | return self._explicitly_invertible 63 | 64 | def _transform_layer(self, layer_id, x, compute_y, compute_log_det): 65 | flow = self._flows[layer_id] 66 | return flow.transform(x, compute_y, compute_log_det) 67 | 68 | def _inverse_transform_layer(self, layer_id, y, compute_x, compute_log_det): 69 | flow = self._flows[layer_id] 70 | return flow.inverse_transform(y, compute_x, compute_log_det) 71 | -------------------------------------------------------------------------------- /tfsnippet/layers/initialization.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | __all__ = [ 4 | 'default_kernel_initializer', 5 | ] 6 | 7 | 8 | def default_kernel_initializer(weight_norm=False): 9 | """ 10 | Get the default initializer for layer kernels (i.e., `W` of layers). 11 | 12 | Args: 13 | weight_norm: Whether or not to apply weight normalization 14 | (Salimans & Kingma, 2016) on the kernel? If is not :obj:`False` 15 | or :obj:`None`, will use ``tf.random_normal_initializer(0, .05)``. 16 | 17 | Returns: 18 | The default initializer for kernels. 19 | """ 20 | if weight_norm not in (False, None): 21 | return tf.random_normal_initializer(0., .05) 22 | else: 23 | return tf.glorot_normal_initializer() 24 | -------------------------------------------------------------------------------- /tfsnippet/layers/normalization/__init__.py: -------------------------------------------------------------------------------- 1 | from .act_norm_ import * 2 | from .weight_norm_ import * 3 | 4 | __all__ = [ 5 | 'ActNorm', 'act_norm', 'weight_norm', 6 | ] 7 | -------------------------------------------------------------------------------- /tfsnippet/layers/regularization.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tfsnippet.utils import add_name_arg_doc 4 | 5 | __all__ = ['l2_regularizer'] 6 | 7 | 8 | @add_name_arg_doc 9 | def l2_regularizer(lambda_, name=None): 10 | """ 11 | Construct an L2 regularizer that computes the L2 regularization loss:: 12 | 13 | output = lambda_ * 0.5 * sum(input ** 2) 14 | 15 | Args: 16 | lambda_ (float or Tensor or None): The coefficiency of L2 regularizer. 17 | If `lambda_` is :obj:`None`, will return :obj:`None`. 18 | 19 | Returns: 20 | (tf.Tensor) -> tf.Tensor: A function that computes the L2 21 | regularization term for input tensor. Will be :obj:`None` 22 | if `lambda_` is :obj:`None`. 23 | """ 24 | if lambda_ is None: 25 | return None 26 | 27 | def regularizer(input): 28 | input = tf.convert_to_tensor(input) 29 | with tf.name_scope(name, default_name='l2_regularization', 30 | values=[input]): 31 | return lambda_ * tf.nn.l2_loss(input) 32 | 33 | return regularizer 34 | -------------------------------------------------------------------------------- /tfsnippet/layers/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | __all__ = [] 4 | 5 | 6 | def validate_weight_norm_arg(weight_norm, axis, use_scale): 7 | """ 8 | Validate the specified `weight_norm` argument. 9 | 10 | Args: 11 | weight_norm (bool or (tf.Tensor) -> tf.Tensor)): 12 | If :obj:`True`, wraps :func:`~tfsnippet.layers.weight_norm` 13 | with `axis` and `use_scale` argument. If a callable function, 14 | it will be returned directly. 15 | axis (int): The axis argument for `weight_norm`. 
16 | use_scale (bool): The `use_scale` argument for `weight_norm`. 17 | 18 | Returns: 19 | None or (tf.Tensor) -> tf.Tensor: The weight normalization function, 20 | or None if weight normalization is not enabled. 21 | """ 22 | from .normalization import weight_norm as weight_norm_fn 23 | if callable(weight_norm): 24 | return weight_norm 25 | elif weight_norm is True: 26 | return functools.partial(weight_norm_fn, axis=axis, use_scale=use_scale) 27 | elif weight_norm in (False, None): 28 | return None 29 | else: 30 | raise TypeError('Invalid value for argument `weight_norm`: expected ' 31 | 'a bool or a callable function, got {!r}.'. 32 | format(weight_norm)) 33 | -------------------------------------------------------------------------------- /tfsnippet/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .assertions import * 2 | from .classification import * 3 | from .control_flows import * 4 | from .convolution import * 5 | from .evaluation import * 6 | from .loop import * 7 | from .misc import * 8 | from .shape_utils import * 9 | from .shifting import * 10 | from .type_utils import * 11 | 12 | __all__ = [ 13 | 'add_n_broadcast', 'assert_rank', 'assert_rank_at_least', 14 | 'assert_scalar_equal', 'assert_shape_equal', 'bits_per_dimension', 15 | 'broadcast_concat', 'broadcast_to_shape', 'broadcast_to_shape_strict', 16 | 'classification_accuracy', 'convert_to_tensor_and_cast', 'depth_to_space', 17 | 'flatten_to_ndims', 'log_mean_exp', 'log_sum_exp', 'maybe_clip_value', 18 | 'pixelcnn_2d_sample', 'prepend_dims', 'reshape_tail', 'shift', 19 | 'smart_cond', 'softmax_classification_output', 'space_to_depth', 20 | 'transpose_conv2d_axis', 'transpose_conv2d_channels_last_to_x', 21 | 'transpose_conv2d_channels_x_to_last', 'unflatten_from_ndims', 22 | ] 23 | -------------------------------------------------------------------------------- /tfsnippet/ops/classification.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tfsnippet.utils import add_name_arg_doc, InputSpec, get_static_shape 4 | 5 | __all__ = [ 6 | 'classification_accuracy', 7 | 'softmax_classification_output', 8 | ] 9 | 10 | 11 | @add_name_arg_doc 12 | def classification_accuracy(y_pred, y_true, name=None): 13 | """ 14 | Compute the classification accuracy for `y_pred` and `y_true`. 15 | 16 | Args: 17 | y_pred: The predicted labels. 18 | y_true: The ground truth labels. Its shape must match `y_pred`. 19 | 20 | Returns: 21 | tf.Tensor: The accuracy. 22 | """ 23 | y_pred = tf.convert_to_tensor(y_pred) 24 | y_true = InputSpec(shape=get_static_shape(y_pred)). \ 25 | validate('y_true', y_true) 26 | with tf.name_scope(name, default_name='classification_accuracy', 27 | values=[y_pred, y_true]): 28 | return tf.reduce_mean( 29 | tf.cast(tf.equal(y_pred, y_true), dtype=tf.float32)) 30 | 31 | 32 | @add_name_arg_doc 33 | def softmax_classification_output(logits, name=None): 34 | """ 35 | Get the most possible softmax classification output for each logit. 36 | 37 | Args: 38 | logits: The softmax logits. Its last dimension will be treated 39 | as the softmax logits dimension, and will be reduced. 40 | 41 | Returns: 42 | tf.Tensor: tf.int32 tensor, the class label for each logit. 
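        For example::

            logits = tf.random_normal([32, 5])              # 5-class logits
            labels = softmax_classification_output(logits)  # shape [32], tf.int32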
43 | """ 44 | logits = InputSpec(shape=('...', '?', '?')).validate('logits', logits) 45 | with tf.name_scope(name, default_name='softmax_classification_output', 46 | values=[logits]): 47 | return tf.argmax(logits, axis=-1, output_type=tf.int32) 48 | -------------------------------------------------------------------------------- /tfsnippet/ops/control_flows.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tfsnippet.utils import add_name_arg_doc, is_tensor_object 4 | 5 | __all__ = ['smart_cond'] 6 | 7 | 8 | @add_name_arg_doc 9 | def smart_cond(cond, true_fn, false_fn, name=None): 10 | """ 11 | Execute `true_fn` or `false_fn` according to `cond`. 12 | 13 | Args: 14 | cond (bool or tf.Tensor): A bool constant or a tensor. 15 | true_fn (() -> tf.Tensor): The function of the true branch. 16 | false_fn (() -> tf.Tensor): The function of the false branch. 17 | 18 | Returns: 19 | tf.Tensor: The output tensor. 20 | """ 21 | if is_tensor_object(cond): 22 | return tf.cond(cond, true_fn, false_fn, name=name) 23 | else: 24 | if cond: 25 | return true_fn() 26 | else: 27 | return false_fn() 28 | -------------------------------------------------------------------------------- /tfsnippet/ops/convolution.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tfsnippet.utils import add_name_arg_doc 4 | from .shape_utils import flatten_to_ndims, unflatten_from_ndims 5 | 6 | __all__ = ['space_to_depth', 'depth_to_space'] 7 | 8 | 9 | @add_name_arg_doc 10 | def space_to_depth(input, block_size, channels_last=True, name=None): 11 | """ 12 | Wraps :func:`tf.space_to_depth`, to support tensors higher than 4-d. 13 | 14 | Args: 15 | input: The input tensor, at least 4-d. 16 | block_size (int): An int >= 2, the size of the spatial block. 17 | channels_last (bool): Whether or not the channels axis 18 | is the last axis in the input tensor? 19 | 20 | Returns: 21 | tf.Tensor: The output tensor. 22 | 23 | See Also: 24 | :func:`tf.space_to_depth` 25 | """ 26 | block_size = int(block_size) 27 | data_format = 'NHWC' if channels_last else 'NCHW' 28 | input = tf.convert_to_tensor(input) 29 | with tf.name_scope(name or 'space_to_depth', values=[input]): 30 | output, s1, s2 = flatten_to_ndims(input, ndims=4) 31 | output = tf.space_to_depth(output, block_size, data_format=data_format) 32 | output = unflatten_from_ndims(output, s1, s2) 33 | return output 34 | 35 | 36 | @add_name_arg_doc 37 | def depth_to_space(input, block_size, channels_last=True, name=None): 38 | """ 39 | Wraps :func:`tf.depth_to_space`, to support tensors higher than 4-d. 40 | 41 | Args: 42 | input: The input tensor, at least 4-d. 43 | block_size (int): An int >= 2, the size of the spatial block. 44 | channels_last (bool): Whether or not the channels axis 45 | is the last axis in the input tensor? 46 | 47 | Returns: 48 | tf.Tensor: The output tensor. 
49 | 
50 |     See Also:
51 |         :func:`tf.depth_to_space`
52 |     """
53 |     block_size = int(block_size)
54 |     data_format = 'NHWC' if channels_last else 'NCHW'
55 |     input = tf.convert_to_tensor(input)
56 |     with tf.name_scope(name or 'depth_to_space', values=[input]):
57 |         output, s1, s2 = flatten_to_ndims(input, ndims=4)
58 |         output = tf.depth_to_space(output, block_size, data_format=data_format)
59 |         output = unflatten_from_ndims(output, s1, s2)
60 |         return output
61 | 
--------------------------------------------------------------------------------
/tfsnippet/ops/evaluation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | 
4 | from .type_utils import convert_to_tensor_and_cast
5 | 
6 | __all__ = ['bits_per_dimension']
7 | 
8 | 
9 | def bits_per_dimension(log_p, value_size, scale=256., name=None):
10 |     """
11 |     Compute "bits per dimension" of `x`.
12 | 
13 |     `BPD(x) = - log(p(x)) / (log(2) * Dim(x))`
14 | 
15 |     If `x = s * u` (i.e., `u` is `x` rescaled by `1 / s`), then:
16 | 
17 |     `BPD(x) = - (log(p(u)) - log(s) * Dim(x)) / (log(2) * Dim(x))`
18 | 
19 |     Args:
20 |         log_p (Tensor): If `scale` is specified, then it should be `log(p(u))`.
21 |             Otherwise it should be `log(p(x))`.
22 |         value_size (int or Tensor): The size of each `x`, i.e., `Dim(x)`.
23 |         scale (float or Tensor or None): The scale `s`, where `x = s * u`,
24 |             and `log_p` is `log(p(u))`.
25 | 
26 |     Returns:
27 |         tf.Tensor: The computed "bits per dimension" of `x`.
28 |     """
29 |     log_p = tf.convert_to_tensor(log_p)
30 |     dtype = log_p.dtype.base_dtype
31 | 
32 |     with tf.name_scope(name, default_name='bits_per_dimension', values=[log_p]):
33 |         if scale is not None:
34 |             scale = convert_to_tensor_and_cast(scale, dtype)
35 |             nll = tf.log(scale) * value_size - log_p
36 |         else:
37 |             nll = -log_p
38 |         ret = nll / (np.log(2) * value_size)
39 | 
40 |     return ret
41 | 
--------------------------------------------------------------------------------
/tfsnippet/ops/type_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | __all__ = ['convert_to_tensor_and_cast']
4 | 
5 | 
6 | def convert_to_tensor_and_cast(x, dtype=None):
7 |     """
8 |     Convert `x` into a :class:`tf.Tensor`, and cast its dtype if required.
9 | 
10 |     Args:
11 |         x: The tensor to be converted into a :class:`tf.Tensor`.
12 |         dtype (tf.DType): The data type.
13 | 
14 |     Returns:
15 |         tf.Tensor: The converted and casted tensor.
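# [Editor's example] Sketch of `bits_per_dimension` for 32x32x3 images whose
# pixel values were rescaled from [0, 256) down to [0, 1), i.e. x = 256 * u;
# `log_p` stands for the model's log-likelihoods and is an assumed placeholder.
import tensorflow as tf
from tfsnippet.ops import bits_per_dimension

log_p = tf.placeholder(tf.float32, shape=[None])  # log p(u), one per image
bpd = bits_per_dimension(log_p, value_size=32 * 32 * 3, scale=256.)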
16 | """ 17 | x = tf.convert_to_tensor(x) 18 | if dtype is not None: 19 | dtype = tf.as_dtype(dtype) 20 | if dtype != x.dtype: 21 | x = tf.cast(x, dtype) 22 | return x 23 | -------------------------------------------------------------------------------- /tfsnippet/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .samplers import * 2 | 3 | __all__ = [ 4 | 'BaseSampler', 'BernoulliSampler', 'UniformNoiseSampler', 5 | ] 6 | -------------------------------------------------------------------------------- /tfsnippet/scaffold/__init__.py: -------------------------------------------------------------------------------- 1 | from .checkpoint import * 2 | from .event_keys import * 3 | from .logging_ import * 4 | from .scheduled_var import * 5 | from .train_loop_ import * 6 | 7 | __all__ = [ 8 | 'AnnealingVariable', 'CheckpointSavableObject', 'CheckpointSaver', 9 | 'DefaultMetricFormatter', 'EventKeys', 'MetricFormatter', 'MetricLogger', 10 | 'ScheduledVariable', 'TrainLoop', 'summarize_variables', 11 | ] 12 | -------------------------------------------------------------------------------- /tfsnippet/scaffold/event_keys.py: -------------------------------------------------------------------------------- 1 | __all__ = ['EventKeys'] 2 | 3 | 4 | class EventKeys(object): 5 | """Defines event keys for TFSnippet.""" 6 | # (TrainLoop) Enter the train loop. 7 | ENTER_LOOP = 'enter_loop' 8 | 9 | # (TrainLoop) Exit the train loop. 10 | EXIT_LOOP = 'exit_loop' 11 | 12 | # (TrainLoop) When metrics (except time metrics) have been collected. 13 | METRICS_COLLECTED = 'metrics_collected' 14 | 15 | # (TrainLoop) When time metrics have been collected. 16 | TIME_METRICS_COLLECTED = 'time_metrics_collected' 17 | 18 | # (TrainLoop) When metric statistics have been printed. 19 | METRIC_STATS_PRINTED = 'metric_stats_printed' 20 | 21 | # (TrainLoop) When time metric statistics have been printed. 22 | TIME_METRIC_STATS_PRINTED = 'time_metric_stats_printed' 23 | 24 | # (TrainLoop) When TensorFlow summary has been added. 25 | SUMMARY_ADDED = 'summary_added' 26 | 27 | # (TrainLoop, Trainer) Before executing an epoch. 28 | BEFORE_EPOCH = 'before_epoch' 29 | 30 | # (Trainer) Run evaluation after an epoch. 31 | EPOCH_EVALUATION = 'epoch_evaluation' 32 | 33 | # (Trainer) Anneal after an epoch. 34 | EPOCH_ANNEALING = 'epoch_annealing' 35 | 36 | # (Trainer) Log after an epoch. 37 | EPOCH_LOGGING = 'epoch_logging' 38 | 39 | # (TrainLoop, Trainer) After executing an epoch. 40 | AFTER_EPOCH = 'after_epoch' 41 | 42 | # (TrainLoop, Trainer) Before executing a step. 43 | BEFORE_STEP = 'before_step' 44 | 45 | # (Trainer) Run evaluation after a step. 46 | STEP_EVALUATION = 'step_evaluation' 47 | 48 | # (Trainer) Anneal after a step. 49 | STEP_ANNEALING = 'step_annealing' 50 | 51 | # (Trainer) Log after a step. 52 | STEP_LOGGING = 'step_logging' 53 | 54 | # (TrainLoop, Trainer) After executing a step. 55 | AFTER_STEP = 'after_step' 56 | 57 | # (Trainer, Evaluator) Before execution. 58 | BEFORE_EXECUTION = 'before_execution' 59 | 60 | # (Evaluator) After execution. 61 | AFTER_EXECUTION = 'after_execution' 62 | -------------------------------------------------------------------------------- /tfsnippet/shortcuts.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package provides shortcuts to utilities from second-level packages. 
3 | """ 4 | 5 | from .dataflows import DataFlow, DataMapper, SlidingWindow 6 | from .utils.config_utils import (Config, ConfigField, get_config_defaults, 7 | register_config_arguments) 8 | from .utils.graph_keys import GraphKeys 9 | from .utils.invertible_matrix import InvertibleMatrix 10 | from .utils.model_vars import model_variable, get_model_variables 11 | from .utils.reuse import instance_reuse, global_reuse, VarScopeObject 12 | from .utils.settings_ import settings 13 | from .utils.summary_collector import (SummaryCollector, add_histogram, 14 | add_summary, default_summary_collector) 15 | 16 | __all__ = [ 17 | # from tfsnippet.dataflows 18 | 'DataFlow', 'DataMapper', 'SlidingWindow', 19 | 20 | # from tfsnippet.utils.config_utils 21 | 'Config', 'ConfigField', 22 | 'get_config_defaults', 'register_config_arguments', 23 | 24 | # from tfsnippet.utils.graph_keys 25 | 'GraphKeys', 26 | 27 | # from tfsnippet.utils.invertible_matrix 28 | 'InvertibleMatrix', 29 | 30 | # from tfsnippet.utils.model_vars 31 | 'model_variable', 'get_model_variables', 32 | 33 | # from tfsnippet.utils.reuse 34 | 'instance_reuse', 'global_reuse', 'VarScopeObject', 35 | 36 | # from tfsnippet.utils.settings_ 37 | 'settings', 38 | 39 | # from tfsnippet.utils.summary_collector 40 | 'SummaryCollector', 'add_histogram', 'add_summary', 41 | 'default_summary_collector', 42 | ] 43 | -------------------------------------------------------------------------------- /tfsnippet/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_trainer import * 2 | from .dynamic_values import * 3 | from .evaluator import * 4 | from .feed_dict import * 5 | from .loss_trainer import * 6 | from .trainer import * 7 | from .validator import * 8 | 9 | __all__ = [ 10 | 'AnnealingScalar', 'BaseTrainer', 'DynamicValue', 'Evaluator', 11 | 'LossTrainer', 'Trainer', 'Validator', 'auto_batch_weight', 12 | 'merge_feed_dict', 'resolve_feed_dict', 13 | ] 14 | -------------------------------------------------------------------------------- /tfsnippet/trainer/feed_dict.py: -------------------------------------------------------------------------------- 1 | from tfsnippet.scaffold import ScheduledVariable 2 | from .dynamic_values import DynamicValue 3 | 4 | __all__ = ['resolve_feed_dict', 'merge_feed_dict'] 5 | 6 | 7 | def resolve_feed_dict(feed_dict, inplace=False): 8 | """ 9 | Resolve all dynamic values in `feed_dict` into fixed values. 10 | 11 | The supported dynamic value types and corresponding resolving method 12 | is listed as follows: 13 | 14 | 1. :class:`ScheduledVariable`: :meth:`get()` will be called. 15 | 2. :class:`DynamicValue`: :meth:`get()` will be called. 16 | 3. callable object: Will be called to get the value. 17 | 18 | Args: 19 | feed_dict (dict[tf.Tensor, any]): The feed dict to be resolved. 20 | inplace (bool): Whether or not to fill resolved values in 21 | the input `feed_dict` directly, instead of copying a new one? 22 | (default :obj:`False`) 23 | 24 | Returns: 25 | The resolved feed dict. 26 | """ 27 | if not inplace: 28 | feed_dict = dict(feed_dict) 29 | for k in feed_dict: 30 | v = feed_dict[k] 31 | if isinstance(v, ScheduledVariable): 32 | feed_dict[k] = v.get() 33 | elif isinstance(v, DynamicValue): 34 | feed_dict[k] = v.get() 35 | elif callable(v): 36 | feed_dict[k] = v() 37 | return feed_dict 38 | 39 | 40 | def merge_feed_dict(*feed_dicts): 41 | """ 42 | Merge all feed dicts into one. 43 | 44 | Args: 45 | \**feed_dicts: List of feed dicts. 
46 |             values specified in the previous ones. If a :obj:`None` is
47 |             specified, it will be simply ignored.
48 | 
49 |     Returns:
50 |         The merged feed dict.
51 |     """
52 |     ret = {}
53 |     for feed_dict in feed_dicts:
54 |         if feed_dict is not None:
55 |             ret.update(feed_dict)
56 |     return ret
57 | 
--------------------------------------------------------------------------------
/tfsnippet/trainer/loss_trainer.py:
--------------------------------------------------------------------------------
1 | from tfsnippet.scaffold import TrainLoop
2 | from tfsnippet.utils import deprecated, deprecated_arg
3 | from .trainer import Trainer
4 | from .feed_dict import merge_feed_dict
5 | 
6 | __all__ = ['LossTrainer']
7 | 
8 | 
9 | @deprecated('use :class:`Trainer` instead.', version='0.1')
10 | class LossTrainer(Trainer):
11 |     """
12 |     A subclass of :class:`Trainer`, which optimizes a single loss.
13 |     """
14 | 
15 |     def __init__(self, loop, loss, train_op, inputs, data_flow, feed_dict=None,
16 |                  metric_name='loss'):
17 |         """
18 |         Construct a new :class:`LossTrainer`.
19 | 
20 |         Args:
21 |             loop (TrainLoop): The training loop object.
22 |             loss (tf.Tensor): The training loss.
23 |             train_op (tf.Operation): The training operation.
24 |             inputs (list[tf.Tensor]): The input placeholders. The number of
25 |                 tensors, and the order of tensors, should both match the arrays
26 |                 of each mini-batch data, provided by `data_flow`.
27 |             data_flow (DataFlow): The training data flow. Each mini-batch must
28 |                 contain one array for each placeholder in `inputs`.
29 |             feed_dict: The feed dict for training. It will be merged with
30 |                 the arrays provided by `data_flow` in each step.
31 |                 (default :obj:`None`)
32 |             metric_name (str): The metric name for collecting training loss.
33 |         """
34 |         super(LossTrainer, self).__init__(
35 |             loop=loop, train_op=train_op, inputs=inputs, data_flow=data_flow,
36 |             feed_dict=feed_dict, metrics={metric_name: loss}
37 |         )
38 | 
39 |     @property
40 |     def loss(self):
41 |         """Get the training loss."""
42 |         return list(self.metrics.values())[0]
43 | 
44 |     @property
45 |     def metric_name(self):
46 |         """Get the metric name for collecting training loss."""
47 |         return list(self.metrics.keys())[0]
48 | 
49 |     @deprecated_arg('feed_dict', version='0.1')
50 |     def run(self, feed_dict=None):
51 |         """
52 |         Run training loop.
53 | 
54 |         Args:
55 |             feed_dict: DEPRECATED. The extra feed dict to be merged with
56 |                 the already configured dict. (default :obj:`None`)
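# [Editor's example] Sketch of `merge_feed_dict`; keys would normally be
# placeholder tensors, strings are used here only for brevity.
from tfsnippet.trainer import merge_feed_dict

merged = merge_feed_dict({'x': 1, 'y': 2}, None, {'y': 20})
# -> {'x': 1, 'y': 20}; the None entry is simply ignored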
57 |         """
58 |         old_feed_dict = self._feed_dict
59 |         try:
60 |             if feed_dict is not None:  # pragma: no cover
61 |                 self._feed_dict = merge_feed_dict(self._feed_dict, feed_dict)
62 |             super(LossTrainer, self).run()
63 |         finally:
64 |             self._feed_dict = old_feed_dict
65 | 
--------------------------------------------------------------------------------
/tfsnippet/trainer/validator.py:
--------------------------------------------------------------------------------
1 | from tfsnippet.utils import deprecated
2 | from .evaluator import Evaluator, auto_batch_weight
3 | 
4 | __all__ = ['Validator']
5 | 
6 | 
7 | @deprecated('use :class:`Evaluator` instead.', version='0.1')
8 | class Validator(Evaluator):
9 |     """Class to compute validation loss and other metrics."""
10 | 
11 |     def __init__(self, loop, metrics, inputs, data_flow, feed_dict=None,
12 |                  time_metric_name='valid_time',
13 |                  batch_weight_func=auto_batch_weight):  # pragma: no cover
14 |         super(Validator, self).__init__(
15 |             loop=loop, metrics=metrics, inputs=inputs, data_flow=data_flow,
16 |             feed_dict=feed_dict, time_metric_name=time_metric_name,
17 |             batch_weight_func=batch_weight_func
18 |         )
19 | 
--------------------------------------------------------------------------------
/tfsnippet/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .archive_file import *
2 | from .caching import *
3 | from .concepts import *
4 | from .config_utils import *
5 | from .console_table import *
6 | from .data_utils import *
7 | from .debugging import *
8 | from .deprecation import *
9 | from .doc_utils import *
10 | from .events import *
11 | from .graph_keys import *
12 | from .imported import *
13 | from .invertible_matrix import *
14 | from .misc import *
15 | from .model_vars import *
16 | from .random import *
17 | from .registry import *
18 | from .reuse import *
19 | from .scope import *
20 | from .session import *
21 | from .settings_ import *
22 | from .shape_utils import *
23 | from .statistics import *
24 | from .summary_collector import *
25 | from .tensor_spec import *
26 | from .tensor_wrapper import *
27 | from .tfver import *
28 | from .type_utils import *
29 | 
30 | __all__ = [
31 |     'AutoInitAndCloseable', 'BaseRegistry', 'BoolConfigValidator', 'CacheDir',
32 |     'ClassRegistry', 'Config', 'ConfigField', 'ConfigValidator',
33 |     'ConsoleTable', 'ContextStack', 'Disposable', 'DisposableContext',
34 |     'DocInherit', 'ETA', 'EventSource', 'Extractor', 'FloatConfigValidator',
35 |     'GraphKeys', 'InputSpec', 'IntConfigValidator', 'InvertibleMatrix',
36 |     'NoReentrantContext', 'ParamSpec', 'PermutationMatrix', 'RarExtractor',
37 |     'StatisticsCollector', 'StrConfigValidator', 'SummaryCollector',
38 |     'TFSnippetConfig', 'TarExtractor', 'TemporaryDirectory',
39 |     'TensorArgValidator', 'TensorSpec', 'TensorWrapper', 'VarScopeObject',
40 |     'VarScopeRandomState', 'ZipExtractor', 'add_histogram',
41 |     'add_name_and_scope_arg_doc', 'add_name_arg_doc', 'add_summary',
42 |     'append_arg_to_doc', 'append_to_doc', 'assert_deps', 'camel_to_underscore',
43 |     'concat_shapes', 'create_session', 'default_summary_collector',
44 |     'deprecated', 'deprecated_arg', 'ensure_variables_initialized',
45 |     'generate_random_seed', 'get_batch_size', 'get_cache_root',
46 |     'get_config_defaults', 'get_config_validator', 'get_default_scope_name',
47 |     'get_default_session_or_error', 'get_dimension_size',
48 |     'get_dimensions_size', 'get_model_variables', 'get_rank',
49 |     'get_reuse_stack_top', 'get_shape', 'get_static_shape',
50 |     'get_uninitialized_variables', 'get_variable_ddi', 'get_variables_as_dict',
51 |     'global_reuse', 'humanize_duration', 'instance_reuse', 'is_float',
52 |     'is_integer', 'is_shape_equal', 'is_tensor_object',
53 |     'is_tensorflow_version_higher_or_equal', 'iter_files', 'makedirs',
54 |     'maybe_add_histogram', 'maybe_check_numerics', 'maybe_close',
55 |     'minibatch_slices_iterator', 'model_variable', 'print_as_table',
56 |     'register_config_arguments', 'register_config_validator',
57 |     'register_tensor_wrapper_class', 'reopen_variable_scope',
58 |     'resolve_negative_axis', 'root_variable_scope', 'scoped_set_config',
59 |     'set_cache_root', 'set_random_seed', 'settings', 'split_numpy_array',
60 |     'split_numpy_arrays', 'validate_enum_arg', 'validate_group_ndims_arg',
61 |     'validate_int_tuple_arg', 'validate_n_samples_arg',
62 |     'validate_positive_int_arg',
63 | ]
64 | 
--------------------------------------------------------------------------------
/tfsnippet/utils/debugging.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | 
3 | import tensorflow as tf
4 | 
5 | from .doc_utils import add_name_arg_doc
6 | from .graph_keys import GraphKeys
7 | 
8 | __all__ = [
9 |     'maybe_check_numerics',
10 |     'assert_deps',
11 |     'maybe_add_histogram',
12 | ]
13 | 
14 | 
15 | @add_name_arg_doc
16 | def maybe_check_numerics(tensor, message, name=None):
17 |     """
18 |     If ``tfsnippet.settings.check_numerics == True``, check the numerics of
19 |     `tensor`. Otherwise do nothing.
20 | 
21 |     Args:
22 |         tensor: The tensor to be checked.
23 |         message: The message to display when numerical issues occur.
24 | 
25 |     Returns:
26 |         tf.Tensor: The tensor, whose numerics have been checked.
27 |     """
28 |     from .settings_ import settings
29 |     if settings.check_numerics:
30 |         tensor = tf.convert_to_tensor(tensor)
31 |         return tf.check_numerics(tensor, message, name=name)
32 |     else:
33 |         return tensor
34 | 
35 | 
36 | @contextmanager
37 | def assert_deps(assert_ops):
38 |     """
39 |     If ``tfsnippet.settings.enable_assertions == True``, open a context that
40 |     will run `assert_ops`. Otherwise do nothing.
41 | 
42 |     Args:
43 |         assert_ops (Iterable[tf.Operation or None]): A list of assertion
44 |             operations. :obj:`None` items will be ignored.
45 | 
46 |     Yields:
47 |         bool: A boolean indicating whether or not the assertion operations
48 |             are non-empty and have been executed.
49 |     """
50 |     from .settings_ import settings
51 |     assert_ops = [o for o in assert_ops if o is not None]
52 |     if assert_ops and settings.enable_assertions:
53 |         with tf.control_dependencies(assert_ops):
54 |             yield True
55 |     else:
56 |         for op in assert_ops:
57 |             # let TensorFlow not warn about not using this assertion operation
58 |             if hasattr(op, 'mark_used'):
59 |                 op.mark_used()
60 |         yield False
61 | 
62 | 
63 | @add_name_arg_doc
64 | def maybe_add_histogram(tensor, summary_name=None, strip_scope=False,
65 |                         collections=None, name=None):
66 |     """
67 |     If ``tfsnippet.settings.auto_histogram == True``, add the histogram
68 |     of `tensor` via :func:`tfsnippet.add_histogram`. Otherwise do nothing.
69 | 
70 |     Args:
71 |         tensor: Take histogram of this tensor.
72 |         summary_name: Specify the summary name for `tensor`.
73 |         strip_scope: If :obj:`True`, strip the name scope from `tensor.name`
74 |             when adding the histogram.
75 |         collections: Add the histogram to these collections. Defaults to
76 |             `[tfsnippet.GraphKeys.AUTO_HISTOGRAM]`.
77 | 
78 |     Returns:
79 |         The serialized histogram tensor of `tensor`, or :obj:`None` if auto histogram is not enabled.
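# [Editor's example] Sketch of `assert_deps` guarding a computation with an
# assertion op; names are illustrative.
import tensorflow as tf
from tfsnippet.utils import assert_deps

x = tf.placeholder(tf.float32, shape=[None])
with assert_deps([tf.assert_non_negative(x)]) as asserted:
    # `asserted` is True iff assertions are enabled and the op list is
    # non-empty; ops created here then run under the assertion.
    y = tf.sqrt(x)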
80 | 
81 |     See Also:
82 |         :func:`tfsnippet.add_histogram`
83 |     """
84 |     from .settings_ import settings
85 |     from .summary_collector import add_histogram
86 |     if settings.auto_histogram:
87 |         if collections is None:
88 |             collections = (GraphKeys.AUTO_HISTOGRAM,)
89 |         return add_histogram(
90 |             tensor, summary_name=summary_name, collections=collections,
91 |             strip_scope=strip_scope, name=name or 'maybe_add_histogram'
92 |         )
93 | 
--------------------------------------------------------------------------------
/tfsnippet/utils/graph_keys.py:
--------------------------------------------------------------------------------
1 | __all__ = ['GraphKeys']
2 | 
3 | 
4 | class GraphKeys(object):
5 |     """Defines TensorFlow graph collection keys for TFSnippet."""
6 | 
7 |     AUTO_HISTOGRAM = 'TFSNIPPET_AUTO_HISTOGRAM'
8 | 
--------------------------------------------------------------------------------
/tfsnippet/utils/imported.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | try:
5 |     from tempfile import TemporaryDirectory
6 | except ImportError:
7 |     from backports.tempfile import TemporaryDirectory
8 | 
9 | __all__ = [
10 |     'TemporaryDirectory', 'makedirs'
11 | ]
12 | 
13 | 
14 | if sys.version_info[:2] < (3, 5):
15 |     import pathlib2
16 | 
17 |     def makedirs(name, mode=0o777, exist_ok=False):
18 |         pathlib2.Path(name).mkdir(mode=mode, parents=True, exist_ok=exist_ok)
19 | else:
20 |     makedirs = os.makedirs
21 | 
--------------------------------------------------------------------------------
/tfsnippet/utils/model_vars.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | __all__ = ['model_variable', 'get_model_variables']
4 | 
5 | 
6 | def model_variable(name,
7 |                    shape=None,
8 |                    dtype=None,
9 |                    initializer=None,
10 |                    regularizer=None,
11 |                    constraint=None,
12 |                    trainable=True,
13 |                    collections=None,
14 |                    **kwargs):
15 |     """
16 |     Get or create a model variable.
17 | 
18 |     When the variable is created, it will be added to both the `GLOBAL_VARIABLES`
19 |     and the `MODEL_VARIABLES` collections.
20 | 
21 |     Args:
22 |         name: Name of the variable.
23 |         shape: Shape of the variable.
24 |         dtype: Data type of the variable.
25 |         initializer: Initializer of the variable.
26 |         regularizer: Regularizer of the variable.
27 |         constraint: Constraint of the variable.
28 |         trainable (bool): Whether or not the variable is trainable?
29 |         collections: In addition to `GLOBAL_VARIABLES` and `MODEL_VARIABLES`,
30 |             also add the variable to these collections.
31 |         \\**kwargs: Other named arguments passed to :func:`tf.get_variable`.
32 | 
33 |     Returns:
34 |         tf.Variable: The variable.
35 |     """
36 |     collections = list(set(
37 |         list(collections or ()) +
38 |         [tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.MODEL_VARIABLES]
39 |     ))
40 |     return tf.get_variable(
41 |         name=name,
42 |         shape=shape,
43 |         dtype=dtype,
44 |         initializer=initializer,
45 |         regularizer=regularizer,
46 |         constraint=constraint,
47 |         trainable=trainable,
48 |         collections=collections,
49 |         **kwargs
50 |     )
51 | 
52 | 
53 | def get_model_variables(scope=None):
54 |     """
55 |     Get all model variables (i.e., variables in `MODEL_VARIABLES` collection).
56 | 
57 |     Args:
58 |         scope: If specified, will obtain variables only within this scope.
59 | 
60 |     Returns:
61 |         list[tf.Variable]: The model variables.
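# [Editor's example] Sketch of `model_variable`; the variable name and shape
# are illustrative.
import tensorflow as tf
from tfsnippet.utils import model_variable, get_model_variables

kernel = model_variable('kernel', shape=[3, 3, 16, 32], dtype=tf.float32)
assert kernel in get_model_variables()  # also in GLOBAL_VARIABLES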
62 | """ 63 | return tf.get_collection(tf.GraphKeys.MODEL_VARIABLES, scope=scope) 64 | -------------------------------------------------------------------------------- /tfsnippet/utils/random.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | __all__ = ['generate_random_seed', 'set_random_seed', 'VarScopeRandomState'] 7 | 8 | 9 | def generate_random_seed(): 10 | """ 11 | Generate a new random seed from the default NumPy random state. 12 | 13 | Returns: 14 | int: The new random seed. 15 | """ 16 | return np.random.randint(0xffffffff) 17 | 18 | 19 | def set_random_seed(seed): 20 | """ 21 | Generate random seeds for NumPy, TensorFlow and TFSnippet. 22 | 23 | Args: 24 | seed (int): The seed used to generate the separated seeds for 25 | all concerning modules. 26 | """ 27 | np.random.seed(seed) 28 | seeds = [generate_random_seed() for _ in range(4)] 29 | 30 | if hasattr(random, 'seed'): 31 | random.seed(seeds[0]) 32 | np.random.seed(seeds[1]) 33 | tf.set_random_seed(seeds[2]) 34 | VarScopeRandomState.set_global_seed(seeds[3]) 35 | 36 | 37 | class VarScopeRandomState(np.random.RandomState): 38 | """ 39 | A sub-class of :class:`np.random.RandomState`, which uses a variable-scope 40 | dependent seed. It is guaranteed for a :class:`VarScopeRandomState` 41 | initialized with the same global seed and variable scopes with the same 42 | name to produce exactly the same random sequence. 43 | """ 44 | 45 | _global_seed = 0 46 | 47 | def __init__(self, variable_scope): 48 | vs_name = variable_scope.name 49 | seed = (self._global_seed & 0xfffffff) ^ (hash(vs_name) & 0xffffffff) 50 | super(VarScopeRandomState, self).__init__(seed=seed) 51 | 52 | @classmethod 53 | def set_global_seed(cls, seed): 54 | """ 55 | Set the global random seed for all new :class:`VarScopeRandomState`. 56 | 57 | If not set, the default global random seed is `0`. 58 | 59 | Args: 60 | seed (int): The global random seed. 61 | """ 62 | cls._global_seed = int(seed) 63 | -------------------------------------------------------------------------------- /tfsnippet/utils/registry.py: -------------------------------------------------------------------------------- 1 | import six 2 | 3 | from .doc_utils import DocInherit 4 | 5 | __all__ = ['BaseRegistry', 'ClassRegistry'] 6 | 7 | 8 | @DocInherit 9 | class BaseRegistry(object): 10 | """ 11 | A base class for implement a type or object registry. 12 | 13 | Usage:: 14 | 15 | registry = BaseRegistry() 16 | registry.register('MNIST', spt.datasets.MNIST()) 17 | """ 18 | 19 | def __init__(self, ignore_case=False): 20 | """ 21 | Construct a new :class:`BaseRegistry`. 22 | 23 | Args: 24 | ignore_case (bool): Whether or not to ignore case in object names? 25 | """ 26 | ignore_case = bool(ignore_case) 27 | 28 | if ignore_case: 29 | self._norm_key = lambda s: str(s).lower() 30 | else: 31 | self._norm_key = lambda s: str(s) 32 | self._ignore_case = ignore_case 33 | self._name_and_objects = [] 34 | self._key_to_object = {} 35 | 36 | @property 37 | def ignore_case(self): 38 | """Whether or not to ignore the case?""" 39 | return self._ignore_case 40 | 41 | def __iter__(self): 42 | return (n for n, o in self._name_and_objects) 43 | 44 | def register(self, name, obj): 45 | """ 46 | Register an object. 47 | 48 | Args: 49 | name (str): Name of the object. 50 | obj: The object. 
51 | """ 52 | key = self._norm_key(name) 53 | if key in self._key_to_object: 54 | raise KeyError('Object already registered: {!r}'.format(name)) 55 | self._key_to_object[key] = obj 56 | self._name_and_objects.append((name, obj)) 57 | 58 | def get(self, name): 59 | """ 60 | Get an object. 61 | 62 | Args: 63 | name (str): Name of the object. 64 | 65 | Returns: 66 | The retrieved object. 67 | 68 | Raises: 69 | KeyError: If `name` is not registered. 70 | """ 71 | key = self._norm_key(name) 72 | if key not in self._key_to_object: 73 | raise KeyError('Object not registered: {!r}'.format(name)) 74 | return self._key_to_object[key] 75 | 76 | 77 | class ClassRegistry(BaseRegistry): 78 | """ 79 | A subclass of :class:`BaseRegistry`, dedicated for classes. 80 | 81 | Usage:: 82 | 83 | Class MyClass(object): 84 | 85 | def __init__(self, value, message): 86 | ... 87 | 88 | registry = ClassRegistry() 89 | registry.register('MyClass', MyClass) 90 | 91 | obj = registry.create_object('MyClass', 123, message='message') 92 | """ 93 | 94 | def register(self, name, obj): 95 | if not isinstance(obj, six.class_types): 96 | raise TypeError('`obj` is not a class: {!r}'.format(obj)) 97 | return super(ClassRegistry, self).register(name, obj) 98 | 99 | def construct(self, name, *args, **kwargs): 100 | """ 101 | Construct an object according to class `name` and arguments. 102 | 103 | Args: 104 | name (str): Name of the class. 105 | *args: Arguments passed to the class constructor. 106 | \\**kwargs: Named arguments passed to the class constructor. 107 | 108 | Returns: 109 | The constructed object. 110 | 111 | Raises: 112 | KeyError: If `name` is not registered. 113 | """ 114 | return self.get(name)(*args, **kwargs) 115 | -------------------------------------------------------------------------------- /tfsnippet/utils/scope.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | import six 4 | import tensorflow as tf 5 | from tensorflow.python.ops import variable_scope as variable_scope_ops 6 | 7 | __all__ = [ 8 | 'get_default_scope_name', 9 | 'reopen_variable_scope', 10 | 'root_variable_scope', 11 | ] 12 | 13 | 14 | def get_default_scope_name(name, cls_or_instance=None): 15 | """ 16 | Generate a valid default scope name. 17 | 18 | Args: 19 | name (str): The base name. 20 | cls_or_instance: The class or the instance object, optional. 21 | If it has attribute ``variable_scope``, then ``variable_scope.name`` 22 | will be used as a hint for the name prefix. Otherwise, its class 23 | name will be used as the name prefix. 24 | 25 | Returns: 26 | str: The generated scope name. 27 | """ 28 | # compose the candidate name 29 | prefix = '' 30 | if cls_or_instance is not None: 31 | if hasattr(cls_or_instance, 'variable_scope') and \ 32 | isinstance(cls_or_instance.variable_scope, tf.VariableScope): 33 | vs_name = cls_or_instance.variable_scope.name 34 | vs_name = vs_name.rsplit('/', 1)[-1] 35 | prefix = '{}.'.format(vs_name) 36 | else: 37 | if not isinstance(cls_or_instance, six.class_types): 38 | cls_or_instance = cls_or_instance.__class__ 39 | prefix = '{}.'.format(cls_or_instance.__name__).lstrip('_') 40 | name = prefix + name 41 | 42 | # validate the name 43 | name = name.lstrip('_') 44 | return name 45 | 46 | 47 | @contextmanager 48 | def reopen_variable_scope(var_scope, **kwargs): 49 | """ 50 | Reopen the specified `var_scope` and its original name scope. 51 | 52 | Args: 53 | var_scope (tf.VariableScope): The variable scope instance. 
54 |         **kwargs: Named arguments for opening the variable scope.
55 |     """
56 |     if not isinstance(var_scope, tf.VariableScope):
57 |         raise TypeError('`var_scope` must be an instance of `tf.VariableScope`')
58 | 
59 |     with tf.variable_scope(var_scope,
60 |                            auxiliary_name_scope=False,
61 |                            **kwargs) as vs:
62 |         with tf.name_scope(var_scope.original_name_scope):
63 |             yield vs
64 | 
65 | 
66 | @contextmanager
67 | def root_variable_scope(**kwargs):
68 |     """
69 |     Open the root variable scope and its name scope.
70 | 
71 |     Args:
72 |         **kwargs: Named arguments for opening the root variable scope.
73 |     """
74 |     # `tf.variable_scope` does not support opening the root variable scope
75 |     # from an empty name. It always prepends the name of the current variable
76 |     # scope to the front of the opened variable scope. So we get the current
77 |     # scope, and pretend it is the root scope.
78 |     scope = tf.get_variable_scope()
79 |     old_name = scope.name
80 |     try:
81 |         scope._name = ''
82 |         with variable_scope_ops._pure_variable_scope('', **kwargs) as vs:
83 |             scope._name = old_name
84 |             with tf.name_scope(None):
85 |                 yield vs
86 |     finally:
87 |         scope._name = old_name
88 | 
--------------------------------------------------------------------------------
/tfsnippet/utils/settings_.py:
--------------------------------------------------------------------------------
1 | from .config_utils import Config, ConfigField
2 | 
3 | __all__ = ['TFSnippetConfig', 'settings']
4 | 
5 | 
6 | class TFSnippetConfig(Config):
7 |     """Global configurations of TFSnippet."""
8 | 
9 |     enable_assertions = ConfigField(
10 |         bool, default=True,
11 |         description='Whether or not to enable assertion operations?'
12 |     )
13 |     check_numerics = ConfigField(
14 |         bool, default=False,
15 |         description='Whether or not to check numeric issues?'
16 |     )
17 |     auto_histogram = ConfigField(
18 |         bool, default=False,
19 |         description='Whether or not to automatically add histograms of layer '
20 |                     'parameters and outputs to the collection '
21 |                     '`tfsnippet.GraphKeys.AUTO_HISTOGRAM`?'
22 |     )
23 |     file_cache_checksum = ConfigField(
24 |         bool, default=False,
25 |         description='Whether or not to validate the checksum of cached files?'
26 |     )
27 | 
28 | 
29 | settings = TFSnippetConfig()
30 | """The TFSnippet global configuration object."""
31 | 
--------------------------------------------------------------------------------
/tfsnippet/utils/tfver.py:
--------------------------------------------------------------------------------
1 | import semver
2 | 
3 | import tensorflow as tf
4 | 
5 | __all__ = ['is_tensorflow_version_higher_or_equal']
6 | 
7 | 
8 | def is_tensorflow_version_higher_or_equal(version):
9 |     """
10 |     Check whether the version of TensorFlow is higher than or equal to
11 |     `version`.
12 | 
13 |     Args:
14 |         version (str): Expected version of TensorFlow.
15 | 
16 |     Returns:
17 |         bool: True if higher or equal to, False if not.
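# [Editor's example] Sketch of `reopen_variable_scope` re-entering a scope to
# create a second variable next to the first; names are illustrative.
import tensorflow as tf
from tfsnippet.utils import reopen_variable_scope

with tf.variable_scope('model') as vs:
    w = tf.get_variable('w', shape=[1])

with reopen_variable_scope(vs):
    b = tf.get_variable('b', shape=[1])  # created as 'model/b'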
18 | """ 19 | try: 20 | compare_result = semver.compare_loose(version, tf.__version__) 21 | except AttributeError: 22 | compare_result = semver.compare(version, tf.__version__) 23 | return compare_result <= 0 24 | -------------------------------------------------------------------------------- /tfsnippet/variational/__init__.py: -------------------------------------------------------------------------------- 1 | from .chain import * 2 | from .estimators import * 3 | from .evaluation import * 4 | from .inference import * 5 | from .objectives import * 6 | 7 | __all__ = [ 8 | 'VariationalChain', 'VariationalEvaluation', 'VariationalInference', 9 | 'VariationalLowerBounds', 'VariationalTrainingObjectives', 10 | 'elbo_objective', 'importance_sampling_log_likelihood', 'iwae_estimator', 11 | 'monte_carlo_objective', 'nvil_estimator', 'sgvb_estimator', 12 | 'vimco_estimator', 13 | ] 14 | -------------------------------------------------------------------------------- /tfsnippet/variational/evaluation.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tfsnippet.ops import log_mean_exp 4 | from .utils import _require_multi_samples 5 | 6 | __all__ = ['importance_sampling_log_likelihood'] 7 | 8 | 9 | def importance_sampling_log_likelihood(log_joint, latent_log_prob, axis, 10 | keepdims=False, name=None): 11 | """ 12 | Compute :math:`\\log p(\\mathbf{x})` by importance sampling. 13 | 14 | .. math:: 15 | 16 | \\log p(\\mathbf{x}) = 17 | \\log \\mathbb{E}_{q(\\mathbf{z}|\\mathbf{x})} \\Big[\\exp\\big(\\log p(\\mathbf{x},\\mathbf{z}) - \\log q(\\mathbf{z}|\\mathbf{x})\\big) \\Big] 18 | 19 | Args: 20 | log_joint: Values of :math:`\\log p(\\mathbf{z},\\mathbf{x})`, 21 | computed with :math:`\\mathbf{z} \\sim q(\\mathbf{z}|\\mathbf{x})`. 22 | latent_log_prob: :math:`q(\\mathbf{z}|\\mathbf{x})`. 23 | axis: The sampling dimensions to be averaged out. 24 | keepdims (bool): When `axis` is specified, whether or not to keep 25 | the averaged dimensions? (default :obj:`False`) 26 | name (str): TensorFlow name scope of the graph nodes. 27 | (default "importance_sampling_log_likelihood") 28 | 29 | Returns: 30 | The computed :math:`\\log p(x)`. 31 | """ 32 | _require_multi_samples(axis, 'importance sampling log-likelihood') 33 | log_joint = tf.convert_to_tensor(log_joint) 34 | latent_log_prob = tf.convert_to_tensor(latent_log_prob) 35 | with tf.name_scope(name, default_name='importance_sampling_log_likelihood', 36 | values=[log_joint, latent_log_prob]): 37 | log_p = log_mean_exp( 38 | log_joint - latent_log_prob, axis=axis, keepdims=keepdims) 39 | return log_p 40 | -------------------------------------------------------------------------------- /tfsnippet/variational/objectives.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tfsnippet.ops import log_mean_exp 4 | from .utils import _require_multi_samples 5 | 6 | __all__ = ['elbo_objective', 'monte_carlo_objective'] 7 | 8 | 9 | def elbo_objective(log_joint, latent_log_prob, axis=None, keepdims=False, 10 | name=None): 11 | """ 12 | Derive the ELBO objective. 13 | 14 | .. math:: 15 | 16 | \\mathbb{E}_{\\mathbf{z} \\sim q_{\\phi}(\\mathbf{z}|\\mathbf{x})}\\big[ 17 | \\log p_{\\theta}(\\mathbf{x},\\mathbf{z}) - \\log q_{\\phi}(\\mathbf{z}|\\mathbf{x}) 18 | \\big] 19 | 20 | Args: 21 | log_joint: Values of :math:`\\log p(\\mathbf{z},\\mathbf{x})`, 22 | computed with :math:`\\mathbf{z} \\sim q(\\mathbf{z}|\\mathbf{x})`. 
23 |         latent_log_prob: Values of :math:`\\log q(\\mathbf{z}|\\mathbf{x})`.
24 |         axis: The sampling dimensions to be averaged out.
25 |             If :obj:`None`, no dimensions will be averaged out.
26 |         keepdims (bool): When `axis` is specified, whether or not to keep
27 |             the averaged dimensions? (default :obj:`False`)
28 |         name (str): TensorFlow name scope of the graph nodes.
29 |             (default "elbo_objective")
30 | 
31 |     Returns:
32 |         tf.Tensor: The ELBO objective. Not applicable for training.
33 |     """
34 |     log_joint = tf.convert_to_tensor(log_joint)
35 |     latent_log_prob = tf.convert_to_tensor(latent_log_prob)
36 |     with tf.name_scope(name,
37 |                        default_name='elbo_objective',
38 |                        values=[log_joint, latent_log_prob]):
39 |         objective = log_joint - latent_log_prob
40 |         if axis is not None:
41 |             objective = tf.reduce_mean(objective, axis=axis, keepdims=keepdims)
42 |         return objective
43 | 
44 | 
45 | def monte_carlo_objective(log_joint, latent_log_prob, axis=None,
46 |                           keepdims=False, name=None):
47 |     """
48 |     Derive the Monte-Carlo objective.
49 | 
50 |     .. math::
51 | 
52 |         \\mathcal{L}_{K}(\\mathbf{x};\\theta,\\phi) =
53 |             \\mathbb{E}_{\\mathbf{z}^{(1:K)} \\sim q_{\\phi}(\\mathbf{z}|\\mathbf{x})}\\Bigg[
54 |                 \\log \\frac{1}{K} \\sum_{k=1}^K {
55 |                     \\frac{p_{\\theta}(\\mathbf{x},\\mathbf{z}^{(k)})}
56 |                          {q_{\\phi}(\\mathbf{z}^{(k)}|\\mathbf{x})}
57 |                 }
58 |             \\Bigg]
59 | 
60 |     Args:
61 |         log_joint: Values of :math:`\\log p(\\mathbf{z},\\mathbf{x})`,
62 |             computed with :math:`\\mathbf{z} \\sim q(\\mathbf{z}|\\mathbf{x})`.
63 |         latent_log_prob: Values of :math:`\\log q(\\mathbf{z}|\\mathbf{x})`.
64 |         axis: The sampling dimensions to be averaged out.
65 |         keepdims (bool): When `axis` is specified, whether or not to keep
66 |             the averaged dimensions? (default :obj:`False`)
67 |         name (str): TensorFlow name scope of the graph nodes.
68 |             (default "monte_carlo_objective")
69 | 
70 |     Returns:
71 |         tf.Tensor: The Monte Carlo objective. Not applicable for training.
72 |     """
73 |     _require_multi_samples(axis, 'monte carlo objective')
74 |     log_joint = tf.convert_to_tensor(log_joint)
75 |     latent_log_prob = tf.convert_to_tensor(latent_log_prob)
76 |     with tf.name_scope(name,
77 |                        default_name='monte_carlo_objective',
78 |                        values=[log_joint, latent_log_prob]):
79 |         likelihood = log_joint - latent_log_prob
80 |         objective = log_mean_exp(likelihood, axis=axis, keepdims=keepdims)
81 |         return objective
82 | 
--------------------------------------------------------------------------------
/tfsnippet/variational/utils.py:
--------------------------------------------------------------------------------
1 | 
2 | def _require_multi_samples(axis, name):
3 |     if axis is None:
4 |         raise ValueError('{} requires multi-samples of latent variables, '
5 |                          'thus the `axis` argument must be specified'.
6 |                          format(name))
7 | 
--------------------------------------------------------------------------------
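# [Editor's example] Sketch contrasting the two objectives above; the
# placeholders are illustrative. `monte_carlo_objective` insists on a sample
# axis, while `elbo_objective` also accepts axis=None.
import tensorflow as tf
from tfsnippet.variational import elbo_objective, monte_carlo_objective

log_joint = tf.placeholder(tf.float32, shape=[None, None])        # [K, batch]
latent_log_prob = tf.placeholder(tf.float32, shape=[None, None])  # [K, batch]

elbo = elbo_objective(log_joint, latent_log_prob, axis=0)
iwae = monte_carlo_objective(log_joint, latent_log_prob, axis=0)
# For K == 1 the two coincide; for K > 1 the Monte-Carlo bound is tighter.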