├── LICENSE.txt ├── README.rst ├── docs ├── Makefile └── source │ ├── conf.py │ ├── howto.rst │ ├── index.rst │ ├── quick.rst │ └── reference.rst ├── setup.cfg ├── setup.py ├── tftables.py └── tftables_test.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016 G. H. Collin (ghcollin) 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | `tftables `_ allows convenient access to HDF5 files with Tensorflow. 2 | A class for reading batches of data out of arrays or tables is provided. 3 | A secondary class wraps both the primary reader and a Tensorflow FIFOQueue for straight-forward streaming 4 | of data from HDF5 files into Tensorflow operations. 5 | 6 | The library is backed by `multitables `_ for high-speed reading of HDF5 7 | datasets. ``multitables`` is based on PyTables (``tables``), so this library can make use of any compression algorithms 8 | that PyTables supports. 9 | 10 | Licence 11 | ======= 12 | 13 | This software is distributed under the MIT licence. 14 | See the `LICENSE.txt `_ file for details. 15 | 16 | Installation 17 | ============ 18 | 19 | :: 20 | 21 | pip install tftables 22 | 23 | Alternatively, to install from HEAD, run 24 | 25 | :: 26 | 27 | pip install git+https://github.com/ghcollin/tftables.git 28 | 29 | You can also `download `_ 30 | or `clone the repository `_ and run 31 | 32 | :: 33 | 34 | python setup.py install 35 | 36 | ``tftables`` depends on ``multitables``, ``numpy`` and ``tensorflow``. The package is compatible with the latest versions of python 37 | 2 and 3. 38 | 39 | Quick start 40 | =========== 41 | 42 | An example of accessing a table in a HDF5 file. 43 | 44 | .. code:: python 45 | 46 | import tftables 47 | import tensorflow as tf 48 | 49 | with tf.device('/cpu:0'): 50 | # This function preprocesses the batches before they 51 | # are loaded into the internal queue. 52 | # You can cast data, or do one-hot transforms. 53 | # If the dataset is a table, this function is required. 54 | def input_transform(tbl_batch): 55 | labels = tbl_batch['label'] 56 | data = tbl_batch['data'] 57 | 58 | truth = tf.to_float(tf.one_hot(labels, num_labels, 1, 0)) 59 | data_float = tf.to_float(data) 60 | 61 | return truth, data_float 62 | 63 | # Open the HDF5 file and create a loader for a dataset. 
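# (Note: num_labels used above and my_network used below are assumed to be defined elsewhere in your own code.)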
64 | # The batch_size defines the length (in the outer dimension) 65 | # of the elements (batches) returned by the reader. 66 | # Takes a function as input that pre-processes the data. 67 | loader = tftables.load_dataset(filename='path/to/h5_file.h5', 68 | dataset_path='/internal/h5/path', 69 | input_transform=input_transform, 70 | batch_size=20) 71 | 72 | # To get the data, we dequeue it from the loader. 73 | # Tensorflow tensors are returned in the same order as input_transform 74 | truth_batch, data_batch = loader.dequeue() 75 | 76 | # The placeholder can then be used in your network 77 | result = my_network(truth_batch, data_batch) 78 | 79 | with tf.Session() as sess: 80 | 81 | # This context manager starts and stops the internal threads and 82 | # processes used to read the data from disk and store it in the queue. 83 | with loader.begin(sess): 84 | for _ in range(num_iterations): 85 | sess.run(result) 86 | 87 | 88 | If the dataset is an array instead of a table, then ``input_transform`` can be omitted 89 | if no pre-processing is required. If only a single pass through the dataset is desired, 90 | then you should pass ``cyclic=False`` to ``load_dataset``. 91 | 92 | 93 | Examples 94 | ======== 95 | 96 | See the `how-to `_ for more in-depth documentation, and the 97 | `unit tests `_ for complete examples. 98 | 99 | Documentation 100 | ============= 101 | 102 | `Online documentation `_ is available. 103 | A `how to `_ gives a basic overview of the library. 104 | 105 | Offline documentation can be built from the ``docs`` folder using ``sphinx``. -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = build 9 | 10 | # Internal variables.
11 | PAPEROPT_a4 = -D latex_paper_size=a4 12 | PAPEROPT_letter = -D latex_paper_size=letter 13 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source 14 | # the i18n builder cannot share the environment and doctrees with the others 15 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source 16 | 17 | .PHONY: help 18 | help: 19 | @echo "Please use \`make ' where is one of" 20 | @echo " html to make standalone HTML files" 21 | @echo " dirhtml to make HTML files named index.html in directories" 22 | @echo " singlehtml to make a single large HTML file" 23 | @echo " pickle to make pickle files" 24 | @echo " json to make JSON files" 25 | @echo " htmlhelp to make HTML files and a HTML help project" 26 | @echo " qthelp to make HTML files and a qthelp project" 27 | @echo " applehelp to make an Apple Help Book" 28 | @echo " devhelp to make HTML files and a Devhelp project" 29 | @echo " epub to make an epub" 30 | @echo " epub3 to make an epub3" 31 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 32 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 33 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 34 | @echo " text to make text files" 35 | @echo " man to make manual pages" 36 | @echo " texinfo to make Texinfo files" 37 | @echo " info to make Texinfo files and run them through makeinfo" 38 | @echo " gettext to make PO message catalogs" 39 | @echo " changes to make an overview of all changed/added/deprecated items" 40 | @echo " xml to make Docutils-native XML files" 41 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 42 | @echo " linkcheck to check all external links for integrity" 43 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 44 | @echo " coverage to run coverage check of the documentation (if enabled)" 45 | @echo " dummy to check syntax errors of document sources" 46 | 47 | .PHONY: clean 48 | clean: 49 | rm -rf $(BUILDDIR)/* 50 | 51 | .PHONY: html 52 | html: 53 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 54 | @echo 55 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 56 | 57 | .PHONY: dirhtml 58 | dirhtml: 59 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 60 | @echo 61 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 62 | 63 | .PHONY: singlehtml 64 | singlehtml: 65 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 66 | @echo 67 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 68 | 69 | .PHONY: pickle 70 | pickle: 71 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 72 | @echo 73 | @echo "Build finished; now you can process the pickle files." 74 | 75 | .PHONY: json 76 | json: 77 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 78 | @echo 79 | @echo "Build finished; now you can process the JSON files." 80 | 81 | .PHONY: htmlhelp 82 | htmlhelp: 83 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 84 | @echo 85 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 86 | ".hhp project file in $(BUILDDIR)/htmlhelp." 
87 | 88 | .PHONY: qthelp 89 | qthelp: 90 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 91 | @echo 92 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 93 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 94 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/tftables.qhcp" 95 | @echo "To view the help file:" 96 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/tftables.qhc" 97 | 98 | .PHONY: applehelp 99 | applehelp: 100 | $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp 101 | @echo 102 | @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." 103 | @echo "N.B. You won't be able to view it unless you put it in" \ 104 | "~/Library/Documentation/Help or install it in your application" \ 105 | "bundle." 106 | 107 | .PHONY: devhelp 108 | devhelp: 109 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 110 | @echo 111 | @echo "Build finished." 112 | @echo "To view the help file:" 113 | @echo "# mkdir -p $$HOME/.local/share/devhelp/tftables" 114 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/tftables" 115 | @echo "# devhelp" 116 | 117 | .PHONY: epub 118 | epub: 119 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 120 | @echo 121 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 122 | 123 | .PHONY: epub3 124 | epub3: 125 | $(SPHINXBUILD) -b epub3 $(ALLSPHINXOPTS) $(BUILDDIR)/epub3 126 | @echo 127 | @echo "Build finished. The epub3 file is in $(BUILDDIR)/epub3." 128 | 129 | .PHONY: latex 130 | latex: 131 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 132 | @echo 133 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 134 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 135 | "(use \`make latexpdf' here to do that automatically)." 136 | 137 | .PHONY: latexpdf 138 | latexpdf: 139 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 140 | @echo "Running LaTeX files through pdflatex..." 141 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 142 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 143 | 144 | .PHONY: latexpdfja 145 | latexpdfja: 146 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 147 | @echo "Running LaTeX files through platex and dvipdfmx..." 148 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 149 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 150 | 151 | .PHONY: text 152 | text: 153 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 154 | @echo 155 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 156 | 157 | .PHONY: man 158 | man: 159 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 160 | @echo 161 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 162 | 163 | .PHONY: texinfo 164 | texinfo: 165 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 166 | @echo 167 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 168 | @echo "Run \`make' in that directory to run these through makeinfo" \ 169 | "(use \`make info' here to do that automatically)." 170 | 171 | .PHONY: info 172 | info: 173 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 174 | @echo "Running Texinfo files through makeinfo..." 175 | make -C $(BUILDDIR)/texinfo info 176 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 177 | 178 | .PHONY: gettext 179 | gettext: 180 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 181 | @echo 182 | @echo "Build finished. 
The message catalogs are in $(BUILDDIR)/locale." 183 | 184 | .PHONY: changes 185 | changes: 186 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 187 | @echo 188 | @echo "The overview file is in $(BUILDDIR)/changes." 189 | 190 | .PHONY: linkcheck 191 | linkcheck: 192 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 193 | @echo 194 | @echo "Link check complete; look for any errors in the above output " \ 195 | "or in $(BUILDDIR)/linkcheck/output.txt." 196 | 197 | .PHONY: doctest 198 | doctest: 199 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 200 | @echo "Testing of doctests in the sources finished, look at the " \ 201 | "results in $(BUILDDIR)/doctest/output.txt." 202 | 203 | .PHONY: coverage 204 | coverage: 205 | $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage 206 | @echo "Testing of coverage in the sources finished, look at the " \ 207 | "results in $(BUILDDIR)/coverage/python.txt." 208 | 209 | .PHONY: xml 210 | xml: 211 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 212 | @echo 213 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 214 | 215 | .PHONY: pseudoxml 216 | pseudoxml: 217 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 218 | @echo 219 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 220 | 221 | .PHONY: dummy 222 | dummy: 223 | $(SPHINXBUILD) -b dummy $(ALLSPHINXOPTS) $(BUILDDIR)/dummy 224 | @echo 225 | @echo "Build finished. Dummy builder generates no files." 226 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # tftables documentation build configuration file, created by 4 | # sphinx-quickstart on Tue Mar 7 21:24:01 2017. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | # If extensions (or modules to document with autodoc) are in another directory, 16 | # add these directories to sys.path here. If the directory is relative to the 17 | # documentation root, use os.path.abspath to make it absolute, like shown here. 18 | # 19 | import os 20 | import sys 21 | # sys.path.insert(0, os.path.abspath('.')) 22 | sys.path.insert(0, os.path.abspath('../..')) 23 | 24 | from mock import Mock as MagicMock 25 | 26 | class Mock(MagicMock): 27 | @classmethod 28 | def __getattr__(cls, name): 29 | return Mock() 30 | 31 | MOCK_MODULES = ['multitables', 'numpy', 'tensorflow'] 32 | sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) 33 | 34 | # -- General configuration ------------------------------------------------ 35 | 36 | # If your documentation needs a minimal Sphinx version, state it here. 37 | # 38 | # needs_sphinx = '1.0' 39 | 40 | # Add any Sphinx extension module names here, as strings. They can be 41 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 42 | # ones. 43 | extensions = [ 44 | 'sphinx.ext.autodoc', 45 | ] 46 | 47 | # Add any paths that contain templates here, relative to this directory. 48 | templates_path = ['_templates'] 49 | 50 | # The suffix(es) of source filenames. 
51 | # You can specify multiple suffix as a list of string: 52 | # 53 | # source_suffix = ['.rst', '.md'] 54 | source_suffix = '.rst' 55 | 56 | # The encoding of source files. 57 | # 58 | # source_encoding = 'utf-8-sig' 59 | 60 | # The master toctree document. 61 | master_doc = 'index' 62 | 63 | # General information about the project. 64 | project = u'tftables' 65 | copyright = u'2017, ghcollin' 66 | author = u'ghcollin' 67 | 68 | # The version info for the project you're documenting, acts as replacement for 69 | # |version| and |release|, also used in various other places throughout the 70 | # built documents. 71 | # 72 | # The short X.Y version. 73 | version = u'1.0.0' 74 | # The full version, including alpha/beta/rc tags. 75 | release = u'1.0.0' 76 | 77 | # The language for content autogenerated by Sphinx. Refer to documentation 78 | # for a list of supported languages. 79 | # 80 | # This is also used if you do content translation via gettext catalogs. 81 | # Usually you set "language" from the command line for these cases. 82 | language = None 83 | 84 | # There are two options for replacing |today|: either, you set today to some 85 | # non-false value, then it is used: 86 | # 87 | # today = '' 88 | # 89 | # Else, today_fmt is used as the format for a strftime call. 90 | # 91 | # today_fmt = '%B %d, %Y' 92 | 93 | # List of patterns, relative to source directory, that match files and 94 | # directories to ignore when looking for source files. 95 | # This patterns also effect to html_static_path and html_extra_path 96 | exclude_patterns = [] 97 | 98 | # The reST default role (used for this markup: `text`) to use for all 99 | # documents. 100 | # 101 | # default_role = None 102 | 103 | # If true, '()' will be appended to :func: etc. cross-reference text. 104 | # 105 | # add_function_parentheses = True 106 | 107 | # If true, the current module name will be prepended to all description 108 | # unit titles (such as .. function::). 109 | # 110 | # add_module_names = True 111 | 112 | # If true, sectionauthor and moduleauthor directives will be shown in the 113 | # output. They are ignored by default. 114 | # 115 | # show_authors = False 116 | 117 | # The name of the Pygments (syntax highlighting) style to use. 118 | pygments_style = 'sphinx' 119 | 120 | # A list of ignored prefixes for module index sorting. 121 | # modindex_common_prefix = [] 122 | 123 | # If true, keep warnings as "system message" paragraphs in the built documents. 124 | # keep_warnings = False 125 | 126 | # If true, `todo` and `todoList` produce output, else they produce nothing. 127 | todo_include_todos = False 128 | 129 | 130 | # -- Options for HTML output ---------------------------------------------- 131 | 132 | # The theme to use for HTML and HTML Help pages. See the documentation for 133 | # a list of builtin themes. 134 | # 135 | html_theme = 'alabaster' 136 | 137 | # Theme options are theme-specific and customize the look and feel of a theme 138 | # further. For a list of options available for each theme, see the 139 | # documentation. 140 | # 141 | # html_theme_options = {} 142 | 143 | # Add any paths that contain custom themes here, relative to this directory. 144 | # html_theme_path = [] 145 | 146 | # The name for this set of Sphinx documents. 147 | # " v documentation" by default. 148 | # 149 | # html_title = u'tftables v1.0.0' 150 | 151 | # A shorter title for the navigation bar. Default is the same as html_title. 
152 | # 153 | # html_short_title = None 154 | 155 | # The name of an image file (relative to this directory) to place at the top 156 | # of the sidebar. 157 | # 158 | # html_logo = None 159 | 160 | # The name of an image file (relative to this directory) to use as a favicon of 161 | # the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 162 | # pixels large. 163 | # 164 | # html_favicon = None 165 | 166 | # Add any paths that contain custom static files (such as style sheets) here, 167 | # relative to this directory. They are copied after the builtin static files, 168 | # so a file named "default.css" will overwrite the builtin "default.css". 169 | html_static_path = ['_static'] 170 | 171 | # Add any extra paths that contain custom files (such as robots.txt or 172 | # .htaccess) here, relative to this directory. These files are copied 173 | # directly to the root of the documentation. 174 | # 175 | # html_extra_path = [] 176 | 177 | # If not None, a 'Last updated on:' timestamp is inserted at every page 178 | # bottom, using the given strftime format. 179 | # The empty string is equivalent to '%b %d, %Y'. 180 | # 181 | # html_last_updated_fmt = None 182 | 183 | # If true, SmartyPants will be used to convert quotes and dashes to 184 | # typographically correct entities. 185 | # 186 | # html_use_smartypants = True 187 | 188 | # Custom sidebar templates, maps document names to template names. 189 | # 190 | # html_sidebars = {} 191 | 192 | # Additional templates that should be rendered to pages, maps page names to 193 | # template names. 194 | # 195 | # html_additional_pages = {} 196 | 197 | # If false, no module index is generated. 198 | # 199 | # html_domain_indices = True 200 | 201 | # If false, no index is generated. 202 | # 203 | # html_use_index = True 204 | 205 | # If true, the index is split into individual pages for each letter. 206 | # 207 | # html_split_index = False 208 | 209 | # If true, links to the reST sources are added to the pages. 210 | # 211 | # html_show_sourcelink = True 212 | 213 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 214 | # 215 | # html_show_sphinx = True 216 | 217 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 218 | # 219 | # html_show_copyright = True 220 | 221 | # If true, an OpenSearch description file will be output, and all pages will 222 | # contain a tag referring to it. The value of this option must be the 223 | # base URL from which the finished HTML is served. 224 | # 225 | # html_use_opensearch = '' 226 | 227 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 228 | # html_file_suffix = None 229 | 230 | # Language to be used for generating the HTML full-text search index. 231 | # Sphinx supports the following languages: 232 | # 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' 233 | # 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr', 'zh' 234 | # 235 | # html_search_language = 'en' 236 | 237 | # A dictionary with options for the search language support, empty by default. 238 | # 'ja' uses this config value. 239 | # 'zh' user can custom change `jieba` dictionary path. 240 | # 241 | # html_search_options = {'type': 'default'} 242 | 243 | # The name of a javascript file (relative to the configuration directory) that 244 | # implements a search results scorer. If empty, the default will be used. 245 | # 246 | # html_search_scorer = 'scorer.js' 247 | 248 | # Output file base name for HTML help builder. 
249 | htmlhelp_basename = 'tftablesdoc' 250 | 251 | # -- Options for LaTeX output --------------------------------------------- 252 | 253 | latex_elements = { 254 | # The paper size ('letterpaper' or 'a4paper'). 255 | # 256 | # 'papersize': 'letterpaper', 257 | 258 | # The font size ('10pt', '11pt' or '12pt'). 259 | # 260 | # 'pointsize': '10pt', 261 | 262 | # Additional stuff for the LaTeX preamble. 263 | # 264 | # 'preamble': '', 265 | 266 | # Latex figure (float) alignment 267 | # 268 | # 'figure_align': 'htbp', 269 | } 270 | 271 | # Grouping the document tree into LaTeX files. List of tuples 272 | # (source start file, target name, title, 273 | # author, documentclass [howto, manual, or own class]). 274 | latex_documents = [ 275 | (master_doc, 'tftables.tex', u'tftables Documentation', 276 | u'ghcollin', 'manual'), 277 | ] 278 | 279 | # The name of an image file (relative to this directory) to place at the top of 280 | # the title page. 281 | # 282 | # latex_logo = None 283 | 284 | # For "manual" documents, if this is true, then toplevel headings are parts, 285 | # not chapters. 286 | # 287 | # latex_use_parts = False 288 | 289 | # If true, show page references after internal links. 290 | # 291 | # latex_show_pagerefs = False 292 | 293 | # If true, show URL addresses after external links. 294 | # 295 | # latex_show_urls = False 296 | 297 | # Documents to append as an appendix to all manuals. 298 | # 299 | # latex_appendices = [] 300 | 301 | # It false, will not define \strong, \code, itleref, \crossref ... but only 302 | # \sphinxstrong, ..., \sphinxtitleref, ... To help avoid clash with user added 303 | # packages. 304 | # 305 | # latex_keep_old_macro_names = True 306 | 307 | # If false, no module index is generated. 308 | # 309 | # latex_domain_indices = True 310 | 311 | 312 | # -- Options for manual page output --------------------------------------- 313 | 314 | # One entry per manual page. List of tuples 315 | # (source start file, name, description, authors, manual section). 316 | man_pages = [ 317 | (master_doc, 'tftables', u'tftables Documentation', 318 | [author], 1) 319 | ] 320 | 321 | # If true, show URL addresses after external links. 322 | # 323 | # man_show_urls = False 324 | 325 | 326 | # -- Options for Texinfo output ------------------------------------------- 327 | 328 | # Grouping the document tree into Texinfo files. List of tuples 329 | # (source start file, target name, title, author, 330 | # dir menu entry, description, category) 331 | texinfo_documents = [ 332 | (master_doc, 'tftables', u'tftables Documentation', 333 | author, 'tftables', 'One line description of project.', 334 | 'Miscellaneous'), 335 | ] 336 | 337 | # Documents to append as an appendix to all manuals. 338 | # 339 | # texinfo_appendices = [] 340 | 341 | # If false, no module index is generated. 342 | # 343 | # texinfo_domain_indices = True 344 | 345 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 346 | # 347 | # texinfo_show_urls = 'footnote' 348 | 349 | # If true, do not generate a @detailmenu in the "Top" node's menu. 350 | # 351 | # texinfo_no_detailmenu = False 352 | -------------------------------------------------------------------------------- /docs/source/howto.rst: -------------------------------------------------------------------------------- 1 | How To 2 | ****** 3 | 4 | Use of the library starts with creating a ``TableReader`` object. 5 | 6 | .. 
code:: python 7 | 8 | import tftables 9 | reader = tftables.open_file(filename="/path/to/h5/file", batch_size=10) 10 | 11 | Here the batch size is specified as an argument to the ``open_file`` function. The batch_size defines the length 12 | (in the outer dimension) of the elements (batches) returned by ``reader``. 13 | 14 | Accessing a single array 15 | ======================== 16 | 17 | Suppose you only want to read a single array from your HDF5 file. Doing this is quite straight-forward. 18 | Start by getting a tensorflow placeholder for your batch from ``reader``. 19 | 20 | .. code:: python 21 | 22 | array_batch_placeholder = reader.get_batch( 23 | path = '/h5/path', # This is the path to your array inside the HDF5 file. 24 | cyclic = True, # In cyclic access, when the reader gets to the end of the 25 | # array, it will wrap back to the beginning and continue. 26 | ordered = False # The reader will not require the rows of the array to be 27 | # returned in the same order as on disk. 28 | ) 29 | 30 | # You can transform the batch however you like now. 31 | # For example, casting it to floats. 32 | array_batch_float = tf.to_float(array_batch_placeholder) 33 | 34 | # The data can now be fed into your network 35 | result = my_network(array_batch_float) 36 | 37 | with tf.Session() as sess: 38 | # The feed method provides a generator that returns 39 | # feed_dict's containing batches from your HDF5 file. 40 | for i, feed_dict in enumerate(reader.feed()): 41 | sess.run(result, feed_dict=feed_dict) 42 | if i >= N: 43 | break 44 | 45 | # Finally, the reader should be closed. 46 | reader.close() 47 | 48 | Note that be default, the ``ordered`` argument to ``get_batch`` is set to ``True``. If you require the rows of the 49 | array to be returned in the same order as they are on disk, then you should leave it as ``ordered = True``. 50 | However, this may result in a performance penalty. In machine learning, rows of a dataset often represent 51 | independent examples, or data points. Thus their ordering is not important. 52 | 53 | Accessing a single table 54 | ======================== 55 | 56 | When reading from a table, the ``get_batch`` method returns a dictionary. The columns of the table form the keys 57 | of this dictionary, and the values are tensorflow placeholders for batches of each column. If one of the columns has 58 | a compound datatype, then its corresponding value in the dictionary will itself be a dictionary. In this way, 59 | recursive compound datatypes will give recursive dictionaries. 60 | 61 | For example, if your table just had two columns, named ``label`` and ``data``, then you could use: 62 | 63 | .. code:: python 64 | 65 | table_batch = reader.get_batch( 66 | path = '/path/to/table', 67 | cyclic = True, 68 | ordered = False 69 | ) 70 | 71 | label_batch = table_batch['label'] 72 | data_batch = table_batch['data'] 73 | 74 | If your table was a bit more complicated, with columns named ``label`` and ``value``. And the ``value`` column has 75 | a compound type with fields named ``image`` and ``lidar``, then you could use: 76 | 77 | .. 
code:: python 78 | 79 | table_batch = reader.get_batch( 80 | path = '/path/to/complex_table', 81 | cyclic = True, 82 | ordered = False 83 | ) 84 | 85 | label_batch = table_batch['label'] 86 | value_batch = table_batch['value'] 87 | 88 | image_batch = value_batch['image'] 89 | lidar_batch = value_batch['lidar'] 90 | 91 | Using a FIFO queue 92 | ================== 93 | 94 | Copying data to the GPU through a ``feed_dict`` is notoriously slow in Tensorflow. It is much faster to buffer 95 | data in a queue. You are free to manage your own queues, but a helper class is included to make this task easier. 96 | 97 | .. code:: python 98 | 99 | # As before 100 | array_batch_placeholder = reader.get_batch( 101 | path = '/h5/path', 102 | cyclic = True, 103 | ordered = False) 104 | array_batch_float = tf.to_float(array_batch_placeholder) 105 | 106 | # Now we create a FIFO Loader 107 | loader = reader.get_fifoloader( 108 | queue_size = 10, # The maximum number of elements that the 109 | # internal Tensorflow queue should hold. 110 | inputs = [array_batch_float], # A list of tensors that will be stored 111 | # in the queue. 112 | threads = 1 # The number of threads used to stuff the 113 | # queue. If ordered access to a dataset 114 | # was requested, then only 1 thread 115 | # should be used. 116 | ) 117 | 118 | # Batches can now be dequeued from the loader for use in your network. 119 | array_batch_cpu = loader.dequeue() 120 | result = my_network(array_batch_cpu) 121 | 122 | with tf.Session() as sess: 123 | 124 | # The loader needs to be started with your Tensorflow session. 125 | loader.start(sess) 126 | 127 | for i in range(N): 128 | # You can now cleanly evaluate your network without a feed_dict. 129 | sess.run(result) 130 | 131 | # It also needs to be stopped for clean shutdown. 132 | loader.stop(sess) 133 | 134 | # Finally, the reader should be closed. 135 | reader.close() 136 | 137 | Non-cyclic access 138 | ----------------- 139 | 140 | If you are classifying a dataset, rather than training a model, then you probably only want to run through the 141 | dataset once. This can be done by passing ``cyclic = False`` to ``get_batch``. Once finished, the internal Tensorflow 142 | queue will throw an instance of the ``tensorflow.errors.OutOfRangeError`` exception to signal termination of the loop. 143 | 144 | This can be caught manually with a try-catch block: 145 | 146 | .. code:: python 147 | 148 | with tf.Session() as sess: 149 | loader.start(sess) 150 | 151 | try: 152 | # Keep iterating until the exception breaks the loop 153 | while True: 154 | sess.run(result) 155 | # Now silently catch the exception. 156 | except tf.errors.OutOfRangeError: 157 | pass 158 | 159 | loader.stop(sess) 160 | 161 | A slightly more elegant solution is to use a context manager supplied by the loader class: 162 | 163 | .. code:: python 164 | 165 | with tf.Session() as sess: 166 | loader.start(sess) 167 | 168 | # This context manager suppresses the exception. 169 | with loader.catch_termination(): 170 | # Keep iterating until the exception breaks the loop 171 | while True: 172 | sess.run(result) 173 | 174 | loader.stop(sess) 175 | 176 | Start stop context manager 177 | -------------------------- 178 | 179 | In either cyclic or non-cyclic access, we can use a context manager to start and stop the loader class. 180 | 181 | .. 
code:: python 182 | 183 | with tf.Session() as sess: 184 | with loader.begin(sess): 185 | # Loop 186 | 187 | Quick access to a single dataset 188 | ================================ 189 | 190 | It is highly recommended that you use a single dataset, this allows you to use unordered access which is a fastest 191 | way of reading data. If you have multiple sources of data, such as labels and images, then you should organise them 192 | into a table. This also has performance benefits due to the locality of the data. 193 | 194 | When you only have one dataset, the function ``load_dataset`` is provided to set up the reader and loader for you. 195 | Any preprocessing that need to be done CPU side before loading into the queue can be written as a function that 196 | generates a Tensorflow graph. This input transformation function is fed into ``load_dataset`` as an argument. 197 | 198 | The input transform function should return a list of tensors that will be stored in the queue. The input transform 199 | is required when the dataset is a table, as the dictionary needs to be turned into a list. 200 | 201 | .. code:: python 202 | 203 | # This function preprocesses the batches before they 204 | # are loaded into the internal queue. 205 | # You can cast data, or do one-hot transforms. 206 | # If the dataset is a table, this function is required. 207 | def input_transform(tbl_batch): 208 | labels = tbl_batch['label'] 209 | data = tbl_batch['data'] 210 | 211 | truth = tf.to_float(tf.one_hot(labels, num_labels, 1, 0)) 212 | data_float = tf.to_float(data) 213 | 214 | return truth, data_float 215 | 216 | # Open the HDF5 file and create a loader for a dataset. 217 | # The batch_size defines the length (in the outer dimension) 218 | # of the elements (batches) returned by the reader. 219 | # Takes a function as input that pre-processes the data. 220 | loader = tftables.load_dataset(filename='path/to/h5_file.h5', 221 | dataset_path='/internal/h5/path', 222 | input_transform=input_transform, 223 | batch_size=20) 224 | 225 | # To get the data, we dequeue it from the loader. 226 | # Tensorflow tensors are returned in the same order as input_transformation 227 | truth_batch, data_batch = loader.dequeue() 228 | 229 | # The placeholder can then be used in your network 230 | result = my_network(truth_batch, data_batch) 231 | 232 | with tf.Session() as sess: 233 | 234 | # This context manager starts and stops the internal threads and 235 | # processes used to read the data from disk and store it in the queue. 236 | with loader.begin(sess): 237 | for _ in range(num_iterations): 238 | sess.run(result) 239 | 240 | When using ``load_dataset`` the reader is automatically closed when the loader is stopped. 241 | 242 | Accessing multiple datasets 243 | =========================== 244 | 245 | If your HDF5 file has multiple datasets (multiple arrays, tables or both) then you should write a script to transform 246 | it into a file with only a single table. If this isn't possible, then you can access the datasets directly through 247 | ``tftables``, but must do so using ordered access (otherwise the datasets can get out of sync). 248 | 249 | .. code:: python 250 | 251 | # Use get_batch to access the table. 252 | # Both datasets must be accessed in ordered mode. 253 | table_batch_dict = reader.get_batch( 254 | path = '/internal/h5_path/to/table', 255 | ordered = True) 256 | col_A_pl, col_B_pl = table_batch_dict['col_A'], table_batch_dict['col_B'] 257 | 258 | # Now use get_batch again to access an array. 
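# (It is assumed here that '/my_label_array' has the same number of rows as the table above,
# so that the ordered batches from the two datasets stay synchronised.)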
259 | # Both datasets must be accessed in ordered mode. 260 | labels_batch = reader.get_batch('/my_label_array', ordered = True) 261 | truth_batch = tf.one_hot(labels_batch, 2, 1, 0) 262 | 263 | # The loader takes a list of tensors to be stored in the queue. 264 | # When accessing in ordered mode, threads should be set to 1. 265 | loader = reader.get_fifoloader( 266 | queue_size = 10, 267 | inputs = [truth_batch, col_A_pl, col_B_pl], 268 | threads = 1) 269 | 270 | # Batches are taken out of the queue using a dequeue operation. 271 | # Tensors are returned in the order they were given when creating the loader. 272 | truth_cpu, col_A_cpu, col_B_cpu = loader.dequeue() 273 | 274 | # The dequeued data can then be used in your network. 275 | result = my_network(truth_cpu, col_A_cpu, col_B_cpu) 276 | 277 | with tf.Session() as sess: 278 | with loader.begin(sess): 279 | for _ in range(N): 280 | sess.run(result) 281 | 282 | reader.close() 283 | 284 | Ordered access is enabled be default when using ``get_batch`` as a safety measure. It is disabled when using 285 | ``load_dataset`` as that function restricts access to a single dataset. -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. tftables documentation master file, created by 2 | sphinx-quickstart on Tue Mar 7 21:24:01 2017. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | tftables documentation 7 | ********************** 8 | 9 | `tftables `_ allows convenient access to HDF5 files with Tensorflow. 10 | A class for reading batches of data out of arrays or tables is provided. 11 | A secondary class wraps both the primary reader and a Tensorflow FIFOQueue for straight-forward streaming 12 | of data from HDF5 files into Tensorflow operations. 13 | 14 | The library is backed by `multitables `_ for high-speed reading of HDF5 15 | datasets. ``multitables`` is based on PyTables (``tables``), so this library can make use of any compression algorithms 16 | that PyTables supports. 17 | 18 | Contents 19 | ======== 20 | 21 | .. toctree:: 22 | :maxdepth: 2 23 | 24 | quick 25 | howto 26 | reference 27 | 28 | Licence 29 | ======= 30 | 31 | This software is distributed under the MIT licence. 32 | See the `LICENSE.txt `_ file for details. 33 | 34 | Indices and tables 35 | ================== 36 | 37 | * :ref:`genindex` 38 | * :ref:`modindex` 39 | * :ref:`search` 40 | 41 | -------------------------------------------------------------------------------- /docs/source/quick.rst: -------------------------------------------------------------------------------- 1 | Quick Start 2 | *********** 3 | 4 | Installation 5 | ============ 6 | 7 | :: 8 | 9 | pip install tftables 10 | 11 | Alternatively, to install from HEAD, run 12 | 13 | :: 14 | 15 | pip install git+https://github.com/ghcollin/tftables.git 16 | 17 | You can also `download `_ 18 | or `clone the repository `_ and run 19 | 20 | :: 21 | 22 | python setup.py install 23 | 24 | ``tftables`` depends on ``multitables``, ``numpy`` and ``tensorflow``. The package is compatible with the latest versions of python 25 | 2 and 3. 26 | 27 | Quick start 28 | =========== 29 | 30 | An example of accessing a table in a HDF5 file. 31 | 32 | .. 
code:: python 33 | 34 | import tftables 35 | import tensorflow as tf 36 | 37 | with tf.device('/cpu:0'): 38 | # This function preprocesses the batches before they 39 | # are loaded into the internal queue. 40 | # You can cast data, or do one-hot transforms. 41 | # If the dataset is a table, this function is required. 42 | def input_transform(tbl_batch): 43 | labels = tbl_batch['label'] 44 | data = tbl_batch['data'] 45 | 46 | truth = tf.to_float(tf.one_hot(labels, num_labels, 1, 0)) 47 | data_float = tf.to_float(data) 48 | 49 | return truth, data_float 50 | 51 | # Open the HDF5 file and create a loader for a dataset. 52 | # The batch_size defines the length (in the outer dimension) 53 | # of the elements (batches) returned by the reader. 54 | # Takes a function as input that pre-processes the data. 55 | loader = tftables.load_dataset(filename=self.test_filename, 56 | dataset_path=self.test_mock_data_path, 57 | input_transform=input_transform, 58 | batch_size=20) 59 | 60 | # To get the data, we dequeue it from the loader. 61 | # Tensorflow tensors are returned in the same order as input_transformation 62 | truth_batch, data_batch = loader.dequeue() 63 | 64 | # The placeholder can then be used in your network 65 | result = my_network(truth_batch, data_batch) 66 | 67 | with tf.Session() as sess: 68 | 69 | # This context manager starts and stops the internal threads and 70 | # processes used to read the data from disk and store it in the queue. 71 | with loader.begin(sess): 72 | for _ in range(num_iterations): 73 | sess.run(result) 74 | 75 | 76 | If the dataset is an array instead of a table. Then ``input_transform`` can be omitted 77 | if no pre-processing is required. If only a single pass through the dataset is desired, 78 | then you should pass ``cyclic=False`` to ``load_dataset``. 79 | 80 | 81 | Examples 82 | ======== 83 | 84 | See the :doc:`How-To ` for more in-depth documentation, and the 85 | `unit tests `_ for complete examples. -------------------------------------------------------------------------------- /docs/source/reference.rst: -------------------------------------------------------------------------------- 1 | Reference 2 | ********* 3 | 4 | .. automodule:: tftables 5 | :members: open_file, load_dataset, FileReader, FIFOQueueLoader -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | # This flag says that the code is written to work on both Python 2 and Python 3 | # 3. If at all possible, it is good practice to do this. If you cannot, you 4 | # will need to generate wheels for each Python version that you support. 
5 | universal=1 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='tftables', 5 | version='1.1.2', 6 | url='https://github.com/ghcollin/tftables', 7 | description='Interface for reading HDF5 files into Tensorflow.', 8 | long_description=open("README.rst").read(), 9 | keywords='tensorflow HDF5', 10 | license='MIT', 11 | author='ghcollin', 12 | author_email='', 13 | classifiers=[ 14 | 'Development Status :: 4 - Beta', 15 | 'Intended Audience :: Developers', 16 | 'Intended Audience :: Science/Research', 17 | 'License :: OSI Approved :: MIT License', 18 | 'Programming Language :: Python :: 2.7', 19 | 'Programming Language :: Python :: 3', 20 | 'Topic :: Scientific/Engineering :: Mathematics', 21 | 'Topic :: Software Development :: Libraries :: Python Modules' 22 | ], 23 | py_modules=['tftables'], 24 | install_requires=['multitables', 'numpy!=1.10.1', 'tensorflow'] 25 | ) -------------------------------------------------------------------------------- /tftables.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2016 G. H. Collin (ghcollin) 2 | # 3 | # This software may be modified and distributed under the terms 4 | # of the MIT license. See the LICENSE.txt file for details. 5 | 6 | import tensorflow as tf 7 | import multitables as mtb 8 | import numpy as np 9 | import threading 10 | import contextlib 11 | 12 | __author__ = "G. H. Collin" 13 | __version__ = "1.1.2" 14 | 15 | def open_file(filename, batch_size, **kw_args): 16 | """ 17 | Open a HDF5 file for streaming with multitables. 18 | Batches will be retrieved with size ``batch_size``. 19 | Additional keyword arguments will be passed to the ``multitables.Streamer`` object. 20 | 21 | :param filename: Filename for the HDF5 file to be read. 22 | :param batch_size: The size of the batches to be fetched by this reader. 23 | :param kw_args: Optional arguments to pass to multitables. 24 | :return: A FileReader instance. 25 | """ 26 | return FileReader(filename, batch_size, **kw_args) 27 | 28 | 29 | def load_dataset(filename, dataset_path, batch_size, queue_size=8, 30 | input_transform=None, 31 | ordered=False, 32 | cyclic=True, 33 | processes=None, 34 | threads=None): 35 | """ 36 | Convenience function to quickly and easily load a dataset using best guess defaults. 37 | If a table is loaded, then the ``input_transformation`` argument is required. 38 | Returns an instance of ``FIFOQueueLoader`` that loads this dataset into a fifo queue. 39 | 40 | This function takes a single argument, which is either a tensorflow placeholder for the 41 | requested array or a dictionary of tensorflow placeholders for the columns in the 42 | requested table. The output of this function should be either a single tensorflow tensor, 43 | a tuple of tensorflow tensors, or a list of tensorflow tensors. A subsequent call to 44 | ``loader.dequeue()`` will return tensors in the same order as ``input_transform``. 45 | 46 | For example, if an array is stored in uint8 format, but we want to cast 47 | it to float32 format to do work on the GPU, the ``input_transform`` would be: 48 | 49 | :: 50 | 51 | def input_transform(ary_batch): 52 | return tf.cast(ary_batch, tf.float32) 53 | 54 | If, instead we were loading a table with column names ``label`` and ``data`` we 55 | need to transform this into a list. 
We might use something like the following 56 | to also do the one hot transform. 57 | 58 | :: 59 | 60 | def input_transform(tbl_batch): 61 | labels = tbl_batch['labels'] 62 | data = tbl_batch['data'] 63 | 64 | truth = tf.to_float(tf.one_hot(labels, num_labels, 1, 0)) 65 | data_float = tf.to_float(data) 66 | 67 | return truth, data_float 68 | 69 | Then the subsequent call to ``loader.dequeue()`` returns these int the same order: 70 | 71 | :: 72 | 73 | truth_batch, data_batch = loader.dequeue() 74 | 75 | By default, this function does not preserve on-disk ordering, and gives cyclic access. 76 | The disk ordering can be preserved using the ``ordered`` argument; however, this may result 77 | in slower read performance. 78 | 79 | :param filename: The filename to the HDF5 file. 80 | :param dataset_path: The internal HDF5 path to the dataset. 81 | :param batch_size: The size of the batches to be loaded into tensorflow. 82 | :param queue_size: The size of the tensorflow FIFO queue. 83 | :param input_transform: A function that transforms the batch before being loaded into the queue. 84 | :param ordered: Preserve the on-disk ordering of the requested dataset. 85 | :param cyclic: Data will be loaded in an endless loop that wraps around the end of the dataset. 86 | :param processes: Number of concurrent processes that multitables should use to read data from disk. 87 | :param threads: Number of threads to use to preprocess data and load the FIFO queue. 88 | :return: a loader for the dataset 89 | """ 90 | if processes is None: 91 | processes = (queue_size + 1) // 2 92 | if threads is None: 93 | threads = 1 if ordered else processes 94 | 95 | reader = FileReader(filename, batch_size) 96 | 97 | batch = reader.get_batch(dataset_path, ordered=ordered, cyclic=cyclic, n_procs=processes) 98 | 99 | if input_transform is not None: 100 | # Transform the input based on user specified function. 101 | processed_batch = input_transform(batch) 102 | elif isinstance(batch, dict): 103 | # If the user tries to load a table, but no function is given, then we cannot go further. 104 | # Table's return dictionaries and there is no good default on how to handle this. 105 | raise ValueError("Table datasets must have an input transformation.") 106 | else: 107 | # User loaded an array, no processing requested or required. 108 | processed_batch = batch 109 | 110 | if isinstance(processed_batch, list): 111 | # If the user gave a list, we're good 112 | pass 113 | elif isinstance(processed_batch, tuple): 114 | # If the user gave a tuple, turn it into a list 115 | processed_batch = list(processed_batch) 116 | else: 117 | # If the user returned a single value, also turn it into a list 118 | processed_batch = [processed_batch] 119 | 120 | loader = FIFOQueueLoader(reader, queue_size, processed_batch, threads=threads) 121 | # The user never gets a reference to the reader, so we request the loader to close the 122 | # reader for us when it is stopped. 123 | loader.close_reader = True 124 | 125 | return loader 126 | 127 | 128 | class FileReader: 129 | """This class reads batches from datasets in a HDF5 file.""" 130 | 131 | def __init__(self, filename, batch_size, **kw_args): 132 | """ 133 | Create a HDF5 file reader that reads batches of size ``batch_size``. 134 | The batch size is the number of elements of the outer-most dimension of the datasets that 135 | will be read. This can thought of as the number of rows that will be read at once and returned 136 | to the user. 137 | 138 | :param filename: The HDF5 file to read. 
139 | :param batch_size: The size of the batches to be read. 140 | :param kw_args: Optional arguments to pass to multitables. 141 | """ 142 | self.streamer = mtb.Streamer(filename, **kw_args) 143 | self.vars = [] 144 | self.batch_size = batch_size 145 | self.queues = [] 146 | self.order_lock = None 147 | 148 | @staticmethod 149 | def __match_slices(slice1, len1, slice2): 150 | """ 151 | Assures that the two given slices are compatible with each other and slice1 does no extend past the end 152 | of an array with length len1. 153 | If slice1 would extend greater than len1, then slice1 is spliced to wrap around len1. 154 | slice2 would then be spliced to match the two new slices for slice1. 155 | 156 | :param slice1: Slice that will be checked against slice1. 157 | :param len1: The length of an array that slice1 should wrap around. 158 | :param slice2: The slice that should be spliced to match slice1. 159 | :return: Two tuples. 160 | Tuple 1 contains two slices that correspond to the non-wraped part of slice1 and slice2. 161 | Tuple 2 contains two slices that correspond to the wrapped part of slice1 and slice2. 162 | """ 163 | delta_A, delta_B = len1 - slice1.start, slice1.stop - len1 164 | 165 | slice1_A = slice(slice1.start, slice1.start + delta_A) 166 | slice2_A = slice(slice2.start, slice2.start + delta_A) 167 | 168 | slice1_B = slice(0, 0 + delta_B) 169 | slice2_B = slice(slice2_A.stop, slice2_A.stop + delta_B) 170 | return (slice1_A, slice2_A), (slice1_B, slice2_B) 171 | 172 | @staticmethod 173 | def __to_tf_dtype(np_dtype): 174 | """ 175 | Converts a numpy dtype to a tensorflow dtype. 176 | This may return a larger dtype if no exact fit to np_dtype can be made. 177 | 178 | :param np_dtype: The numpy dtype to convert 179 | :return: A tensorflow dtype that matches np_dtype as closely as possible. 180 | """ 181 | # We try converting first so that the code gracefully falls back if tensorflow one day supports uint32/64. 182 | try: 183 | return tf.as_dtype(np_dtype) 184 | except TypeError as e: 185 | # there is no tensorflow dtype for uint32 at the moment, but we can stuff these into int64s safely 186 | if np_dtype == np.uint32: 187 | return tf.int64 188 | elif np_dtype == np.uint64: 189 | raise ValueError("Arrays with 64-bit unsigned integer type are not supported, as Tensorflow " 190 | + "has no corresponding data type.") 191 | raise e 192 | 193 | @staticmethod 194 | def __create_placeholders(type, batch_shape): 195 | """ 196 | Recursive function for creating placeholders. If the type is simple (not-compound) then a single tensorflow 197 | placeholder is returned will the appropriate batch_shape. 198 | If the type is compound, then a dictionary is returned. Each key of the dictionary corresponds to a 199 | column (or element) of the compound type. Each value of the dictionary contains the corresponding placeholder. 200 | 201 | The placeholders for this dictionary are created by recursively calling this function. Thus, a tree of 202 | dictionaries is created if the compound type contains other compound types. 203 | 204 | :param type: The corresponding numpy data type for this placeholder. 205 | :param batch_shape: The shape of the batch for this placeholder. 206 | :return: Either a placeholder, or a dictionary of placeholders. 207 | """ 208 | # If .fields is None, then the array is just a simple (not-compound) array. So a single placeholder is returned. 
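        # Illustrative example (added comment; names are hypothetical): a compound dtype such as
        #     np.dtype([('label', np.int32), ('data', np.float32, (28, 28))])
        # with batch_shape [20] produces a dictionary of placeholders of the form
        #     {'label': tf.placeholder(tf.int32, [20]),
        #      'data': tf.placeholder(tf.float32, [20, 28, 28])},
        # while a simple (non-compound) dtype produces a single placeholder of shape batch_shape.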
209 | if type.fields is None: 210 | placeholder = tf.placeholder(shape=batch_shape, dtype=FileReader.__to_tf_dtype(type)) 211 | result = placeholder 212 | 213 | # Otherwise, a dictionary of placeholders is needed: 214 | # As tensorflow doesn't support tensors with compound (or 'structured') types, a tensor (and thus placeholder) 215 | # if needed for each column in this array. 216 | else: 217 | placeholders = {} 218 | for name in type.fields.keys(): 219 | field_dtype = type.fields[name][0] # np dtype for the column 220 | subdtype = field_dtype.subdtype 221 | 222 | # The subdtype will be None, if this is a scalar. 223 | if subdtype is None: 224 | placeholder = FileReader.__create_placeholders(field_dtype, batch_shape) 225 | placeholders[name] = placeholder 226 | # If the column contains a sub-array, then subdtype is not None. 227 | else: 228 | subfield_type, subfield_shape = subdtype # subfield_shape is a shape of the sub-array 229 | # Append the sub-array shape to the batch_shape, as we are creating a single tensor for each column. 230 | subfield_batch_shape = batch_shape + list(subfield_shape) 231 | placeholder = FileReader.__create_placeholders(subfield_type, subfield_batch_shape) 232 | placeholders[name] = placeholder 233 | result = placeholders 234 | 235 | return result 236 | 237 | def get_batch(self, path, **kw_args): 238 | """ 239 | Get a Tensorflow placeholder for a batch that will be read from the dataset located at path. 240 | Additional key word arguments will be forwarded to the get_queue method in multitables. 241 | This defaults the multitables arguments `cyclic` and `ordered` to true. 242 | 243 | When ordering of batches is unimportant, the `ordered` argument can be set to False for potentially 244 | better performance. When reading from multiple datasets (eg; when examples and labels are in two different 245 | arrays), it is recommended to set `ordered` to True to preserve synchronisation. 246 | 247 | If the dataset is a table (or other compound-type array) then a dictionary of placeholders will be returned 248 | instead. The keys of this dictionary correspond to the column names of the table (or compound sub-types). 249 | 250 | :param path: The internal HDF5 path to the dataset to be read. 251 | :param kw_args: Optional arguments to be forwarded to multitables. 252 | :return: Either a placeholder or a dictionary depending on the type of dataset. 253 | If the dataset is a plain array, a placeholder representing once batch is returned. 254 | If the dataset is a table or compound type, a dictionary of placeholders is returned. 255 | """ 256 | if 'cyclic' not in kw_args: 257 | kw_args['cyclic'] = True 258 | if 'ordered' not in kw_args: 259 | kw_args['ordered'] = True 260 | if kw_args['ordered']: 261 | if self.order_lock is None: 262 | self.order_lock = threading.Lock() 263 | queue = self.streamer.get_queue(path=path, **kw_args) 264 | block_size = queue.block_size 265 | # get an example for finding data types and row sizes. 266 | example = self.streamer.get_remainder(path, block_size) 267 | batch_type = example.dtype 268 | inner_shape = example.shape[1:] 269 | batch_shape = [self.batch_size] + list(inner_shape) 270 | 271 | # Generator for reading batches. 272 | def read_batch(): 273 | # A 'scratch' space of one batch is needed to take care of remainder elements. 274 | # Here, remainder elements are defined as those left over when the batch size does not divide 275 | # the block size evenly. 
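            # For example (illustrative numbers only): with block_size = 5 and batch_size = 2,
            # each block yields two full batches and leaves one remainder row, which is held
            # in the scratch batch until enough rows accumulate to form a complete batch.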
276 | scratch_offset = 0 277 | scratch = np.zeros(batch_shape, dtype=batch_type) 278 | 279 | while True: 280 | guard = queue.get() 281 | if guard is mtb.QueueClosed: 282 | if kw_args['ordered']: 283 | remainder = self.streamer.get_remainder(path, block_size) 284 | remaining_scratch_space = self.batch_size - scratch_offset 285 | if len(remainder) >= remaining_scratch_space: 286 | rows_to_write = min(remaining_scratch_space, len(remainder)) 287 | scratch[scratch_offset:scratch_offset+rows_to_write] = remainder[:rows_to_write] 288 | yield scratch 289 | indexes = range(rows_to_write, len(remainder) + 1, self.batch_size) 290 | for start, end in zip(indexes[:-1], indexes[1:]): 291 | yield remainder[start:end] 292 | break 293 | with guard as block: 294 | block_offset = 0 295 | if kw_args['ordered'] and scratch_offset != 0: 296 | remaining_scratch_space = self.batch_size - scratch_offset 297 | rows_to_write = min(remaining_scratch_space, block_size) 298 | scratch[scratch_offset:scratch_offset+rows_to_write] = block[:rows_to_write] 299 | scratch_offset = scratch_offset + rows_to_write 300 | if scratch_offset == self.batch_size: 301 | yield scratch 302 | scratch_offset = 0 303 | block_offset = rows_to_write 304 | if block_offset == block_size: 305 | continue 306 | 307 | # First, if the batch size is smaller than the block size, then 308 | # batches are extracted from the block as yielded. 309 | indexes = range(block_offset, block_size+1, self.batch_size) 310 | for start, end in zip(indexes[:-1], indexes[1:]): 311 | yield block[start:end] 312 | 313 | # However, if the batch size is larger than the block size, or the 314 | # batch size does not divide the block size evenly, then there will be remainder elements. 315 | remainder = slice(indexes[-1], block_size) 316 | # These remainder elements will be written into the scratch batch, starting at the current offset. 317 | write_slice = slice(scratch_offset, scratch_offset + (remainder.stop - remainder.start)) 318 | 319 | if write_slice.stop < self.batch_size: 320 | scratch[write_slice] = block[remainder] 321 | # It is possible though, that the remainder elements will write off the end of the scratch block. 322 | else: 323 | # In this case, the remainder elements need to be split into 2 groups: Those 324 | # before the end (slices_A) and those after (slices_B). slices_B will then wrap 325 | # around to the start of the scratch batch. 326 | slices_A, slices_B = FileReader.__match_slices(write_slice, self.batch_size, remainder) 327 | # Write the before group. 328 | scratch[slices_A[0]] = block[slices_A[1]] 329 | # The scratch batch is now full, so yield it. 330 | yield scratch 331 | # Now that the batch was yieled, it is safe to write to the front of it. 332 | scratch[slices_B[0]] = block[slices_B[1]] 333 | # Reset the write_slice so that batch_offset will be updated correctly. 334 | write_slice = slices_B[0] 335 | 336 | # Update the batch_offset, now the remainder elements are written. 337 | scratch_offset = write_slice.stop 338 | 339 | result = FileReader.__create_placeholders(batch_type, batch_shape) 340 | 341 | self.vars.append((read_batch, result)) 342 | self.queues.append(queue) 343 | 344 | return result 345 | 346 | @contextlib.contextmanager 347 | def __feed_lock(self): 348 | """ 349 | If ordered access was requested for any variables, then the feed method should 350 | be locked to prevent accidental data races. 
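If ordered access was never requested, no lock exists and this context manager simply yields without locking.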
351 |         :return:
352 |         """
353 |         if self.order_lock is not None:
354 |             with self.order_lock:
355 |                 yield
356 |         else:
357 |             yield
358 | 
359 |     @staticmethod
360 |     def __feed_batch(feed_dict, batch, placeholders):
361 |         """
362 |         Recursive function for filling in the feed_dict. This recursively walks the dictionary tree given
363 |         by placeholders and adds an element to feed_dict for each leaf.
364 | 
365 |         :param feed_dict: The feed_dict to fill.
366 |         :param batch: The batch containing the data to be fed.
367 |         :param placeholders: Either a single placeholder, or a dictionary of placeholders.
368 |         :return: None
369 |         """
370 |         if isinstance(placeholders, dict):
371 |             for name in placeholders.keys():
372 |                 FileReader.__feed_batch(feed_dict, batch[name], placeholders[name])
373 |         else:
374 |             feed_dict[placeholders] = batch
375 | 
376 |     def feed(self):
377 |         """
378 |         Generator for feeding a tensorflow operation. Each iteration returns a feed_dict that contains
379 |         the data for one batch. This method reads data for *all* placeholders created.
380 | 
381 |         :return: A generator which yields tensorflow feed_dicts.
382 |         """
383 |         with self.__feed_lock():
384 |             # The reader generators are initialised here to allow safe multi-threaded access to the reader.
385 |             generators = [(reader(), placeholders) for reader, placeholders in self.vars]
386 |             while True:
387 |                 feed_dict = {}
388 |                 for gen, placeholders in generators:
389 |                     # Get the next batch.
390 |                     try:
391 |                         # Unfortunately Tensorflow seems to keep references to these arrays around somewhere,
392 |                         # so a copy is required to prevent data corruption.
393 |                         batch = next(gen).copy()
394 |                     except StopIteration:
395 |                         return
396 |                     # Populate the feed_dict with the elements of this batch.
397 |                     FileReader.__feed_batch(feed_dict, batch, placeholders)
398 |                 yield feed_dict
399 | 
400 |     def close(self):
401 |         """
402 |         Closes the internal queues, signaling the background processes to stop.
403 |         This calls the multitables.Streamer.Queue.close method on each queue.
404 | 
405 |         :return: None
406 |         """
407 |         for q in self.queues:
408 |             q.close()
409 | 
410 |     def get_fifoloader(self, queue_size, inputs, threads=None):
411 |         """
412 |         Convenience method for creating a FIFOQueueLoader object.
413 |         See the FIFOQueueLoader constructor for documentation on parameters.
414 | 
415 |         :param queue_size: The maximum number of elements the loader's internal Tensorflow queue will hold.
416 |         :param inputs: A list of tensors, derived from this reader's placeholders, to be stored in the queue.
417 |         :param threads: Number of threads used to fill the queue. Defaults to 1 if ordered access to this reader was
418 |             requested, otherwise defaults to 2.
419 |         :return: A FIFOQueueLoader instance.
420 |         """
421 |         threads = (2 if self.order_lock is None else 1) if threads is None else threads
422 |         return FIFOQueueLoader(self, queue_size, inputs, threads)
423 | 
424 | 
425 | @contextlib.contextmanager
426 | def _contextsuppress(exception):
427 |     """
428 |     Exception suppression context manager.
429 |     Similar functionality is provided by ``contextlib.suppress``, which is not available in python 2.7.
430 |     :param exception: The exception to suppress.
431 |     :return: A context manager that suppresses the exception.
432 |     """
433 |     try:
434 |         yield
435 |     except exception:
436 |         pass
437 | 
438 | 
439 | class FIFOQueueLoader:
440 |     """A class to handle the creation and population of a Tensorflow FIFOQueue."""
441 | 
442 |     def __init__(self, reader, size, inputs, threads=1):
443 |         """
444 |         Creates a loader that populates a Tensorflow FIFOQueue.
445 |         Experimentation suggests this tends to perform best when threads=1.
446 |         The graph defined by the inputs should be derived only from placeholders created
447 |         by the supplied reader object.
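
        A minimal usage sketch (illustrative only; the HDF5 path, ``my_network`` and the sizes here are
        hypothetical, and ``reader`` is assumed to be an already-open ``tftables.FileReader``)::

            batch = reader.get_batch('/internal/h5/path', ordered=False)
            loader = FIFOQueueLoader(reader, size=4, inputs=[tf.to_float(batch)], threads=1)
            result = my_network(loader.dequeue())
            with tf.Session() as sess:
                with loader.begin(sess):
                    value = sess.run(result)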
448 | 
449 |         :param reader: An instance of the associated FileReader class.
450 |         :param size: The maximum size of the internal queue.
451 |         :param inputs: A list of tensors that will be stored in the queue.
452 |         :param threads: Number of background threads to populate the queue with.
453 |         """
454 |         self.reader = reader
455 |         self.coord = tf.train.Coordinator()
456 |         self.q = tf.FIFOQueue(size, [i.dtype for i in inputs], [i.get_shape() for i in inputs])
457 |         self.enq_op = self.q.enqueue(inputs)
458 |         self.q_close_now_op = self.q.close(cancel_pending_enqueues=True)
459 |         self.n_threads = threads
460 |         self.threads = []
461 |         self.monitor_thread = None
462 |         self.close_reader = False
463 | 
464 |     def __read_thread(self, sess):
465 |         """
466 |         Function that defines the background threads. Feeds data from the reader into the FIFOQueue.
467 | 
468 |         :param sess: Tensorflow session.
469 |         :return: None
470 |         """
471 |         with self.coord.stop_on_exception():
472 |             with _contextsuppress(tf.errors.CancelledError):
473 |                 for feed_dict in self.reader.feed():
474 |                     sess.run(self.enq_op, feed_dict=feed_dict)
475 | 
476 |                     if self.coord.should_stop():
477 |                         break
478 | 
479 |     def __monitor(self, sess):
480 |         self.coord.join(self.threads)
481 |         sess.run(self.q_close_now_op)
482 | 
483 |     def dequeue(self):
484 |         """
485 |         Returns a dequeue operation. Elements defined by the input tensors and supplied by the reader
486 |         are returned from this operation. This calls the dequeue method on the internal Tensorflow FIFOQueue.
487 | 
488 |         :return: A dequeue operation.
489 |         """
490 |         return self.q.dequeue()
491 | 
492 |     def start(self, sess):
493 |         """
494 |         Starts the background threads. The enqueue operations are run in the given Tensorflow session.
495 | 
496 |         :param sess: Tensorflow session.
497 |         :return: None
498 |         """
499 |         if self.monitor_thread is not None:
500 |             raise Exception("This loader has already been started.")
501 | 
502 |         for _ in range(self.n_threads):
503 |             t = threading.Thread(target=FIFOQueueLoader.__read_thread, args=(self, sess))
504 |             t.daemon = True
505 |             t.start()
506 |             self.threads.append(t)
507 | 
508 |         self.monitor_thread = threading.Thread(target=FIFOQueueLoader.__monitor, args=(self, sess))
509 |         self.monitor_thread.daemon = True
510 |         self.monitor_thread.start()
511 | 
512 |     def stop(self, sess):
513 |         """
514 |         Stops the background threads, and joins them. This should be called after all operations are complete.
515 | 
516 |         :param sess: The Tensorflow session that this queue loader was started with.
517 |         :return: None
518 |         """
519 |         self.coord.request_stop()
520 |         sess.run(self.q_close_now_op)
521 |         self.coord.join([self.monitor_thread])
522 |         if self.close_reader:
523 |             self.reader.close()
524 | 
525 |     @staticmethod
526 |     def catch_termination():
527 |         """
528 |         In non-cyclic access, once the end of the dataset is reached, an exception
529 |         is raised to halt all access to the queue.
530 |         This context manager catches this exception for silent handling
531 |         of the termination condition.
532 |         :return: A context manager that suppresses the termination exception.
533 |         """
534 |         return _contextsuppress(tf.errors.OutOfRangeError)
535 | 
536 |     @contextlib.contextmanager
537 |     def begin(self, tf_session, catch_termination=True):
538 |         """
539 |         Convenience context manager for starting and stopping the loader.
540 |         :param tf_session: The current Tensorflow session.
541 |         :param catch_termination: Catch the termination of the loop for non-cyclic access.
542 | :return: 543 | """ 544 | self.start(tf_session) 545 | try: 546 | if catch_termination: 547 | with self.catch_termination(): 548 | yield 549 | else: 550 | yield 551 | finally: 552 | self.stop(tf_session) 553 | -------------------------------------------------------------------------------- /tftables_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2016 G. H. Collin (ghcollin) 2 | # 3 | # This software may be modified and distributed under the terms 4 | # of the MIT license. See the LICENSE.txt file for details. 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | import tables 9 | import tempfile 10 | import os 11 | import shutil 12 | import tqdm 13 | 14 | import tftables 15 | 16 | test_table_col_A_shape = (100,200) 17 | test_table_col_B_shape = (7,49) 18 | 19 | 20 | class TestTableRow(tables.IsDescription): 21 | col_A = tables.UInt32Col(shape=test_table_col_A_shape) 22 | col_B = tables.Float64Col(shape=test_table_col_B_shape) 23 | 24 | test_mock_data_shape = (100, 100) 25 | 26 | 27 | class TestMockDataRow(tables.IsDescription): 28 | label = tables.UInt32Col() 29 | data = tables.Float64Col(shape=test_mock_data_shape) 30 | 31 | 32 | def lcm(a,b): 33 | import fractions 34 | return abs(a * b) // fractions.gcd(a, b) if a and b else 0 35 | 36 | 37 | def get_batches(array, size, trim_remainder=False): 38 | result = [ array[i:i+size] for i in range(0, len(array), size)] 39 | if trim_remainder and len(result[-1]) != len(result[0]): 40 | result = result[:-1] 41 | return result 42 | 43 | 44 | def assert_array_equal(self, a, b): 45 | self.assertTrue(np.array_equal(a, b), 46 | msg="LHS: \n" + str(a) + "\n RHS: \n" + str(b)) 47 | 48 | 49 | def assert_items_equal(self, a, b, key, epsilon=0): 50 | a = [item for sublist in a for item in sublist] 51 | b = [item for sublist in b for item in sublist] 52 | self.assertEqual(len(a), len(b)) 53 | #a_sorted, b_sorted = (a, b) if key is None else (sorted(a, key=key), sorted(b, key=key)) 54 | 55 | unique_a, counts_a = np.unique(a, return_counts=True) 56 | unique_b, counts_b = np.unique(b, return_counts=True) 57 | 58 | self.assertAllEqual(unique_a, unique_b) 59 | 60 | epsilon *= np.prod(a[0].shape) 61 | delta = counts_a - counts_b 62 | self.assertLessEqual(np.max(np.abs(delta)), 1, msg="More than one extra copy of an element.\n" + str(delta) 63 | + "\n" + str(np.unique(delta, return_counts=True))) 64 | non_zero = np.abs(delta) > 0 65 | n_non_zero = np.sum(non_zero) 66 | self.assertLessEqual(n_non_zero, epsilon, msg="Num. 
zero deltas=" + str(n_non_zero) + " epsilon=" + str(epsilon) 67 | + "\n" + str(np.unique(delta, return_counts=True)) 68 | + "\n" + str(delta)) 69 | 70 | 71 | class TFTablesTest(tf.test.TestCase): 72 | 73 | def setUp(self): 74 | self.test_dir = tempfile.mkdtemp() 75 | self.test_filename = os.path.join(self.test_dir, 'test.h5') 76 | test_file = tables.open_file(self.test_filename, 'w') 77 | 78 | self.test_array = np.arange(100*1000).reshape((1000, 10, 10)) 79 | self.test_array_path = '/test_array' 80 | array = test_file.create_array(test_file.root, self.test_array_path[1:], self.test_array) 81 | 82 | self.test_table_ary = np.array([ ( 83 | np.random.randint(256, size=np.prod(test_table_col_A_shape)).reshape(test_table_col_A_shape), 84 | np.random.rand(*test_table_col_B_shape)) for _ in range(100) ], 85 | dtype=tables.dtype_from_descr(TestTableRow)) 86 | self.test_table_path = '/test_table' 87 | table = test_file.create_table(test_file.root, self.test_table_path[1:], TestTableRow) 88 | table.append(self.test_table_ary) 89 | 90 | self.test_uint64_array = np.arange(10).astype(np.uint64) 91 | self.test_uint64_array_path = '/test_uint64' 92 | uint64_array = test_file.create_array(test_file.root, self.test_uint64_array_path[1:], self.test_uint64_array) 93 | 94 | self.test_mock_data_ary = np.array([ ( 95 | np.random.rand(*test_mock_data_shape), 96 | np.random.randint(10, size=1)[0] ) for _ in range(1000) ], 97 | dtype=tables.dtype_from_descr(TestMockDataRow)) 98 | self.test_mock_data_path = '/mock_data' 99 | mock = test_file.create_table(test_file.root, self.test_mock_data_path[1:], TestMockDataRow) 100 | mock.append(self.test_mock_data_ary) 101 | 102 | test_file.close() 103 | 104 | def tearDown(self): 105 | import time 106 | time.sleep(5) 107 | shutil.rmtree(self.test_dir) 108 | 109 | def test_cyclic_unordered(self): 110 | N = 4 111 | N_threads = 4 112 | 113 | def set_up(path, array, batchsize, get_tensors): 114 | blocksize = batchsize*2 + 1 115 | reader = tftables.open_file(self.test_filename, batchsize) 116 | cycles = lcm(len(array), blocksize)//len(array) 117 | batch = reader.get_batch(path, block_size=blocksize, ordered=False) 118 | batches = get_batches(array, batchsize)*cycles*N_threads 119 | loader = reader.get_fifoloader(N, get_tensors(batch), threads=N_threads) 120 | return reader, loader, batches, batch 121 | 122 | array_batchsize = 10 123 | array_reader, array_loader, array_batches, array_batch_pl = set_up(self.test_array_path, self.test_array, 124 | array_batchsize, lambda x: [x]) 125 | array_data = array_loader.dequeue() 126 | array_result = [] 127 | 128 | table_batchsize = 5 129 | table_reader, table_loader, table_batches, table_batch_pl = set_up(self.test_table_path, self.test_table_ary, 130 | table_batchsize, lambda x: [x['col_A'], x['col_B']]) 131 | table_A_data, table_B_data = table_loader.dequeue() 132 | table_result = [] 133 | 134 | with self.test_session() as sess: 135 | sess.run(tf.global_variables_initializer()) 136 | 137 | array_loader.start(sess) 138 | table_loader.start(sess) 139 | 140 | for i in tqdm.tqdm(range(len(array_batches))): 141 | array_result.append(sess.run(array_data).copy()) 142 | self.assertEqual(len(array_result[-1]), array_batchsize) 143 | 144 | assert_items_equal(self, array_batches, array_result, 145 | key=lambda x: x[0, 0], epsilon=2*N_threads*array_batchsize) 146 | 147 | for i in tqdm.tqdm(range(len(table_batches))): 148 | result = np.zeros_like(table_batches[0]) 149 | result['col_A'], result['col_B'] = sess.run([table_A_data, table_B_data]) 150 | 
table_result.append(result) 151 | self.assertEqual(len(table_result[-1]), table_batchsize) 152 | 153 | assert_items_equal(self, table_batches, table_result, 154 | key=lambda x: x[1][0, 0], epsilon=2*N_threads*table_batchsize) 155 | 156 | try: 157 | array_loader.stop(sess) 158 | table_loader.stop(sess) 159 | except tf.errors.CancelledError: 160 | pass 161 | 162 | array_reader.close() 163 | table_reader.close() 164 | 165 | def test_shared_reader(self): 166 | batch_size = 8 167 | reader = tftables.open_file(self.test_filename, batch_size) 168 | 169 | array_batch = reader.get_batch(self.test_array_path, cyclic=False) 170 | table_batch = reader.get_batch(self.test_table_path, cyclic=False) 171 | 172 | array_batches = get_batches(self.test_array, batch_size, trim_remainder=True) 173 | table_batches = get_batches(self.test_table_ary, batch_size, trim_remainder=True) 174 | total_batches = min(len(array_batches), len(table_batches)) 175 | 176 | loader = reader.get_fifoloader(10, [array_batch, table_batch['col_A'], table_batch['col_B']], threads=4) 177 | 178 | deq = loader.dequeue() 179 | array_result = [] 180 | table_result = [] 181 | 182 | with self.test_session() as sess: 183 | sess.run(tf.global_variables_initializer()) 184 | 185 | loader.start(sess) 186 | 187 | with loader.catch_termination(): 188 | while True: 189 | tbl = np.zeros_like(self.test_table_ary[:batch_size]) 190 | ary, tbl['col_A'], tbl['col_B'] = sess.run(deq) 191 | array_result.append(ary) 192 | table_result.append(tbl) 193 | 194 | 195 | assert_items_equal(self, array_result, array_batches[:total_batches], 196 | key=None, epsilon=0) 197 | 198 | assert_items_equal(self, table_result, table_batches[:total_batches], 199 | key=None, epsilon=0) 200 | 201 | loader.stop(sess) 202 | 203 | reader.close() 204 | 205 | def test_uint64(self): 206 | reader = tftables.open_file(self.test_filename, 10) 207 | with self.assertRaises(ValueError): 208 | batch = reader.get_batch("/test_uint64") 209 | reader.close() 210 | 211 | 212 | def test_quick_start_A(self): 213 | my_network = lambda x, y: x 214 | num_iterations = 100 215 | num_labels = 10 216 | 217 | with tf.device('/cpu:0'): 218 | # This function preprocesses the batches before they 219 | # are loaded into the internal queue. 220 | # You can cast data, or do one-hot transforms. 221 | # If the dataset is a table, this function is required. 222 | def input_transform(tbl_batch): 223 | labels = tbl_batch['label'] 224 | data = tbl_batch['data'] 225 | 226 | truth = tf.to_float(tf.one_hot(labels, num_labels, 1, 0)) 227 | data_float = tf.to_float(data) 228 | 229 | return truth, data_float 230 | 231 | # Open the HDF5 file and create a loader for a dataset. 232 | # The batch_size defines the length (in the outer dimension) 233 | # of the elements (batches) returned by the reader. 234 | # Takes a function as input that pre-processes the data. 235 | loader = tftables.load_dataset(filename=self.test_filename, 236 | dataset_path=self.test_mock_data_path, 237 | input_transform=input_transform, 238 | batch_size=20) 239 | 240 | # To get the data, we dequeue it from the loader. 241 | # Tensorflow tensors are returned in the same order as input_transformation 242 | truth_batch, data_batch = loader.dequeue() 243 | 244 | # The placeholder can then be used in your network 245 | result = my_network(truth_batch, data_batch) 246 | 247 | with tf.Session() as sess: 248 | 249 | # This context manager starts and stops the internal threads and 250 | # processes used to read the data from disk and store it in the queue. 
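            # (By default, loader.begin() also silently catches the tf.errors.OutOfRangeError that is
            # raised once a non-cyclic dataset has been fully consumed.)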
251 | with loader.begin(sess): 252 | for _ in range(num_iterations): 253 | sess.run(result) 254 | 255 | 256 | def test_howto(self): 257 | def my_network(*args): 258 | return args[0] 259 | N = 100 260 | 261 | reader = tftables.open_file(filename=self.test_filename, batch_size=10) 262 | 263 | # Accessing a single array 264 | # ======================== 265 | 266 | array_batch_placeholder = reader.get_batch( 267 | path=self.test_array_path, # This is the path to your array inside the HDF5 file. 268 | cyclic=True, # In cyclic access, when the reader gets to the end of the 269 | # array, it will wrap back to the beginning and continue. 270 | ordered=False # The reader will not require the rows of the array to be 271 | # returned in the same order as on disk. 272 | ) 273 | 274 | # You can transform the batch however you like now. 275 | # For example, casting it to floats. 276 | array_batch_float = tf.to_float(array_batch_placeholder) 277 | 278 | # The data can now be fed into your network 279 | result = my_network(array_batch_float) 280 | 281 | with tf.Session() as sess: 282 | # The feed method provides a generator that returns 283 | # feed_dict's containing batches from your HDF5 file. 284 | for i, feed_dict in enumerate(reader.feed()): 285 | sess.run(result, feed_dict=feed_dict) 286 | if i >= N: 287 | break 288 | 289 | # Finally, the reader should be closed. 290 | #reader.close() 291 | 292 | # Accessing a single table 293 | # ======================== 294 | 295 | table_batch = reader.get_batch( 296 | path=self.test_mock_data_path, 297 | cyclic=True, 298 | ordered=False 299 | ) 300 | 301 | label_batch = table_batch['label'] 302 | data_batch = table_batch['data'] 303 | 304 | # Using a FIFO queue 305 | # ================== 306 | 307 | # As before 308 | array_batch_placeholder = reader.get_batch( 309 | path=self.test_array_path, 310 | cyclic=True, 311 | ordered=False) 312 | array_batch_float = tf.to_float(array_batch_placeholder) 313 | 314 | # Now we create a FIFO Loader 315 | loader = reader.get_fifoloader( 316 | queue_size=10, # The maximum number of elements that the 317 | # internal Tensorflow queue should hold. 318 | inputs=[array_batch_float], # A list of tensors that will be stored 319 | # in the queue. 320 | threads=1 # The number of threads used to stuff the 321 | # queue. If ordered access to a dataset 322 | # was requested, then only 1 thread 323 | # should be used. 324 | ) 325 | 326 | # Batches can now be dequeued from the loader for use in your network. 327 | array_batch_cpu = loader.dequeue() 328 | result = my_network(array_batch_cpu) 329 | 330 | with tf.Session() as sess: 331 | 332 | # The loader needs to be started with your Tensorflow session. 333 | loader.start(sess) 334 | 335 | for i in range(N): 336 | # You can now cleanly evaluate your network without a feed_dict. 337 | sess.run(result) 338 | 339 | # It also needs to be stopped for clean shutdown. 340 | loader.stop(sess) 341 | 342 | # Finally, the reader should be closed. 343 | #reader.close() 344 | 345 | # Accessing multiple datasets 346 | # =========================== 347 | 348 | # Use get_batch to access the table. 349 | # Both datasets must be accessed in ordered mode. 350 | table_batch_dict = reader.get_batch( 351 | path=self.test_table_path, 352 | ordered=True) 353 | col_A_pl, col_B_pl = table_batch_dict['col_A'], table_batch_dict['col_B'] 354 | 355 | # Now use get_batch again to access an array. 356 | # Both datasets must be accessed in ordered mode. 
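        # Ordered access keeps the two datasets synchronised: with ordered=False, each dataset's
        # background queue could deliver its blocks in a different order, so corresponding rows would
        # no longer line up in the same batch. Ordered access also means only one feeding thread
        # should be used, as noted below.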
357 | labels_batch = reader.get_batch(self.test_array_path, ordered=True) 358 | truth_batch = tf.one_hot(labels_batch, 2, 1, 0) 359 | 360 | # The loader takes a list of tensors to be stored in the queue. 361 | # When accessing in ordered mode, threads should be set to 1. 362 | loader = reader.get_fifoloader( 363 | queue_size=10, 364 | inputs=[truth_batch, col_A_pl, col_B_pl], 365 | threads=1) 366 | 367 | # Batches are taken out of the queue using a dequeue operation. 368 | # Tensors are returned in the order they were given when creating the loader. 369 | truth_cpu, col_A_cpu, col_B_cpu = loader.dequeue() 370 | 371 | # The dequeued data can then be used in your network. 372 | result = my_network(truth_cpu, col_A_cpu, col_B_cpu) 373 | 374 | with tf.Session() as sess: 375 | with loader.begin(sess): 376 | for _ in range(N): 377 | sess.run(result) 378 | 379 | reader.close() 380 | 381 | def test_howto_quick(self): 382 | my_network = lambda x, y: x 383 | num_iterations = 100 384 | num_labels = 256 385 | 386 | # This function preprocesses the batches before they 387 | # are loaded into the internal queue. 388 | # You can cast data, or do one-hot transforms. 389 | # If the dataset is a table, this function is required. 390 | def input_transform(tbl_batch): 391 | labels = tbl_batch['label'] 392 | data = tbl_batch['data'] 393 | 394 | truth = tf.to_float(tf.one_hot(labels, num_labels, 1, 0)) 395 | data_float = tf.to_float(data) 396 | 397 | return truth, data_float 398 | 399 | # Open the HDF5 file and create a loader for a dataset. 400 | # The batch_size defines the length (in the outer dimension) 401 | # of the elements (batches) returned by the reader. 402 | # Takes a function as input that pre-processes the data. 403 | loader = tftables.load_dataset(filename=self.test_filename, 404 | dataset_path=self.test_mock_data_path, 405 | input_transform=input_transform, 406 | batch_size=20) 407 | 408 | # To get the data, we dequeue it from the loader. 409 | # Tensorflow tensors are returned in the same order as input_transformation 410 | truth_batch, data_batch = loader.dequeue() 411 | 412 | # The placeholder can then be used in your network 413 | result = my_network(truth_batch, data_batch) 414 | 415 | with tf.Session() as sess: 416 | # This context manager starts and stops the internal threads and 417 | # processes used to read the data from disk and store it in the queue. 418 | with loader.begin(sess): 419 | for _ in range(num_iterations): 420 | sess.run(result) 421 | 422 | def test_howto_cyclic1(self): 423 | 424 | def my_network(*args): 425 | return args[0] 426 | 427 | reader = tftables.open_file(filename=self.test_filename, batch_size=10) 428 | 429 | # Non-cyclic access 430 | # ----------------- 431 | 432 | array_batch_placeholder = reader.get_batch( 433 | path=self.test_array_path, 434 | cyclic=False, 435 | ordered=False) 436 | array_batch_float = tf.to_float(array_batch_placeholder) 437 | 438 | loader = reader.get_fifoloader( 439 | queue_size=10, 440 | inputs=[array_batch_float], 441 | threads=1 442 | ) 443 | 444 | array_batch_cpu = loader.dequeue() 445 | result = my_network(array_batch_cpu) 446 | 447 | with tf.Session() as sess: 448 | loader.start(sess) 449 | 450 | try: 451 | # Keep iterating until the exception breaks the loop 452 | while True: 453 | sess.run(result) 454 | # Now silently catch the exception. 
455 | except tf.errors.OutOfRangeError: 456 | pass 457 | 458 | loader.stop(sess) 459 | 460 | def test_howto_cyclic2(self): 461 | 462 | def my_network(*args): 463 | return args[0] 464 | 465 | reader = tftables.open_file(filename=self.test_filename, batch_size=10) 466 | 467 | # Non-cyclic access 468 | # ----------------- 469 | 470 | array_batch_placeholder = reader.get_batch( 471 | path=self.test_array_path, 472 | cyclic=False, 473 | ordered=False) 474 | array_batch_float = tf.to_float(array_batch_placeholder) 475 | 476 | loader = reader.get_fifoloader( 477 | queue_size=10, 478 | inputs=[array_batch_float], 479 | threads=1 480 | ) 481 | 482 | array_batch_cpu = loader.dequeue() 483 | result = my_network(array_batch_cpu) 484 | 485 | with tf.Session() as sess: 486 | loader.start(sess) 487 | 488 | # This context manager suppresses the exception. 489 | with loader.catch_termination(): 490 | # Keep iterating until the exception breaks the loop 491 | while True: 492 | sess.run(result) 493 | 494 | loader.stop(sess) 495 | 496 | if __name__ == '__main__': 497 | tf.test.main() 498 | --------------------------------------------------------------------------------