├── .gitignore ├── LICENSE.txt ├── README.md ├── build.sh ├── docs ├── _static │ └── images │ │ ├── intro-graph.jpg │ │ ├── intro-metrics.jpg │ │ └── intro-watch.jpg ├── conf.py ├── index.rst ├── makefile ├── readme.md ├── tensorfx.data.rst ├── tensorfx.models.nn.rst ├── tensorfx.models.rst ├── tensorfx.prediction.rst └── tensorfx.training.rst ├── init.sh ├── requirements.txt ├── samples ├── iris │ ├── data.py │ ├── data │ │ ├── eval.csv │ │ ├── eval.tfrecord │ │ ├── metadata.json │ │ ├── schema.yaml │ │ ├── train.csv │ │ └── train.tfrecord │ ├── run_csv.sh │ ├── run_df.sh │ ├── run_examples.sh │ └── trainer │ │ ├── __init__.py │ │ ├── csv.py │ │ ├── df.py │ │ ├── examples.py │ │ └── features.yaml └── readme.md ├── setup.py ├── src ├── __init__.py ├── _version.py ├── data │ ├── __init__.py │ ├── _dataset.py │ ├── _ds_csv.py │ ├── _ds_df.py │ ├── _ds_examples.py │ ├── _features.py │ ├── _metadata.py │ ├── _schema.py │ └── _transforms.py ├── models │ ├── __init__.py │ ├── _classification.py │ └── nn │ │ ├── __init__.py │ │ └── _ff.py ├── prediction │ ├── __init__.py │ └── _model.py ├── tools │ ├── __init__.py │ ├── _predict.py │ ├── _scaffold.py │ ├── _train.py │ └── tfx.py └── training │ ├── __init__.py │ ├── _args.py │ ├── _config.py │ ├── _hooks.py │ ├── _job.py │ ├── _model.py │ └── _trainer.py └── tests ├── data ├── __init__.py ├── dataset_tests.py ├── features_tests.py └── schema_tests.py ├── main.py └── training ├── __init__.py └── config_tests.py /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | tensorfx 3 | 4 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introduction to TensorFX 2 | 3 | TensorFX is an end-to-end application framework that simplifies machine learning with 4 | [TensorFlow](http://tensorflow.org) - both training models and using them for prediction.
It is 5 | designed from the ground up to make the mainline scenarios simple with higher-level building blocks, 6 | while ensuring custom or complex scenarios remain possible by preserving the flexibility of 7 | TensorFlow APIs. 8 | 9 | There are some important principles that shape the design of the framework: 10 | 11 | 1. **Simple, consistent set of usage patterns** 12 | Local or cloud, single node or distributed execution, in-memory data or big data sharded across 13 | files - you should only have to write code once, in a single way, regardless of how the code executes. 14 | 15 | 2. **A Toolbox with Useful Abstractions** 16 | The right entrypoint for the task at hand, starting with off-the-shelf algorithms that let you 17 | focus on feature engineering and hyperparameter tuning. If you need to solve something unique, you 18 | can focus on building TensorFlow graphs, rather than infrastructure code (distributed cluster 19 | setup, checkpointing, logging, exporting models, etc.). 20 | 21 | 3. **Declarative** 22 | Using YAML, JSON, and simplified Python interfaces to minimize the amount of boilerplate code. 23 | 24 | OK, enough context... here is some information to get you started. 25 | 26 | 27 | ## Getting Started 28 | Once you have a Python environment (recommendation: use Miniconda), installation is straightforward: 29 | 30 | pip install tensorflow 31 | pip install tensorfx 32 | 33 | Note that TensorFX depends on TensorFlow 1.0 and supporting libraries such as numpy and pandas. 34 | 35 | 36 | ## Documentation 37 | Documentation is at https://tensorlab.github.io/tensorfx/. This includes API reference topics, as 38 | well as conceptual and how-to topics. They are a work in progress, but check them out! The repository 39 | also contains a few samples that demonstrate how to get started, with more to be 40 | added over time. 41 | 42 | 43 | ## Contributions and Development 44 | We welcome contributions in the form of ideas, issues, and samples, as well as code. Since the project is at 45 | a very early stage and evolving rapidly, it's best to start a discussion by filing an issue for 46 | any contribution. 47 | 48 | ### Building and Testing 49 | If you want to develop within the repository, clone it, and run the following commands: 50 | 51 | # Install requirements and set up the environment 52 | source init.sh pip 53 | 54 | # Build and Test 55 | ./build.sh test 56 | 57 | ### Related Links 58 | 59 | * Development workflow [TODO: Add wiki entry] 60 | 61 | 62 | ## Hello World - Iris Classification Model 63 | This sample is a quick 5-minute introduction to using TensorFX. Here is the code for building 64 | a feed-forward neural network classification model for the 65 | [iris dataset](https://archive.ics.uci.edu/ml/datasets/Iris).
66 | 67 | import tensorfx as tfx 68 | import tensorfx.models.nn as nn 69 | 70 | # Hyperparameters, training parameters, and data 71 | args = nn.FeedForwardClassificationArguments.parse(parse_job=True) 72 | dataset = tfx.data.CsvDataSet(args.data_schema, 73 | train=args.data_train, 74 | eval=args.data_eval, 75 | metadata=args.data_metadata, 76 | features=args.data_features) 77 | 78 | # Instantiate the model builder 79 | classification = nn.FeedForwardClassification(args) 80 | 81 | # Training 82 | trainer = tfx.training.ModelTrainer() 83 | model = trainer.train(classification, dataset, args.output) 84 | 85 | # Prediction 86 | instances = [ 87 | '6.3,3.3,6,2.5', # virginica 88 | '4.4,3,1.3,0.2', # setosa 89 | '6.1,2.8,4.7,1.2' # versicolor 90 | ] 91 | predictions = model.predict(instances) 92 | 93 | Here's an outline of the steps involved in basic usage of TensorFX: 94 | 95 | 1. Parse (or build) an Arguments object, usually from the command line, to define hyperparameters. 96 | This object corresponds to the kind of model you are training - 97 | `FeedForwardClassificationArguments` in this case. 98 | 2. Create a DataSet to reference training and evaluation data, along with supporting configuration - 99 | namely schema, metadata, and features (more on these below). 100 | 3. Initialize the model builder - in this case `FeedForwardClassification`. 101 | 4. Initialize the model trainer, and invoke `train()`, which runs the training process and returns a 102 | model. 103 | 5. Load some instances you want to run through the model and call `predict()`. 104 | 105 | #### Schema - schema.yaml 106 | The schema describes the structure of your data. It can be defined programmatically, but is 107 | conveniently expressed in declarative YAML form and placed alongside the training data. 108 | 109 | fields: 110 | - name: species 111 | type: discrete 112 | - name: petal_length 113 | type: numeric 114 | - name: petal_width 115 | type: numeric 116 | - name: sepal_length 117 | type: numeric 118 | - name: sepal_width 119 | type: numeric 120 | 121 | #### Metadata - metadata.json 122 | Metadata is the result of analyzing training data, based on type information in the schema. 123 | Iris is a tiny dataset, so its metadata is readily producible using simple Python code that loops over 124 | the data. For real-world and large datasets, you'll find Spark and BigQuery (on Google Cloud 125 | Platform) to be essential data processing runtimes. Stay tuned - TensorFX will provide support for 126 | these capabilities out of the box. 127 | 128 | { 129 | "species": { "entries": ["setosa", "virginica", "versicolor"] }, 130 | "petal_length": { "min": 4.3, "max": 7.9 }, 131 | "petal_width": { "min": 2.0, "max": 4.4 }, 132 | "sepal_length": { "min": 1.1, "max": 6.9 }, 133 | "sepal_width": { "min": 0.1, "max": 2.5 } 134 | } 135 | 136 | #### Features - features.yaml 137 | Like the schema, features can be defined programmatically or expressed in YAML. Features describe 138 | the set of inputs that your models operate over, and how they are produced by applying 139 | transformations to the fields in your data. These transformations are turned into TensorFlow graph 140 | constructs and applied consistently to both training and prediction data.
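For reference, features can also be built in code rather than YAML. Here is a minimal sketch, adapted from `samples/iris/trainer/df.py` later in this repository (df.py scales all four fields; the YAML below additionally uses a `log` transform for the sepal fields):

    import tensorfx as tfx

    # Equivalent in spirit to features.yaml: X concatenates the transformed
    # input fields, and Y is the target label derived from the species field.
    features = [
        tfx.data.Feature.concatenate('X',
            tfx.data.Feature.scale('pl', 'petal_length'),
            tfx.data.Feature.scale('pw', 'petal_width'),
            tfx.data.Feature.scale('sl', 'sepal_length'),
            tfx.data.Feature.scale('sw', 'sepal_width')),
        tfx.data.Feature.target('Y', 'species')
    ]
    feature_set = tfx.data.FeatureSet.create(features)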
141 | 142 | In this particular example, the FeedForwardClassification model requires two features: X, defining 143 | the values the model uses to produce inferences, and Y, the target label that the model is 144 | expected to predict. They are defined as follows: 145 | 146 | features: 147 | - name: X 148 | type: concat 149 | features: 150 | - name: petal_width 151 | type: scale 152 | - name: petal_length 153 | type: scale 154 | - name: sepal_width 155 | type: log 156 | - name: sepal_length 157 | type: log 158 | - name: Y 159 | type: target 160 | fields: species 161 | 162 | #### Running the Model 163 | The Python code in the sample can be run directly, or using the `tfx train` tool, as shown: 164 | 165 | cd samples 166 | tfx train \ 167 | --module iris.trainer.csv \ 168 | --output /tmp/tensorfx/iris/csv \ 169 | --data-train iris/data/train.csv \ 170 | --data-eval iris/data/eval.csv \ 171 | --data-schema iris/data/schema.yaml \ 172 | --data-metadata iris/data/metadata.json \ 173 | --data-features iris/trainer/features.yaml \ 174 | --log-level-tensorflow ERROR \ 175 | --log-level INFO \ 176 | --batch-size 5 \ 177 | --max-steps 2000 \ 178 | --checkpoint-interval-secs 1 \ 179 | --hidden-layers:1 20 \ 180 | --hidden-layers:2 10 181 | 182 | Once the training is complete, you can list the contents of the output directory. You should 183 | see the model (the prediction graph and learnt variables) in the `model` subdirectory, alongside 184 | checkpoints and summaries. 185 | 186 | ls -R /tmp/tensorfx/iris/csv 187 | checkpoints job.yaml model summaries 188 | 189 | /tmp/tensorfx/iris/csv/checkpoints: 190 | checkpoint model.ckpt-2000.index 191 | model.ckpt-1.data-00000-of-00001 model.ckpt-2000.meta 192 | model.ckpt-1.index model.ckpt-2001.data-00000-of-00001 193 | model.ckpt-1.meta model.ckpt-2001.index 194 | model.ckpt-1562.data-00000-of-00001 model.ckpt-2001.meta 195 | model.ckpt-1562.index model.ckpt-778.data-00000-of-00001 196 | model.ckpt-1562.meta model.ckpt-778.index 197 | model.ckpt-2000.data-00000-of-00001 model.ckpt-778.meta 198 | 199 | /tmp/tensorfx/iris/csv/model: 200 | saved_model.pb variables 201 | 202 | /tmp/tensorfx/iris/csv/model/variables: 203 | variables.data-00000-of-00001 variables.index 204 | 205 | /tmp/tensorfx/iris/csv/summaries: 206 | eval prediction train 207 | 208 | /tmp/tensorfx/iris/csv/summaries/eval: 209 | events.out.tfevents.1488351760 210 | events.out.tfevents.1488352853 211 | 212 | /tmp/tensorfx/iris/csv/summaries/prediction: 213 | events.out.tfevents.1488351765 214 | 215 | /tmp/tensorfx/iris/csv/summaries/train: 216 | events.out.tfevents.1488351760 217 | events.out.tfevents.1488352852 218 | 219 | Summaries are TensorFlow events logged during training. They can be observed while the training 220 | job is running (which is essential when running a long or real training job) to understand how your 221 | training is progressing, or how the model is converging (or not!). 222 | 223 | tensorboard --logdir /tmp/tensorfx/iris/csv 224 | 225 | This should bring up TensorBoard. It's useful to see the graph structure, metrics, and other tensors 226 | that are automatically published.
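Separately, note that the `model` directory shown above is a standard TensorFlow SavedModel (`saved_model.pb` plus a `variables` subdirectory), so it can be loaded and inspected with stock TensorFlow APIs, independent of TensorFX. A minimal sketch for TF 1.x, assuming the model was exported under the default serving tag:

    import tensorflow as tf

    export_dir = '/tmp/tensorfx/iris/csv/model'
    with tf.Session(graph=tf.Graph()) as sess:
      # Load the exported prediction graph and its learnt variables.
      tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)
      # List a few of the operations in the loaded prediction graph.
      for op in sess.graph.get_operations()[:10]:
        print op.name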
227 | 228 | **Training Graph** 229 | 230 | ![Graphs in TensorBoard](https://tensorlab.github.io/tensorfx/_static/images/intro-graph.jpg) 231 | 232 | **Training Metrics -- Accuracy, Loss and Throughput** 233 | 234 | ![Metrics in TensorBoard](https://tensorlab.github.io/tensorfx/_static/images/intro-metrics.jpg) 235 | 236 | **Model Variables -- Weights, Gradients, etc.** 237 | 238 | ![Watching Learnt Variables](https://tensorlab.github.io/tensorfx/_static/images/intro-watch.jpg) 239 | 240 | 241 | As you can see, the out-of-the-box model takes care of a number of details. The same code can be run on 242 | a single machine, or in a cluster (of course, iris is too simple a problem to need that). 243 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # Create the build directory with inputs (sources) and outputs (binaries). 5 | # Ensure it is empty, to build from a clean state. 6 | mkdir -p build 7 | rm -rf build 8 | mkdir -p build 9 | 10 | # Copy source files 11 | cp requirements.txt build 12 | cp setup.py build 13 | cp -r src build/tensorfx 14 | 15 | # Generate the README expected by PyPI from original markdown 16 | curl --silent http://c.docverter.com/convert \ 17 | -F from=markdown \ 18 | -F to=rst \ 19 | -F input_files[]=@README.md > build/README.rst 20 | 21 | # Finally, build 22 | pushd build > /dev/null 23 | python setup.py sdist > setup.log 24 | popd > /dev/null 25 | 26 | echo 'Build completed successfully!' 27 | 28 | 29 | # Copy over tests 30 | cp -r tests build/tests 31 | 32 | echo 'Tests copied successfully!' 33 | 34 | 35 | # Optionally run tests 36 | if [ "$1" == "test" ]; then 37 | pushd build/tests > /dev/null 38 | python main.py 39 | popd > /dev/null 40 | 41 | echo 'Tests completed' 42 | fi 43 | -------------------------------------------------------------------------------- /docs/_static/images/intro-graph.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlab/tensorfx/98a91b1d9e657f9444b6140712a5dedcdd906acf/docs/_static/images/intro-graph.jpg -------------------------------------------------------------------------------- /docs/_static/images/intro-metrics.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlab/tensorfx/98a91b1d9e657f9444b6140712a5dedcdd906acf/docs/_static/images/intro-metrics.jpg -------------------------------------------------------------------------------- /docs/_static/images/intro-watch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlab/tensorfx/98a91b1d9e657f9444b6140712a5dedcdd906acf/docs/_static/images/intro-watch.jpg -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # TensorFX documentation build configuration file, created by 4 | # sphinx-quickstart on Mon Feb 20 23:55:56 2017. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default.
14 | 15 | # If extensions (or modules to document with autodoc) are in another directory, 16 | # add these directories to sys.path here. If the directory is relative to the 17 | # documentation root, use os.path.abspath to make it absolute, as shown here. 18 | # 19 | 20 | import os 21 | import sys 22 | import sphinx_rtd_theme 23 | 24 | sys.path.append(os.path.abspath('../')) 25 | 26 | 27 | # -- General configuration ------------------------------------------------ 28 | 29 | # If your documentation needs a minimal Sphinx version, state it here. 30 | # 31 | # needs_sphinx = '1.0' 32 | 33 | # Add any Sphinx extension module names here, as strings. They can be 34 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 35 | # ones. 36 | extensions = [ 37 | 'sphinx.ext.autodoc', 38 | 'sphinx.ext.githubpages', 39 | 'sphinxcontrib.napoleon' 40 | ] 41 | 42 | # Add any paths that contain templates here, relative to this directory. 43 | templates_path = ['_templates'] 44 | 45 | # The suffix(es) of source filenames. 46 | # You can specify multiple suffixes as a list of strings: 47 | # source_suffix = ['.rst', '.md'] 48 | source_suffix = '.rst' 49 | 50 | # The master toctree document. 51 | master_doc = 'index' 52 | 53 | # General information about the project. 54 | project = u'TensorFX' 55 | copyright = u'2017, TensorLab Project' 56 | author = u'TensorLab Project' 57 | 58 | # The version info for the project you're documenting, acts as a replacement for 59 | # |version| and |release|, also used in various other places throughout the 60 | # built documents. 61 | # 62 | # The short X.Y version. 63 | version = u'0.1' 64 | # The full version, including alpha/beta/rc tags. 65 | release = u'0.1' 66 | 67 | # The language for content autogenerated by Sphinx. Refer to documentation 68 | # for a list of supported languages. 69 | # 70 | # This is also used if you do content translation via gettext catalogs. 71 | # Usually you set "language" from the command line for these cases. 72 | language = 'en' 73 | 74 | # List of patterns, relative to source directory, that match files and 75 | # directories to ignore when looking for source files. 76 | # These patterns also affect html_static_path and html_extra_path 77 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 78 | 79 | # The name of the Pygments (syntax highlighting) style to use. 80 | pygments_style = 'sphinx' 81 | 82 | # If true, `todo` and `todoList` produce output, else they produce nothing. 83 | todo_include_todos = False 84 | 85 | 86 | # -- Options for HTML output ---------------------------------------------- 87 | 88 | # The theme to use for HTML and HTML Help pages. See the documentation for 89 | # a list of builtin themes. 90 | # 91 | html_theme = 'sphinx_rtd_theme' 92 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 93 | 94 | # Theme options are theme-specific and customize the look and feel of a theme 95 | # further. For a list of options available for each theme, see the 96 | # documentation. 97 | # 98 | # html_theme_options = {} 99 | 100 | # Add any paths that contain custom static files (such as style sheets) here, 101 | # relative to this directory. They are copied after the builtin static files, 102 | # so a file named "default.css" will overwrite the builtin "default.css".
103 | html_static_path = ['_static'] 104 | 105 | html_show_sourcelink = False 106 | html_show_sphinx = False 107 | html_use_opensearch = '' 108 | html_title = '' 109 | 110 | 111 | # -- Options for HTMLHelp output ------------------------------------------ 112 | 113 | # Output file base name for HTML help builder. 114 | htmlhelp_basename = 'apidoc' 115 | 116 | 117 | # -- Options for LaTeX output --------------------------------------------- 118 | 119 | latex_elements = { 120 | # The paper size ('letterpaper' or 'a4paper'). 121 | # 122 | # 'papersize': 'letterpaper', 123 | 124 | # The font size ('10pt', '11pt' or '12pt'). 125 | # 126 | # 'pointsize': '10pt', 127 | 128 | # Additional stuff for the LaTeX preamble. 129 | # 130 | # 'preamble': '', 131 | 132 | # Latex figure (float) alignment 133 | # 134 | # 'figure_align': 'htbp', 135 | } 136 | 137 | # Grouping the document tree into LaTeX files. List of tuples 138 | # (source start file, target name, title, 139 | # author, documentclass [howto, manual, or own class]). 140 | latex_documents = [ 141 | (master_doc, 'TensorFX.tex', u'TensorFX Documentation', 142 | u'TensorLab', 'manual'), 143 | ] 144 | 145 | 146 | # -- Options for manual page output --------------------------------------- 147 | 148 | # One entry per manual page. List of tuples 149 | # (source start file, name, description, authors, manual section). 150 | man_pages = [ 151 | (master_doc, 'tensorfx', u'TensorFX Documentation', 152 | [author], 1) 153 | ] 154 | 155 | 156 | # -- Options for Texinfo output ------------------------------------------- 157 | 158 | # Grouping the document tree into Texinfo files. List of tuples 159 | # (source start file, target name, title, author, 160 | # dir menu entry, description, category) 161 | texinfo_documents = [ 162 | (master_doc, 'TensorFX', u'TensorFX Documentation', 163 | author, 'TensorFX', 'One line description of project.', 164 | 'Miscellaneous'), 165 | ] 166 | 167 | 168 | 169 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. TensorFX documentation master file, created by 2 | sphinx-quickstart on Mon Feb 20 23:55:56 2017. 3 | 4 | TensorFX Modules 5 | ================ 6 | 7 | .. toctree:: 8 | :maxdepth: 2 9 | 10 | tensorfx.data 11 | tensorfx.models 12 | tensorfx.models.nn 13 | tensorfx.training 14 | tensorfx.prediction 15 | 16 | 17 | 18 | TensorFX Links 19 | ============== 20 | 21 | * :ref:`genindex` 22 | * `GitHub `_ 23 | 24 | -------------------------------------------------------------------------------- /docs/makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = TensorFX 8 | SOURCEDIR = . 9 | BUILDDIR = ../../tensorfx-docs 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/readme.md: -------------------------------------------------------------------------------- 1 | # Reference Documentation 2 | 3 | ## Setup 4 | Docs are built using Sphinx, and use the 'read-the-docs' theme. 5 | 6 | pip install sphinx sphinx_rtd_theme sphinxcontrib-napoleon 7 | 8 | ## Build 9 | Docs are built into the ../../tensorfx-docs/html directory, which is expected to be a clone of the 10 | gh-pages branch of this repository. 11 | 12 | mkdir -p ../../tensorfx-docs 13 | make html 14 | -------------------------------------------------------------------------------- /docs/tensorfx.data.rst: -------------------------------------------------------------------------------- 1 | tensorfx.data 2 | ============= 3 | 4 | .. automodule:: tensorfx.data 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | DataSet and DataSource Implementations 10 | -------------------------------------- 11 | 12 | .. autoclass:: tensorfx.data.DataSet 13 | :members: 14 | 15 | .. autoclass:: tensorfx.data.DataSource 16 | :members: 17 | 18 | .. autoclass:: tensorfx.data.CsvDataSet 19 | :members: 20 | 21 | .. autoclass:: tensorfx.data.CsvDataSource 22 | :members: 23 | 24 | .. autoclass:: tensorfx.data.DataFrameDataSet 25 | :members: 26 | 27 | .. autoclass:: tensorfx.data.DataFrameDataSource 28 | :members: 29 | 30 | 31 | Schema and Metadata 32 | ------------------- 33 | 34 | .. autoclass:: tensorfx.data.Schema 35 | :members: 36 | 37 | .. autoclass:: tensorfx.data.SchemaField 38 | :members: 39 | 40 | .. autoclass:: tensorfx.data.SchemaFieldType 41 | :members: 42 | 43 | .. autoclass:: tensorfx.data.Metadata 44 | :members: 45 | 46 | 47 | Features 48 | -------- 49 | 50 | .. autoclass:: tensorfx.data.FeatureSet 51 | :members: 52 | 53 | .. autoclass:: tensorfx.data.Feature 54 | :members: 55 | 56 | .. autoclass:: tensorfx.data.FeatureType 57 | :members: 58 | 59 | -------------------------------------------------------------------------------- /docs/tensorfx.models.nn.rst: -------------------------------------------------------------------------------- 1 | tensorfx.models.nn 2 | ================== 3 | 4 | .. automodule:: tensorfx.models.nn 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Feed-forward Neural Networks 10 | ---------------------------- 11 | 12 | .. autoclass:: tensorfx.models.nn.FeedForwardModelArguments 13 | :members: 14 | 15 | .. autoclass:: tensorfx.models.nn.FeedForwardClassificationArguments 16 | :members: 17 | 18 | .. autoclass:: tensorfx.models.nn.FeedForwardClassification 19 | :members: 20 | 21 | -------------------------------------------------------------------------------- /docs/tensorfx.models.rst: -------------------------------------------------------------------------------- 1 | tensorfx.models 2 | =============== 3 | 4 | .. automodule:: tensorfx.models 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Machine Learning Scenarios 10 | -------------------------- 11 | 12 | .. autoclass:: tensorfx.models.ClassificationScenario 13 | :members: 14 | 15 | -------------------------------------------------------------------------------- /docs/tensorfx.prediction.rst: -------------------------------------------------------------------------------- 1 | tensorfx.prediction 2 | =================== 3 | 4 | .. 
automodule:: tensorfx.prediction 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | .. autoclass:: tensorfx.prediction.Model 10 | :members: 11 | 12 | -------------------------------------------------------------------------------- /docs/tensorfx.training.rst: -------------------------------------------------------------------------------- 1 | tensorfx.training 2 | ================= 3 | 4 | .. automodule:: tensorfx.training 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Model Builder 10 | ------------- 11 | 12 | .. autoclass:: tensorfx.training.ModelArguments 13 | :members: 14 | 15 | .. autoclass:: tensorfx.training.ModelBuilder 16 | :members: 17 | 18 | 19 | Training Jobs 20 | ------------- 21 | 22 | .. autoclass:: tensorfx.training.Configuration 23 | :members: 24 | 25 | .. autoclass:: tensorfx.training.ModelTrainer 26 | :members: 27 | 28 | -------------------------------------------------------------------------------- /init.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ln -sf src tensorfx 4 | 5 | export REPO=$(git rev-parse --show-toplevel) 6 | export PYTHONPATH=$REPO:$REPO/samples:$PYTHONPATH 7 | export PYTHONDONTWRITEBYTECODE=1 8 | 9 | # Optionally install python packages 10 | if [ "$1" == "pip" ]; then 11 | pip install -r requirements.txt 12 | fi 13 | 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Required packages 2 | # pip install -r requirements.txt 3 | # 4 | 5 | argparse==1.1 6 | enum34>=1.1.6,<1.2 7 | pandas>=0.19,<0.20 8 | pyyaml>=3.12,<4.0 9 | ujson>=1.35,<2.0 10 | -------------------------------------------------------------------------------- /samples/iris/data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLab. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # data.py 14 | # A utility to generate data in TF.Example protobufs saved into a TF.Record file. 15 | 16 | import pandas as pd 17 | import tensorflow as tf 18 | import tensorflow.core.example.example_pb2 as examples 19 | 20 | def load_data(): 21 | # Load data into DataFrame objects.
22 | columns = ['species', 'petal_length', 'petal_width', 'sepal_length', 'sepal_width'] 23 | df_train = pd.read_csv('data/train.csv', names=columns) 24 | df_eval = pd.read_csv('data/eval.csv', names=columns) 25 | 26 | return df_train, df_eval 27 | 28 | def convert_data(df, path): 29 | writer = tf.python_io.TFRecordWriter(path) 30 | for index, row in df.iterrows(): 31 | example = examples.Example() 32 | features = example.features 33 | features.feature['species'].bytes_list.value.append(row['species']) 34 | features.feature['petal_length'].float_list.value.append(row['petal_length']) 35 | features.feature['petal_width'].float_list.value.append(row['petal_width']) 36 | features.feature['sepal_length'].float_list.value.append(row['sepal_length']) 37 | features.feature['sepal_width'].float_list.value.append(row['sepal_width']) 38 | 39 | record = example.SerializeToString() 40 | writer.write(record) 41 | writer.close() 42 | 43 | 44 | def main(): 45 | df_train, df_eval = load_data() 46 | convert_data(df_train, 'data/train.tfrecord') 47 | convert_data(df_eval, 'data/eval.tfrecord') 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /samples/iris/data/eval.csv: -------------------------------------------------------------------------------- 1 | virginica,4.9,2.5,4.5,1.7 2 | versicolor,5.7,2.8,4.1,1.3 3 | versicolor,5.1,2.5,3,1.1 4 | setosa,4.8,3,1.4,0.1 5 | versicolor,5.6,2.5,3.9,1.1 6 | setosa,5.4,3.7,1.5,0.2 7 | setosa,5.5,3.5,1.3,0.2 8 | versicolor,6.2,2.2,4.5,1.5 9 | setosa,5.1,3.4,1.5,0.2 10 | virginica,6.3,3.3,6,2.5 11 | setosa,4.4,3,1.3,0.2 12 | versicolor,6.1,2.8,4.7,1.2 13 | versicolor,5.7,2.9,4.2,1.3 14 | setosa,5,3.3,1.4,0.2 15 | versicolor,5.6,2.7,4.2,1.3 16 | setosa,5,3.5,1.6,0.6 17 | virginica,7.7,2.8,6.7,2 18 | setosa,4.6,3.6,1,0.2 19 | versicolor,6.6,2.9,4.6,1.3 20 | virginica,6.5,3,5.2,2 21 | versicolor,6.5,2.8,4.6,1.5 22 | setosa,5.3,3.7,1.5,0.2 23 | versicolor,6.7,3,5,1.7 24 | versicolor,6.3,2.5,4.9,1.5 25 | virginica,7.7,3,6.1,2.3 26 | setosa,5.2,4.1,1.5,0.1 27 | virginica,6.7,3.3,5.7,2.1 28 | virginica,5.8,2.7,5.1,1.9 29 | setosa,5.4,3.4,1.7,0.2 30 | setosa,5,3.6,1.4,0.2 -------------------------------------------------------------------------------- /samples/iris/data/eval.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlab/tensorfx/98a91b1d9e657f9444b6140712a5dedcdd906acf/samples/iris/data/eval.tfrecord -------------------------------------------------------------------------------- /samples/iris/data/metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "species": { 3 | "entries": ["setosa", "virginica", "versicolor"] 4 | }, 5 | "petal_length": { 6 | "min": 4.3, 7 | "max": 7.9, 8 | "mean": 5.867500, 9 | "std": 0.827385 10 | }, 11 | "petal_width": { 12 | "min": 2.0, 13 | "max": 4.4, 14 | "mean": 3.050833, 15 | "std": 0.431335 16 | }, 17 | "sepal_length": { 18 | "min": 1.1, 19 | "max": 6.9, 20 | "mean": 3.830833, 21 | "std": 1.747497 22 | }, 23 | "sepal_width": { 24 | "min": 0.1, 25 | "max": 2.5, 26 | "mean": 1.232500, 27 | "std": 0.759053 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /samples/iris/data/schema.yaml: -------------------------------------------------------------------------------- 1 | fields: 2 | - name: species 3 | type: discrete 4 | - name: petal_length 5 | type: real 6 | - name: 
petal_width 7 | type: real 8 | - name: sepal_length 9 | type: real 10 | - name: sepal_width 11 | type: real 12 | 13 | -------------------------------------------------------------------------------- /samples/iris/data/train.csv: -------------------------------------------------------------------------------- 1 | setosa,4.6,3.1,1.5,0.2 2 | setosa,5.1,3.8,1.5,0.3 3 | setosa,4.4,3.2,1.3,0.2 4 | versicolor,6.3,2.3,4.4,1.3 5 | versicolor,6.6,3,4.4,1.4 6 | versicolor,6,2.2,4,1 7 | setosa,5.1,3.8,1.6,0.2 8 | virginica,6.7,3,5.2,2.3 9 | versicolor,6.9,3.1,4.9,1.5 10 | versicolor,5.9,3.2,4.8,1.8 11 | virginica,6.8,3.2,5.9,2.3 12 | virginica,6.3,2.7,4.9,1.8 13 | virginica,5.6,2.8,4.9,2 14 | setosa,5.4,3.9,1.3,0.4 15 | setosa,4.6,3.4,1.4,0.3 16 | versicolor,6.7,3.1,4.7,1.5 17 | virginica,7.4,2.8,6.1,1.9 18 | setosa,4.9,3,1.4,0.2 19 | virginica,6.3,2.5,5,1.9 20 | setosa,5.2,3.4,1.4,0.2 21 | versicolor,5.5,2.6,4.4,1.2 22 | virginica,7.2,3.6,6.1,2.5 23 | virginica,6.9,3.2,5.7,2.3 24 | setosa,5.1,3.8,1.9,0.4 25 | setosa,4.9,3.1,1.5,0.1 26 | setosa,5,3.2,1.2,0.2 27 | virginica,6.4,2.7,5.3,1.9 28 | setosa,4.8,3,1.4,0.3 29 | virginica,7.9,3.8,6.4,2 30 | versicolor,6.8,2.8,4.8,1.4 31 | setosa,5.4,3.9,1.7,0.4 32 | versicolor,5.5,2.5,4,1.3 33 | virginica,6.3,3.4,5.6,2.4 34 | setosa,4.8,3.1,1.6,0.2 35 | virginica,6,2.2,5,1.5 36 | virginica,6.4,3.1,5.5,1.8 37 | setosa,5.1,3.3,1.7,0.5 38 | versicolor,5.7,3,4.2,1.2 39 | versicolor,5.8,2.7,4.1,1 40 | virginica,5.9,3,5.1,1.8 41 | setosa,5,3,1.6,0.2 42 | versicolor,6.2,2.9,4.3,1.3 43 | versicolor,5.7,2.6,3.5,1 44 | versicolor,6.1,2.8,4,1.3 45 | versicolor,6.4,2.9,4.3,1.3 46 | setosa,4.9,3.1,1.5,0.1 47 | setosa,4.9,3.1,1.5,0.1 48 | versicolor,5.6,3,4.1,1.3 49 | versicolor,6,2.7,5.1,1.6 50 | versicolor,7,3.2,4.7,1.4 51 | virginica,6.4,3.2,5.3,2.3 52 | versicolor,5.5,2.3,4,1.3 53 | virginica,7.2,3,5.8,1.6 54 | virginica,5.8,2.8,5.1,2.4 55 | setosa,5.4,3.4,1.5,0.4 56 | virginica,6.3,2.9,5.6,1.8 57 | versicolor,6.1,2.9,4.7,1.4 58 | setosa,5.1,3.5,1.4,0.3 59 | versicolor,6.7,3.1,4.4,1.4 60 | setosa,5.8,4,1.2,0.2 61 | versicolor,6.4,3.2,4.5,1.5 62 | virginica,7.1,3,5.9,2.1 63 | setosa,4.4,2.9,1.4,0.2 64 | versicolor,5.8,2.7,3.9,1.2 65 | virginica,6.1,2.6,5.6,1.4 66 | virginica,6,3,4.8,1.8 67 | versicolor,5.4,3,4.5,1.5 68 | virginica,7.6,3,6.6,2.1 69 | setosa,5,3.4,1.6,0.4 70 | virginica,6.9,3.1,5.4,2.1 71 | versicolor,5.6,3,4.5,1.5 72 | setosa,4.8,3.4,1.6,0.2 73 | versicolor,5.7,2.8,4.5,1.3 74 | virginica,6.8,3,5.5,2.1 75 | versicolor,5.9,3,4.2,1.5 76 | virginica,6.7,3.3,5.7,2.5 77 | virginica,6.5,3.2,5.1,2 78 | virginica,6.7,3.1,5.6,2.4 79 | setosa,5.5,4.2,1.4,0.2 80 | versicolor,5.5,2.4,3.8,1.1 81 | setosa,5,3.4,1.5,0.2 82 | virginica,6.4,2.8,5.6,2.1 83 | versicolor,6.3,3.3,4.7,1.6 84 | virginica,6.1,3,4.9,1.8 85 | virginica,7.7,2.6,6.9,2.3 86 | virginica,7.2,3.2,6,1.8 87 | versicolor,4.9,2.4,3.3,1 88 | virginica,6.5,3,5.5,1.8 89 | virginica,6.2,2.8,4.8,1.8 90 | setosa,5.7,4.4,1.5,0.4 91 | setosa,4.7,3.2,1.3,0.2 92 | virginica,7.3,2.9,6.3,1.8 93 | virginica,7.7,3.8,6.7,2.2 94 | setosa,4.5,2.3,1.3,0.3 95 | virginica,6.9,3.1,5.1,2.3 96 | setosa,4.3,3,1.1,0.1 97 | virginica,6.3,2.8,5.1,1.5 98 | versicolor,5,2.3,3.3,1 99 | setosa,5.7,3.8,1.7,0.3 100 | virginica,6.4,2.8,5.6,2.2 101 | virginica,5.7,2.5,5,2 102 | versicolor,6,3.4,4.5,1.6 103 | versicolor,5.8,2.6,4,1.2 104 | versicolor,6.1,3,4.6,1.4 105 | virginica,6.7,2.5,5.8,1.8 106 | virginica,5.8,2.7,5.1,1.9 107 | setosa,5,3.5,1.3,0.3 108 | versicolor,5.2,2.7,3.9,1.4 109 | virginica,6.5,3,5.8,2.2 110 | 
versicolor,5.6,2.9,3.6,1.3 111 | setosa,5.2,3.5,1.5,0.2 112 | versicolor,5.5,2.4,3.7,1 113 | setosa,4.8,3.4,1.9,0.2 114 | versicolor,6,2.9,4.5,1.5 115 | setosa,5.1,3.5,1.4,0.2 116 | versicolor,5,2,3.5,1 117 | virginica,6.2,3.4,5.4,2.3 118 | setosa,4.6,3.2,1.4,0.2 119 | setosa,5.1,3.7,1.5,0.4 120 | setosa,4.7,3.2,1.6,0.2 -------------------------------------------------------------------------------- /samples/iris/data/train.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlab/tensorfx/98a91b1d9e657f9444b6140712a5dedcdd906acf/samples/iris/data/train.tfrecord -------------------------------------------------------------------------------- /samples/iris/run_csv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python -m tensorfx.tools.tfx train $1 \ 4 | --module trainer.csv \ 5 | --output /tmp/tensorfx/iris/csv \ 6 | --data-train data/train.csv \ 7 | --data-eval data/eval.csv \ 8 | --data-schema data/schema.yaml \ 9 | --data-metadata data/metadata.json \ 10 | --data-features trainer/features.yaml \ 11 | --log-level-tensorflow ERROR \ 12 | --log-level INFO \ 13 | --batch-size 5 \ 14 | --max-steps 2000 \ 15 | --checkpoint-interval-secs 1 \ 16 | --hidden-layers:1 20 \ 17 | --hidden-layers:2 10 \ 18 | 19 | -------------------------------------------------------------------------------- /samples/iris/run_df.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python -m trainer.df 4 | 5 | -------------------------------------------------------------------------------- /samples/iris/run_examples.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python -m tensorfx.tools.tfx train $1 \ 4 | --module trainer.examples \ 5 | --output /tmp/tensorfx/iris/examples \ 6 | --data-train data/train.tfrecord \ 7 | --data-eval data/eval.tfrecord \ 8 | --data-schema data/schema.yaml \ 9 | --data-metadata data/metadata.json \ 10 | --data-features trainer/features.yaml \ 11 | --log-level-tensorflow ERROR \ 12 | --log-level INFO \ 13 | --batch-size 5 \ 14 | --max-steps 2000 \ 15 | --checkpoint-interval-secs 1 \ 16 | --hidden-layers:1 20 \ 17 | --hidden-layers:2 10 \ 18 | 19 | -------------------------------------------------------------------------------- /samples/iris/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLab. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # __init__.py 14 | # Defines the iris trainer module. 15 | -------------------------------------------------------------------------------- /samples/iris/trainer/csv.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLab. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # csv.py 14 | # Implements the iris classification training job using csv data. 15 | 16 | import tensorfx as tfx 17 | import tensorfx.models.nn as nn 18 | 19 | args = nn.FeedForwardClassificationArguments.parse(parse_job=True) 20 | dataset = tfx.data.CsvDataSet(args.data_schema, 21 | train=args.data_train, 22 | eval=args.data_eval, 23 | metadata=args.data_metadata, 24 | features=args.data_features) 25 | 26 | classification = nn.FeedForwardClassification(args) 27 | 28 | trainer = tfx.training.ModelTrainer() 29 | trainer.train(classification, dataset, args.output) 30 | -------------------------------------------------------------------------------- /samples/iris/trainer/df.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLab. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # df.py 14 | # Demonstrates a standalone client that uses in-memory data (using a pandas DataFrame), in-code 15 | # definition of the dataset schema, features and the model, training, and finally predictions. 16 | 17 | import json 18 | import pandas as pd 19 | import tensorfx as tfx 20 | import tensorfx.models.nn as nn 21 | 22 | 23 | def create_dataset(): 24 | """Programmatically build the DataSet 25 | """ 26 | # Load data into DataFrame objects. 27 | columns = ['species', 'petal_length', 'petal_width', 'sepal_length', 'sepal_width'] 28 | df_train = pd.read_csv('data/train.csv', names=columns) 29 | df_eval = pd.read_csv('data/eval.csv', names=columns) 30 | 31 | df_train['species'] = df_train['species'].astype('category') 32 | df_eval['species'] = df_eval['species'].astype('category') 33 | 34 | # NOTE: Ordinarily, this would be specified in YAML configuration, but defined in code to 35 | # demonstrate the programmatic interface to FeatureSet and Feature objects. This is equivalent 36 | # to features.yaml. 37 | features = [ 38 | tfx.data.Feature.concatenate('X', 39 | tfx.data.Feature.scale('pl', 'petal_length'), 40 | tfx.data.Feature.scale('pw', 'petal_width'), 41 | tfx.data.Feature.scale('sl', 'sepal_length'), 42 | tfx.data.Feature.scale('sw', 'sepal_width')), 43 | tfx.data.Feature.target('Y', 'species') 44 | ] 45 | 46 | return tfx.data.DataFrameDataSet(features=tfx.data.FeatureSet.create(features), 47 | train=df_train, eval=df_eval) 48 | 49 | 50 | def create_args(): 51 | """Programmatically create the arguments.
52 | """ 53 | # Build the arguments (programmatically starting with defaults, instead of parsing the 54 | # program's command-line flags using parse(). 55 | args = nn.FeedForwardClassificationArguments.default() 56 | args.batch_size = 5 57 | args.max_steps = 2000 58 | args.checkpoint_interval_secs = 1 59 | args.hidden_layers = [('l1', 20, 'relu'), ('l2', 10, 'relu')] 60 | 61 | return args 62 | 63 | 64 | def main(): 65 | args = create_args() 66 | dataset = create_dataset() 67 | 68 | # Define the model and the trainer to train the model 69 | classification = nn.FeedForwardClassification(args) 70 | trainer = tfx.training.ModelTrainer() 71 | 72 | # Train; since this is training in-process (i.e. by default single node training), the training 73 | # process is run as the 'master' node, which happens to load and return the exported model that 74 | # can conveniently be used to produce predictions. 75 | print 'Training...' 76 | model = trainer.train(classification, dataset, output='/tmp/tensorfx/iris/df') 77 | 78 | # Predict; predictions are returned as a set of dictionaries, in the same order as the input 79 | # instances. 80 | print 'Predicting...' 81 | instances = [ 82 | '6.3,3.3,6,2.5', # virginica 83 | '4.4,3,1.3,0.2', # setosa 84 | '6.1,2.8,4.7,1.2' # versicolor 85 | ] 86 | predictions = model.predict(instances) 87 | 88 | # Print out instances and corresponding predictions 89 | print '' 90 | for instance, prediction in zip(instances, predictions): 91 | print '%s -> %s\n' % (instance, json.dumps(prediction, indent=2)) 92 | 93 | 94 | if __name__ == '__main__': 95 | main() 96 | -------------------------------------------------------------------------------- /samples/iris/trainer/examples.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLab. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # examples.py 14 | # Implements the iris classification training job using examples data. 
15 | 16 | import tensorfx as tfx 17 | import tensorfx.models.nn as nn 18 | 19 | args = nn.FeedForwardClassificationArguments.parse(parse_job=True) 20 | dataset = tfx.data.ExamplesDataSet(args.data_schema, 21 | train=args.data_train, 22 | eval=args.data_eval, 23 | metadata=args.data_metadata, 24 | features=args.data_features) 25 | 26 | classification = nn.FeedForwardClassification(args) 27 | 28 | trainer = tfx.training.ModelTrainer() 29 | trainer.train(classification, dataset, args.output) 30 | -------------------------------------------------------------------------------- /samples/iris/trainer/features.yaml: -------------------------------------------------------------------------------- 1 | features: 2 | - name: X 3 | type: concat 4 | features: 5 | - name: petal_width 6 | type: scale 7 | - name: petal_length 8 | type: scale 9 | - name: sepal_width 10 | type: log 11 | - name: sepal_length 12 | type: log 13 | - name: Y 14 | type: target 15 | fields: species 16 | -------------------------------------------------------------------------------- /samples/readme.md: -------------------------------------------------------------------------------- 1 | # Samples 2 | 3 | ## Iris 4 | 5 | This is the 'Hello World' of machine learning. It demonstrates two ways of 6 | using TensorFX. 7 | 8 | * Using in-memory (DataFrame) training and evaluation data. To run: 9 | 10 | cd iris && ./run_df.sh 11 | 12 | * Using file-based training and evaluation data, as well as schema, metadata, 13 | and features declared in external YAML or JSON files. To run: 14 | 15 | cd iris && ./run_csv.sh 16 | 17 | This invokes the train tool from TensorFX to launch a Python module as a 18 | trainer process, as follows: 19 | 20 | tfx train \ 21 | --module trainer.csv --output [output path] \ 22 | ... trainer-specific arguments 23 | 24 | 25 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2016 TensorLab. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 5 | # in compliance with the License. You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software distributed under the License 10 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | # or implied. See the License for the specific language governing permissions and limitations under 12 | # the License.
13 |
14 | # To publish to PyPI, use:
15 | # python setup.py bdist_wheel upload -r pypi
16 |
17 | import setuptools
18 |
19 | with open('tensorfx/_version.py') as vf:
20 |   exec(vf.read())
21 |
22 | with open('requirements.txt') as rf:
23 |   dependencies = rf.readlines()
24 |   dependencies = map(lambda d: d.strip(), dependencies)
25 |   dependencies = filter(lambda d: d and not d.startswith('#'), dependencies)
26 |
27 | setuptools.setup(
28 |   name='tensorfx',
29 |   version=__version__,
30 |   packages=[
31 |     'tensorfx',
32 |     'tensorfx.data',
33 |     'tensorfx.training',
34 |     'tensorfx.prediction',
35 |     'tensorfx.tools',
36 |     'tensorfx.models',
37 |     'tensorfx.models.nn'
38 |   ],
39 |   entry_points={
40 |     'console_scripts': [
41 |       'tfx = tensorfx.tools.tfx:main'
42 |     ],
43 |   },
44 |   data_files=[('.', ['requirements.txt'])],
45 |   install_requires=dependencies,
46 |   author='Nikhil Kothari',
47 |   author_email='nikhilk@twitter',
48 |   url='https://github.com/TensorLab/tensorfx',
49 |   license="Apache Software License",
50 |   description='TensorFX framework for training and serving machine learning models with TensorFlow',
51 |   keywords=[
52 |     'TensorLab',
53 |     'TensorFlow',
54 |     'Machine Learning',
55 |     'Deep Learning',
56 |     'Google'
57 |   ],
58 |   classifiers=[
59 |     # https://pypi.python.org/pypi?%3Aaction=list_classifiers
60 |     'Development Status :: 3 - Alpha',
61 |     'Environment :: Other Environment',
62 |     'Intended Audience :: Developers',
63 |     'License :: OSI Approved :: Apache Software License',
64 |     'Programming Language :: Python',
65 |     'Programming Language :: Python :: 2.7',
66 |     'Operating System :: OS Independent',
67 |     'Topic :: Scientific/Engineering :: Artificial Intelligence',
68 |     'Topic :: Software Development :: Libraries',
69 |     'Topic :: Software Development :: Libraries :: Python Modules'
70 |   ]
71 | )
72 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 TensorLab. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4 | # in compliance with the License. You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software distributed under the License
9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10 | # or implied. See the License for the specific language governing permissions and limitations under
11 | # the License.
12 |
13 | # __init__.py
14 | # tensorfx module declaration.
15 |
16 | import tensorfx.data as data
17 | import tensorfx.training as training
18 | import tensorfx.prediction as prediction
19 |
20 | from _version import __version__
--------------------------------------------------------------------------------
/src/_version.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 TensorLab. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4 | # in compliance with the License.
You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # _version.py 14 | # Declares package version. 15 | 16 | __version__ = '0.1.4' 17 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLab. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # __init__.py 14 | # tensorfx.data module declaration. 15 | 16 | from _schema import SchemaFieldType, SchemaField, Schema 17 | from _metadata import Metadata 18 | from _features import FeatureType, Feature, FeatureSet 19 | from _transforms import Transformer 20 | 21 | from _dataset import DataSet, DataSource 22 | from _ds_csv import CsvDataSet, CsvDataSource 23 | from _ds_df import DataFrameDataSet, DataFrameDataSource 24 | from _ds_examples import ExamplesDataSet, ExamplesDataSource 25 | -------------------------------------------------------------------------------- /src/data/_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLab. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # _dataset.py 14 | # Implementation of DataSet and DataSource classes. 15 | 16 | import tensorflow as tf 17 | from tensorflow.python.lib.io import file_io as tfio 18 | from ._schema import Schema 19 | from ._metadata import Metadata 20 | from ._features import FeatureSet 21 | 22 | 23 | class DataSet(object): 24 | """A class representing data to be used within a job. 25 | 26 | A DataSet contains one or more DataSource instances, each associated with a name. 27 | """ 28 | def __init__(self, datasources, schema, metadata, features): 29 | """Initializes a DataSet with the specified DataSource instances. 30 | 31 | Arguments: 32 | datasources: the set of contained DataSource instances key'ed by name. 33 | schema: the description of the source data. 34 | metadata: additional per-field information associated with the data. 
35 |       features: the optional description of the transformed data.
36 |     """
37 |     self._datasources = datasources
38 |
39 |     if type(schema) is str:
40 |       # Interpret this as a file path if the value is a string
41 |       schema = tfio.read_file_to_string(schema)
42 |     schema = Schema.parse(schema)
43 |     self._schema = schema
44 |
45 |     if metadata:
46 |       if type(metadata) is str:
47 |         # Interpret this as a file path if the value is a string
48 |         metadata = tfio.read_file_to_string(metadata)
49 |       metadata = Metadata.parse(metadata)
50 |     self._metadata = metadata
51 |
52 |     if features:
53 |       if type(features) is str:
54 |         # Interpret this as a file path if the value is a string
55 |         features = tfio.read_file_to_string(features)
56 |       features = FeatureSet.parse(features)
57 |     self._features = features
58 |
59 |   @property
60 |   def schema(self):
61 |     """Retrieves the schema associated with the DataSet.
62 |     """
63 |     return self._schema
64 |
65 |   @property
66 |   def metadata(self):
67 |     """Retrieves the metadata associated with the DataSet.
68 |     """
69 |     return self._metadata
70 |
71 |   @property
72 |   def features(self):
73 |     """Retrieves the features defined within the DataSet.
74 |     """
75 |     return self._features
76 |
77 |   @property
78 |   def sources(self):
79 |     """Retrieves the names of the contained DataSource instances.
80 |     """
81 |     return self._datasources.keys()
82 |
83 |   def __getitem__(self, index):
84 |     """Retrieves a named DataSource within the DataSet.
85 |
86 |     Arguments:
87 |       index: the name of the DataSource to retrieve.
88 |     Returns:
89 |       The DataSource if there is one with the specified name; None otherwise.
90 |     """
91 |     return self._datasources.get(index, None)
92 |
93 |   def __len__(self):
94 |     """Retrieves the number of contained DataSource instances.
95 |     """
96 |     return len(self._datasources)
97 |
98 |   def parse_instances(self, instances, prediction=False):
99 |     """Parses input instances according to the associated schema, metadata and features.
100 |
101 |     Arguments:
102 |       instances: The tensor containing input strings.
103 |       prediction: Whether the instances are being parsed for producing predictions or not.
104 |     Returns:
105 |       A dictionary of tensors key'ed by feature names.
106 |     """
107 |     raise NotImplementedError()
108 |
109 |
110 | class DataSource(object):
111 |   """A base class representing data that can be read for use in a job.
112 |   """
113 |   def __init__(self):
114 |     """Initializes an instance of a DataSource.
115 |     """
116 |     pass
117 |
118 |   def read(self, batch=128, shuffle=False, shuffle_buffer=1000, epochs=0, threads=1):
119 |     """Reads the data represented by this DataSource using a TensorFlow reader.
120 |
121 |     Arguments:
122 |       batch: The number of records to read at a time.
123 |       shuffle: Whether to shuffle the list of files.
124 |       shuffle_buffer: When shuffling, the number of extra items to keep in the queue for randomness.
125 |       epochs: The number of epochs or passes over the data to perform.
126 |       threads: the number of threads to use to read from the queue.
127 |     Returns:
128 |       A tensor containing a list of instances read.
129 | """ 130 | instances = self.read_instances(batch, shuffle, epochs) 131 | 132 | queue_capacity = (threads + 3) * batch 133 | if shuffle: 134 | queue_capacity = queue_capacity + shuffle_buffer 135 | return tf.train.shuffle_batch([instances], 136 | batch_size=batch, allow_smaller_final_batch=True, 137 | enqueue_many=True, 138 | capacity=queue_capacity, 139 | min_after_dequeue=shuffle_buffer, 140 | num_threads=threads, 141 | name='shuffle_batch') 142 | else: 143 | return tf.train.batch([instances], batch_size=batch, allow_smaller_final_batch=True, 144 | enqueue_many=True, capacity=queue_capacity, 145 | num_threads=threads, 146 | name='batch') 147 | 148 | def read_instances(self, count, shuffle, epochs): 149 | """Reads the data represented by this DataSource using a TensorFlow reader. 150 | 151 | Arguments: 152 | count: The number of instances to read in at most. 153 | shuffle: Whether to shuffle the input queue of files. 154 | epochs: The number of epochs or passes over the data to perform. 155 | Returns: 156 | A tensor containing instances that are read. 157 | """ 158 | raise NotImplementedError('read_instances must be implemented in a derived class.') 159 | -------------------------------------------------------------------------------- /src/data/_ds_csv.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLab. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # _ds_csv.py 14 | # Implementation of CsvDataSource. 15 | 16 | import tensorflow as tf 17 | from ._dataset import DataSet, DataSource 18 | from ._schema import SchemaFieldType 19 | 20 | 21 | class CsvDataSet(DataSet): 22 | """A DataSet representing data in csv format. 23 | """ 24 | def __init__(self, schema, metadata=None, features=None, **kwargs): 25 | """Initializes a CsvDataSet with the specified DataSource instances. 26 | 27 | Arguments: 28 | schema: the description of the source data. 29 | metadata: additional per-field information associated with the data. 30 | features: the optional description of the transformed data. 31 | kwargs: the set of CsvDataSource instances or csv paths to populate this DataSet with. 32 | """ 33 | datasources = {} 34 | for name, value in kwargs.iteritems(): 35 | if isinstance(value, str): 36 | value = CsvDataSource(value) 37 | 38 | if isinstance(value, CsvDataSource): 39 | datasources[name] = value 40 | else: 41 | raise ValueError('The specified DataSource is not a CsvDataSource') 42 | 43 | if not len(datasources): 44 | raise ValueError('At least one DataSource must be specified.') 45 | 46 | super(CsvDataSet, self).__init__(datasources, schema, metadata, features) 47 | 48 | def parse_instances(self, instances, prediction=False): 49 | """Parses input instances according to the associated schema. 50 | 51 | Arguments: 52 | instances: The tensor containing input strings. 53 | prediction: Whether the instances are being parsed for producing predictions or not. 
54 |     Returns:
55 |       A dictionary of tensors key'ed by field names.
56 |     """
57 |     return parse_csv(self.schema, instances, prediction)
58 |
59 |
60 | class CsvDataSource(DataSource):
61 |   """A DataSource representing one or more csv files.
62 |   """
63 |   def __init__(self, path, delimiter=','):
64 |     """Initializes an instance of a CsvDataSource with the specified csv file(s).
65 |
66 |     Arguments:
67 |       path: the csv file containing the data. This can be a pattern to represent a set of files.
68 |       delimiter: the delimiter character used.
69 |     """
70 |     super(CsvDataSource, self).__init__()
71 |     self._path = path
72 |     self._delimiter = delimiter
73 |
74 |   @property
75 |   def path(self):
76 |     """Retrieves the path represented by the DataSource.
77 |     """
78 |     return self._path
79 |
80 |   def read_instances(self, count, shuffle, epochs):
81 |     """Reads the data represented by this DataSource using a TensorFlow reader.
82 |
83 |     Arguments:
84 |       count, shuffle, epochs: as documented on DataSource.read_instances.
85 |     Returns:
86 |       A tensor containing instances that are read.
87 |     """
88 |     # None implies unlimited; switch the value to None when epochs is 0.
89 |     epochs = epochs or None
90 |
91 |     files = tf.train.match_filenames_once(self._path, name='files')
92 |     queue = tf.train.string_input_producer(files, num_epochs=epochs, shuffle=shuffle,
93 |                                            name='queue')
94 |     reader = tf.TextLineReader(name='reader')
95 |     _, instances = reader.read_up_to(queue, count, name='read')
96 |
97 |     return instances
98 |
99 |
100 | def parse_csv(schema, instances, prediction):
101 |   """A wrapper around decode_csv that parses csv instances based on provided Schema information.
102 |   """
103 |   if prediction:
104 |     # For training and evaluation data, the expectation is the target column is always present.
105 |     # For prediction however, the target may or may not be present.
106 |     # - In true prediction use-cases, the target is unknown and never present.
107 |     # - In prediction for model evaluation use-cases, the target is present.
108 |     # To use a single prediction graph, the missing target needs to be detected by comparing the
109 |     # number of columns in instances with the number of columns defined in the schema. If there are
110 |     # fewer columns, prepend a ',' (with the assumption that the target is always the first column).
111 |     #
112 |     # To get the number of columns in instances, split the first instance on ',', and use
113 |     # the first dimension of the shape of the resulting substring values.
114 |     columns = tf.shape(tf.string_split([instances[0]], delimiter=',').values)[0]
115 |     instances = tf.cond(tf.less(columns, len(schema)),
116 |                         lambda: tf.string_join([tf.constant(','), instances]),
117 |                         lambda: instances)
118 |
119 |   # Convert the schema into a set of tensor defaults, to be used for parsing csv data.
120 |   defaults = []
121 |   for field in schema:
122 |     if field.length != 1:
123 |       # TODO: Support variable length, and list columns in csv.
124 |       raise ValueError('Unsupported schema field "%s". Length must be 1.'
                       % field.name)
125 |
126 |     if field.type == SchemaFieldType.integer:
127 |       field_default = tf.constant(0, dtype=tf.int64)
128 |     elif field.type == SchemaFieldType.real:
129 |       field_default = tf.constant(0.0, dtype=tf.float32)
130 |     else:
131 |       # discrete, text, binary
132 |       field_default = tf.constant('', dtype=tf.string)
133 |     defaults.append([field_default])
134 |
135 |   values = tf.decode_csv(instances, defaults, name='csv')
136 |
137 |   parsed_instances = {}
138 |   for field, value in zip(schema, values):
139 |     # The parsed values are scalars, so each tensor is of shape (None,); turn them into tensors
140 |     # of shape (None, 1).
141 |     parsed_instances[field.name] = tf.expand_dims(value, axis=1, name=field.name)
142 |
143 |   return parsed_instances
144 |
--------------------------------------------------------------------------------
/src/data/_ds_df.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 TensorLab. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4 | # in compliance with the License. You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software distributed under the License
9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10 | # or implied. See the License for the specific language governing permissions and limitations under
11 | # the License.
12 |
13 | # _ds_df.py
14 | # Implementation of DataFrameDataSet and DataFrameDataSource.
15 |
16 | import numpy as np
17 | import tensorflow as tf
18 | from ._dataset import DataSet, DataSource
19 | from ._schema import Schema, SchemaField, SchemaFieldType
20 | from ._ds_csv import parse_csv
21 |
22 |
23 | class DataFrameDataSet(DataSet):
24 |   """A DataSet representing data loaded as Pandas DataFrame instances.
25 |   """
26 |   def __init__(self, features=None, **kwargs):
27 |     """Initializes a DataFrameDataSet with the specified DataSource instances.
28 |
29 |     Arguments:
30 |       features: the optional description of the transformed data.
31 |       kwargs: the set of DataFrameDataSource or pandas DataFrame instances to populate this DataSet with.
32 |     """
33 |     # Import pandas here, rather than at module load, to avoid loading the library at startup,
34 |     # and to keep only a soft dependency on the library.
35 |     # Since the user is passing in DataFrame instances, the library can safely be assumed to
36 |     # be installed and importable.
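    #
    # A usage sketch (the variable names are illustrative, not part of the sample):
    #   df_train = pd.read_csv('iris/data/train.csv')
    #   df_eval = pd.read_csv('iris/data/eval.csv')
    #   dataset = DataFrameDataSet(train=df_train, eval=df_eval, features='features.yaml')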
37 | import pandas as pd 38 | 39 | def create_schema(df): 40 | fields = [] 41 | for name, dtype in zip(df.columns, df.dtypes): 42 | if type(dtype) == pd.types.dtypes.CategoricalDtype: 43 | fields.append(SchemaField.discrete(name)) 44 | elif dtype in (np.int32, np.int64): 45 | fields.append(SchemaField.integer(name)) 46 | elif dtype in (np.float32, np.float64): 47 | fields.append(SchemaField.real(name)) 48 | else: 49 | raise ValueError('Unsupported data type "%s" in column "%s"' % (str(dtype), name)) 50 | return Schema(fields) 51 | 52 | def create_metadata(df): 53 | metadata = {} 54 | for name, dtype in zip(df.columns, df.dtypes): 55 | md = {} 56 | if type(dtype) == pd.types.dtypes.CategoricalDtype: 57 | entries = list(df[name].unique()) 58 | if np.nan in entries: 59 | entries.remove(np.nan) 60 | md['vocab'] = {'entries': sorted(entries)} 61 | elif dtype in (np.int32, np.int64, np.float32, np.float64): 62 | for stat, stat_value in df[name].describe().iteritems(): 63 | if stat == 'min': 64 | md['min'] = stat_value 65 | if stat == 'max': 66 | md['max'] = stat_value 67 | metadata[name] = md 68 | return metadata 69 | 70 | schema = None 71 | metadata = None 72 | datasources = {} 73 | for name, value in kwargs.iteritems(): 74 | if isinstance(value, pd.DataFrame): 75 | value = DataFrameDataSource(value) 76 | 77 | if isinstance(value, DataFrameDataSource): 78 | datasources[name] = value 79 | else: 80 | raise ValueError('The specified DataSource is not a DataFrameDataSource') 81 | 82 | if not schema: 83 | schema = create_schema(value.dataframe) 84 | if not metadata: 85 | metadata = create_metadata(value.dataframe) 86 | 87 | if not len(datasources): 88 | raise ValueError('At least one DataSource must be specified.') 89 | 90 | super(DataFrameDataSet, self).__init__(datasources, schema, metadata, features) 91 | 92 | def parse_instances(self, instances, prediction=False): 93 | """Parses input instances according to the associated schema. 94 | 95 | Arguments: 96 | instances: The tensor containing input strings. 97 | prediction: Whether the instances are being parsed for producing predictions or not. 98 | Returns: 99 | A dictionary of tensors key'ed by feature names. 100 | """ 101 | return parse_csv(self.schema, instances, prediction) 102 | 103 | 104 | class DataFrameDataSource(DataSource): 105 | """A DataSource representing a Pandas DataFrame. 106 | 107 | This class is useful for working with local/in-memory data. 108 | """ 109 | def __init__(self, df): 110 | """Initializes an instance of a DataFrameDataSource with the specified Pandas DataFrame. 111 | 112 | Arguments: 113 | df: the DataFrame instance to use. 114 | """ 115 | super(DataFrameDataSource, self).__init__() 116 | self._df = df 117 | 118 | @property 119 | def dataframe(self): 120 | """Retrieves the DataFrame represented by this DataSource. 121 | """ 122 | return self._df 123 | 124 | def read_instances(self, count, shuffle, epochs): 125 | """Reads the data represented by this DataSource using a TensorFlow reader. 126 | 127 | Arguments: 128 | epochs: The number of epochs or passes over the data to perform. 129 | Returns: 130 | A tensor containing instances that are read. 131 | """ 132 | # None implies unlimited; switch the value to None when epochs is 0. 133 | epochs = epochs or None 134 | 135 | with tf.device(''): 136 | # Ensure the device is local and the queue, dequeuing and lookup all happen on the default 137 | # device, which is required for the py_func operation. 
138 |
139 |       # A UDF that, given a batch of indices, returns a batch of string (csv formatted)
140 |       # instances from the DataFrame.
141 |       df = self._df
142 |       def reader(indices):
143 |         rows = df.iloc[indices]
144 |         return [map(lambda r: ','.join(r), rows.values.astype('string'))]
145 |
146 |       queue = tf.train.range_input_producer(self._df.shape[0], num_epochs=epochs, shuffle=shuffle,
147 |                                             name='queue')
148 |       indices = queue.dequeue_up_to(count)
149 |       instances = tf.py_func(reader, [indices], tf.string, name='read')
150 |       instances.set_shape((None,))
151 |
152 |     return instances
153 |
--------------------------------------------------------------------------------
/src/data/_ds_examples.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 TensorLab. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4 | # in compliance with the License. You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software distributed under the License
9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10 | # or implied. See the License for the specific language governing permissions and limitations under
11 | # the License.
12 |
13 | # _ds_examples.py
14 | # Implementation of ExamplesDataSource.
15 |
16 | import tensorflow as tf
17 | from ._dataset import DataSet, DataSource
18 | from ._schema import SchemaFieldType
19 |
20 |
21 | class ExamplesDataSet(DataSet):
22 |   """A DataSet representing data stored as tf.Example protos in TFRecord files.
23 |   """
24 |   def __init__(self, schema, metadata=None, features=None, **kwargs):
25 |     """Initializes an ExamplesDataSet with the specified DataSource instances.
26 |
27 |     Arguments:
28 |       schema: the description of the source data.
29 |       metadata: additional per-field information associated with the data.
30 |       features: the optional description of the transformed data.
31 |       kwargs: the set of ExamplesDataSource instances or TFRecord paths to populate this DataSet.
32 |     """
33 |     datasources = {}
34 |     for name, value in kwargs.iteritems():
35 |       if isinstance(value, str):
36 |         value = ExamplesDataSource(value)
37 |
38 |       if isinstance(value, ExamplesDataSource):
39 |         datasources[name] = value
40 |       else:
41 |         raise ValueError('The specified DataSource is not an ExamplesDataSource')
42 |
43 |     if not len(datasources):
44 |       raise ValueError('At least one DataSource must be specified.')
45 |
46 |     super(ExamplesDataSet, self).__init__(datasources, schema, metadata, features)
47 |
48 |   def parse_instances(self, instances, prediction=False):
49 |     """Parses input instances according to the associated schema.
50 |
51 |     Arguments:
52 |       instances: The tensor containing input strings.
53 |       prediction: Whether the instances are being parsed for producing predictions or not.
54 |     Returns:
55 |       A dictionary of tensors key'ed by field names.
56 |     """
57 |     # Convert the schema into an equivalent Example schema (expressed as features in Example
58 |     # terminology).
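    # For example, a real-valued schema field of length 1 becomes
    # tf.FixedLenFeature(shape=[1], dtype=tf.float32, default_value=[0.0]), while a
    # variable-length (length 0) discrete field becomes tf.VarLenFeature(dtype=tf.string).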
59 |     features = {}
60 |     for field in self.schema:
61 |       if field.type == SchemaFieldType.integer:
62 |         dtype = tf.int64
63 |         default_value = [0]
64 |       elif field.type == SchemaFieldType.real:
65 |         dtype = tf.float32
66 |         default_value = [0.0]
67 |       else:
68 |         # discrete
69 |         dtype = tf.string
70 |         default_value = ['']
71 |
72 |       if field.length == 0:
73 |         feature = tf.VarLenFeature(dtype=dtype)
74 |       else:
75 |         if field.length != 1:
76 |           default_value = default_value * field.length
77 |         feature = tf.FixedLenFeature(shape=[field.length], dtype=dtype, default_value=default_value)
78 |
79 |       features[field.name] = feature
80 |
81 |     return tf.parse_example(instances, features, name='examples')
82 |
83 |
84 | class ExamplesDataSource(DataSource):
85 |   """A DataSource representing one or more TFRecord files containing tf.Example data.
86 |   """
87 |   def __init__(self, path, compressed=False):
88 |     """Initializes an ExamplesDataSource with the specified TFRecord file(s).
89 |
90 |     Arguments:
91 |       path: TFRecord file containing the data. This can be a pattern to represent a set of files.
92 |       compressed: Whether the TFRecord files are compressed.
93 |     """
94 |     super(ExamplesDataSource, self).__init__()
95 |     self._path = path
96 |     self._compressed = compressed
97 |
98 |   @property
99 |   def path(self):
100 |     """Retrieves the path represented by the DataSource.
101 |     """
102 |     return self._path
103 |
104 |   def read_instances(self, count, shuffle, epochs):
105 |     """Reads the data represented by this DataSource using a TensorFlow reader.
106 |
107 |     Arguments:
108 |       count, shuffle, epochs: as documented on DataSource.read_instances.
109 |     Returns:
110 |       A tensor containing instances that are read.
111 |     """
112 |     # None implies unlimited; switch the value to None when epochs is 0.
113 |     epochs = epochs or None
114 |
115 |     options = None
116 |     if self._compressed:
117 |       options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
118 |
119 |     files = tf.train.match_filenames_once(self._path, name='files')
120 |     queue = tf.train.string_input_producer(files, num_epochs=epochs, shuffle=shuffle,
121 |                                            name='queue')
122 |     reader = tf.TFRecordReader(options=options, name='reader')
123 |     _, instances = reader.read_up_to(queue, count, name='read')
124 |
125 |     return instances
126 |
--------------------------------------------------------------------------------
/src/data/_features.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 TensorLab. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4 | # in compliance with the License. You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software distributed under the License
9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10 | # or implied. See the License for the specific language governing permissions and limitations under
11 | # the License.
12 |
13 | # _features.py
14 | # Implementation of FeatureSet and related classes.
15 |
16 | import enum
17 | import tensorflow as tf
18 | import yaml
19 |
20 |
21 | class FeatureType(enum.Enum):
22 |   """Defines the type of Feature instances.
23 | """ 24 | identity = 'identity' 25 | target = 'target' 26 | concat = 'concat' 27 | log = 'log' 28 | scale = 'scale' 29 | bucketize = 'bucketize' 30 | one_hot = 'one-hot' 31 | 32 | 33 | def _lookup_feature_type(s): 34 | for t in FeatureType: 35 | if t.value == s: 36 | return t 37 | raise ValueError('Invalid FeatureType "%s".' % s) 38 | 39 | 40 | class Feature(object): 41 | """Defines a named feature within a FeatureSet. 42 | """ 43 | def __init__(self, name, type, fields=None, features=None, transform=None): 44 | """Initializes a Feature with its name and source fields. 45 | 46 | Arguments: 47 | name: the name of the feature. 48 | type: the type of the feature. 49 | fields: the names of the fields making up this feature. 50 | features: the names of the features making up this feature in case of composite features. 51 | transform: transform configuration to produce the feature. 52 | """ 53 | self._name = name 54 | self._type = type 55 | self._fields = fields 56 | self._features = features 57 | self._transform = transform 58 | 59 | @classmethod 60 | def identity(cls, name, field=None): 61 | """Creates a feature representing an un-transformed schema field. 62 | 63 | Arguments: 64 | name: the name of the feature. 65 | field: the name of the field. If absenst, this uses the name as the field name as well. 66 | Returns: 67 | An instance of a Feature. 68 | """ 69 | if not field: 70 | # Optimize for an identity feature named the same as the field it represents. 71 | field = name 72 | return cls(name, FeatureType.identity, fields=[field]) 73 | 74 | @classmethod 75 | def target(cls, name, field): 76 | """Creates a feature representing the target value. 77 | 78 | Arguments: 79 | name: the name of the feature. 80 | field: the name of the field. 81 | Returns: 82 | An instance of a Feature. 83 | """ 84 | return cls(name, FeatureType.target, fields=[field]) 85 | 86 | @classmethod 87 | def concatenate(cls, name, *args): 88 | """Creates a composite feature that is a concatenation of multiple features. 89 | 90 | Arguments: 91 | name: the name of the feature. 92 | args: the sequence of features to concatenate. 93 | Returns: 94 | An instance of a Feature. 95 | """ 96 | if not len(args): 97 | raise ValueError('One or more features must be specified.') 98 | 99 | if type(args[0]) == list: 100 | features = args[0] 101 | else: 102 | features = list(args) 103 | 104 | return cls(name, FeatureType.concat, features=features) 105 | 106 | @classmethod 107 | def log(cls, name, field): 108 | """Creates a feature representing a log value of a numeric field. 109 | 110 | Arguments: 111 | name: The name of the feature. 112 | field: The name of the field to create the feature from. 113 | Returns: 114 | An instance of a Feature. 115 | """ 116 | return cls(name, FeatureType.log, fields=[field]) 117 | 118 | @classmethod 119 | def scale(cls, name, field, range=(0, 1)): 120 | """Creates a feature representing a scaled version of a numeric field. 121 | 122 | In order to perform scaling, the metadata will be looked up for the field, to retrieve min, max 123 | and mean values. 124 | 125 | Arguments: 126 | name: The name of the feature. 127 | field: The name of the field to create the feature from. 128 | range: The target range of the feature. 129 | Returns: 130 | An instance of a Feature. 131 | """ 132 | # TODO: What about the other scaling approaches, besides this (min-max scaling)? 
133 |     transform = {'min': range[0], 'max': range[1]}
134 |     return cls(name, FeatureType.scale, fields=[field], transform=transform)
135 |
136 |   @classmethod
137 |   def bucketize(cls, name, field, boundaries):
138 |     """Creates a feature representing a bucketized version of a numeric field.
139 |
140 |     The value returned is the one-hot representation of the index of the bucket that the
141 |     value falls into.
142 |
143 |     Arguments:
144 |       name: The name of the feature.
145 |       field: The name of the field to create the feature from.
146 |       boundaries: The list of bucket boundaries.
147 |     Returns:
148 |       An instance of a Feature.
149 |     """
150 |     transform = {'boundaries': ','.join(map(str, boundaries))}
151 |     return cls(name, FeatureType.bucketize, fields=[field], transform=transform)
152 |
153 |   @classmethod
154 |   def one_hot(cls, name, field):
155 |     """Creates a feature representing a one-hot representation of a discrete field.
156 |
157 |     Arguments:
158 |       name: The name of the feature.
159 |       field: The name of the field to create the feature from.
160 |     Returns:
161 |       An instance of a Feature.
162 |     """
163 |     return cls(name, FeatureType.one_hot, fields=[field])
164 |
165 |   @property
166 |   def name(self):
167 |     """Retrieves the name of the feature.
168 |     """
169 |     return self._name
170 |
171 |   @property
172 |   def features(self):
173 |     """Retrieves the features making up a composite feature.
174 |     """
175 |     return self._features
176 |
177 |   @property
178 |   def field(self):
179 |     """Retrieves the field making up the feature if the feature is based on a single field.
180 |     """
181 |     if len(self._fields) == 1:
182 |       return self._fields[0]
183 |     return None
184 |
185 |   @property
186 |   def fields(self):
187 |     """Retrieves the fields making up the feature.
188 |     """
189 |     return self._fields
190 |
191 |   @property
192 |   def type(self):
193 |     """Retrieves the type of the feature.
194 |     """
195 |     return self._type
196 |
197 |   @property
198 |   def transform(self):
199 |     """Retrieves the transform configuration to produce the feature.
200 |     """
201 |     return self._transform
202 |
203 |   def format(self):
204 |     """Retrieves the raw serializable representation of the feature.
205 |     """
206 |     data = {'name': self._name, 'type': self._type.value}
207 |     if self._fields:
208 |       data['fields'] = ','.join(self._fields)
209 |     if self._transform:
210 |       data['transform'] = self._transform
211 |     if self._features:
212 |       data['features'] = map(lambda f: f.format(), self._features)
213 |     return data
214 |
215 |   @staticmethod
216 |   def parse(data):
217 |     """Parses a feature from its serialized data representation.
218 |
219 |     Arguments:
220 |       data: A dictionary holding the serialized representation.
221 |     Returns:
222 |       The parsed Feature instance.
223 |     """
224 |     name = data['name']
225 |     feature_type = _lookup_feature_type(data.get('type', 'identity'))
226 |     transform = data.get('transform', None)
227 |
228 |     fields = None
229 |     features = None
230 |     if feature_type == FeatureType.concat:
231 |       features = []
232 |       for f in data['features']:
233 |         feature = Feature.parse(f)
234 |         features.append(feature)
235 |     else:
236 |       fields = data.get('fields', name)
237 |       if type(fields) is str:
238 |         fields = map(lambda n: n.strip(), fields.split(','))
239 |
240 |     return Feature(name, feature_type, fields=fields, features=features, transform=transform)
241 |
242 |
243 | class FeatureSet(object):
244 |   """Represents the set of features consumed by a model during training and prediction.
245 |
246 |   A FeatureSet contains a set of named features. Features are derived from input fields specified
247 |   in a schema and constructed using a transformation.
248 |   """
249 |   def __init__(self, features):
250 |     """Initializes a FeatureSet from its specified set of features.
251 |
252 |     Arguments:
253 |       features: the list of features within a FeatureSet.
254 |     """
255 |     self._features = features
256 |     self._features_map = dict(map(lambda f: (f.name, f), features))
257 |
258 |   @staticmethod
259 |   def create(*args):
260 |     """Creates a FeatureSet from a set of features.
261 |
262 |     Arguments:
263 |       args: a list or sequence of features defining the FeatureSet.
264 |     Returns:
265 |       A FeatureSet instance.
266 |     """
267 |     if not len(args):
268 |       raise ValueError('One or more features must be specified.')
269 |
270 |     if type(args[0]) == list:
271 |       return FeatureSet(args[0])
272 |     else:
273 |       return FeatureSet(list(args))
274 |
275 |   @staticmethod
276 |   def parse(spec):
277 |     """Parses a FeatureSet from a YAML specification.
278 |
279 |     Arguments:
280 |       spec: The feature specification to parse.
281 |     Returns:
282 |       A FeatureSet instance.
283 |     """
284 |     if isinstance(spec, FeatureSet):
285 |       return spec
286 |
287 |     spec = yaml.safe_load(spec)
288 |
289 |     features = []
290 |     for f in spec['features']:
291 |       feature = Feature.parse(f)
292 |       features.append(feature)
293 |
294 |     return FeatureSet(features)
295 |
296 |   def __getitem__(self, index):
297 |     """Retrieves the specified Feature by name.
298 |
299 |     Arguments:
300 |       index: the name of the feature.
301 |     Returns:
302 |       The Feature if it exists; None otherwise.
303 |     """
304 |     return self._features_map.get(index, None)
305 |
306 |   def __len__(self):
307 |     """Retrieves the number of Features defined.
308 |     """
309 |     return len(self._features)
310 |
311 |   def __iter__(self):
312 |     """Creates an iterator over the features in the FeatureSet.
313 |     """
314 |     for feature in self._features:
315 |       yield feature
--------------------------------------------------------------------------------
/src/data/_metadata.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 TensorLab. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4 | # in compliance with the License. You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software distributed under the License
9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10 | # or implied. See the License for the specific language governing permissions and limitations under
11 | # the License.
12 |
13 | # _metadata.py
14 | # Implementation of Metadata.
15 |
16 | import ujson
17 |
18 |
19 | class Metadata(object):
20 |   """This class encapsulates metadata for individual fields within a dataset.
21 |
22 |   Metadata is key'ed by individual field names, and is represented as key/value pairs, specific
23 |   to the type of the field, and the analysis performed to generate the metadata.
24 |   """
25 |   def __init__(self, md):
26 |     """Initializes an instance of a Metadata object.
27 |
28 |     Arguments:
29 |       md: the metadata map key'ed by field names.
30 |     """
31 |     self._md = md
32 |
33 |   @staticmethod
34 |   def parse(metadata):
35 |     """Parses a Metadata instance from a JSON specification.
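
    A metadata specification is a JSON object mapping field names to per-field metadata
    dictionaries, e.g. (an illustrative sketch based on the iris sample's target field):
    '{"species": {"vocab": {"entries": ["setosa", "versicolor", "virginica"]}}}'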
36 |
37 |     Arguments:
38 |       metadata: The metadata to parse.
39 |     Returns:
40 |       A Metadata instance.
41 |     """
42 |     md = ujson.loads(metadata)
43 |     return Metadata(md)
44 |
45 |   def __getitem__(self, index):
46 |     """Retrieves the metadata of the specified field by name.
47 |
48 |     Arguments:
49 |       index: the name of the field whose metadata is to be retrieved.
50 |     Returns:
51 |       The metadata dictionary for the specified field, or an empty dictionary.
52 |     """
53 |     return self._md.get(index, {})
54 |
55 |   def __len__(self):
56 |     """Retrieves the number of fields that have metadata defined.
57 |     """
58 |     return len(self._md)
59 |
--------------------------------------------------------------------------------
/src/data/_schema.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 TensorLab. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4 | # in compliance with the License. You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software distributed under the License
9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10 | # or implied. See the License for the specific language governing permissions and limitations under
11 | # the License.
12 |
13 | # _schema.py
14 | # Implementation of Schema and related classes.
15 |
16 | import enum
17 | import yaml
18 |
19 |
20 | class SchemaFieldType(enum.Enum):
21 |   """Defines the types of SchemaField instances.
22 |   """
23 |   integer = 'integer'
24 |   real = 'real'
25 |   discrete = 'discrete'
26 |
27 |
28 | class SchemaField(object):
29 |   """Defines a named and typed field within a Schema.
30 |   """
31 |   def __init__(self, name, type, length):
32 |     """Initializes a SchemaField with its name and type.
33 |
34 |     Arguments:
35 |       name: the name of the field.
36 |       type: the type of the field.
37 |       length: the valence of the field (0 implies variable length)
38 |     """
39 |     self._name = name
40 |     self._type = type
41 |     self._length = length
42 |
43 |     # TODO: Add support for default values
44 |
45 |   @classmethod
46 |   def discrete(cls, name, length=1):
47 |     """Creates a field representing a discrete value.
48 |
49 |     Arguments:
50 |       name: the name of the field.
51 |       length: the valence of the field (0 implies variable length)
52 |     """
53 |     return cls(name, SchemaFieldType.discrete, length)
54 |
55 |   @classmethod
56 |   def integer(cls, name, length=1):
57 |     """Creates a field representing an integer.
58 |
59 |     Arguments:
60 |       name: the name of the field.
61 |       length: the valence of the field (0 implies variable length)
62 |     """
63 |     return cls(name, SchemaFieldType.integer, length)
64 |
65 |   @classmethod
66 |   def real(cls, name, length=1):
67 |     """Creates a field representing a real number.
68 |
69 |     Arguments:
70 |       name: the name of the field.
71 |       length: the valence of the field (0 implies variable length)
72 |     """
73 |     return cls(name, SchemaFieldType.real, length)
74 |
75 |   @property
76 |   def name(self):
77 |     """Retrieves the name of the field.
78 |     """
79 |     return self._name
80 |
81 |   @property
82 |   def type(self):
83 |     """Retrieves the type of the field.
84 |     """
85 |     return self._type
86 |
87 |   @property
88 |   def length(self):
89 |     """Retrieves the length of the field (0 implies variable length).
90 | """ 91 | return self._length 92 | 93 | @property 94 | def numeric(self): 95 | """Returns whether the field is a numeric type, i.e. integer or real. 96 | """ 97 | return self._type in [SchemaFieldType.integer, SchemaFieldType.real] 98 | 99 | 100 | class Schema(object): 101 | """Defines the schema of a DataSet. 102 | 103 | The schema represents the structure of the source data before it is transformed into features. 104 | """ 105 | def __init__(self, fields): 106 | """Initializes a Schema with the specified set of fields. 107 | 108 | Arguments: 109 | fields: a list of fields representing an ordered set of columns. 110 | """ 111 | if not len(fields): 112 | raise ValueError('One or more fields must be specified') 113 | 114 | self._fields = fields 115 | self._field_map = dict(map(lambda f: (f.name, f), fields)) 116 | 117 | @staticmethod 118 | def create(*args): 119 | """Creates a Schema from a set of fields. 120 | 121 | Arguments: 122 | args: a list or sequence of ordered fields defining the schema. 123 | Returns: 124 | A Schema instance. 125 | """ 126 | if not len(args): 127 | raise ValueError('One or more fields must be specified.') 128 | 129 | if type(args[0]) == list: 130 | return Schema(args[0]) 131 | else: 132 | return Schema(list(args)) 133 | 134 | def format(self): 135 | """Formats a Schema instance into its YAML specification. 136 | 137 | Returns: 138 | A string containing the YAML specification. 139 | """ 140 | fields = map(lambda f: {'name': f.name, 'type': f.type.name, 'length': f.length}, 141 | self._fields) 142 | spec = {'fields': fields} 143 | 144 | return yaml.safe_dump(spec, default_flow_style=False) 145 | 146 | @staticmethod 147 | def parse(spec): 148 | """Parses a Schema from a YAML specification. 149 | 150 | Arguments: 151 | spec: The schema specification to parse. 152 | Returns: 153 | A Schema instance. 154 | """ 155 | if isinstance(spec, Schema): 156 | return spec 157 | 158 | spec = yaml.safe_load(spec) 159 | fields = map(lambda f: SchemaField(f['name'], SchemaFieldType[f['type']], f.get('length', 1)), 160 | spec['fields']) 161 | return Schema(fields) 162 | 163 | @property 164 | def fields(self): 165 | """Retrieve the names of the fields in the schema. 166 | """ 167 | return map(lambda f: f.name, self._fields) 168 | 169 | def __getitem__(self, index): 170 | """Retrives the specified SchemaField by name or position. 171 | 172 | Arguments: 173 | index: the name or index of the field. 174 | Returns: 175 | The SchemaField if it exists; None otherwise. 176 | """ 177 | if type(index) is int: 178 | return self._fields[index] if len(self._fields) > index else None 179 | else: 180 | return self._field_map.get(index, None) 181 | 182 | def __iter__(self): 183 | """Creates an iterator to iterate over the fields. 184 | """ 185 | for field in self._fields: 186 | yield field 187 | 188 | def __len__(self): 189 | """Retrieves the number of SchemaFields defined. 190 | """ 191 | return len(self._fields) 192 | -------------------------------------------------------------------------------- /src/data/_transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLab. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. 
You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # _transforms.py 14 | # Implementation of various transforms to build features. 15 | 16 | import tensorflow as tf 17 | from ._features import FeatureType 18 | from ._schema import SchemaFieldType 19 | 20 | 21 | class Transformer(object): 22 | """Implements transformation logic. 23 | """ 24 | def __init__(self, dataset): 25 | """Initializes a Transformer. 26 | 27 | Arguments: 28 | dataset: The dataset containing the data to be transformed into features. 29 | """ 30 | self._dataset = dataset 31 | 32 | def transform(self, instances): 33 | """Transforms the supplied instances into features. 34 | 35 | Arguments: 36 | instances: a dictionary of tensors key'ed by field names corresponding to the schema. 37 | Returns: 38 | A dictionary of tensors key'ed by feature names corresponding to the feature set. 39 | """ 40 | features = self._dataset.features 41 | 42 | # The top-level set of features is to be represented as a map of tensors, so transform the 43 | # features, and use the map result. 44 | _, tensor_map = _transform_features(instances, features, 45 | self._dataset.schema, 46 | self._dataset.metadata) 47 | return tensor_map 48 | 49 | 50 | def _identity(instances, feature, schema, metadata): 51 | """Applies the identity transform, which causes the unmodified field value to be used. 52 | """ 53 | return tf.identity(instances[feature.field], name='identity') 54 | 55 | 56 | def _target(instances, feature, schema, metadata): 57 | """Applies the target transform, which causes the unmodified field value to be used. 58 | """ 59 | # The result of parsing csv is a tensor of shape (None, 1), and we want to return a list of 60 | # scalars, or specifically, tensor of shape (None, ). 61 | return tf.squeeze(instances[feature.field], name='target') 62 | 63 | 64 | def _concat(instances, feature, schema, metadata): 65 | """Applies the composite transform, to compose a single tensor from a set of features. 66 | """ 67 | tensors, _ = _transform_features(instances, feature.features, schema, metadata) 68 | return tf.concat(tensors, axis=1, name='concat') 69 | 70 | 71 | def _log(instances, feature, schema, metadata): 72 | """Applies the log transform to a numeric field. 73 | """ 74 | field = schema[feature.field] 75 | if not field.numeric: 76 | raise ValueError('A log transform cannot be applied to non-numerical field "%s".' % 77 | feature.field) 78 | 79 | # Add 1 to avoid log of 0 (still assuming the field does not have negative values) 80 | return tf.log(instances[feature.field] + 1, name='log') 81 | 82 | 83 | def _scale(instances, feature, schema, metadata): 84 | """Applies the scale transform to a numeric field. 85 | """ 86 | field = schema[feature.field] 87 | if not field.numeric: 88 | raise ValueError('A scale transform cannot be applied to non-numerical field "%s".' 
%
89 |                      feature.field)
90 |
91 |   transform = feature.transform
92 |   md = metadata[feature.field]
93 |
94 |   value = instances[feature.field]
95 |
96 |   range_min = float(md['min'])
97 |   range_max = float(md['max'])
98 |   value = (value - range_min) / (range_max - range_min)
99 |
100 |   if transform:
101 |     target_min = float(transform['min'])
102 |     target_max = float(transform['max'])
103 |     if (target_min != 0.0) or (target_max != 1.0):
104 |       value = value * (target_max - target_min) + target_min
105 |
106 |   return tf.identity(value, name='scale')
107 |
108 |
109 | def _bucketize(instances, feature, schema, metadata):
110 |   """Applies the bucketize transform to a numeric field.
111 |   """
112 |   field = schema[feature.field]
113 |   if not field.numeric:
114 |     raise ValueError('A bucketize transform cannot be applied to non-numerical field "%s".' %
115 |                      feature.field)
116 |
117 |   transform = feature.transform
118 |   boundaries = map(float, transform['boundaries'].split(','))
119 |
120 |   # TODO: Figure out how to use tf.case instead of this contrib op
121 |   from tensorflow.contrib.layers.python.ops.bucketization_op import bucketize
122 |
123 |   # Create a one-hot encoded tensor. The depth of this tensor is the number of buckets
124 |   # defined by N boundaries, which is N + 1.
125 |   # A squeeze is needed to remove the extra dimension added to the shape.
126 |   value = instances[feature.field]
127 |
128 |   value = tf.squeeze(tf.one_hot(bucketize(value, boundaries, name='bucket'),
129 |                                 depth=len(boundaries) + 1, on_value=1.0, off_value=0.0,
130 |                                 name='one_hot'),
131 |                      axis=1, name='bucketize')
132 |   value.set_shape((None, len(boundaries) + 1))
133 |   return value
134 |
135 |
136 | def _one_hot(instances, feature, schema, metadata):
137 |   """Applies the one-hot transform to a discrete field.
138 |   """
139 |   field = schema[feature.field]
140 |   if field.type != SchemaFieldType.discrete:
141 |     raise ValueError('A one-hot transform cannot be applied to non-discrete field "%s".' %
142 |                      feature.field)
143 |
144 |   md = metadata[feature.field]
145 |   if not md:
146 |     raise ValueError('A one-hot transform requires metadata listing the unique values.')
147 |
148 |   entries = md['entries']
149 |   table = tf.contrib.lookup.HashTable(
150 |       tf.contrib.lookup.KeyValueTensorInitializer(entries,
151 |                                                   tf.range(0, len(entries), dtype=tf.int64),
152 |                                                   tf.string, tf.int64),
153 |       default_value=len(entries), name='entries')
154 |
155 |   # Create a one-hot encoded tensor with one added to the number of values to account for the
156 |   # default value returned by the table for unknown/failed lookups.
157 |   # A squeeze is needed to remove the extra dimension added to the shape.
158 |   value = instances[feature.field]
159 |
160 |   value = tf.squeeze(tf.one_hot(table.lookup(value), len(entries) + 1, on_value=1.0, off_value=0.0),
161 |                      axis=1,
162 |                      name='one_hot')
163 |   value.set_shape((None, len(entries) + 1))
164 |   return value
165 |
166 |
167 | _transformers = {
168 |   FeatureType.identity.name: _identity,
169 |   FeatureType.target.name: _target,
170 |   FeatureType.concat.name: _concat,
171 |   FeatureType.log.name: _log,
172 |   FeatureType.scale.name: _scale,
173 |   FeatureType.bucketize.name: _bucketize,
174 |   FeatureType.one_hot.name: _one_hot
175 | }
176 |
177 | def _transform_features(instances, features, schema, metadata):
178 |   """Transforms a list of features, to produce a list and map of tensor values.
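
  Arguments:
    instances: a dictionary of tensors key'ed by schema field names.
    features: the list of Feature instances to transform.
    schema: the Schema describing the source data.
    metadata: the Metadata consulted by transforms such as scale and one_hot.
  Returns:
    A tuple of a list of feature tensors and a dictionary of tensors key'ed by feature name.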
179 | """ 180 | tensors = [] 181 | tensor_map = {} 182 | 183 | for f in features: 184 | transformer = _transformers[f.type.name] 185 | with tf.name_scope(f.name): 186 | value = transformer(instances, f, schema, metadata) 187 | 188 | tensors.append(value) 189 | tensor_map[f.name] = value 190 | 191 | return tensors, tensor_map 192 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLab. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # __init__.py 14 | # tensorfx.models module declaration. 15 | 16 | from ._classification import ClassificationModelArguments, ClassificationModelBuilder 17 | from ._classification import StringLabelClassification 18 | -------------------------------------------------------------------------------- /src/models/_classification.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLab. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # _classification.py 14 | # Implements ClassificationModelBuilder and ClassificationModelArguments. 15 | 16 | import tensorflow as tf 17 | import tensorfx as tfx 18 | 19 | class ClassificationModelArguments(tfx.training.ModelArguments): 20 | """Arguments for classification models. 21 | """ 22 | @classmethod 23 | def init_parser(cls, parser): 24 | """Initializes the argument parser. 25 | 26 | Args: 27 | parser: An argument parser instance to be initialized with arguments. 28 | """ 29 | super(ClassificationModelArguments, cls).init_parser(parser) 30 | 31 | def process(self): 32 | """Processes the parsed arguments to produce any additional objects. 33 | """ 34 | pass 35 | 36 | 37 | class ClassificationModelBuilder(tfx.training.ModelBuilder): 38 | """A ModelBuilder for building classification models. 39 | 40 | A classification model treats the target value as a label. The label might be a discrete 41 | value (which is converted to integer indices), or may be pre-indexed. 42 | """ 43 | def __init__(self, args): 44 | super(ClassificationModelBuilder, self).__init__(args) 45 | self._classification = None 46 | 47 | @property 48 | def classification(self): 49 | """Returns the classification helper object. 
50 | """ 51 | return self._classification 52 | 53 | def build_graph_interfaces(self, dataset, config): 54 | """Builds graph interfaces for training and evaluating a model, and for predicting using it. 55 | 56 | A graph interface is an object containing a TensorFlow graph member, as well as members 57 | corresponding to various tensors and ops within the graph. 58 | 59 | ClassificationModelBuilder also builds a classification helper object for use during graph 60 | building. 61 | 62 | Arguments: 63 | dataset: The dataset to use during training. 64 | config: The training Configuration object. 65 | Returns: 66 | A tuple consisting of the training, evaluation and prediction interfaces. 67 | """ 68 | target_feature = filter(lambda f: f.type == tfx.data.FeatureType.target, dataset.features)[0] 69 | target_field = dataset.schema[target_feature.field] 70 | target_metadata = dataset.metadata[target_feature.field] 71 | 72 | if target_field.type == tfx.data.SchemaFieldType.discrete: 73 | self._classification = StringLabelClassification(target_metadata['vocab']['entries']) 74 | else: 75 | self._classification = None 76 | 77 | return super(ClassificationModelBuilder, self).build_graph_interfaces(dataset, config) 78 | 79 | 80 | class StringLabelClassification(object): 81 | """A classification scenario involving string label names. 82 | 83 | Labels will be converted to indices when using the input, and indices back to labels to produce 84 | output. 85 | """ 86 | def __init__(self, labels): 87 | """Initializes an instance of StringLabelClassification with specified label names. 88 | """ 89 | self._labels = labels 90 | self._num_labels = len(labels) 91 | 92 | @property 93 | def num_labels(self): 94 | """Returns the number of labels in the model. 95 | """ 96 | return self._num_labels 97 | 98 | def keys(self, inputs): 99 | """Retrieves the keys, if present from the inputs. 100 | 101 | Arguments: 102 | inputs: the dictionary of tensors corresponding to the input. 103 | Returns: 104 | A tensor containing the keys if a keys feature exists, None otherwise. 105 | """ 106 | return inputs.get('key', None) 107 | 108 | def features(self, inputs): 109 | """Retrieves the features to use to build a model. 110 | 111 | For classification models, the default behavior is to use a feature named 'X' to represent the 112 | input features for the model. 113 | 114 | Arguments: 115 | inputs: the dictionary of tensors corresponding to the input. 116 | Returns: 117 | A tensor containing model input features. 118 | """ 119 | return inputs['X'] 120 | 121 | def target_labels(self, inputs): 122 | """Retrieves the target labels to use to build a model. 123 | 124 | For classification models, the default behavior is to use a feature named 'Y' to represent the 125 | target features for the model. 126 | 127 | Arguments: 128 | inputs: the dictionary of tensors corresponding to the input. 129 | Returns: 130 | A tensor containing the target labels. 131 | """ 132 | return inputs['Y'] 133 | 134 | def target_label_indices(self, inputs, one_hot=True): 135 | """Retrieves the target labels to use to build a model, as a set of indices. 136 | 137 | For classification models, the default behavior is to use a feature named 'Y' to represent the 138 | target features for the model. The labels are used to perform a lookup to produce indices. 139 | 140 | Arguments: 141 | inputs: the dictionary of tensors corresponding to the input. 142 | one_hot: whether to convert the indices into their one-hot representation. 
143 |     Returns:
144 |       A tensor containing the target labels as indices.
145 |     """
146 |     labels = inputs['Y']
147 | 
148 |     with tf.name_scope('label_table'):
149 |       string_int_mapping = tf.contrib.lookup.KeyValueTensorInitializer(
150 |           self._labels, tf.range(0, self._num_labels, dtype=tf.int64), tf.string, tf.int64)
151 |       table = tf.contrib.lookup.HashTable(string_int_mapping, default_value=-1)
152 | 
153 |       if one_hot:
154 |         indices = tf.squeeze(tf.one_hot(table.lookup(labels), self._num_labels), name='indices')
155 |       else:
156 |         indices = table.lookup(labels, name='indices')
157 | 
158 |       return indices
159 | 
160 |   def output_labels(self, indices):
161 |     """Produces the output labels to represent a model's output.
162 | 
163 |     The indices are used to look up corresponding label names.
164 | 
165 |     Arguments:
166 |       indices: The predicted label indices.
167 |     Returns:
168 |       A tensor containing output predicted label names.
169 |     """
170 |     with tf.name_scope('label_table'):
171 |       int_string_mapping = tf.contrib.lookup.KeyValueTensorInitializer(
172 |           tf.range(0, self._num_labels, dtype=tf.int64), self._labels, tf.int64, tf.string)
173 |       table = tf.contrib.lookup.HashTable(int_string_mapping, default_value='')
174 | 
175 |       return table.lookup(indices, name='label')
176 | 
--------------------------------------------------------------------------------
/src/models/nn/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 TensorLab. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4 | # in compliance with the License. You may obtain a copy of the License at
5 | #
6 | #     http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software distributed under the License
9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10 | # or implied. See the License for the specific language governing permissions and limitations under
11 | # the License.
12 | 
13 | # __init__.py
14 | # tensorfx.models.nn module declaration.
15 | 
16 | from ._ff import FeedForwardClassificationArguments, FeedForwardClassification
17 | 
--------------------------------------------------------------------------------
/src/models/nn/_ff.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 TensorLab. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4 | # in compliance with the License. You may obtain a copy of the License at
5 | #
6 | #     http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software distributed under the License
9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10 | # or implied. See the License for the specific language governing permissions and limitations under
11 | # the License.
12 | 
13 | # _ff.py
14 | # Implements FeedForwardClassification.
15 | 
16 | import math
17 | import tensorflow as tf
18 | import tensorfx as tfx
19 | import tensorfx.models as models
20 | 
21 | 
22 | def _init_parser(parser):
23 |   """Initializes the parser for feed-forward models.
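  Hidden layers are specified with the indexed var-arg syntax handled by AddVarArgAction;
  for example (a hypothetical invocation):

    trainer --hidden-layers:1 20 --hidden-layers:2 10

  builds two hidden layers of 20 and 10 units respectively.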
24 |   """
25 |   optimization = parser.add_argument_group(title='Optimization',
26 |       description='Arguments determining the optimizer behavior.')
27 |   optimization.add_argument('--learning-rate', metavar='rate', type=float, default=0.01,
28 |                             help='The learning rate used by the optimizer at each training step.')
29 | 
30 |   nn = parser.add_argument_group(title='Neural Network',
31 |       description='Arguments controlling the structure of the neural network.')
32 |   nn.add_argument('--hidden-layers', metavar='units', type=int, required=False,
33 |                   action=parser.var_args_action,
34 |                   help='The size of each hidden layer to add.')
35 | 
36 | def _process_args(args):
37 |   """Processes arguments for feed-forward models.
38 |   """
39 |   if args.hidden_layers:
40 |     args.hidden_layers = map(lambda (i, s): ('layer_%d' % i, s, 'relu'),
41 |                              enumerate(args.hidden_layers))
42 |   else:
43 |     args.hidden_layers = []
44 | 
45 |   args.optimizer = tf.train.GradientDescentOptimizer(args.learning_rate)
46 | 
47 | 
48 | class FeedForwardClassificationArguments(models.ClassificationModelArguments):
49 |   """Arguments for feed-forward classification neural networks.
50 |   """
51 |   @classmethod
52 |   def init_parser(cls, parser):
53 |     """Initializes the argument parser.
54 | 
55 |     Args:
56 |       parser: An argument parser instance to be initialized with arguments.
57 |     """
58 |     super(FeedForwardClassificationArguments, cls).init_parser(parser)
59 |     _init_parser(parser)
60 | 
61 |   def process(self):
62 |     """Processes the parsed arguments to produce any additional objects.
63 |     """
64 |     super(FeedForwardClassificationArguments, self).process()
65 |     _process_args(self)
66 | 
67 | 
68 | class FeedForwardClassification(models.ClassificationModelBuilder):
69 |   """A ModelBuilder for building feed-forward fully connected neural network models.
70 | 
71 |   These models are also known as multi-layer perceptrons.
72 |   """
73 |   def __init__(self, args):
74 |     super(FeedForwardClassification, self).__init__(args)
75 | 
76 |   def build_inference(self, inputs, training):
77 |     histograms = {}
78 |     scalars = {}
79 | 
80 |     # Build a set of hidden layers. The input to the first hidden layer is
81 |     # the features tensor, whose shape is (batch, size).
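    # Conceptually, each hidden layer in the loop below computes outputs = activation(x * W + b).
    # An illustrative sketch of one iteration (not additional graph code):
    #
    #   outputs = tf.nn.relu(tf.nn.xw_plus_b(x, weights, biases))
    #
    # where weights has shape (x_size, size), biases has shape (size,), and the activation
    # (e.g. tf.nn.relu) is looked up by name from the --hidden-layers spec.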
82 |     x = self.classification.features(inputs)
83 |     x_size = x.get_shape()[1].value
84 | 
85 |     for name, size, activation in self.args.hidden_layers:
86 |       with tf.name_scope(name):
87 |         weights = tf.Variable(tf.truncated_normal([x_size, size],
88 |                                                    stddev=1.0 / math.sqrt(float(x_size))),
89 |                               name='weights')
90 |         biases = tf.Variable(tf.zeros([size]), name='biases')
91 |         outputs = tf.nn.xw_plus_b(x, weights, biases, name='outputs')
92 | 
93 |         histograms[outputs.op.name + '.activations'] = outputs
94 |         scalars[outputs.op.name + '.sparsity'] = tf.nn.zero_fraction(outputs)
95 | 
96 |         if activation:
97 |           activation_fn = getattr(tf.nn, activation)
98 |           outputs = activation_fn(outputs, name=activation)
99 |         x = outputs
100 |         x_size = size
101 | 
102 |     with tf.name_scope('logits'):
103 |       weights = tf.Variable(tf.truncated_normal([x_size, self._classification.num_labels],
104 |                                                  stddev=1.0 / math.sqrt(float(x_size))),
105 |                             name='weights')
106 |       biases = tf.Variable(tf.zeros([self._classification.num_labels]), name='biases')
107 |       logits = tf.nn.xw_plus_b(x, weights, biases, name='outputs')
108 | 
109 |       histograms[logits.op.name + '.activations'] = logits
110 |       scalars[logits.op.name + '.sparsity'] = tf.nn.zero_fraction(logits)
111 | 
112 |     if training:
113 |       with tf.name_scope(''):
114 |         for name, t in scalars.iteritems():
115 |           tf.summary.scalar(name, t)
116 | 
117 |         for name, t in histograms.iteritems():
118 |           tf.summary.histogram(name, t)
119 | 
120 |         for t in tf.trainable_variables():
121 |           tf.summary.histogram(t.op.name, t)
122 | 
123 |     return logits
124 | 
125 |   def build_training(self, global_steps, inputs, inferences):
126 |     with tf.name_scope('target'):
127 |       label_indices = self.classification.target_label_indices(inputs)
128 | 
129 |     with tf.name_scope('error'):
130 |       cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=inferences,
131 |                                                               labels=label_indices,
132 |                                                               name='softmax_cross_entropy')
133 |       loss = tf.reduce_mean(cross_entropy, name='loss')
134 | 
135 |       averager = tf.train.ExponentialMovingAverage(0.99, name='loss_averager')
136 |       averaging = averager.apply([loss])
137 | 
138 |       with tf.name_scope(''):
139 |         tf.summary.scalar('metrics/loss', loss)
140 |         tf.summary.scalar('metrics/loss.average', averager.average(loss))
141 | 
142 |     with tf.control_dependencies([averaging]):
143 |       with tf.name_scope(self.args.optimizer.get_name()):
144 |         gradients = self.args.optimizer.compute_gradients(loss, var_list=tf.trainable_variables())
145 |         train = self.args.optimizer.apply_gradients(gradients, global_steps, name='optimize')
146 | 
147 |       with tf.name_scope(''):
148 |         for gradient, t in gradients:
149 |           if gradient is not None:
150 |             tf.summary.histogram(t.op.name + '.gradients', gradient)
151 | 
152 |     return loss, train
153 | 
154 |   def build_output(self, inputs, inferences):
155 |     scores = tf.nn.softmax(inferences, name='scores')
156 |     tf.add_to_collection('outputs', scores)
157 | 
158 |     with tf.name_scope('labels'):
159 |       label_indices = tf.arg_max(inferences, 1, name='arg_max')
160 |       labels = self.classification.output_labels(label_indices)
161 |       tf.add_to_collection('outputs', labels)
162 | 
163 |     keys = self.classification.keys(inputs)
164 |     if keys is not None:  # keys is a Tensor, so compare against None rather than using truthiness.
165 |       # Key feature, if it exists, is a passthrough to the output.
166 |       # The use of identity is to name the tensor and correspondingly the output field.
167 |       keys = tf.identity(keys, name='key')
168 |       tf.add_to_collection('outputs', keys)
169 | 
170 |     return {
171 |       'label': labels,
172 |       'score': scores
173 |     }
174 | 
175 |   def build_evaluation(self, inputs, outputs):
176 |     target_labels = self.classification.target_labels(inputs)
177 | 
178 |     with tf.name_scope('accuracy'):
179 |       accuracy, eval = tf.contrib.metrics.streaming_accuracy(outputs['label'], target_labels)
180 | 
181 |     with tf.name_scope(''):
182 |       tf.summary.scalar('metrics/accuracy', accuracy)
183 | 
184 |     return accuracy, eval
185 | 
--------------------------------------------------------------------------------
/src/prediction/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 TensorLab. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4 | # in compliance with the License. You may obtain a copy of the License at
5 | #
6 | #     http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software distributed under the License
9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10 | # or implied. See the License for the specific language governing permissions and limitations under
11 | # the License.
12 | 
13 | # __init__.py
14 | # tensorfx.prediction module declaration.
15 | 
16 | from _model import Model
17 | 
--------------------------------------------------------------------------------
/src/prediction/_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 TensorLab. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4 | # in compliance with the License. You may obtain a copy of the License at
5 | #
6 | #     http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software distributed under the License
9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10 | # or implied. See the License for the specific language governing permissions and limitations under
11 | # the License.
12 | 
13 | # _model.py
14 | # Implements the Model class.
15 | 
16 | import numpy as np
17 | import tensorflow as tf
18 | 
19 | 
20 | class Model(object):
21 |   """A model performs inferences using TensorFlow to produce predictions.
22 | 
23 |   A model is loaded from a checkpoint that was produced during training.
24 |   """
25 |   def __init__(self, session, inputs, outputs):
26 |     """Initializes a Model using a TensorFlow session containing an initialized prediction graph.
27 | 
28 |     Arguments:
29 |       session: The TensorFlow session to use for evaluating inferences.
30 |       inputs: A map of input names to corresponding graph tensors.
31 |       outputs: A map of output names to corresponding graph tensors.
32 |     """
33 |     self._session = session
34 |     self._inputs = inputs
35 |     self._outputs = outputs
36 | 
37 |     # Optimize for the one input key for the currently supported single input graphs
38 |     self._input_key = inputs[inputs.keys()[0]]
39 | 
40 |   @classmethod
41 |   def load(cls, path):
42 |     """Imports a previously exported saved model.
43 | 
44 |     Arguments:
45 |       - path: The location on disk where the saved model exists.
46 |     Returns:
47 |       An initialized Model object that can be used for performing prediction.
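
    Example (a hypothetical model path, with a CSV-formatted instance):

      model = Model.load('/tmp/output/model')
      predictions = model.predict(['5.1,3.5,1.4,0.2'])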
48 | """ 49 | with tf.Graph().as_default() as graph: 50 | session = tf.Session() 51 | 52 | metagraph = tf.saved_model.loader.load(session, ['serve'], path) 53 | signature = _parse_signature(metagraph) 54 | 55 | inputs = {} 56 | for alias in signature.inputs: 57 | inputs[alias] = signature.inputs[alias].name 58 | outputs = {} 59 | for alias in signature.outputs: 60 | outputs[alias] = signature.outputs[alias].name 61 | 62 | return cls(session, inputs, outputs) 63 | 64 | 65 | @staticmethod 66 | def save(session, path, inputs, outputs): 67 | """Exports the current session, the loaded graph, and variables into a saved model. 68 | 69 | Arguments: 70 | - session: the TensorFlow session with variables to save. 71 | - path: the location where the output model directory should be created. 72 | - inputs: the list of tensors constituting the input to the prediction graph. 73 | - outputs: the list of tensors constituting the outputs of the prediction graph. 74 | """ 75 | signature_map = {'serving_default': _build_signature(inputs, outputs)} 76 | model_builder = tf.saved_model.builder.SavedModelBuilder(path) 77 | model_builder.add_meta_graph_and_variables(session, 78 | tags=['serve'], 79 | signature_def_map=signature_map, 80 | clear_devices=True) 81 | model_builder.save() 82 | 83 | def predict(self, instances): 84 | """Performs inference to return predictions for the specified instances of data. 85 | 86 | Arguments: 87 | - instances: either an object, or list of objects each containing feature values. 88 | """ 89 | if not instances: 90 | return [] 91 | 92 | # TODO: Support for DataFrames and a flag of whether to append prediction outputs to input 93 | # DataFrame. 94 | 95 | # Run the instances through the session to retrieve the prediction outputs 96 | results = self._session.run(self._outputs, feed_dict={self._input_key: instances}) 97 | 98 | # Convert outputs, which are in dictionary of lists representation (alias -> batch of values) to 99 | # list of predictions representation (list of dictionaries, where each dict is alias -> value). 100 | predictions = [{} for _ in range(len(instances))] 101 | 102 | for alias in self._outputs.iterkeys(): 103 | values = results[alias] 104 | for index, value in enumerate(values): 105 | if isinstance(value, np.ndarray): 106 | value = value.tolist() 107 | predictions[index][alias] = value 108 | 109 | return predictions 110 | 111 | 112 | def _build_signature(inputs, outputs): 113 | def tensor_alias(tensor): 114 | local_name = tensor.name.split('/')[-1] 115 | return local_name.split(':')[0] 116 | 117 | input_map = {} 118 | output_map = {} 119 | for tensor in inputs: 120 | input_map[tensor_alias(tensor)] = tf.saved_model.utils.build_tensor_info(tensor) 121 | for tensor in outputs: 122 | output_map[tensor_alias(tensor)] = tf.saved_model.utils.build_tensor_info(tensor) 123 | 124 | return tf.saved_model.signature_def_utils.build_signature_def( 125 | inputs=input_map, 126 | outputs=output_map, 127 | method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME) 128 | 129 | 130 | def _parse_signature(metagraph): 131 | if not metagraph.signature_def: 132 | raise ValueError('Invalid model. The saved model does not define a signature.') 133 | if len(metagraph.signature_def) > 1: 134 | raise ValueError('Invalid model. Only models with a single signature are supported.') 135 | 136 | signature = metagraph.signature_def.get('serving_default', None) 137 | if not signature: 138 | raise ValueError('Invalid model. 
Unexpected signature type.') 139 | 140 | if len(signature.inputs) != 1: 141 | raise ValueError('Invalid model. Only models with a single input are supported.') 142 | for alias in signature.inputs: 143 | if signature.inputs[alias].dtype != tf.string.as_datatype_enum: 144 | raise ValueError('Invalid model. Only models with a string input are supported.') 145 | if len(signature.outputs) == 0: 146 | raise ValueError('Invalid model. Only models with at least one output are supported.') 147 | 148 | return signature 149 | -------------------------------------------------------------------------------- /src/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLab. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # __init__.py 14 | # tensorfx.tools module declaration. 15 | -------------------------------------------------------------------------------- /src/tools/_predict.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLab. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # _predict.py 14 | # Implements PredictCommand. 15 | 16 | import json 17 | import os 18 | import sys 19 | import tensorflow as tf 20 | import tensorfx as tfx 21 | 22 | 23 | class PredictCommand(object): 24 | """Implements the tfx predict command to use a model to produce predictions. 25 | """ 26 | name = 'predict' 27 | help = 'Produces predictions using a model.' 28 | extra = False 29 | 30 | @staticmethod 31 | def build_parser(parser): 32 | parser.add_argument('--model', metavar='path', type=str, required=True, 33 | help='The path to a previously trained model.') 34 | parser.add_argument('--input', metavar='path', type=str, 35 | help='The path to a file with input instances. Uses stdin by default.') 36 | parser.add_argument('--output', metavar='path', type=str, 37 | help='The path to a file to write outputs to. Uses stdout by default.') 38 | parser.add_argument('--batch-size', metavar='instances', type=int, default=10, 39 | help='The number of instances to predict per batch.') 40 | 41 | @staticmethod 42 | def run(args): 43 | # TODO: Figure out where to do JSON and TF initialization in more common way. 
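    # Setting FLOAT_REPR is intended to shorten floats in the JSON output to five decimal
    # places; e.g. a score of 0.123456789 would be emitted as 0.12346 rather than its full repr.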
44 |     json.encoder.FLOAT_REPR = lambda f: ('%.5f' % f)
45 | 
46 |     tf.logging.set_verbosity(tf.logging.ERROR)
47 |     os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(tf.logging.ERROR)
48 | 
49 |     model = tfx.prediction.Model.load(args.model)
50 | 
51 |     with TextSource(args.input, args.batch_size) as source, TextSink(args.output) as sink:
52 |       for instances in source:
53 |         predictions = model.predict(instances)
54 |         lines = map(lambda p: json.dumps(p, sort_keys=True), predictions)
55 |         sink.write(lines)
56 | 
57 | 
58 | class TextSource(object):
59 | 
60 |   def __init__(self, file=None, batch_size=1):
61 |     self._file = file
62 |     self._batch_size = batch_size
63 | 
64 |   def __enter__(self):
65 |     self._stream = open(self._file, 'r') if self._file else sys.stdin
66 |     return self
67 | 
68 |   def __exit__(self, type, value, traceback):
69 |     if self._stream and self._file:
70 |       self._stream.close()
71 | 
72 |   def __iter__(self):
73 |     instances = []
74 | 
75 |     while True:
76 |       instance = self._stream.readline().strip()
77 |       if not instance:
78 |         # EOF
79 |         break
80 | 
81 |       instances.append(instance)
82 |       if len(instances) == self._batch_size:
83 |         # A desired batch of instances is available
84 |         yield instances
85 |         instances = []
86 | 
87 |     if instances:
88 |       yield instances
89 | 
90 | 
91 | class TextSink(object):
92 | 
93 |   def __init__(self, file=None):
94 |     self._file = file
95 | 
96 |   def __enter__(self):
97 |     self._stream = open(self._file, 'w') if self._file else sys.stdout
98 |     return self
99 | 
100 |   def __exit__(self, type, value, traceback):
101 |     if self._stream and self._file:
102 |       self._stream.close()
103 | 
104 |   def write(self, lines):
105 |     for l in lines:
106 |       self._stream.write(l + '\n')
107 | 
--------------------------------------------------------------------------------
/src/tools/_scaffold.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 TensorLab. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4 | # in compliance with the License. You may obtain a copy of the License at
5 | #
6 | #     http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software distributed under the License
9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10 | # or implied. See the License for the specific language governing permissions and limitations under
11 | # the License.
12 | 
13 | # _scaffold.py
14 | # Implements ScaffoldCommand.
15 | 
16 | import os
17 | import tensorfx as tfx
18 | 
19 | class ScaffoldCommand(object):
20 |   """Implements the tfx scaffold command to create a new TensorFX project from a template.
21 |   """
22 |   name = 'scaffold'
23 |   help = 'Creates a new project from a template.'
24 |   extra = False
25 | 
26 |   @staticmethod
27 |   def build_parser(parser):
28 |     parser.add_argument('--name', metavar='name', type=str, required=True,
29 |                         help='The name of the model to use when instantiating the template')
30 |     parser.add_argument('--dir', metavar='path', type=str, required=False, default=os.getcwd(),
31 |                         help='The directory in which to instantiate the template')
32 |     parser.add_argument('--model', metavar='type', type=str, required=False, default='custom',
33 |                         help='The type of model to create; e.g. "nn.FeedForwardClassification"')
34 | 
35 |   @staticmethod
36 |   def run(args):
37 |     variables = {
38 |       'name': args.name,
39 |       'tensorfx_version': tfx.__version__
40 |     }
41 | 
42 |     contents = {
43 |       'setup.py': _scaffold_setup_py.format(**variables),
44 |       'trainer/__init__.py': _scaffold_trainer_init_py.format(**variables),
45 |     }
46 | 
47 |     if args.model == 'custom':
48 |       variables['model_class'] = args.name[0].upper() + args.name[1:]
49 |       contents['trainer/main.py'] = _scaffold_trainer_main_py_custom.format(**variables)
50 |       contents['trainer/model.py'] = _scaffold_trainer_model_py.format(**variables)
51 |     else:
52 |       variables['model'] = args.model
53 |       variables['model_set'] = args.model.split('.')[0]
54 |       contents['trainer/main.py'] = _scaffold_trainer_main_py.format(**variables)
55 | 
56 |     scaffold_path = os.path.join(args.dir, args.name)
57 |     for path, content in contents.iteritems():
58 |       content_path = os.path.join(scaffold_path, path)
59 | 
60 |       content_dir = os.path.dirname(content_path)
61 |       if not os.path.isdir(content_dir):
62 |         os.makedirs(content_dir)
63 | 
64 |       with open(content_path, 'w') as content_file:
65 |         content_file.write(content)
66 | 
67 | 
68 | # TODO: Externalize these into a template directory
69 | 
70 | _scaffold_setup_py = """# setup.py
71 | 
72 | import setuptools
73 | 
74 | # The name and version of the package.
75 | name = '{name}'
76 | version = '1.0'
77 | 
78 | # The main modules in the package.
79 | trainer_main = '{name}.trainer.main'
80 | 
81 | 
82 | def main():
83 |   \"""Invokes setup to build or install a distribution of the package.
84 |   \"""
85 |   setuptools.setup(name=name, version=version,
86 |                    packages=setuptools.find_packages(),
87 |                    install_requires=[
88 |                      'tensorfx=={tensorfx_version}'
89 |                    ])
90 | 
91 | 
92 | if __name__ == '__main__':
93 |   main()
94 | """
95 | 
96 | _scaffold_trainer_init_py = """# __init__.py
97 | # Declaration of {name}.trainer module.
98 | """
99 | 
100 | _scaffold_trainer_main_py = """# main.py
101 | # Implementation of training module.
102 | 
103 | import tensorflow as tf
104 | import tensorfx as tfx
105 | import tensorfx.models.{model_set} as {model_set}
106 | 
107 | args = {model}Arguments.parse(parse_job=True)
108 | dataset = tfx.data.CsvDataSet(args.data_schema,
109 |                               train=args.data_train,
110 |                               eval=args.data_eval,
111 |                               metadata=args.data_metadata,
112 |                               features=args.data_features)
113 | 
114 | builder = {model}(args)
115 | 
116 | trainer = tfx.training.ModelTrainer()
117 | model = trainer.train(builder, dataset, args.output)
118 | """
119 | 
120 | _scaffold_trainer_main_py_custom = """# main.py
121 | # Implementation of training module.
122 | 
123 | import tensorflow as tf
124 | import tensorfx as tfx
125 | import _model as model
126 | 
127 | args = model.{model_class}Arguments.parse(parse_job=True)
128 | dataset = tfx.data.CsvDataSet(args.data_schema,
129 |                               train=args.data_train,
130 |                               eval=args.data_eval,
131 |                               metadata=args.data_metadata,
132 |                               features=args.data_features)
133 | 
134 | builder = model.{model_class}(args)
135 | 
136 | trainer = tfx.training.ModelTrainer()
137 | model = trainer.train(builder, dataset, args.output)
138 | """
139 | 
140 | _scaffold_trainer_model_py = """# model.py
141 | # Implementation of model module.
142 | 
143 | import tensorflow as tf
144 | import tensorfx as tfx
145 | 
146 | class {model_class}Arguments(tfx.training.ModelArguments):
147 |   \"""Declares arguments supported by the model.
148 |   \"""
149 |   @classmethod
150 |   def init_parser(cls, parser):
151 |     super({model_class}Arguments, cls).init_parser(parser)
152 | 
153 |     # TODO: Add additional model-specific arguments.
154 | 
155 | 
156 | class {model_class}(tfx.training.ModelBuilder):
157 |   \"""Builds the graphs for training, evaluating and predicting with the model.
158 |   \"""
159 |   def __init__(self, args):
160 |     super({model_class}, self).__init__(args)
161 | 
162 |   # TODO: Implement one or more of the graph building methods. These include one or more of
163 |   # build_input(), build_inference(), build_training(), build_output(), and build_evaluation() or
164 |   # build_training_graph(), build_evaluation_graph(), and build_prediction_graph().
165 |   # See the documentation for more details.
166 | """
167 | 
--------------------------------------------------------------------------------
/src/tools/_train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 TensorLab. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4 | # in compliance with the License. You may obtain a copy of the License at
5 | #
6 | #     http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software distributed under the License
9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10 | # or implied. See the License for the specific language governing permissions and limitations under
11 | # the License.
12 | 
13 | # _train.py
14 | # Implements TrainCommand.
15 | 
16 | import json
17 | import os
18 | import subprocess
19 | import sys
20 | 
21 | _PORT = 14000
22 | 
23 | class TrainCommand(object):
24 |   """Implements the tfx train command to launch single node and distributed training.
25 |   """
26 |   name = 'train'
27 |   help = 'Launches local training jobs for development.'
28 |   extra = True
29 | 
30 |   @staticmethod
31 |   def build_parser(parser):
32 |     parser.add_argument('--module', metavar='name', type=str, required=True,
33 |                         help='The name of the training module to launch')
34 |     parser.add_argument('--output', metavar='path', type=str, default='output',
35 |                         help='The path to write outputs')
36 |     parser.add_argument('--distributed', action='store_true',
37 |                         help='Runs a multi-node (master, worker, parameter server) cluster')
38 | 
39 |   @staticmethod
40 |   def run(args):
41 |     args.extra.extend([
42 |       '--job-dir', os.path.abspath(args.output)
43 |     ])
44 | 
45 |     cmd = ['python', '-m', args.module] + args.extra
46 | 
47 |     if args.distributed:
48 |       print 'Launching training tasks (master, worker, parameter server)...'
49 |       print ' '.join(cmd)
50 |       print '----\n'
51 | 
52 |       ps_task = _start_task(cmd, _create_distributed_config('ps'))
53 |       master_task = _start_task(cmd, _create_distributed_config('master'))
54 |       worker_task = _start_task(cmd, _create_distributed_config('worker'))
55 |     else:
56 |       print 'Launching training task...'
57 |       print ' '.join(cmd)
58 |       print '----\n'
59 | 
60 |       master_task = _start_task(cmd, _create_simple_config())
61 |       ps_task = None
62 |       worker_task = None
63 | 
64 |     try:
65 |       master_task.wait()
66 |     finally:
67 |       if worker_task:
68 |         _kill_task(worker_task)
69 |       if ps_task:
70 |         _kill_task(ps_task)
71 |       _kill_task(master_task)
72 | 
73 | 
74 | def _create_simple_config():
75 |   return {
76 |     'task': {'type': 'master', 'index': 0},
77 |     'job': {'local': True}
78 |   }
79 | 
80 | def _create_distributed_config(task):
81 |   return {
82 |     'cluster': {
83 |       'ps': ['localhost:%d' % _PORT],
84 |       'master': ['localhost:%d' % (_PORT + 1)],
85 |       'worker': ['localhost:%d' % (_PORT + 2)]
86 |     },
87 |     'task': {'type': task, 'index': 0},
88 |     'job': {'local': True}
89 |   }
90 | 
91 | def _start_task(cmd, config):
92 |   env = os.environ.copy()
93 |   env['TF_CONFIG'] = json.dumps(config)
94 |   return subprocess.Popen(cmd, env=env)
95 | 
96 | def _kill_task(process):
97 |   try:
98 |     process.terminate()
99 |   except:
100 |     pass
101 | 
--------------------------------------------------------------------------------
/src/tools/tfx.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 TensorLab. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4 | # in compliance with the License. You may obtain a copy of the License at
5 | #
6 | #     http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software distributed under the License
9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10 | # or implied. See the License for the specific language governing permissions and limitations under
11 | # the License.
12 | 
13 | # tfx.py
14 | # tensorfx.tools.tfx module to implement the tfx command-line tool.
15 | 
16 | import argparse
17 | import sys
18 | from _scaffold import ScaffoldCommand
19 | from _train import TrainCommand
20 | from _predict import PredictCommand
21 | 
22 | 
23 | def _build_cli():
24 |   """Builds the command-line interface.
25 |   """
26 |   commands = [
27 |     ScaffoldCommand,
28 |     TrainCommand,
29 |     PredictCommand
30 |   ]
31 | 
32 |   cli = argparse.ArgumentParser(prog='tfx')
33 |   subparsers = cli.add_subparsers(title='Available commands')
34 | 
35 |   for command in commands:
36 |     command_parser = subparsers.add_parser(command.name, help=command.help,
37 |                                            usage='%(prog)s [--help] [options]')
38 |     command_parser.set_defaults(command=command)
39 |     command.build_parser(command_parser)
40 | 
41 |   return cli
42 | 
43 | 
44 | def main(args=None):
45 |   if not args:
46 |     args = sys.argv[1:]
47 | 
48 |   cli = _build_cli()
49 |   args, extra_args = cli.parse_known_args(args)
50 | 
51 |   command = args.command
52 |   del args.command
53 | 
54 |   if extra_args:
55 |     if command.extra:
56 |       args.extra = extra_args
57 |     else:
58 |       cli.error('unrecognized arguments %s' % ' '.join(extra_args))
59 | 
60 |   command.run(args)
61 | 
62 | 
63 | if __name__ == '__main__':
64 |   main(sys.argv[1:])
65 | 
--------------------------------------------------------------------------------
/src/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 TensorLab. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4 | # in compliance with the License. 
You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # __init__.py 14 | # tensorfx.training module declaration. 15 | 16 | from _config import Configuration 17 | from _args import ModelArguments 18 | from _model import ModelBuilder 19 | from _trainer import ModelTrainer 20 | -------------------------------------------------------------------------------- /src/training/_args.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLab. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # _args.py 14 | # Defines ModelArguments and related classes. 15 | 16 | import argparse 17 | import logging 18 | import sys 19 | import tensorfx as tfx 20 | 21 | 22 | class ModelArguments(argparse.Namespace): 23 | 24 | def process(self): 25 | """Processes the parsed arguments to produce any additional objects. 26 | """ 27 | # Convert strings to logging values 28 | self.log_level = getattr(logging, self.log_level) 29 | self.log_level_tensorflow = getattr(logging, self.log_level_tensorflow) 30 | 31 | @classmethod 32 | def default(cls): 33 | """Creates an instance of the arguments with default values. 34 | 35 | Returns: 36 | The model arguments with default values. 37 | """ 38 | return cls.parse(args=[]) 39 | 40 | @classmethod 41 | def parse(cls, args=None, parse_job=False): 42 | """Parses training arguments. 43 | 44 | Arguments: 45 | args: the arguments to parse. If unspecified, the process arguments are used. 46 | parse_job: whether to parse the job related standard (input and output) arguments. 47 | Returns: 48 | The parsed arguments. 49 | """ 50 | if args is None: 51 | args = sys.argv[1:] 52 | 53 | argparser = ModelArgumentsParser(add_job_arguments=parse_job) 54 | cls.init_parser(argparser) 55 | 56 | args_object = argparser.parse_args(args, namespace=cls()) 57 | args_object._args = args 58 | args_object.process() 59 | 60 | return args_object 61 | 62 | @classmethod 63 | def init_parser(cls, parser): 64 | """Initializes the argument parser. 65 | 66 | Args: 67 | parser: An argument parser instance to be initialized with arguments. 
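
    The registered arguments support invocations such as (a hypothetical command line):

      trainer --max-steps 2000 --batch-size 64 --log-level DEBUG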
68 |     """
69 |     session = parser.add_argument_group(title='Session',
70 |         description='Arguments controlling the session loop.')
71 |     session.add_argument('--max-steps', type=int, default=1000,
72 |                          help='The number of steps to execute during the training job.')
73 |     session.add_argument('--batch-size', type=int, default=128,
74 |                          help='The number of instances to read and process in each training step.')
75 |     session.add_argument('--epochs', type=int, default=0,
76 |                          help='The number of passes over the training data to make.')
77 |     session.add_argument('--checkpoint-interval-secs', type=int, default=60 * 5,
78 |                          help='The frequency of checkpoints to create during the training job.')
79 | 
80 |     log_levels = ['FATAL', 'ERROR', 'WARN', 'INFO', 'DEBUG']
81 | 
82 |     log = parser.add_argument_group(title='Logging and Diagnostics',
83 |         description='Arguments controlling logging during training.')
84 |     log.add_argument('--log-level-tensorflow', metavar='level', type=str, default='ERROR',
85 |                      choices=log_levels,
86 |                      help='The logging level for TensorFlow generated log messages.')
87 |     log.add_argument('--log-device-placement', default=False, action='store_true',
88 |                      help='Whether to log placement of ops and tensors on devices.')
89 |     log.add_argument('--log-level', metavar='level', type=str, default='INFO', choices=log_levels,
90 |                      help='The logging level for training.')
91 |     log.add_argument('--log-interval-steps', metavar='steps', type=int, default=100,
92 |                      help='The frequency of training logs and summary events to generate.')
93 | 
94 | 
95 | class ModelArgumentsParser(argparse.ArgumentParser):
96 | 
97 |   def __init__(self, add_job_arguments):
98 |     # TODO: Add description, epilogue, etc.
99 |     super(ModelArgumentsParser, self).__init__(prog='trainer', usage='%(prog)s [--help] [options]')
100 |     self.var_args_action = AddVarArgAction
101 | 
102 |     job = self.add_argument_group(title='Job',
103 |         description='Arguments defining job inputs and outputs.')
104 |     job.add_argument('--data-schema', metavar='path', type=str, required=False,
105 |                      help='The schema (columns, types) of the data being referenced (YAML).')
106 |     job.add_argument('--data-metadata', metavar='path', type=str, required=False,
107 |                      help='The statistics and vocabularies of the data being referenced (JSON).')
108 |     job.add_argument('--data-features', metavar='path', type=str, required=False,
109 |                      help='The set of features to transform the raw data into (YAML).')
110 |     job.add_argument('--data-train', metavar='path', type=str, required=False,
111 |                      help='The data to use for training. This can include wildcards.')
112 |     job.add_argument('--data-eval', metavar='path', type=str, required=False,
113 |                      help='The data to use for evaluation. This can include wildcards.')
114 | 
115 |     # The framework uses output, but Cloud ML Engine uses job-dir. Only one should be provided.
116 |     job.add_argument('--output', type=str, dest='output', required=False,
117 |                      help='The output path to use for training outputs.')
118 |     job.add_argument('--job-dir', type=str, dest='output', required=False,
119 |                      help='For Cloud ML Engine compatibility only. Use --output instead.')
120 | 
121 |   def _parse_optional(self, arg_string):
122 |     suffix_index = arg_string.find(':')
123 |     if suffix_index < 0:
124 |       return super(ModelArgumentsParser, self)._parse_optional(arg_string)
125 | 
126 |     original_arg_string = arg_string
127 |     suffix = arg_string[suffix_index + 1:]
128 |     arg_string = arg_string[0:suffix_index]
129 | 
130 |     option_tuple = super(ModelArgumentsParser, self)._parse_optional(arg_string)
131 |     if not option_tuple:
132 |       return option_tuple
133 | 
134 |     action, option_string, explicit_arg = option_tuple
135 |     if isinstance(action, AddVarArgAction):
136 |       return action, suffix, explicit_arg
137 |     else:
138 |       self.exit(-1, message='Unknown argument %s' % original_arg_string)
139 | 
140 | 
141 | class AddVarArgAction(argparse.Action):
142 |   def __init__(self,
143 |                option_strings,
144 |                dest,
145 |                nargs=None,
146 |                const=None,
147 |                default=None,
148 |                type=None,
149 |                choices=None,
150 |                required=False,
151 |                help=None,
152 |                metavar=None):
153 |     super(AddVarArgAction, self).__init__(
154 |         option_strings=option_strings,
155 |         dest=dest,
156 |         nargs=nargs,
157 |         const=const,
158 |         default=default,
159 |         type=type,
160 |         choices=choices,
161 |         required=required,
162 |         help=help,
163 |         metavar=metavar)
164 | 
165 |   def __call__(self, parser, namespace, values, option_string=None):
166 |     index = 0
167 |     try:
168 |       index = int(option_string) - 1
169 |     except ValueError:
170 |       pass
171 | 
172 |     list = getattr(namespace, self.dest)
173 |     if list is None:
174 |       list = []
175 |       setattr(namespace, self.dest, list)
176 | 
177 |     if index >= len(list):
178 |       list.extend([self.default] * (index + 1 - len(list)))
179 |     list[index] = values
180 | 
--------------------------------------------------------------------------------
/src/training/_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 TensorLab. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4 | # in compliance with the License. You may obtain a copy of the License at
5 | #
6 | #     http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software distributed under the License
9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10 | # or implied. See the License for the specific language governing permissions and limitations under
11 | # the License.
12 | 
13 | # _config.py
14 | # Implements Configuration.
15 | 
16 | import json
17 | import os
18 | import tensorflow as tf
19 | 
20 | _TASK_PARAM_SERVER = 'ps'
21 | _TASK_WORKER = 'worker'
22 | _TASK_MASTER = 'master'
23 | 
24 | 
25 | class Configuration(object):
26 |   """Contains configuration information for the training process.
27 |   """
28 |   def __init__(self, task, cluster, job, env):
29 |     """Initializes a Configuration instance from the individual configuration objects.
30 | 
31 |     Task configuration represents the current training task (for both single node and distributed
32 |     training), while cluster configuration represents the cluster and should be None in single
33 |     node training.
34 |     Job configuration represents any environment-specific representation of the training job.
35 | 
36 |     Arguments:
37 |       task: current TensorFlow task configuration.
38 |       cluster: the TensorFlow cluster configuration for distributed training.
39 |       job: environment-specific job configuration.
40 |       env: the environment-provided configuration information.
41 |     """
42 |     self._task = type('TaskSpec', (object,), task)
43 |     self._cluster = tf.train.ClusterSpec(cluster) if cluster else None
44 |     self._job = type('JobSpec', (object,), job)
45 |     self._env = env
46 | 
47 |   @classmethod
48 |   def environment(cls):
49 |     """Creates a Configuration object for single node and distributed training.
50 | 
51 |     This relies on looking up configuration from an environment variable, 'TF_CONFIG', which allows
52 |     a hosting environment to configure the training process.
53 |     The specific environment variable is expected to be a JSON formatted dictionary containing
54 |     configuration about the current task, cluster and job.
55 | 
56 |     Returns:
57 |       A Configuration instance matching the current environment.
58 |     """
59 |     env = json.loads(os.environ.get('TF_CONFIG', '{}'))
60 | 
61 |     # Note that the lookup for 'task' must handle the case where it is missing, as well as when it
62 |     # is specified, but is empty, to support both single node and distributed training.
63 | 
64 |     return cls(env.get('task', None) or {'type': 'master', 'index': 0},
65 |                env.get('cluster', None),
66 |                env.get('job', {'local': True}),
67 |                env)
68 | 
69 |   @classmethod
70 |   def local(cls):
71 |     """Creates a Configuration object representing single node training in a process.
72 | 
73 |     Returns:
74 |       A default Configuration instance with simple configuration.
75 |     """
76 |     return cls(task={'type': 'master', 'index': 0}, cluster=None, job={'local': True}, env={})
77 | 
78 |   @property
79 |   def distributed(self):
80 |     """Determines whether the training being performed is distributed or single node.
81 | 
82 |     Returns:
83 |       True if the configuration represents distributed training; False otherwise.
84 |     """
85 |     return self._cluster is not None
86 | 
87 |   @property
88 |   def cluster(self):
89 |     """Retrieves the cluster definition containing the current node.
90 | 
91 |     This is None if the current node is part of a single node training job.
92 |     """
93 |     return self._cluster
94 | 
95 |   @property
96 |   def job(self):
97 |     """Retrieves the job definition of the current training job.
98 |     """
99 |     return self._job
100 | 
101 |   @property
102 |   def task(self):
103 |     """Retrieves the task definition associated with the current node.
104 | 
105 |     If no task information is provided, this is None.
106 |     """
107 |     return self._task
108 | 
109 |   @property
110 |   def device(self):
111 |     """Retrieves the device associated with the current node.
112 |     """
113 |     return '/job:%s/task:%d' % (self._task.type, self._task.index)
114 | 
115 |   @property
116 |   def master(self):
117 |     """Retrieves whether the current task is a master task.
118 |     """
119 |     return self._task.type == _TASK_MASTER
120 | 
121 |   @property
122 |   def param_server(self):
123 |     """Retrieves whether the current task is a parameter server task.
124 |     """
125 |     return self._task.type == _TASK_PARAM_SERVER
126 | 
127 |   @property
128 |   def worker(self):
129 |     """Retrieves whether the current task is a worker task.
130 |     """
131 |     return self._task.type == _TASK_WORKER
132 | 
133 |   def create_device_setter(self, args):
134 |     """Creates the device setter, which assigns variables and ops to devices in distributed mode.
135 | 
136 |     Arguments:
137 |       args: the arguments associated with the current job.
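
    With this setter, variables are placed on parameter server tasks ('/job:ps') while ops run
    on the current task's device (e.g. '/job:worker/task:0').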
138 |     """
139 |     # TODO: Provide a way to specify a custom strategy or setter
140 |     return tf.train.replica_device_setter(cluster=self._cluster,
141 |                                           ps_device='/job:ps',
142 |                                           worker_device=self.device)
143 | 
144 |   def create_server(self):
145 |     """Creates the TensorFlow server, which is required for distributed training.
146 |     """
147 |     if not self.distributed:
148 |       return None
149 |     return tf.train.Server(self._cluster, self._task.type, self._task.index, protocol='grpc')
150 | 
--------------------------------------------------------------------------------
/src/training/_hooks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 TensorLab. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4 | # in compliance with the License. You may obtain a copy of the License at
5 | #
6 | #     http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software distributed under the License
9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10 | # or implied. See the License for the specific language governing permissions and limitations under
11 | # the License.
12 | 
13 | # _hooks.py
14 | # Implements various session hooks needed for training.
15 | 
16 | import logging
17 | import os
18 | import tensorflow as tf
19 | import tensorfx as tfx
20 | import time
21 | from tensorflow.core.framework import summary_pb2 as tfsummaries
22 | 
23 | 
24 | class StopTrainingHook(tf.train.SessionRunHook):
25 |   """Stops training after a specified number of steps.
26 |   """
27 |   def __init__(self, job):
28 |     """Initializes an instance of StopTrainingHook.
29 | 
30 |     Arguments:
31 |       job: The current training job.
32 |     """
33 |     self._global_steps = job.training.global_steps
34 |     self._max_steps = job.args.max_steps
35 | 
36 |   def before_run(self, context):
37 |     return tf.train.SessionRunArgs(self._global_steps)
38 | 
39 |   def after_run(self, context, values):
40 |     global_steps_completed = values.results
41 |     if global_steps_completed >= self._max_steps:
42 |       context.request_stop()
43 | 
44 | 
45 | class LogSessionHook(tf.train.SessionRunHook):
46 |   """Logs the session loop, outputting step counts and throughput.
47 |   """
48 |   _MESSAGE_FORMAT = 'Run: %.2f sec; Steps: %d; Duration: %d sec; Throughput: %.1f instances/sec'
49 |   def __init__(self, job):
50 |     """Initializes an instance of LogSessionHook.
51 | 
52 |     Arguments:
53 |       job: The current training job.
54 |     """
55 |     self._log_interval_steps = job.args.log_interval_steps
56 |     self._batch_size = job.args.batch_size
57 | 
58 |     self._start_time = time.time()
59 |     self._steps_completed = 0
60 |     self._step_start_time = 0
61 | 
62 |   def before_run(self, context):
63 |     self._step_start_time = time.time()
64 | 
65 |   def after_run(self, context, values):
66 |     self._steps_completed += 1
67 | 
68 |     if self._steps_completed == 1 or \
69 |        self._steps_completed % self._log_interval_steps == 0:
70 |       end_time = time.time()
71 |       run_time = end_time - self._step_start_time
72 |       duration = end_time - self._start_time
73 |       throughput = self._steps_completed * float(self._batch_size) / float(duration)
74 | 
75 |       logging.info(LogSessionHook._MESSAGE_FORMAT,
76 |                    run_time, self._steps_completed, duration, throughput)
77 | 
78 | 
79 | class LogTrainingHook(tf.train.SessionRunHook):
80 |   """Logs training progress and produces summary events.
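
  A logged message looks like the following (values are illustrative):
    Global steps: 100; Duration: 42 sec; Throughput: 304.8 instances/sec; Loss: 0.693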
81 |   """
82 |   _MESSAGE_FORMAT = 'Global steps: %d; Duration: %d sec; Throughput: %.1f instances/sec; Loss: %.3f'
83 |   def __init__(self, job):
84 |     """Initializes an instance of LogTrainingHook.
85 | 
86 |     Arguments:
87 |       job: The current training job.
88 |     """
89 |     self._global_steps = job.training.global_steps
90 |     self._loss = job.training.loss
91 |     self._summary_op = job.training.summary_op
92 | 
93 |     self._log_interval_steps = job.args.log_interval_steps
94 |     self._max_steps = job.args.max_steps
95 |     self._batch_size = job.args.batch_size
96 | 
97 |     self._summary_writer = tf.summary.FileWriter(job.summaries_path('train'))
98 |     self._summary_writer.add_graph(job.training.graph)
99 | 
100 |     self._start_time = time.time()
101 |     self._global_steps_completed = 0
102 | 
103 |   def before_run(self, context):
104 |     current_step = self._global_steps_completed + 1
105 |     if (current_step % self._log_interval_steps == 0) or \
106 |        (current_step + 1 >= self._max_steps):
107 |       return tf.train.SessionRunArgs([self._global_steps, self._loss, self._summary_op])
108 |     else:
109 |       return tf.train.SessionRunArgs([self._global_steps])
110 | 
111 |   def after_run(self, context, values):
112 |     if len(values.results) == 1:
113 |       self._global_steps_completed, = values.results
114 |     else:
115 |       self._global_steps_completed, loss_value, summary = values.results
116 | 
117 |       end_time = time.time()
118 |       duration = end_time - self._start_time
119 |       throughput = self._global_steps_completed * float(self._batch_size) / float(duration)
120 | 
121 |       logging.info(LogTrainingHook._MESSAGE_FORMAT,
122 |                    self._global_steps_completed, duration, throughput, loss_value)
123 | 
124 |       self._summary_writer.add_summary(summary, self._global_steps_completed)
125 |       _log_summary_value(self._summary_writer, 'metrics/throughput', throughput,
126 |                          self._global_steps_completed)
127 |       self._summary_writer.flush()
128 | 
129 | 
130 | class SaveCheckpointHook(tf.train.SessionRunHook):
131 |   """Saves checkpoints during training, evaluates them, and exports the final checkpoint as a model.
132 | 
133 |   This should only be used in master tasks.
134 |   """
135 |   _MESSAGE_FORMAT = 'Global steps: %d; Evaluation metric: %.3f'
136 |   def __init__(self, job):
137 |     """Initializes an instance of SaveCheckpointHook.
138 | 
139 |     Arguments:
140 |       job: The current training job.
141 |     """
142 |     self._job = job
143 | 
144 |     self._global_steps = job.training.global_steps
145 |     self._saver = job.training.saver
146 | 
147 |     self._checkpoint_interval_secs = job.args.checkpoint_interval_secs
148 | 
149 |     self._checkpoint_name = os.path.join(job.checkpoints_path, 'model.ckpt')
150 | 
151 |     self._last_save_time = time.time()
152 |     self._last_save_steps = 0
153 | 
154 |     self._summary_writer = tf.summary.FileWriter(job.summaries_path('eval'))
155 |     self._summary_writer.add_graph(job.evaluation.graph)
156 | 
157 |   def before_run(self, context):
158 |     # Save a checkpoint after the first step (this produces early evaluation results), as well
159 |     # as at every checkpoint interval.
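    # For example, with the default --checkpoint-interval-secs of 300, a checkpoint (and an
    # evaluation pass over the eval data) is produced roughly every five minutes of training.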
160 | if self._last_save_steps == 0 or \ 161 | time.time() - self._last_save_time >= self._checkpoint_interval_secs: 162 | return tf.train.SessionRunArgs([self._global_steps]) 163 | 164 | def after_run(self, context, values): 165 | if values.results: 166 | global_steps_completed, = values.results 167 | checkpoint = self._saver.save(context.session, self._checkpoint_name, global_steps_completed) 168 | self._evaluate(checkpoint, global_steps_completed) 169 | 170 | self._last_save_steps = global_steps_completed 171 | self._last_save_time = time.time() 172 | 173 | def end(self, session): 174 | global_steps_completed = session.run(self._global_steps) 175 | if global_steps_completed != self._last_save_steps: 176 | checkpoint = self._saver.save(session, self._checkpoint_name, global_steps_completed) 177 | self._evaluate(checkpoint, global_steps_completed) 178 | self._export(checkpoint) 179 | 180 | def _evaluate(self, checkpoint, global_steps_completed): 181 | with self._job.evaluation.graph.as_default(): 182 | with tf.Session() as session: 183 | self._job.evaluation.init_op.run() 184 | self._job.evaluation.saver.restore(session, checkpoint) 185 | self._job.evaluation.local_init_op.run() 186 | 187 | coord = tf.train.Coordinator() 188 | threads = tf.train.start_queue_runners(coord=coord) 189 | 190 | try: 191 | while not coord.should_stop(): 192 | session.run(self._job.evaluation.eval_op) 193 | except tf.errors.OutOfRangeError: 194 | # Ignore the error raised at the end of an epoch of eval data. 195 | pass 196 | finally: 197 | coord.request_stop() 198 | coord.join(threads) 199 | 200 | metric_value = session.run(self._job.evaluation.metric) 201 | 202 | summary = session.run(self._job.evaluation.summary_op) 203 | self._summary_writer.add_summary(summary, global_steps_completed) 204 | self._summary_writer.flush() 205 | 206 | logging.info(SaveCheckpointHook._MESSAGE_FORMAT, global_steps_completed, metric_value) 207 | 208 | def _export(self, checkpoint): 209 | summary_writer = tf.summary.FileWriter(self._job.summaries_path('prediction')) 210 | summary_writer.add_graph(self._job.prediction.graph) 211 | summary_writer.close() 212 | 213 | with self._job.prediction.graph.as_default(): 214 | with tf.Session() as session: 215 | self._job.prediction.init_op.run() 216 | self._job.prediction.saver.restore(session, checkpoint) 217 | self._job.prediction.local_init_op.run() 218 | 219 | tfx.prediction.Model.save(session, self._job.model_path, 220 | self._job.prediction.inputs, self._job.prediction.outputs) 221 | 222 | 223 | class CheckNaNLossHook(tf.train.SessionRunHook): 224 | """Checks for NaN loss values to stop or abort training. 225 | """ 226 | # TODO: Implement this 227 | pass 228 | 229 | 230 | def _log_summary_value(summary_writer, tag, value, global_steps): 231 | summary_value = tfsummaries.Summary.Value(tag=tag, simple_value=value) 232 | summary = tfsummaries.Summary(value=[summary_value]) 233 | 234 | summary_writer.add_summary(summary, global_steps) 235 | -------------------------------------------------------------------------------- /src/training/_job.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLab. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. 
You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # _job.py 14 | # Implements Job. 15 | 16 | import os 17 | import logging 18 | import yaml 19 | import sys 20 | import tensorflow as tf 21 | from tensorflow.python.lib.io import file_io as tfio 22 | 23 | class Job(object): 24 | """Represents a training job. 25 | """ 26 | def __init__(self, model_builder, inputs, output, config): 27 | """Initializes a Job instance. 28 | 29 | Arguments: 30 | model_builder: the ModelBuilder associated with the job. 31 | inputs: the input dataset for the job. 32 | output: the output path of the job. 33 | config: the Training configuration. 34 | """ 35 | self._model_builder = model_builder 36 | self._inputs = inputs 37 | self._output = output 38 | self._config = config 39 | 40 | @property 41 | def model_builder(self): 42 | """Retrieves the ModelBuilder being used to build model graphs. 43 | """ 44 | return self._model_builder 45 | 46 | @property 47 | def args(self): 48 | """Retrieves the arguments associated with the job. 49 | """ 50 | return self._model_builder.args 51 | 52 | @property 53 | def inputs(self): 54 | """Retrieves the input dataset of the job. 55 | """ 56 | return self._inputs 57 | 58 | @property 59 | def output_path(self): 60 | """Retrieves the output path of the job. 61 | """ 62 | return self._output 63 | 64 | @property 65 | def checkpoints_path(self): 66 | """Retrieves the checkpoints path within the output path. 67 | """ 68 | return os.path.join(self._output, 'checkpoints') 69 | 70 | @property 71 | def model_path(self): 72 | """Retrieves the model path within the output path. 73 | """ 74 | return os.path.join(self._output, 'model') 75 | 76 | def summaries_path(self, summary): 77 | """Retrieves the summaries path within the output path. 78 | 79 | Arguments: 80 | summary: the type of summary. 81 | """ 82 | return os.path.join(self._output, 'summaries', summary) 83 | 84 | @property 85 | def training(self): 86 | """Retrieves the training graph interface for the job. 87 | """ 88 | return self._training 89 | 90 | @property 91 | def evaluation(self): 92 | """Retrieves the evaluation graph interface for the job. 93 | """ 94 | return self._evaluation 95 | 96 | @property 97 | def prediction(self): 98 | """Retrieves the prediction graph interface for the job. 99 | """ 100 | return self._prediction 101 | 102 | def configure_logging(self): 103 | """Initializes the loggers for the job. 104 | """ 105 | args = self._model_builder.args 106 | 107 | tf.logging.set_verbosity(args.log_level_tensorflow) 108 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(args.log_level_tensorflow) 109 | 110 | logger = logging.getLogger() 111 | if hasattr(self._config.job, 'local') and not logger.handlers: 112 | # Additional setup to output logs to console for local runs. On cloud, this is handled by the 113 | # environment. The additional check for existing logging handler ensures that existing logging 114 | # setup is used; for example, when training is invoked in context of another application. 
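      # For example, a distributed worker task logs as 'INFO worker:0: <message>', while a
      # local run logs as 'INFO: <message>'.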
115 | if self._config.distributed: 116 | format = '%%(levelname)s %s:%d: %%(message)s' 117 | format = format % (self._config.task.type, self._config.task.index) 118 | else: 119 | format = '%(levelname)s: %(message)s' 120 | 121 | handler = logging.StreamHandler(stream=sys.stderr) 122 | handler.setFormatter(logging.Formatter(fmt=format)) 123 | 124 | logger.addHandler(handler) 125 | logger.setLevel(args.log_level) 126 | 127 | def start(self): 128 | """Performs startup logic, including building graphs. 129 | """ 130 | if self._config.master: 131 | # Save out job information for later reference alongside all other outputs. 132 | job_args = ' '.join(self._model_builder.args._args).replace(' --', '\n--').split('\n') 133 | job_info = { 134 | 'config': self._config._env, 135 | 'args': job_args 136 | } 137 | job_spec = yaml.safe_dump(job_info, default_flow_style=False) 138 | job_file = os.path.join(self._output, 'job.yaml') 139 | 140 | tfio.recursive_create_dir(self._output) 141 | tfio.write_string_to_file(job_file, job_spec) 142 | 143 | # Create a checkpoints directory. This is needed to ensure checkpoint restoration logic 144 | # can look up an existing directory. 145 | tfio.recursive_create_dir(self.checkpoints_path) 146 | 147 | # Build the graphs that will be used during the course of the job. 148 | self._training, self._evaluation, self._prediction = \ 149 | self._model_builder.build_graph_interfaces(self._inputs, self._config) 150 | -------------------------------------------------------------------------------- /src/training/_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLab. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # _model.py 14 | # Implements the ModelBuilder base class. 15 | 16 | import tensorflow as tf 17 | import tensorfx as tfx 18 | from ._args import ModelArguments 19 | 20 | 21 | def _create_interface(phase, graph, references): 22 | """Creates an interface object, as a dynamically defined type with the graph and references as attributes. 23 | """ 24 | interface = {'graph': graph} 25 | interface.update(references) 26 | 27 | return type(phase + 'Interface', (object,), interface) 28 | 29 | 30 | class ModelBuilder(object): 31 | """Builds model graphs for different phases: training, evaluation and prediction. 32 | 33 | A model graph is an interface that encapsulates a TensorFlow graph, and references to tensors and 34 | ops within that graph. 35 | 36 | A ModelBuilder serves as a base class for various models. Each concrete model adds its own 37 | logic to build the required TensorFlow graph. 38 | """ 39 | def __init__(self, args): 40 | """Initializes an instance of a ModelBuilder. 41 | 42 | Arguments: 43 | args: the arguments specified for training.
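Example (hypothetical; MyModelBuilder stands in for a concrete subclass that implements
the abstract build_* methods): builder = MyModelBuilder(args), where args is a
ModelArguments instance.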
44 | """ 45 | if args is None or not isinstance(args, ModelArguments): 46 | raise ValueError('args must be an instance of ModelArguments') 47 | 48 | self._args = args 49 | 50 | @property 51 | def args(self): 52 | """Retrieves the set of arguments specified for training. 53 | """ 54 | return self._args 55 | 56 | def build_graph_interfaces(self, dataset, config): 57 | """Builds graph interfaces for training and evaluating a model, and for predicting using it. 58 | 59 | A graph interface is an object containing a TensorFlow graph member, as well as members 60 | corresponding to various tensors and ops within the graph. 61 | 62 | Arguments: 63 | dataset: The dataset to use during training. 64 | config: The training Configuration object. 65 | Returns: 66 | A tuple consisting of the training, evaluation and prediction interfaces. 67 | """ 68 | with tf.Graph().as_default() as graph: 69 | with tf.device(config.create_device_setter(self._args)): 70 | references = self.build_training_graph(dataset) 71 | training = _create_interface('Training', graph, references) 72 | 73 | with tf.Graph().as_default() as graph: 74 | references = self.build_evaluation_graph(dataset) 75 | evaluation = _create_interface('Evaluation', graph, references) 76 | 77 | with tf.Graph().as_default() as graph: 78 | references = self.build_prediction_graph(dataset) 79 | prediction = _create_interface('Prediction', graph, references) 80 | 81 | return training, evaluation, prediction 82 | 83 | def build_training_graph(self, dataset): 84 | """Builds the graph to use for training a model. 85 | 86 | This operates on the current default graph. 87 | 88 | Arguments: 89 | dataset: The dataset to use during training. 90 | Returns: 91 | The set of tensor and op references required for training. 92 | """ 93 | with tf.name_scope('input'): 94 | # For training, ensure the data is shuffled, and don't limit to any fixed number of epochs. 95 | # The datasource to use is the one named 'train' within the dataset. 96 | inputs = self.build_input(dataset, 'train', 97 | batch=self.args.batch_size, 98 | epochs=self.args.epochs, 99 | shuffle=True) 100 | 101 | with tf.name_scope('inference'): 102 | inferences = self.build_inference(inputs, training=True) 103 | 104 | with tf.name_scope('train'): 105 | # The global steps variable is explicitly marked as trainable so that it is saved into 106 | # checkpoints, allowing training to be resumed. 107 | global_steps = tf.Variable(0, name='global_steps', dtype=tf.int64, trainable=True, 108 | collections=[tf.GraphKeys.GLOBAL_VARIABLES, 109 | tf.GraphKeys.GLOBAL_STEP, 110 | tf.GraphKeys.TRAINABLE_VARIABLES]) 111 | loss, train_op = self.build_training(global_steps, inputs, inferences) 112 | 113 | with tf.name_scope('initialization'): 114 | # Create the saver that will be used to save and restore (in the case of resumed training) 115 | # trained variables.
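# Note: sharded=True writes the checkpoint as one shard per device, rather than funneling all
# variables through a single device; this matters mainly for distributed training with
# parameter servers.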
116 | saver = tf.train.Saver(tf.trainable_variables(), sharded=True) 117 | 118 | init_op, local_init_op = self.build_init() 119 | ready_op = tf.report_uninitialized_variables(tf.trainable_variables()) 120 | 121 | # Create the summary op that will merge all summaries across all sub-graphs 122 | summary_op = tf.summary.merge_all() 123 | 124 | scaffold = tf.train.Scaffold(init_op=init_op, 125 | local_init_op=local_init_op, 126 | ready_op=ready_op, 127 | ready_for_local_init_op=ready_op, 128 | summary_op=summary_op, 129 | saver=saver) 130 | scaffold.finalize() 131 | 132 | return { 133 | 'global_steps': global_steps, 134 | 'loss': loss, 135 | 'init_op': init_op, 136 | 'local_init_op': local_init_op, 137 | 'ready_op': ready_op, 138 | 'train_op': train_op, 139 | 'summary_op': summary_op, 140 | 'saver': saver, 141 | 'scaffold': scaffold 142 | } 143 | 144 | def build_evaluation_graph(self, dataset): 145 | """Builds the graph to use for evaluating a model during training. 146 | 147 | Arguments: 148 | dataset: The dataset to use during training. 149 | Returns: 150 | The set of tensor and op references required for evaluation. 151 | """ 152 | with tf.name_scope('input'): 153 | # For evaluation, compute the eval metric over a single pass over the evaluation data, 154 | # and avoid any overhead from shuffling. 155 | # The datasource to use is the one named 'eval' within the dataset. 156 | inputs = self.build_input(dataset, 'eval', batch=1, epochs=1, shuffle=False) 157 | 158 | with tf.name_scope('inference'): 159 | inferences = self.build_inference(inputs, training=False) 160 | 161 | with tf.name_scope('output'): 162 | outputs = self.build_output(inputs, inferences) 163 | 164 | with tf.name_scope('evaluation'): 165 | metric, eval_op = self.build_evaluation(inputs, outputs) 166 | 167 | with tf.name_scope('initialization'): 168 | # Create the saver that will be used to restore trained variables. 169 | saver = tf.train.Saver(tf.trainable_variables(), sharded=True) 170 | 171 | init_op, local_init_op = self.build_init() 172 | 173 | # Create the summary op that will merge all summaries across all sub-graphs 174 | summary_op = tf.summary.merge_all() 175 | 176 | return { 177 | 'metric': metric, 178 | 'init_op': init_op, 179 | 'local_init_op': local_init_op, 180 | 'eval_op': eval_op, 181 | 'summary_op': summary_op, 182 | 'saver': saver 183 | } 184 | 185 | def build_prediction_graph(self, dataset): 186 | """Builds the graph to use for predictions with the trained model. 187 | 188 | Arguments: 189 | dataset: The dataset used during training; its schema and features define the prediction inputs. 190 | Returns: 191 | The set of tensor and op references required for prediction. 192 | """ 193 | with tf.name_scope('input'): 194 | inputs = self.build_input(dataset, source=None, batch=0, epochs=0, shuffle=False) 195 | 196 | with tf.name_scope('inference'): 197 | inferences = self.build_inference(inputs, training=False) 198 | 199 | with tf.name_scope('output'): 200 | outputs = self.build_output(inputs, inferences) 201 | 202 | with tf.name_scope('initialization'): 203 | # Create the saver that will be used to restore trained variables. 204 | saver = tf.train.Saver(tf.trainable_variables(), sharded=True) 205 | 206 | init_op, local_init_op = self.build_init() 207 | 208 | graph_inputs = tf.get_collection('inputs') 209 | if len(graph_inputs) != 1 or graph_inputs[0].dtype != tf.string: 210 | raise Exception('Invalid prediction graph.
Must have a single string input.') 211 | 212 | graph_outputs = tf.get_collection('outputs') 213 | if len(graph_outputs) == 0: 214 | raise Exception('Invalid prediction graph. Must have at least one output.') 215 | 216 | return { 217 | 'init_op': init_op, 218 | 'local_init_op': local_init_op, 219 | 'saver': saver, 220 | 'inputs': graph_inputs, 221 | 'outputs': graph_outputs 222 | } 223 | 224 | def build_init(self): 225 | """Builds the initialization sub-graph. 226 | 227 | The default implementation creates an init op that initializes all global variables, and a 228 | local init op that initializes the non-trainable variables, local variables, and tables needed 229 | for local initialization. 230 | 231 | Initialization is run when the graph is first created, before training. Local initialization is 232 | performed after a previously trained model is loaded. 233 | 234 | Returns: 235 | A tuple containing the init op and local init op to use to initialize the graph. 236 | """ 237 | init_op = tf.variables_initializer(tf.global_variables(), name='init') 238 | 239 | # For some reason not all local variables are in the local variables collection; some are in 240 | # the global variables collection (such as those set up by reader ops). 241 | # So in addition to initializing local variables in the local_init_op, we also initialize the 242 | # set of variables in the global variables collection that are not trainable. 243 | # Tables belong to neither collection, and so must be explicitly included as well. 244 | # All of these will be initialized after restoring from a checkpoint. 245 | variables = tf.global_variables() 246 | for trainable in tf.trainable_variables(): 247 | variables.remove(trainable) 248 | 249 | local_init_op = tf.group(tf.variables_initializer(variables), 250 | tf.variables_initializer(tf.local_variables()), 251 | tf.tables_initializer(), 252 | name='local_init_op') 253 | 254 | # Add the local initialization op to the main op collection, which is looked up at model loading 255 | # time, and is automatically invoked after it has been loaded. 256 | tf.add_to_collection('saved_model_main_op', local_init_op) 257 | 258 | return init_op, local_init_op 259 | 260 | def build_input(self, dataset, source, batch, epochs, shuffle): 261 | """Builds the input sub-graph. 262 | 263 | Arguments: 264 | dataset: the dataset representing the inputs to the training. 265 | source: the name of the data source to use for input (for training and evaluation). 266 | batch: the number of instances to read per batch. 267 | epochs: the number of passes over the data. 268 | shuffle: whether to shuffle the data. 269 | Returns: 270 | A dictionary of tensors keyed by feature names. 271 | """ 272 | prediction = False 273 | if source: 274 | with tf.name_scope('read'): 275 | instances = dataset[source].read(batch=batch, shuffle=shuffle, epochs=epochs) 276 | else: 277 | prediction = True 278 | instances = tf.placeholder(dtype=tf.string, shape=(None,), name='instances') 279 | tf.add_to_collection('inputs', instances) 280 | 281 | with tf.name_scope('parse'): 282 | parsed_instances = dataset.parse_instances(instances, prediction) 283 | 284 | if dataset.features: 285 | with tf.name_scope('transform'): 286 | transformer = tfx.data.Transformer(dataset) 287 | return transformer.transform(parsed_instances) 288 | else: 289 | return parsed_instances 290 | 291 | def build_inference(self, inputs, training): 292 | """Builds the inference sub-graph. 293 | 294 | Arguments: 295 | inputs: the dictionary of tensors corresponding to the input.
296 | training: whether the inference sub-graph is being built for the training graph. 297 | Returns: 298 | The inference values. 299 | """ 300 | raise NotImplementedError('build_inference must be implemented in a derived class.') 301 | 302 | def build_training(self, global_steps, inputs, inferences): 303 | """Builds the training sub-graph. 304 | 305 | Arguments: 306 | global_steps: the global steps variable to use. 307 | inputs: the dictionary of tensors corresponding to the input. 308 | inferences: the inference values. 309 | Returns: 310 | The loss tensor, and the training op. 311 | """ 312 | raise NotImplementedError('build_training must be implemented in a derived class.') 313 | 314 | def build_output(self, inputs, inferences): 315 | """Builds the output sub-graph. 316 | 317 | Arguments: 318 | inputs: the dictionary of tensors corresponding to the input. 319 | inferences: the inference values. 320 | Returns: 321 | A dictionary consisting of the output prediction tensors. 322 | """ 323 | raise NotImplementedError('build_output must be implemented in a derived class.') 324 | 325 | def build_evaluation(self, inputs, outputs): 326 | """Builds the evaluation sub-graph. 327 | 328 | Arguments: 329 | inputs: the dictionary of tensors corresponding to the input. 330 | outputs: the dictionary containing output tensors. 331 | Returns: 332 | The eval metric tensor and the eval op. 333 | """ 334 | raise NotImplementedError('build_evaluation must be implemented in a derived class.') 335 | -------------------------------------------------------------------------------- /src/training/_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLab. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # _trainer.py 14 | # Implements ModelTrainer. 15 | 16 | import tensorflow as tf 17 | import tensorfx as tfx 18 | from ._config import Configuration 19 | from ._hooks import * 20 | from ._job import Job 21 | 22 | 23 | class ModelTrainer(object): 24 | """Provides the functionality to train a model during a training job. 25 | """ 26 | def __init__(self, config=None): 27 | """Initializes a ModelTrainer instance. 28 | 29 | Arguments: 30 | config: an optional configuration providing information about the training job and cluster. 31 | """ 32 | if not config: 33 | # By default, use the configuration specified in the TF_CONFIG environment variable. 34 | config = Configuration.environment() 35 | 36 | self._config = config 37 | 38 | @property 39 | def config(self): 40 | """Retrieves the training configuration. 41 | """ 42 | return self._config 43 | 44 | def train(self, model_builder, inputs, output): 45 | """Runs the training process to train a model. 46 | 47 | Arguments: 48 | model_builder: the ModelBuilder to use to build graphs during training. 49 | inputs: the input dataset for the job. 50 | output: the output path for the job. 51 | Returns: 52 | The trained Model.
The resulting value is only relevant for master nodes. 53 | """ 54 | job = Job(model_builder, inputs, output, self._config) 55 | job.configure_logging() 56 | 57 | server = self._config.create_server() 58 | if server and self._config.param_server: 59 | return self._run_ps(server) 60 | 61 | return self._run_training(server, job) 62 | 63 | def _run_ps(self, server): 64 | """Runs the parameter server task. 65 | 66 | A ps task runs forever (until killed) using the implementation within the TensorFlow runtime. 67 | """ 68 | try: 69 | server.join() 70 | except tf.errors.AbortedError: 71 | pass 72 | 73 | def _run_training(self, server, job): 74 | """Runs the worker and master tasks. 75 | 76 | Worker and master tasks create a TensorFlow session, and run the session loop. The session 77 | loop is customized via session hooks. A worker simply runs the training logic, while a master 78 | is also responsible for producing and evaluating checkpoints, as well as producing summary 79 | event logs, and finally exporting the trained model. 80 | """ 81 | job.start() 82 | 83 | with job.training.graph.as_default() as graph: 84 | master = server.target if server else '' 85 | config = self._create_session_config(job) 86 | hooks = self._create_session_hooks(job) 87 | 88 | if self._config.master: 89 | session_creator = tf.train.ChiefSessionCreator(job.training.scaffold, 90 | master, config, job.checkpoints_path) 91 | else: 92 | session_creator = tf.train.WorkerSessionCreator(job.training.scaffold, master, config) 93 | 94 | with tf.train.MonitoredSession(session_creator, hooks) as session: 95 | while not session.should_stop(): 96 | # TODO: Add session run timeouts 97 | session.run(job.training.train_op) 98 | 99 | if self._config.master: 100 | return tfx.prediction.Model.load(job.model_path) 101 | else: 102 | return None 103 | 104 | def _create_session_config(self, job): 105 | """Creates the TensorFlow session config object. 106 | """ 107 | if self._config.local: 108 | # Don't let each process (especially in the case of a distributed simulation on the local 109 | # machine) attempt to use all CPUs. 110 | parallelism = 1 111 | else: 112 | # Use the default. 113 | parallelism = 0 114 | 115 | # Limit communication to specific devices. Specifically, the goal is to disable communication 116 | # across workers, so as to increase performance and reliability. 117 | device_filters = ['/job:ps', self._config.device] 118 | 119 | return tf.ConfigProto(log_device_placement=job.args.log_device_placement, 120 | device_filters=device_filters, 121 | intra_op_parallelism_threads=parallelism, 122 | inter_op_parallelism_threads=parallelism) 123 | 124 | def _create_session_hooks(self, job): 125 | """Creates the TensorFlow session hooks that customize the session loop. 126 | """ 127 | hooks = [] 128 | 129 | hooks.append(LogSessionHook(job)) 130 | if self._config.master: 131 | hooks.append(LogTrainingHook(job)) 132 | hooks.append(SaveCheckpointHook(job)) 133 | hooks.append(StopTrainingHook(job)) 134 | 135 | return hooks 136 | -------------------------------------------------------------------------------- /tests/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLabs. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License.
You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # __init__.py 14 | # Tests functionality in the tensorfx data module 15 | 16 | -------------------------------------------------------------------------------- /tests/data/dataset_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLabs. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # dataset_tests.py 14 | # Tests Dataset related functionality in tensorfx.data. 15 | 16 | import unittest 17 | import tensorfx as tfx 18 | 19 | 20 | class TestCases(unittest.TestCase): 21 | 22 | def test_empty_dataset(self): 23 | schema = tfx.data.Schema.create(tfx.data.SchemaField.integer('x')) 24 | ds = tfx.data.DataSet({}, schema, None, None) 25 | 26 | self.assertEqual(len(ds), 0) 27 | 28 | def test_create_dataset(self): 29 | schema = tfx.data.Schema.create(tfx.data.SchemaField.integer('x')) 30 | source = tfx.data.DataSource() 31 | ds = tfx.data.DataSet({'foo': source}, schema, None, None) 32 | 33 | self.assertEqual(ds['foo'], source) 34 | 35 | def test_create_multi_source_dataset(self): 36 | schema = tfx.data.Schema.create(tfx.data.SchemaField.integer('x'), 37 | tfx.data.SchemaField.integer('y')) 38 | train = tfx.data.CsvDataSource('...') 39 | eval = tfx.data.CsvDataSource('...') 40 | 41 | ds = tfx.data.CsvDataSet(schema, train=train, eval=eval) 42 | 43 | self.assertEqual(ds['train'], train) 44 | self.assertEqual(ds['eval'], eval) 45 | self.assertEqual(len(ds), 2) 46 | self.assertListEqual(ds.sources, ['train', 'eval']) 47 | -------------------------------------------------------------------------------- /tests/data/features_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLabs. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 
12 | 13 | # features_tests.py 14 | # Tests FeatureSet related functionality in tensorfx.data 15 | 16 | import unittest 17 | import tensorfx as tfx 18 | 19 | 20 | class TestCases(unittest.TestCase): 21 | 22 | def test_create_featureset(self): 23 | t = tfx.data.Feature.target('t', 't') 24 | x = tfx.data.Feature.identity('x', 'x') 25 | features = tfx.data.FeatureSet.create(t, x) 26 | 27 | self.assertEqual(len(features), 2) 28 | self.assertEqual(features['t'], t) 29 | 30 | def test_parse_featureset(self): 31 | spec = """ 32 | features: 33 | - name: target 34 | type: target 35 | fields: c1 36 | - name: f1 37 | type: identity 38 | fields: c3 39 | """ 40 | features = tfx.data.FeatureSet.parse(spec) 41 | 42 | self.assertEqual(len(features), 2) 43 | self.assertEqual(features['target'].fields[0], 'c1') 44 | self.assertEqual(features['target'].type, tfx.data.FeatureType.target) 45 | self.assertEqual(features['f1'].type, tfx.data.FeatureType.identity) 46 | self.assertEqual(features['f1'].fields, ['c3']) 47 | -------------------------------------------------------------------------------- /tests/data/schema_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLabs. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # schema_tests.py 14 | # Tests Schema related functionality in tensorfx.data 15 | 16 | import unittest 17 | import tensorfx as tfx 18 | 19 | 20 | class TestCases(unittest.TestCase): 21 | 22 | def test_create_single_field_schema(self): 23 | f = tfx.data.SchemaField.integer('n') 24 | schema = tfx.data.Schema.create(f) 25 | 26 | self.assertEqual(len(schema), 1) 27 | self.assertEqual(schema['n'], f) 28 | self.assertEqual(schema[0], f) 29 | 30 | def test_create_multi_field_schema(self): 31 | f1 = tfx.data.SchemaField.integer('n') 32 | f2 = tfx.data.SchemaField.discrete('t') 33 | schema = tfx.data.Schema.create(f1, f2) 34 | 35 | self.assertEqual(len(schema), 2) 36 | self.assertEqual(schema['n'], f1) 37 | self.assertEqual(schema[1], f2) 38 | 39 | def test_parse_schema(self): 40 | spec = """ 41 | fields: 42 | - name: f1 43 | type: integer 44 | - name: f2 45 | type: real 46 | - name: f3 47 | type: discrete 48 | """ 49 | schema = tfx.data.Schema.parse(spec) 50 | 51 | self.assertEqual(len(schema), 3) 52 | self.assertEqual(schema[0].name, 'f1') 53 | self.assertEqual(schema['f1'].type, tfx.data.SchemaFieldType.integer) 54 | self.assertEqual(schema['f2'].type, tfx.data.SchemaFieldType.real) 55 | self.assertEqual(schema['f3'].type, tfx.data.SchemaFieldType.discrete) 56 | self.assertEqual(schema.fields, ['f1', 'f2', 'f3']) 57 | -------------------------------------------------------------------------------- /tests/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLabs. All rights reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # main.py 14 | # Entrypoint for tests 15 | 16 | import os 17 | import sys 18 | import unittest 19 | 20 | # Add the repository root to the path so that the library being tested can be imported 21 | sys.path.append(os.path.abspath(os.path.join(__file__, '../..'))) 22 | 23 | # Load the test modules 24 | import data.dataset_tests 25 | import data.schema_tests 26 | import data.features_tests 27 | import training.config_tests 28 | 29 | _TEST_MODULES = [ 30 | data.dataset_tests, 31 | data.schema_tests, 32 | data.features_tests, 33 | training.config_tests 34 | ] 35 | 36 | 37 | def main(): 38 | suite = unittest.TestSuite() 39 | for m in _TEST_MODULES: 40 | suite.addTests(unittest.defaultTestLoader.loadTestsFromModule(m)) 41 | 42 | runner = unittest.TextTestRunner() 43 | result = runner.run(suite) 44 | 45 | sys.exit(len(result.errors) + len(result.failures)) 46 | 47 | 48 | if __name__ == '__main__': 49 | main() 50 | 51 | -------------------------------------------------------------------------------- /tests/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLabs. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # __init__.py 14 | # Tests functionality in the tensorfx training module 15 | 16 | -------------------------------------------------------------------------------- /tests/training/config_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 TensorLabs. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 4 | # in compliance with the License. You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software distributed under the License 9 | # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 | # or implied. See the License for the specific language governing permissions and limitations under 11 | # the License. 12 | 13 | # config_tests.py 14 | # Tests config related functionality in tensorfx.
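# Configuration.environment() follows the TF_CONFIG convention used for distributed TensorFlow
# training: a JSON object in the TF_CONFIG environment variable with 'cluster' and 'task' entries
# (see test_env_config below for a concrete, minimal value).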
15 | 16 | import json 17 | import os 18 | import unittest 19 | 20 | import tensorfx as tfx 21 | 22 | class TestCases(unittest.TestCase): 23 | 24 | def test_local_config(self): 25 | config = tfx.training.Configuration.local() 26 | 27 | self.assertFalse(config.distributed) 28 | self.assertIsNone(config.cluster) 29 | self.assertIsNotNone(config.task) 30 | self.assertEqual(config.task.type, 'master') 31 | self.assertTrue(config.master) 32 | 33 | def test_empty_env_config(self): 34 | config = tfx.training.Configuration.environment() 35 | 36 | self.assertFalse(config.distributed) 37 | self.assertIsNone(config.cluster) 38 | self.assertIsNotNone(config.task) 39 | self.assertEqual(config.task.type, 'master') 40 | self.assertTrue(config.master) 41 | 42 | def test_env_config(self): 43 | config = { 44 | 'task': { 45 | 'type': 'master', 46 | 'index': 0 47 | }, 48 | 'cluster': { 49 | 'hosts': [] 50 | } 51 | } 52 | os.environ['TF_CONFIG'] = json.dumps(config) 53 | self.addCleanup(os.environ.pop, 'TF_CONFIG', None)  # don't leak TF_CONFIG into other tests 54 | config = tfx.training.Configuration.environment() 55 | 56 | self.assertTrue(config.distributed) 57 | self.assertIsNotNone(config.cluster) 58 | self.assertIsNotNone(config.task) 59 | self.assertEqual(config.task.type, 'master') 60 | self.assertTrue(config.master) 61 | 62 | --------------------------------------------------------------------------------
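A minimal end-to-end sketch of how the pieces above compose, assuming a hypothetical concrete ModelBuilder subclass (MyModelBuilder and its my_args ModelArguments instance are illustrative, as are the file paths) and that ModelTrainer is exported from tfx.training alongside Configuration:

import tensorfx as tfx

# Describe the data: parse the schema and wire up train/eval CSV sources,
# mirroring the usage in tests/data/dataset_tests.py.
schema = tfx.data.Schema.parse(open('schema.yaml').read())
dataset = tfx.data.CsvDataSet(schema,
                              train=tfx.data.CsvDataSource('data/train.csv'),
                              eval=tfx.data.CsvDataSource('data/eval.csv'))

# MyModelBuilder (hypothetical) implements build_inference, build_training,
# build_output and build_evaluation.
builder = MyModelBuilder(my_args)

# Train. Per ModelTrainer.train, the master task returns the exported prediction
# Model, while worker tasks return None.
trainer = tfx.training.ModelTrainer()
model = trainer.train(builder, dataset, '/tmp/output')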