├── examples ├── .gitignore ├── mnist_full │ ├── mnist.py │ ├── main.sh │ └── mnist_full.py ├── mnist_evaluate │ ├── mnist.py │ ├── train.py │ ├── main.sh │ └── evaluate.py ├── mnist_infer │ ├── mnist.py │ ├── train.py │ ├── gt.py │ ├── main.sh │ ├── accuracy.py │ └── infer.py ├── mnist_simple │ ├── mnist.py │ ├── train.py │ └── main.sh ├── mnist_distributed │ ├── mnist.py │ ├── train.py │ └── main.sh └── lib │ ├── mnist_test.py │ ├── train.py │ ├── mnist.sh │ └── mnist.py ├── test ├── empty.py └── oracle.py ├── MANIFEST.in ├── .gitignore ├── img └── logo.png ├── qnd ├── util_test.py ├── infer_test.py ├── config_test.py ├── evaluate_test.py ├── test_test.py ├── flag_test.py ├── train_and_evaluate_test.py ├── __init__.py ├── test.py ├── estimator_test.py ├── util.py ├── experiment_test.py ├── inputs_test.py ├── evaluate.py ├── infer.py ├── estimator.py ├── experiment.py ├── flag.py ├── serve.py ├── train_and_evaluate.py ├── config.py └── inputs.py ├── .gitmodules ├── .travis.yml ├── docs ├── index.html └── qnd │ └── index.html ├── wercker.yml ├── setup.py ├── UNLICENSE ├── Rakefile └── README.md /examples/.gitignore: -------------------------------------------------------------------------------- 1 | var 2 | -------------------------------------------------------------------------------- /test/empty.py: -------------------------------------------------------------------------------- 1 | import qnd 2 | -------------------------------------------------------------------------------- /examples/mnist_full/mnist.py: -------------------------------------------------------------------------------- 1 | ../lib/mnist.py -------------------------------------------------------------------------------- /examples/mnist_evaluate/mnist.py: -------------------------------------------------------------------------------- 1 | ../lib/mnist.py -------------------------------------------------------------------------------- /examples/mnist_evaluate/train.py: 
-------------------------------------------------------------------------------- 1 | ../lib/train.py -------------------------------------------------------------------------------- /examples/mnist_infer/mnist.py: -------------------------------------------------------------------------------- 1 | ../lib/mnist.py -------------------------------------------------------------------------------- /examples/mnist_infer/train.py: -------------------------------------------------------------------------------- 1 | ../lib/train.py -------------------------------------------------------------------------------- /examples/mnist_simple/mnist.py: -------------------------------------------------------------------------------- 1 | ../lib/mnist.py -------------------------------------------------------------------------------- /examples/mnist_simple/train.py: -------------------------------------------------------------------------------- 1 | ../lib/train.py -------------------------------------------------------------------------------- /examples/mnist_distributed/mnist.py: -------------------------------------------------------------------------------- 1 | ../lib/mnist.py -------------------------------------------------------------------------------- /examples/mnist_distributed/train.py: -------------------------------------------------------------------------------- 1 | ../lib/train.py -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include UNLICENSE 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | build 3 | dist 4 | *.egg-info 5 | .venv 6 | .cache 7 | -------------------------------------------------------------------------------- /img/logo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/raviqqe/tensorflow-qnd/HEAD/img/logo.png -------------------------------------------------------------------------------- /examples/mnist_simple/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | . ../lib/mnist.sh && 4 | 5 | fetch_dataset && 6 | train 7 | -------------------------------------------------------------------------------- /qnd/util_test.py: -------------------------------------------------------------------------------- 1 | from .util import * 2 | 3 | 4 | def test_func_scope(): 5 | @func_scope 6 | def foo(): 7 | pass 8 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third/tensorflow-rakefile"] 2 | path = third/tensorflow-rakefile 3 | url = https://github.com/raviqqe/tensorflow-rakefile 4 | -------------------------------------------------------------------------------- /examples/mnist_full/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | . ../lib/mnist.sh && 4 | 5 | 6 | git clean -dfx && 7 | 8 | fetch_dataset && 9 | 10 | python3 mnist_full.py $train_options 11 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: trusty 2 | 3 | language: python 4 | 5 | python: 6 | - "3.5" 7 | - "3.6" 8 | 9 | addons: 10 | apt: 11 | packages: 12 | - ruby 13 | 14 | script: 15 | - rake test 16 | -------------------------------------------------------------------------------- /qnd/infer_test.py: -------------------------------------------------------------------------------- 1 | import types 2 | 3 | from . import infer 4 | from . 
import test 5 | 6 | 7 | def test_def_infer(): 8 | test.append_argv("--output_dir", "output") 9 | assert isinstance(infer.def_infer(), types.FunctionType) 10 | -------------------------------------------------------------------------------- /qnd/config_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from . import test 4 | from . import config 5 | 6 | 7 | def test_def_config(): 8 | test.append_argv() 9 | assert isinstance(config.def_config()(), tf.contrib.learn.RunConfig) 10 | -------------------------------------------------------------------------------- /qnd/evaluate_test.py: -------------------------------------------------------------------------------- 1 | import types 2 | 3 | from . import evaluate 4 | from . import test 5 | 6 | 7 | def test_def_evaluate(): 8 | test.append_argv("--output_dir", "output") 9 | assert isinstance(evaluate.def_evaluate(), types.FunctionType) 10 | -------------------------------------------------------------------------------- /qnd/test_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from .test import * 4 | 5 | 6 | def test_oracle_model(): 7 | oracle_model(tf.zeros([100]), tf.zeros([100])) 8 | 9 | 10 | def test_user_input_fn(): 11 | user_input_fn(tf.FIFOQueue(64, [tf.string])) 12 | -------------------------------------------------------------------------------- /qnd/flag_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from .flag import FlagAdder 4 | 5 | 6 | def test_flag_adder(): 7 | sys.argv = ["command", "--foo", "baz"] 8 | 9 | adder = FlagAdder() 10 | adder.add_flag("foo", dest="bar") 11 | assert adder.flags["bar"] == "baz" 12 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 
3 | 4 | 5 | Redirecting ... 6 | 7 | 8 | Redirecting to qnd directory 9 | 10 | 11 | -------------------------------------------------------------------------------- /qnd/train_and_evaluate_test.py: -------------------------------------------------------------------------------- 1 | import types 2 | 3 | from .experiment_test import append_argv 4 | from . import train_and_evaluate 5 | 6 | 7 | def test_def_train_and_evaluate(): 8 | append_argv() 9 | assert isinstance(train_and_evaluate.def_train_and_evaluate(), 10 | types.FunctionType) 11 | -------------------------------------------------------------------------------- /examples/mnist_evaluate/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | . ../lib/mnist.sh && 4 | 5 | 6 | fetch_dataset && 7 | 8 | echo Training a MNIST model... && 9 | train && 10 | 11 | echo Evaluating a model with test data... && 12 | python3 evaluate.py \ 13 | --infer_file $data_dir/test.tfrecords \ 14 | --output_dir $var_dir/output 15 | -------------------------------------------------------------------------------- /examples/lib/mnist_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import mnist 4 | 5 | 6 | def test_read_file(): 7 | data = mnist.read_file(tf.train.string_input_producer( 8 | tf.matching_files('examples/var/data/train.tfrecords'))) 9 | 10 | with tf.Session() as session: 11 | tf.train.queue_runner.start_queue_runners(session) 12 | print(session.run(data)) 13 | -------------------------------------------------------------------------------- /examples/mnist_infer/gt.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import tensorflow as tf 4 | 5 | 6 | def main(): 7 | for serialized in tf.python_io.tf_record_iterator(sys.argv[1]): 8 | example = tf.train.Example() 9 | example.ParseFromString(serialized) 10 | 
print(*example.features.feature["label"].int64_list.value) 11 | 12 | 13 | if __name__ == "__main__": 14 | main() 15 | -------------------------------------------------------------------------------- /examples/mnist_infer/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | . ../lib/mnist.sh && 4 | 5 | 6 | fetch_dataset && 7 | 8 | echo Training a MNIST model... && 9 | train && 10 | 11 | echo Infering labels of test data... && 12 | infer > $prediction_file && 13 | 14 | echo Calculating test accuracy... && 15 | python3 gt.py $data_dir/test.tfrecords > $gt_file && 16 | python3 accuracy.py $prediction_file $gt_file 17 | -------------------------------------------------------------------------------- /examples/mnist_infer/accuracy.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | def main(): 5 | with open(sys.argv[1]) as f1, open(sys.argv[2]) as f2: 6 | lines1 = f1.readlines() 7 | lines2 = f2.readlines() 8 | 9 | assert len(lines1) == len(lines2) 10 | 11 | print(sum(int(line1 == line2) for line1, line2 in zip(lines1, lines2)) 12 | / len(lines1)) 13 | 14 | 15 | if __name__ == "__main__": 16 | main() 17 | -------------------------------------------------------------------------------- /examples/lib/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import qnd 5 | 6 | import mnist 7 | 8 | 9 | train_and_evaluate = qnd.def_train_and_evaluate( 10 | distributed=("distributed" in os.environ)) 11 | 12 | 13 | model = mnist.def_model() 14 | 15 | 16 | def main(): 17 | logging.getLogger().setLevel(logging.INFO) 18 | train_and_evaluate(model, mnist.read_file) 19 | 20 | 21 | if __name__ == "__main__": 22 | main() 23 | -------------------------------------------------------------------------------- /qnd/__init__.py: 
-------------------------------------------------------------------------------- 1 | """Quick and Dirty TensorFlow command framework""" 2 | 3 | from .flag import * 4 | from .infer import def_infer 5 | from .train_and_evaluate import def_train_and_evaluate 6 | from .evaluate import def_evaluate 7 | from .serve import def_serve 8 | 9 | __all__ = ["FLAGS", "add_flag", "add_required_flag", "FlagAdder", 10 | "def_train_and_evaluate", "def_evaluate", "def_infer", "def_serve"] 11 | __version__ = "0.1.11" 12 | -------------------------------------------------------------------------------- /qnd/test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import tensorflow as tf 4 | 5 | 6 | def oracle_model(x, y): 7 | return y, 0.0, tf.no_op() 8 | 9 | 10 | def user_input_fn(filename_queue): 11 | x = filename_queue.dequeue() 12 | return {"x": x}, {"y": x} 13 | 14 | 15 | def append_argv(*args): 16 | command = "THIS_SHOULD_NEVER_MATCH" 17 | 18 | if sys.argv[0] != command: 19 | sys.argv = [command] 20 | 21 | sys.argv += [*args] 22 | -------------------------------------------------------------------------------- /examples/mnist_evaluate/evaluate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import qnd 4 | import tensorflow as tf 5 | 6 | import mnist 7 | 8 | 9 | evaluate = qnd.def_evaluate() 10 | 11 | 12 | model = mnist.def_model() 13 | 14 | 15 | def main(): 16 | logging.getLogger().setLevel(logging.INFO) 17 | 18 | try: 19 | print(evaluate(model, mnist.read_file)) 20 | except tf.errors.OutOfRangeError: 21 | pass 22 | 23 | 24 | if __name__ == "__main__": 25 | main() 26 | -------------------------------------------------------------------------------- /examples/mnist_infer/infer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import qnd 4 | 5 | import mnist 6 | 7 | 8 | infer = qnd.def_infer() 9 | 10 | 
11 | model = mnist.def_model() 12 | 13 | 14 | def read_file(filename_queue): 15 | return mnist.read_file(filename_queue)[0] 16 | 17 | 18 | def main(): 19 | logging.getLogger().setLevel(logging.INFO) 20 | 21 | for label in infer(model, read_file): 22 | print(label) 23 | 24 | 25 | if __name__ == "__main__": 26 | main() 27 | -------------------------------------------------------------------------------- /wercker.yml: -------------------------------------------------------------------------------- 1 | box: ubuntu:latest 2 | 3 | build: 4 | steps: 5 | - script: 6 | name: install 7 | code: | 8 | apt -y update --fix-missing && apt -y install git python3 python3-pip python3-venv python3-wheel rake 9 | 10 | - script: 11 | name: checkout 12 | code: git submodule update --init --recursive 13 | 14 | - script: 15 | name: test 16 | code: rake test 17 | 18 | - script: 19 | name: clean 20 | code: rake clean 21 | -------------------------------------------------------------------------------- /qnd/estimator_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from . 
import test 4 | from .estimator import * 5 | 6 | 7 | def test_def_estimator(): 8 | test.append_argv() 9 | assert isinstance(def_estimator()(test.oracle_model, "output"), 10 | tf.contrib.learn.Estimator) 11 | assert isinstance( 12 | def_estimator()( 13 | lambda x, y: tf.contrib.learn.ModelFnOps( 14 | "train", *test.oracle_model(x, y)), 15 | "output"), 16 | tf.contrib.learn.Estimator) 17 | -------------------------------------------------------------------------------- /test/oracle.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import qnd 4 | import tensorflow as tf 5 | 6 | 7 | def model_fn(x, y): 8 | return (y, 9 | 0.0, 10 | tf.contrib.framework.get_or_create_global_step().assign_add()) 11 | 12 | 13 | def input_fn(q): 14 | shape = (100,) 15 | return tf.zeros(shape, tf.float32), tf.ones(shape, tf.int32) 16 | 17 | 18 | train_and_evaluate = qnd.def_train_and_evaluate( 19 | distributed=("distributed" in os.environ)) 20 | 21 | 22 | def main(): 23 | train_and_evaluate(model_fn, input_fn) 24 | 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /examples/mnist_distributed/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | . 
../lib/mnist.sh || exit 1 4 | 5 | workers=localhost:19310,localhost:10019 6 | 7 | 8 | mnist() { 9 | distributed=yes train \ 10 | --master_host localhost:2049 \ 11 | --worker_hosts $workers \ 12 | --ps_hosts localhost:4242 \ 13 | --task_type "$@" 14 | } 15 | 16 | 17 | main() { 18 | fetch_dataset || exit 1 19 | 20 | mnist ps > $var_dir/ps.log 2>&1 & 21 | 22 | worker_id=0 23 | for worker in $(echo $workers | tr , ' ') 24 | do 25 | mnist worker --task_index $worker_id > $var_dir/worker-$worker_id.log 2>&1 & 26 | worker_id=$(($worker_id + 1)) 27 | done && 28 | 29 | mnist master 30 | } 31 | 32 | 33 | main 34 | -------------------------------------------------------------------------------- /qnd/util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import inspect 3 | 4 | import tensorflow as tf 5 | 6 | 7 | def func_scope(func): 8 | @functools.wraps(func) 9 | def wrapped_func(*args, **kwargs): 10 | with tf.variable_scope(func.__name__): 11 | return func(*args, **kwargs) 12 | 13 | # inspect.getargspec() (used in TensorFlow) cannot be deceived by 14 | # functools.wraps() somehow. So we need to assign a signature of an original 15 | # function to a wrapper. This can be a bug of Python. 16 | wrapped_func.__signature__ = inspect.signature(func) 17 | return wrapped_func 18 | 19 | 20 | def are_instances(objects, klass): 21 | return all(isinstance(obj, klass) for obj in objects) 22 | -------------------------------------------------------------------------------- /qnd/experiment_test.py: -------------------------------------------------------------------------------- 1 | import types 2 | 3 | import tensorflow as tf 4 | 5 | from . import test 6 | from . import experiment 7 | from . 
import inputs_test 8 | 9 | 10 | def test_def_experiment(): 11 | append_argv() 12 | 13 | def_experiment_fn = experiment.def_def_experiment_fn() 14 | _assert_is_function(def_experiment_fn) 15 | 16 | experiment_fn = def_experiment_fn(test.oracle_model, test.user_input_fn) 17 | _assert_is_function(experiment_fn) 18 | 19 | assert isinstance(experiment_fn("output"), tf.contrib.learn.Experiment) 20 | 21 | 22 | def _assert_is_function(obj): 23 | assert isinstance(obj, types.FunctionType) 24 | 25 | 26 | def append_argv(): 27 | inputs_test.append_argv() 28 | -------------------------------------------------------------------------------- /examples/lib/mnist.sh: -------------------------------------------------------------------------------- 1 | var_dir=var 2 | data_dir=$var_dir/data 3 | shared_data_dir=../$data_dir 4 | gt_file=$var_dir/gt.csv 5 | prediction_file=$var_dir/predictions.csv 6 | 7 | train_options=\ 8 | '--train_steps 1000 '\ 9 | '--eval_steps 50 '\ 10 | "--train_file $data_dir/train.tfrecords "\ 11 | "--eval_file $data_dir/validation.tfrecords "\ 12 | "--output_dir $var_dir/output" 13 | 14 | 15 | train() { 16 | python3 train.py $train_options 17 | } 18 | 19 | 20 | infer() { 21 | python3 infer.py \ 22 | --infer_file $data_dir/test.tfrecords \ 23 | --output_dir $var_dir/output 24 | } 25 | 26 | 27 | fetch_dataset() { 28 | if [ ! -d $shared_data_dir ] 29 | then 30 | curl -SL https://github.com/tensorflow/tensorflow/raw/master/tensorflow/examples/how_tos/reading_data/convert_to_records.py | 31 | python3 - --directory $shared_data_dir 32 | fi && 33 | 34 | if [ ! 
-d $data_dir ] 35 | then 36 | mkdir -p $(dirname $data_dir) && 37 | ln -s ../$shared_data_dir $data_dir 38 | fi 39 | } 40 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | import setuptools 3 | import sys 4 | 5 | 6 | if not sys.version_info >= (3, 5): 7 | exit("Sorry, Python must be later than 3.5.") 8 | 9 | 10 | setuptools.setup( 11 | name="tensorflow-qnd", 12 | version=re.search(r'__version__ *= *"([0-9]+\.[0-9]+\.[0-9]+)" *\n', 13 | open("qnd/__init__.py").read()).group(1), 14 | description="Quick and Dirty TensorFlow command framework", 15 | long_description=open("README.md").read(), 16 | license="Public Domain", 17 | author="Yota Toyama", 18 | author_email="raviqqe@gmail.com", 19 | url="https://github.com/raviqqe/tensorflow-qnd/", 20 | packages=["qnd"], 21 | install_requires=["gargparse"], 22 | classifiers=[ 23 | "Development Status :: 3 - Alpha", 24 | "Intended Audience :: Developers", 25 | "License :: Public Domain", 26 | "Programming Language :: Python :: 3.5", 27 | "Programming Language :: Python :: 3.6", 28 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 29 | "Topic :: System :: Networking", 30 | ], 31 | ) 32 | -------------------------------------------------------------------------------- /qnd/inputs_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from . import inputs 4 | from . 
import test 5 | 6 | 7 | _FILE_PATTERN = "*.md" 8 | 9 | 10 | def test_def_input_fn(): 11 | append_argv() 12 | 13 | for def_input_fn in [inputs.def_def_train_input_fn(), 14 | inputs.def_def_eval_input_fn()]: 15 | # Return (tf.Tensor, tf.Tensor) 16 | 17 | features, labels = def_input_fn(lambda queue: (queue.dequeue(),) * 2)() 18 | 19 | assert isinstance(features, tf.Tensor) 20 | assert isinstance(labels, tf.Tensor) 21 | 22 | # Return (dict, dict) 23 | 24 | features, labels = def_input_fn(test.user_input_fn)() 25 | 26 | assert isinstance(features, dict) 27 | assert isinstance(labels, dict) 28 | 29 | _assert_are_instances([*features.keys(), *labels.keys()], str) 30 | _assert_are_instances( 31 | [*features.values(), *labels.values()], tf.Tensor) 32 | 33 | 34 | def _assert_are_instances(objects, klass): 35 | for obj in objects: 36 | assert isinstance(obj, klass) 37 | 38 | 39 | def append_argv(): 40 | test.append_argv("--train_file", _FILE_PATTERN, 41 | "--eval_file", _FILE_PATTERN) 42 | -------------------------------------------------------------------------------- /UNLICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 
15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | For more information, please refer to 25 | -------------------------------------------------------------------------------- /qnd/evaluate.py: -------------------------------------------------------------------------------- 1 | from .estimator import def_estimator 2 | from .flag import FLAGS, add_output_dir_flag 3 | from .inputs import def_def_infer_input_fn 4 | 5 | 6 | def def_evaluate(batch_inputs=True, prepare_filename_queues=True): 7 | """Define `evaluate()` function. 8 | 9 | See also `help(def_evaluate())`. 10 | 11 | - Args 12 | - `batch_inputs`: Same as `def_train_and_evaluate()`'s. 13 | - `prepare_filename_queues`: Same as `def_train_and_evaluate()`'s. 14 | 15 | - Returns 16 | - `evaluate()` function. 17 | """ 18 | add_output_dir_flag() 19 | 20 | estimator = def_estimator(distributed=False) 21 | def_eval_input_fn = def_def_infer_input_fn(batch_inputs, 22 | prepare_filename_queues) 23 | 24 | def evaluate(model_fn, input_fn): 25 | """Evaluate a model with data sample fed by `input_fn`. 26 | 27 | - Args 28 | - `model_fn`: Same as `train_and_evaluate()`'s. 29 | - `input_fn`: Same as `eval_input_fn` argument of 30 | `train_and_evaluate()`. 31 | 32 | - Returns 33 | - Evaluation results. See `Evaluable` interface in TensorFlow. 
34 | """ 35 | return estimator(model_fn, FLAGS.output_dir).evaluate( 36 | input_fn=def_eval_input_fn(input_fn)) 37 | 38 | return evaluate 39 | -------------------------------------------------------------------------------- /qnd/infer.py: -------------------------------------------------------------------------------- 1 | from .estimator import def_estimator 2 | from .flag import FLAGS, add_output_dir_flag 3 | from .inputs import def_def_infer_input_fn 4 | 5 | 6 | def def_infer(batch_inputs=True, prepare_filename_queues=True): 7 | """Define `infer()` function. 8 | 9 | See also `help(def_infer())`. 10 | 11 | - Args 12 | - `batch_inputs`: Same as `def_train_and_evaluate()`'s. 13 | - `prepare_filename_queues`: Same as `def_train_and_evaluate()`'s. 14 | 15 | - Returns 16 | - `infer()` function. 17 | """ 18 | add_output_dir_flag() 19 | 20 | estimator = def_estimator(distributed=False) 21 | def_infer_input_fn = def_def_infer_input_fn(batch_inputs, 22 | prepare_filename_queues) 23 | 24 | def infer(model_fn, input_fn): 25 | """Infer labels or regression values from features of samples fed by 26 | `input_fn`. 27 | 28 | - Args 29 | - `model_fn`: Same as `train_and_evaluate()`'s. 30 | - `input_fn`: Same as `train_input_fn` and `eval_input_fn` 31 | arguments of `train_and_evaluate()` but returns only features. 32 | 33 | - Returns 34 | - Generator of inferred label(s) or regression value(s) for each 35 | sample. 36 | """ 37 | return estimator(model_fn, FLAGS.output_dir).predict( 38 | input_fn=def_infer_input_fn(input_fn)) 39 | 40 | return infer 41 | -------------------------------------------------------------------------------- /qnd/estimator.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import inspect 3 | import typing 4 | 5 | import tensorflow as tf 6 | import tensorflow.contrib.learn as learn 7 | 8 | from . 
import util 9 | from .config import def_config 10 | 11 | 12 | def def_estimator(distributed=False): 13 | config = def_config(distributed) 14 | 15 | @util.func_scope 16 | def estimator(model_fn, model_dir): 17 | return tf.contrib.learn.Estimator(_wrap_model_fn(model_fn), 18 | config=config(), 19 | model_dir=model_dir) 20 | 21 | return estimator 22 | 23 | 24 | def _wrap_model_fn(original_model_fn): 25 | @util.func_scope 26 | def model(features, targets, mode): 27 | are_args = functools.partial(util.are_instances, [features, targets]) 28 | def_model_fn = functools.partial(functools.partial, original_model_fn) 29 | 30 | if are_args(tf.Tensor): 31 | model_fn = def_model_fn(features, targets) 32 | elif are_args(dict): 33 | model_fn = def_model_fn(**features, **targets) 34 | elif isinstance(features, tf.Tensor) and targets is None: 35 | model_fn = def_model_fn(features) 36 | elif isinstance(features, dict) and targets is None: 37 | model_fn = def_model_fn(**features) 38 | else: 39 | raise ValueError( 40 | "features and targets should be both tf.Tensor or dict.") 41 | 42 | results = ( 43 | model_fn(mode=mode) 44 | if "mode" in inspect.signature(model_fn).parameters.keys() else 45 | model_fn()) 46 | 47 | return ( 48 | results 49 | if isinstance(results, learn.ModelFnOps) else 50 | learn.ModelFnOps( 51 | mode, 52 | *(results 53 | if isinstance(results, typing.Sequence) else 54 | (results,)))) 55 | 56 | return model 57 | -------------------------------------------------------------------------------- /qnd/experiment.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from .flag import FlagAdder 4 | from .estimator import def_estimator 5 | from .inputs import def_def_train_input_fn, def_def_eval_input_fn 6 | 7 | 8 | def def_def_experiment_fn(batch_inputs=True, 9 | prepare_filename_queues=True, 10 | distributed=False): 11 | adder = FlagAdder() 12 | 13 | for mode in [tf.contrib.learn.ModeKeys.TRAIN, 14 | 
tf.contrib.learn.ModeKeys.EVAL]: 15 | adder.add_flag( 16 | "{}_steps".format(mode), 17 | type=int, 18 | default=(100 if mode == tf.contrib.learn.ModeKeys.EVAL else None), 19 | help="Maximum number of {} steps".format(mode)) 20 | 21 | adder.add_flag( 22 | "min_eval_frequency", type=int, default=1, 23 | help="Minimum evaluation frequency in number of train steps") 24 | 25 | estimator = def_estimator(distributed) 26 | def_train_input_fn = def_def_train_input_fn(batch_inputs, 27 | prepare_filename_queues) 28 | def_eval_input_fn = def_def_eval_input_fn(batch_inputs, 29 | prepare_filename_queues) 30 | 31 | def def_experiment_fn(model_fn, 32 | train_input_fn, 33 | eval_input_fn=None, 34 | serving_input_fn=None): 35 | def experiment_fn(output_dir): 36 | return tf.contrib.learn.Experiment( 37 | estimator(model_fn, output_dir), 38 | def_train_input_fn(train_input_fn), 39 | def_eval_input_fn(eval_input_fn or train_input_fn), 40 | export_strategies=(serving_input_fn and [ 41 | tf.contrib.learn.make_export_strategy(serving_input_fn), 42 | ]), 43 | **adder.flags) 44 | 45 | return experiment_fn 46 | 47 | return def_experiment_fn 48 | -------------------------------------------------------------------------------- /examples/lib/mnist.py: -------------------------------------------------------------------------------- 1 | import qnd 2 | import tensorflow as tf 3 | 4 | 5 | def _preprocess_image(image): 6 | return tf.to_float(image) / 255 - 0.5 7 | 8 | 9 | def read_file(filename_queue): 10 | _, serialized = tf.TFRecordReader().read(filename_queue) 11 | 12 | def scalar_feature(dtype): return tf.FixedLenFeature([], dtype) 13 | 14 | features = tf.parse_single_example(serialized, { 15 | "image_raw": scalar_feature(tf.string), 16 | "label": scalar_feature(tf.int64), 17 | }) 18 | 19 | image = tf.decode_raw(features["image_raw"], tf.uint8) 20 | image.set_shape([28**2]) 21 | 22 | return _preprocess_image(image), features["label"] 23 | 24 | 25 | def serving_input_fn(): 26 | features = { 27 | 
'image': _preprocess_image(tf.placeholder(tf.uint8, [None, 28**2])), 28 | } 29 | 30 | return tf.contrib.learn.InputFnOps(features, None, features) 31 | 32 | 33 | def minimize(loss): 34 | return tf.train.AdamOptimizer().minimize( 35 | loss, 36 | tf.contrib.framework.get_global_step()) 37 | 38 | 39 | def def_model(): 40 | qnd.add_flag("hidden_layer_size", type=int, default=64, 41 | help="Hidden layer size") 42 | 43 | def model(image, number=None, mode=None): 44 | h = tf.contrib.layers.fully_connected(image, 45 | qnd.FLAGS.hidden_layer_size) 46 | h = tf.contrib.layers.fully_connected(h, 10, activation_fn=None) 47 | 48 | predictions = tf.argmax(h, axis=1) 49 | 50 | if mode == tf.contrib.learn.ModeKeys.INFER: 51 | return predictions 52 | 53 | loss = tf.reduce_mean( 54 | tf.nn.sparse_softmax_cross_entropy_with_logits(labels=number, 55 | logits=h)) 56 | 57 | return predictions, loss, minimize(loss), { 58 | "accuracy": tf.contrib.metrics.streaming_accuracy(predictions, 59 | number)[1], 60 | } 61 | 62 | return model 63 | -------------------------------------------------------------------------------- /qnd/flag.py: -------------------------------------------------------------------------------- 1 | import gargparse 2 | 3 | 4 | FLAGS = gargparse.ARGS 5 | _FLAG_NAMES = set() 6 | 7 | 8 | def add_flag(name, *args, **kwargs): 9 | """Add a flag. 10 | 11 | Added flags can be accessed by `FLAGS` module variable. 12 | (e.g. `FLAGS.my_flag_name`) 13 | 14 | - Args 15 | - `name`: Flag name. Real flag name will be `"--{}".format(name)`. 16 | - `*args`, `**kwargs`: The rest arguments are the same as 17 | `argparse.ArgumentParser.add_argument()`. 18 | """ 19 | global _FLAG_NAMES 20 | 21 | if 'help' not in kwargs: 22 | kwargs['help'] = '(no description)' 23 | 24 | if name not in _FLAG_NAMES: 25 | _FLAG_NAMES.add(name) 26 | gargparse.add_argument("--" + name, *args, **kwargs) 27 | 28 | 29 | def add_required_flag(name, *args, **kwargs): 30 | """Add a required flag. 
31 | 32 | Its interface is the same as `add_flag()` but `required=True` is set by 33 | default. 34 | """ 35 | add_flag(name, *args, required=True, **kwargs) 36 | 37 | 38 | class FlagAdder: 39 | """Manage addition of flags.""" 40 | 41 | def __init__(self): 42 | """Create a `FlagAdder` instance.""" 43 | self._flags = [] 44 | 45 | def add_flag(self, name, *args, **kwargs): 46 | """Add a flag. 47 | 48 | See `add_flag()`. 49 | """ 50 | add_flag(name, *args, **kwargs) 51 | self._flags.append(kwargs.get("dest") or name) 52 | 53 | def add_required_flag(self, name, *args, **kwargs): 54 | """Add a required flag. 55 | 56 | See `add_required_flag()`. 57 | """ 58 | self.add_flag(name, *args, required=True, **kwargs) 59 | 60 | @property 61 | def flags(self): 62 | """Get added flags. 63 | 64 | - Returns 65 | - `dict` of flag names to values added by a `FlagAdder` instance. 66 | """ 67 | return {flag: getattr(FLAGS, flag) for flag in self._flags} 68 | 69 | 70 | def add_output_dir_flag(): 71 | add_flag("output_dir", 72 | default="output", 73 | help="Directory where checkpoint and event files are stored") 74 | -------------------------------------------------------------------------------- /Rakefile: -------------------------------------------------------------------------------- 1 | require_relative './third/tensorflow-rakefile/tfrake.rb' 2 | include TFRake 3 | 4 | README_FILE = 'README.md'.freeze 5 | 6 | define_tasks('qnd', define_pytest: false) 7 | 8 | task_in_venv :pytest do 9 | vsh 'cd examples/lib && . ./mnist.sh && fetch_dataset' 10 | 11 | Dir.glob(['qnd/**/*_test.py', 'examples/**/*_test.py']).each do |file| 12 | vsh :pytest, file 13 | end 14 | end 15 | 16 | task_in_venv :script_test do 17 | vsh 'python3 test/empty.py' 18 | 19 | Dir.glob('test/*.py').each do |file| 20 | vsh :python3, file, '-h' 21 | end 22 | 23 | distributed_oracle = 'distributed=yes python3 test/oracle.py' 24 | vsh "#{distributed_oracle} -h" 25 | 26 | # Worker hosts should not include a master host. 
27 | vsh('!', distributed_oracle.to_s, 28 | '--master_host', 'localhost:4242', 29 | '--worker_hosts', 'localhost:4242', 30 | '--ps_hosts', 'localhost:5151', 31 | '--task_type', 'job', 32 | '--train_file', 'README.md', 33 | '--eval_file', 'setup.py') 34 | end 35 | 36 | %i[mnist_simple mnist_distributed mnist_evaluate mnist_infer].each do |name| 37 | task_in_venv name do 38 | vsh "cd examples/#{name} && ./main.sh" 39 | end 40 | end 41 | 42 | task_in_venv :mnist_full do |t| 43 | [ 44 | nil, 45 | %i[use_eval_input_fn], 46 | # %i[use_serving_input_fn], # TODO: Enable this test when tensorflow/tensorflow #9923 is merged. 47 | %i[use_dict_inputs], 48 | %i[use_model_fn_ops], 49 | %i[self_batch], 50 | %i[self_filename_queue use_eval_input_fn] 51 | ].each do |flags| 52 | vsh( 53 | 'cd', "examples/#{t.name}", '&&', 54 | (flags && flags.map { |flag| "#{flag}=yes" }.join(' ')).to_s, './main.sh' 55 | ) 56 | end 57 | end 58 | 59 | task test: %i[ 60 | pytest 61 | script_test 62 | mnist_simple 63 | mnist_distributed 64 | mnist_evaluate 65 | mnist_infer 66 | mnist_full 67 | ] 68 | 69 | task :readme_examples do 70 | md = File.read(README_FILE) 71 | 72 | command_script = 'train.py' 73 | library_script = 'mnist.py' 74 | 75 | def read_example_file(file) 76 | File.read(File.join('examples/mnist_simple', file)).strip 77 | end 78 | 79 | File.write(README_FILE, %( 80 | #{md.match(/(\A.*## Examples)/m)[0]} 81 | 82 | `#{command_script}` (command script): 83 | 84 | ```python 85 | #{read_example_file command_script} 86 | ``` 87 | 88 | `#{library_script}` (module): 89 | 90 | ```python 91 | #{read_example_file library_script} 92 | ``` 93 | 94 | With the code above, you can create a command with the following interface. 95 | 96 | ``` 97 | #{`#{IN_VENV} cd examples/mnist_simple && python3 #{command_script} -h`.strip} 98 | ``` 99 | 100 | Explore [examples](examples) directory for more information and see how to run 101 | them. 
102 | 103 | 104 | #{md.match(/## Caveats.*\Z/m)[0].strip} 105 | ).lstrip) 106 | end 107 | 108 | task doc: %i[pdoc readme_examples] 109 | -------------------------------------------------------------------------------- /examples/mnist_full/mnist_full.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import qnd 5 | import tensorflow as tf 6 | 7 | import mnist 8 | 9 | 10 | BATCH_SIZE = 64 11 | ENV_FLAGS = ['use_eval_input_fn', 'use_serving_input_fn', 'use_dict_inputs', 12 | 'use_model_fn_ops', 'self_batch', 'self_filename_queue'] 13 | 14 | 15 | def env(name): 16 | assert name in ENV_FLAGS 17 | return name in os.environ 18 | 19 | 20 | for name in ENV_FLAGS: 21 | if env(name): 22 | print("Environment variable, {} is set.".format(name)) 23 | 24 | if env("self_filename_queue"): 25 | qnd.add_required_flag("train_file") 26 | qnd.add_required_flag("eval_file") 27 | 28 | 29 | def filename_queue(train=False): 30 | train_filenames = tf.train.match_filenames_once( 31 | qnd.FLAGS.train_file, name="train_filenames") 32 | eval_filenames = tf.train.match_filenames_once( 33 | qnd.FLAGS.eval_file, name="eval_filenames") 34 | 35 | return tf.train.string_input_producer( 36 | train_filenames if train else eval_filenames, 37 | num_epochs=(None if train else 1), 38 | shuffle=train) 39 | 40 | 41 | def train_batch(*tensors): 42 | capacity = BATCH_SIZE * 10 43 | return tf.train.shuffle_batch(tensors, 44 | batch_size=BATCH_SIZE, 45 | capacity=capacity, 46 | min_after_dequeue=capacity // 2) 47 | 48 | 49 | def eval_batch(*tensors): 50 | return tf.train.batch(tensors, batch_size=BATCH_SIZE) 51 | 52 | 53 | def read_file(filename_queue): 54 | image, number = mnist.read_file(filename_queue) 55 | 56 | return (({"image": image}, {"number": number}) 57 | if env("use_dict_inputs") else 58 | (image, number)) 59 | 60 | 61 | mnist_model = mnist.def_model() 62 | 63 | 64 | def model(image, number=None, 
mode=tf.contrib.learn.ModeKeys.TRAIN): 65 | results = mnist_model(image, number, mode) 66 | 67 | return (tf.contrib.learn.ModelFnOps(mode, *results) 68 | if env("use_model_fn_ops") else 69 | results) 70 | 71 | 72 | train_and_evaluate = qnd.def_train_and_evaluate( 73 | batch_inputs=(not env("self_batch")), 74 | prepare_filename_queues=(not env("self_filename_queue"))) 75 | 76 | 77 | def main(): 78 | logging.getLogger().setLevel(logging.INFO) 79 | 80 | def def_input_fn(batch_fn, filename_queue_fn): 81 | def batch(*tensors): 82 | return batch_fn(*tensors) if env("self_batch") else tensors 83 | 84 | if env("self_filename_queue"): 85 | def input_fn(): 86 | return batch(*read_file(filename_queue_fn())) 87 | else: 88 | def input_fn(filename_queue): 89 | return batch(*read_file(filename_queue)) 90 | 91 | return input_fn 92 | 93 | train_and_evaluate( 94 | model, 95 | def_input_fn(train_batch, lambda: filename_queue(train=True)), 96 | (def_input_fn(eval_batch, lambda: filename_queue()) 97 | if env('use_eval_input_fn') else 98 | None), 99 | (mnist.serving_input_fn if env('use_serving_input_fn') else None)) 100 | 101 | 102 | if __name__ == "__main__": 103 | main() 104 | -------------------------------------------------------------------------------- /qnd/serve.py: -------------------------------------------------------------------------------- 1 | import http.server 2 | import json 3 | import logging 4 | import queue 5 | import threading 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | from .estimator import def_estimator 11 | from .flag import FLAGS, add_flag, add_output_dir_flag 12 | 13 | 14 | def def_serve(): 15 | """Define `serve()` function. 16 | 17 | See also `help(def_serve())`. 18 | 19 | - Returns 20 | - `serve()` function. 
21 | """ 22 | add_output_dir_flag() 23 | add_flag('ip_address', default='') 24 | add_flag('port', type=int, default=80) 25 | 26 | estimator = def_estimator(distributed=False) 27 | 28 | def serve(model_fn, preprocess_fn=None, postprocess_fn=None): 29 | """Serve as an HTTP server. 30 | 31 | - Args 32 | - `model_fn`: Same as `train_and_evaluate()`'s. 33 | - `preprocess_fn`: A function to preprocess server request bodies 34 | in JSON. Its first argument is a function which returns the 35 | JSON input. You may need to use `tf.py_func` to create this 36 | function. 37 | - `postprocess_fn`: A function to convert each prediction into a 38 | JSON-serializable server response. Both `preprocess_fn` and `postprocess_fn` are required despite their `None` defaults. 39 | """ 40 | server = EstimatorServer( 41 | estimator(model_fn, FLAGS.output_dir), 42 | preprocess_fn, 43 | postprocess_fn) 44 | 45 | class Handler(http.server.BaseHTTPRequestHandler): 46 | def do_POST(self): 47 | self.send_response(200) 48 | self.send_header('Content-type', 'application/json') 49 | self.end_headers() 50 | 51 | inputs = json.loads(self.rfile.read( 52 | int(self.headers['Content-Length']))) 53 | 54 | outputs = server.predict(inputs) 55 | 56 | logging.info('Prediction results: {}'.format(outputs)) 57 | 58 | self.wfile.write(json.dumps(outputs).encode()) 59 | 60 | http.server.HTTPServer((FLAGS.ip_address, FLAGS.port), Handler) \ 61 | .serve_forever() 62 | 63 | return serve 64 | 65 | 66 | def _make_json_serializable(x): 67 | if isinstance(x, np.ndarray): 68 | return x.tolist() 69 | elif isinstance(x, dict): 70 | return {key: _make_json_serializable(value) 71 | for key, value in x.items()} 72 | elif isinstance(x, list): 73 | return [_make_json_serializable(value) for value in x] 74 | 75 | return x 76 | 77 | 78 | class EstimatorServer: 79 | def __init__(self, estimator, preprocess_fn=None, postprocess_fn=None): 80 | self._input_queue = queue.Queue() 81 | self._output_queue = queue.Queue() 82 | 83 | def input_fn(): 84 | return (tf.train.batch(preprocess_fn(self._input_queue.get), 85 | 1, 86 |
dynamic_pad=True), 87 | None) 88 | 89 | def target(): 90 | for output in estimator.predict(input_fn=input_fn): 91 | self._output_queue.put(postprocess_fn(output)) 92 | 93 | thread = threading.Thread(target=target, daemon=True) 94 | thread.start() 95 | 96 | def predict(self, inputs): 97 | self._input_queue.put(inputs) 98 | return self._output_queue.get() 99 | -------------------------------------------------------------------------------- /qnd/train_and_evaluate.py: -------------------------------------------------------------------------------- 1 | from tensorflow.contrib.learn.python.learn.learn_runner import run 2 | 3 | from .experiment import def_def_experiment_fn 4 | from .flag import FLAGS, add_output_dir_flag 5 | 6 | 7 | def def_train_and_evaluate(batch_inputs=True, 8 | prepare_filename_queues=True, 9 | distributed=False): 10 | """Define `train_and_evaluate()` function. 11 | 12 | See also `help(def_train_and_evaluate())`. 13 | 14 | - Args 15 | - `batch_inputs`: If `True`, create batches from Tensors returned from 16 | `train_input_fn()` and `eval_input_fn()` and feed them to a model. 17 | - `prepare_filename_queues`: If `True`, create filename queues for 18 | train and eval data based on file paths specified by command line 19 | arguments. 20 | - `distributed`: If `True`, configure command line arguments to train 21 | and evaluate models on a distributed system. 22 | 23 | - Returns 24 | - `train_and_evaluate()` function. 25 | """ 26 | add_output_dir_flag() 27 | 28 | def_experiment_fn = def_def_experiment_fn(batch_inputs, 29 | prepare_filename_queues, 30 | distributed) 31 | 32 | def train_and_evaluate(model_fn, 33 | train_input_fn, 34 | eval_input_fn=None, 35 | serving_input_fn=None): 36 | """Train and evaluate a model with features and targets fed by 37 | `input_fn`s. 38 | 39 | - Args 40 | - `model_fn`: A function to construct a model. 
41 | - Types of its arguments must be one of the following: 42 | - `Tensor, ...` 43 | - `Tensor, ..., mode=ModeKeys` 44 | - Types of its return values must be one of the following: 45 | - `Tensor, Tensor, Operation, eval_metric_ops=dict` 46 | (predictions, loss, train_op, and eval_metric_ops (if any)) 47 | - `ModelFnOps` 48 | - `train_input_fn`, `eval_input_fn`: Functions to create input 49 | Tensors fed into the model. If `eval_input_fn` is `None`, 50 | `train_input_fn` will be used instead. 51 | - Types of its arguments must be one of the following: 52 | - `QueueBase` (a filename queue) 53 | - No argument if `prepare_filename_queues` of 54 | `def_train_and_evaluate()` is `False`. 55 | - Types of its return values must be one of the following: 56 | - `Tensor, Tensor` (features and targets) 57 | - `dict, dict` (features and targets) 58 | - The keys in `dict` objects must match the argument 59 | names of `model_fn`. 60 | 61 | - Returns 62 | - Return value of `tf.contrib.learn.python.learn.learn_runner.run()`. 63 | """ 64 | return run(def_experiment_fn(model_fn, 65 | train_input_fn, 66 | eval_input_fn, 67 | serving_input_fn), 68 | FLAGS.output_dir) 69 | 70 | return train_and_evaluate 71 | -------------------------------------------------------------------------------- /qnd/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | import tensorflow as tf 6 | 7 | from . import util 8 | from .
import flag 9 | from .flag import FLAGS, add_flag, add_required_flag 10 | 11 | 12 | _JOBS = {getattr(tf.contrib.learn.TaskType, name) 13 | for name in ['MASTER', 'PS', 'WORKER']} 14 | 15 | 16 | def def_config(distributed=False): 17 | # ClusterConfig flags 18 | 19 | if distributed: 20 | add_required_flag('master_host', 21 | help='HOSTNAME:PORT pair of a master host') 22 | 23 | def add_hosts_flag(name, **kwargs): 24 | return add_flag( 25 | name, 26 | type=(lambda string: string.split(',')), 27 | default=[], 28 | help='Comma-separated list of $hostname:$port pairs of {}' 29 | .format(name.replace('_', ' ')), 30 | **kwargs) 31 | 32 | add_hosts_flag('ps_hosts', required=True) 33 | add_hosts_flag('worker_hosts') 34 | 35 | add_required_flag('task_type', 36 | help='Must be in {} (aka job)'.format(sorted(_JOBS))) 37 | add_flag('task_index', type=int, default=0, 38 | help='Task index within a job') 39 | 40 | # RunConfig flags 41 | 42 | adder = flag.FlagAdder() 43 | # Default values are based on ones of tf.contrib.learn.RunConfig. 44 | adder.add_flag('num_cores', type=int, default=0, 45 | help='Number of CPU cores used. ' 46 | '0 means use of a default value.') 47 | adder.add_flag('log_device_placement', action='store_true', 48 | help='If specified, log device placement information') 49 | 50 | def saver_help(x): 51 | return 'Number of steps every time of which {} is saved'.format(x) 52 | 53 | adder.add_flag('save_summary_steps', type=int, default=100, 54 | help=saver_help('summary')) 55 | adder.add_flag('save_checkpoints_steps', type=int, 56 | help=saver_help('a model')) 57 | adder.add_flag('keep_checkpoint_max', type=int, default=2049 * 42, 58 | help='Max number of kept checkpoint files') 59 | 60 | @util.func_scope 61 | def config(): 62 | if distributed: 63 | config_env = 'TF_CONFIG' 64 | 65 | if config_env in os.environ and os.environ[config_env]: 66 | logging.warning('A value of the environment variable of ' 67 | 'TensorFlow cluster configuration, {} is ' 68 | 'discarded.' 
69 | .format(config_env)) 70 | 71 | if FLAGS.master_host in FLAGS.worker_hosts: 72 | raise ValueError( 73 | 'Master host {} is found in worker hosts {}.' 74 | .format(FLAGS.master_host, FLAGS.worker_hosts)) 75 | 76 | if FLAGS.task_type not in _JOBS: 77 | raise ValueError('Specified task type (job) {} is not in ' 78 | 'available task types {}' 79 | .format(FLAGS.task_type, _JOBS)) 80 | 81 | os.environ[config_env] = json.dumps({ 82 | 'environment': 'cloud', # tf.contrib.learn.Environment.CLOUD 83 | 'cluster': { 84 | 'master': [FLAGS.master_host], 85 | 'ps': FLAGS.ps_hosts, 86 | 'worker': FLAGS.worker_hosts or [FLAGS.master_host], 87 | }, 88 | 'task': { 89 | 'type': FLAGS.task_type, 90 | 'index': FLAGS.task_index, 91 | }, 92 | }) 93 | 94 | return tf.contrib.learn.RunConfig(**adder.flags) 95 | 96 | return config 97 | -------------------------------------------------------------------------------- /qnd/inputs.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | 5 | from . import util 6 | from .flag import FLAGS, add_flag, add_required_flag 7 | 8 | 9 | DEFAULT_BATCH_SIZE = 64 10 | MODES = [tf.contrib.learn.ModeKeys.TRAIN, 11 | tf.contrib.learn.ModeKeys.EVAL, 12 | tf.contrib.learn.ModeKeys.INFER] 13 | 14 | 15 | def _add_file_flag(mode): 16 | assert isinstance(mode, str) 17 | 18 | flag_name = "{}_file".format(mode) 19 | add_required_flag(flag_name, 20 | help="File path of {0} data file(s). " 21 | "A glob is available. (e.g. 
{0}/*.tfrecords)" 22 | .format(mode)) 23 | return flag_name 24 | 25 | 26 | def def_def_def_input_fn(mode): 27 | assert mode in MODES 28 | 29 | def def_def_input_fn(batch_inputs=True, prepare_filename_queues=True): 30 | if batch_inputs: 31 | add_flag("batch_size", type=int, default=DEFAULT_BATCH_SIZE, 32 | help="Mini-batch size") 33 | add_flag("batch_queue_capacity", 34 | type=int, 35 | # TODO: Set a value for predictable behavior 36 | default=DEFAULT_BATCH_SIZE * 16, 37 | help="Batch queue capacity") 38 | add_flag("num_batch_threads", type=int, default=os.cpu_count(), 39 | help="Number of threads used to create batches") 40 | 41 | if prepare_filename_queues: 42 | _add_file_flag(mode) 43 | filenames_to_queue = def_filenames_to_queue(mode) 44 | 45 | def def_input_fn(user_input_fn): 46 | @util.func_scope 47 | def input_fn(): 48 | inputs = ( 49 | user_input_fn(filenames_to_queue( 50 | tf.matching_files(FLAGS.infer_file) 51 | if mode == tf.contrib.learn.ModeKeys.INFER else 52 | {mode: tf.train.match_filenames_once( 53 | getattr(FLAGS, "{}_file".format(mode)), 54 | name="{}_filenames".format(mode)) 55 | for mode in [tf.contrib.learn.ModeKeys.TRAIN, 56 | tf.contrib.learn.ModeKeys.EVAL]}[mode])) 57 | if prepare_filename_queues else 58 | user_input_fn()) 59 | 60 | inputs = ([inputs] 61 | if type(inputs) in {dict, tf.Tensor} else 62 | inputs) 63 | 64 | _check_inputs(inputs) 65 | 66 | return _batch_inputs(inputs, mode) if batch_inputs else inputs 67 | 68 | return input_fn 69 | 70 | return def_input_fn 71 | 72 | return def_def_input_fn 73 | 74 | 75 | def _batch_inputs(inputs, mode): 76 | input_is_dict = isinstance(inputs[0], dict) 77 | 78 | batched_inputs = _batch_merged_inputs( 79 | _merge_dicts(*inputs) if input_is_dict else inputs, 80 | mode) 81 | 82 | return ([{key: batched_inputs[key] for key in input_.keys()} 83 | for input_ in inputs] 84 | if input_is_dict else 85 | batched_inputs) 86 | 87 | 88 | def _batch_merged_inputs(inputs, mode): 89 | if mode != 
tf.contrib.learn.ModeKeys.INFER: 90 | inputs = _shuffle(inputs, 91 | capacity=FLAGS.batch_queue_capacity, 92 | num_threads=FLAGS.num_batch_threads, 93 | # TODO: Set a proper value for predictable behavior 94 | min_after_dequeue=FLAGS.batch_queue_capacity // 2) 95 | 96 | return tf.train.batch( 97 | inputs, 98 | batch_size=FLAGS.batch_size, 99 | dynamic_pad=True, 100 | capacity=FLAGS.batch_queue_capacity, 101 | num_threads=FLAGS.num_batch_threads, 102 | allow_smaller_final_batch=(mode != tf.contrib.learn.ModeKeys.TRAIN)) 103 | 104 | 105 | def _shuffle(inputs, capacity, min_after_dequeue, num_threads): 106 | if isinstance(inputs, dict): 107 | names, dtypes = zip(*[(key, input_.dtype) 108 | for key, input_ in inputs.items()]) 109 | else: 110 | dtypes = [input_.dtype for input_ in inputs] 111 | 112 | queue = tf.RandomShuffleQueue( 113 | capacity, 114 | min_after_dequeue, 115 | dtypes, 116 | **({'names': names} if isinstance(inputs, dict) else {})) 117 | 118 | tf.train.add_queue_runner(tf.train.QueueRunner( 119 | queue, 120 | [queue.enqueue(inputs)] * num_threads)) 121 | 122 | shuffled_inputs = queue.dequeue() 123 | 124 | for key, input_ in (inputs.items() 125 | if isinstance(inputs, dict) else 126 | enumerate(inputs)): 127 | shuffled_inputs[key].set_shape(input_.get_shape()) 128 | 129 | return shuffled_inputs 130 | 131 | 132 | def _merge_dicts(*dicts): 133 | return {key: value for dict_ in dicts for key, value in dict_.items()} 134 | 135 | 136 | def _check_inputs(inputs): 137 | if len(inputs) not in {1, 2}: 138 | raise ValueError("Too many return values from input_fn. " 139 | "(returned values: {})" 140 | .format(inputs)) 141 | 142 | if len(inputs) == 2 and not isinstance(inputs[0], type(inputs[1])): 143 | raise ValueError("features and targets should be the same type. 
" 144 | "(features type: {}, targets type: {})" 145 | .format(*map(type, inputs))) 146 | 147 | if len(inputs) == 2 and isinstance(inputs[0], dict): 148 | duplicate_keys = inputs[0].keys() & inputs[1].keys() 149 | if len(duplicate_keys) != 0: 150 | raise ValueError( 151 | "Some keys of features and targets are duplicate. ({})" 152 | .format(duplicate_keys)) 153 | 154 | 155 | for mode in MODES: 156 | globals()["def_def_{}_input_fn".format(mode)] = def_def_def_input_fn(mode) 157 | 158 | 159 | def def_filenames_to_queue(mode): 160 | assert mode in MODES 161 | 162 | add_flag("filename_queue_capacity", type=int, default=32, 163 | help="Capacity of filename queues of {}, {} and {} data" 164 | .format(*MODES)) 165 | 166 | @util.func_scope 167 | def filenames_to_queue(filenames): 168 | return tf.train.string_input_producer( 169 | filenames, 170 | num_epochs=(None 171 | if mode == tf.contrib.learn.ModeKeys.TRAIN else 172 | 1), 173 | shuffle=(mode != tf.contrib.learn.ModeKeys.INFER), 174 | capacity=FLAGS.filename_queue_capacity) 175 | 176 | return filenames_to_queue 177 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | 5 | # tensorflow-qnd 6 | 7 | [![PyPI version](https://badge.fury.io/py/tensorflow-qnd.svg)](https://badge.fury.io/py/tensorflow-qnd) 8 | [![Python versions](https://img.shields.io/pypi/pyversions/tensorflow-qnd.svg)](setup.py) 9 | [![Build Status](https://travis-ci.org/raviqqe/tensorflow-qnd.svg?branch=master)](https://travis-ci.org/raviqqe/tensorflow-qnd) 10 | [![License](https://img.shields.io/badge/license-unlicense-lightgray.svg)](https://unlicense.org) 11 | 12 | Quick and Dirty TensorFlow command framework 13 | 14 | tensorflow-qnd is a TensorFlow framework to create commands to train and 15 | evaluate models and run inference with them. 16 | The framework is built on top of 17 | [tf.contrib.learn module](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/learn/python/learn). 18 | Especially if you are working on research projects using TensorFlow, you can 19 | remove most of the boilerplate code with the framework. 20 | All you need to do is define a model constructor `model_fn` and input 21 | producer(s) `input_fn` to feed a dataset to the model. 22 | 23 | ## Features 24 | 25 | - Command creation for: 26 | - Training and evaluation of models 27 | - Inference of labels or regression values with trained models 28 | - Configuration of command line options to set hyperparameters of models etc. 29 | - [Distributed TensorFlow](https://www.tensorflow.org/how_tos/distributed/) 30 | - Just set the optional argument `distributed` of `def_train_and_evaluate()` 31 | to `True` (i.e. `def_train_and_evaluate(distributed=True)`) to enable it. 32 | - Supports only data-parallel training 33 | - Available for training only, not for inference 34 | 35 | ## Installation 36 | 37 | Python 3.5+ and TensorFlow 1.1+ are required. 38 | 39 | ``` 40 | pip3 install --user --upgrade tensorflow-qnd 41 | ``` 42 | 43 | ## Usage 44 | 45 | 1. Add command line arguments with `add_flag` and `add_required_flag` functions. 46 | 2.
Define a `train_and_evaluate` or `infer` function with 47 | `def_train_and_evaluate` or `def_infer` function 48 | 3. Pass `model_fn` (model constructor) and `input_fn` (input producer) functions 49 | to the defined function. 50 | 4. Run the script with appropriate command line arguments. 51 | 52 | For more information, see [documentation](https://raviqqe.github.io/tensorflow-qnd/qnd). 53 | 54 | ## Examples 55 | 56 | `train.py` (command script): 57 | 58 | ```python 59 | import logging 60 | import os 61 | 62 | import qnd 63 | 64 | import mnist 65 | 66 | 67 | train_and_evaluate = qnd.def_train_and_evaluate( 68 | distributed=("distributed" in os.environ)) 69 | 70 | 71 | model = mnist.def_model() 72 | 73 | 74 | def main(): 75 | logging.getLogger().setLevel(logging.INFO) 76 | train_and_evaluate(model, mnist.read_file) 77 | 78 | 79 | if __name__ == "__main__": 80 | main() 81 | ``` 82 | 83 | `mnist.py` (module): 84 | 85 | ```python 86 | import qnd 87 | import tensorflow as tf 88 | 89 | 90 | def _preprocess_image(image): 91 | return tf.to_float(image) / 255 - 0.5 92 | 93 | 94 | def read_file(filename_queue): 95 | _, serialized = tf.TFRecordReader().read(filename_queue) 96 | 97 | def scalar_feature(dtype): return tf.FixedLenFeature([], dtype) 98 | 99 | features = tf.parse_single_example(serialized, { 100 | "image_raw": scalar_feature(tf.string), 101 | "label": scalar_feature(tf.int64), 102 | }) 103 | 104 | image = tf.decode_raw(features["image_raw"], tf.uint8) 105 | image.set_shape([28**2]) 106 | 107 | return _preprocess_image(image), features["label"] 108 | 109 | 110 | def serving_input_fn(): 111 | features = { 112 | 'image': _preprocess_image(tf.placeholder(tf.uint8, [None, 28**2])), 113 | } 114 | 115 | return tf.contrib.learn.InputFnOps(features, None, features) 116 | 117 | 118 | def minimize(loss): 119 | return tf.train.AdamOptimizer().minimize( 120 | loss, 121 | tf.contrib.framework.get_global_step()) 122 | 123 | 124 | def def_model(): 125 | 
qnd.add_flag("hidden_layer_size", type=int, default=64, 126 | help="Hidden layer size") 127 | 128 | def model(image, number=None, mode=None): 129 | h = tf.contrib.layers.fully_connected(image, 130 | qnd.FLAGS.hidden_layer_size) 131 | h = tf.contrib.layers.fully_connected(h, 10, activation_fn=None) 132 | 133 | predictions = tf.argmax(h, axis=1) 134 | 135 | if mode == tf.contrib.learn.ModeKeys.INFER: 136 | return predictions 137 | 138 | loss = tf.reduce_mean( 139 | tf.nn.sparse_softmax_cross_entropy_with_logits(labels=number, 140 | logits=h)) 141 | 142 | return predictions, loss, minimize(loss), { 143 | "accuracy": tf.contrib.metrics.streaming_accuracy(predictions, 144 | number)[1], 145 | } 146 | 147 | return model 148 | ``` 149 | 150 | With the code above, you can create a command with the following interface. 151 | 152 | ``` 153 | usage: train.py [-h] [--output_dir OUTPUT_DIR] [--train_steps TRAIN_STEPS] 154 | [--eval_steps EVAL_STEPS] 155 | [--min_eval_frequency MIN_EVAL_FREQUENCY] 156 | [--num_cores NUM_CORES] [--log_device_placement] 157 | [--save_summary_steps SAVE_SUMMARY_STEPS] 158 | [--save_checkpoints_steps SAVE_CHECKPOINTS_STEPS] 159 | [--keep_checkpoint_max KEEP_CHECKPOINT_MAX] 160 | [--batch_size BATCH_SIZE] 161 | [--batch_queue_capacity BATCH_QUEUE_CAPACITY] 162 | [--num_batch_threads NUM_BATCH_THREADS] --train_file 163 | TRAIN_FILE [--filename_queue_capacity FILENAME_QUEUE_CAPACITY] 164 | --eval_file EVAL_FILE [--hidden_layer_size HIDDEN_LAYER_SIZE] 165 | 166 | optional arguments: 167 | -h, --help show this help message and exit 168 | --output_dir OUTPUT_DIR 169 | Directory where checkpoint and event files are stored 170 | (default: output) 171 | --train_steps TRAIN_STEPS 172 | Maximum number of train steps (default: None) 173 | --eval_steps EVAL_STEPS 174 | Maximum number of eval steps (default: 100) 175 | --min_eval_frequency MIN_EVAL_FREQUENCY 176 | Minimum evaluation frequency in number of train steps 177 | (default: 1) 178 | --num_cores NUM_CORES 
179 | Number of CPU cores used. 0 means use of a default 180 | value. (default: 0) 181 | --log_device_placement 182 | If specified, log device placement information 183 | (default: False) 184 | --save_summary_steps SAVE_SUMMARY_STEPS 185 | Number of steps every time of which summary is saved 186 | (default: 100) 187 | --save_checkpoints_steps SAVE_CHECKPOINTS_STEPS 188 | Number of steps every time of which a model is saved 189 | (default: None) 190 | --keep_checkpoint_max KEEP_CHECKPOINT_MAX 191 | Max number of kept checkpoint files (default: 86058) 192 | --batch_size BATCH_SIZE 193 | Mini-batch size (default: 64) 194 | --batch_queue_capacity BATCH_QUEUE_CAPACITY 195 | Batch queue capacity (default: 1024) 196 | --num_batch_threads NUM_BATCH_THREADS 197 | Number of threads used to create batches (default: 2) 198 | --train_file TRAIN_FILE 199 | File path of train data file(s). A glob is available. 200 | (e.g. train/*.tfrecords) (default: None) 201 | --filename_queue_capacity FILENAME_QUEUE_CAPACITY 202 | Capacity of filename queues of train, eval and infer 203 | data (default: 32) 204 | --eval_file EVAL_FILE 205 | File path of eval data file(s). A glob is available. 206 | (e.g. eval/*.tfrecords) (default: None) 207 | --hidden_layer_size HIDDEN_LAYER_SIZE 208 | Hidden layer size (default: 64) 209 | ``` 210 | 211 | Explore [examples](examples) directory for more information and see how to run 212 | them. 213 | 214 | 215 | ## Caveats 216 | 217 | ### Necessary update of a global step variable 218 | 219 | As done in [examples](examples), you must get a global step variable with 220 | `tf.contrib.framework.get_global_step()` and update (increment) it in each 221 | training step. 
222 | 223 | ### Use streaming metrics for `eval_metric_ops` 224 | 225 | When non-streaming metrics such as `tf.contrib.metrics.accuracy` are used in the 226 | `eval_metric_ops` return value of your `model_fn` or as arguments of 227 | `ModelFnOps`, their reported values will be those of the last batch only in every evaluation 228 | step. 229 | 230 | ## Contributing 231 | 232 | Please open issues for bugs, feature requests, or questions, or send pull 233 | requests. 234 | 235 | ## License 236 | 237 | [The Unlicense](https://unlicense.org) 238 | -------------------------------------------------------------------------------- /docs/qnd/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | qnd API documentation 7 | 8 | 9 | 10 | 11 | 551 | 552 | 853 | 854 | 855 | 947 | 948 | 962 | 963 | 964 | Top 965 | 966 |
967 | 968 | 969 | 1011 | 1012 |
1013 | 1014 | 1015 | 1016 | 1017 | 1018 | 1019 |
1020 |

qnd module

1021 |

Quick and Dirty TensorFlow command framework

1022 | 1023 | 1024 |
1025 |
"""Quick and Dirty TensorFlow command framework"""
1026 | 
1027 | from .flag import *
1028 | from .infer import def_infer
1029 | from .train_and_evaluate import def_train_and_evaluate
1030 | from .evaluate import def_evaluate
1031 | from .serve import def_serve
1032 | 
1033 | __all__ = ["FLAGS", "add_flag", "add_required_flag", "FlagAdder",
1034 |            "def_train_and_evaluate", "def_evaluate", "def_infer", "def_serve"]
1035 | __version__ = "0.1.3"
1036 | 
1037 |
1038 | 1039 |
1040 | 1041 |
1042 |

Module variables

1043 |
1044 |

var FLAGS

1045 | 1046 | 1047 |
1048 |
1049 | 1050 |
1051 | 1052 |

Functions

1053 | 1054 |
1055 |
1056 |

def add_flag(

name, *args, **kwargs)

1057 |
1058 | 1059 | 1060 | 1061 | 1062 |

Add a flag.

1063 |

Added flags can be accessed by FLAGS module variable. 1064 | (e.g. FLAGS.my_flag_name)

1065 |
    1066 |
  • Args
      1067 |
    • name: Flag name. Real flag name will be "--{}".format(name).
    • 1068 |
    • *args, **kwargs: The rest arguments are the same as 1069 | argparse.ArgumentParser.add_argument().
    • 1070 |
    1071 |
  • 1072 |
1073 |
1074 | 1075 |
1076 |
def add_flag(name, *args, **kwargs):
1077 |     """Add a flag.
1078 | 
1079 |     Added flags can be accessed by `FLAGS` module variable.
1080 |     (e.g. `FLAGS.my_flag_name`)
1081 | 
1082 |     - Args
1083 |         - `name`: Flag name. Real flag name will be `"--{}".format(name)`.
1084 |         - `*args`, `**kwargs`: The rest arguments are the same as
1085 |             `argparse.ArgumentParser.add_argument()`.
1086 |     """
1087 |     global _FLAG_NAMES
1088 | 
1089 |     if 'help' not in kwargs:
1090 |         kwargs['help'] = '(no description)'
1091 | 
1092 |     if name not in _FLAG_NAMES:
1093 |         _FLAG_NAMES.add(name)
1094 |         gargparse.add_argument("--" + name, *args, **kwargs)
1095 | 
1096 |
1097 |
1098 | 1099 |
1100 | 1101 | 1102 |
1103 |
1104 |

def add_required_flag(

name, *args, **kwargs)

1105 |
1106 | 1107 | 1108 | 1109 | 1110 |

Add a required flag.

1111 |

Its interface is the same as add_flag() but required=True is set by 1112 | default.

1113 |
1114 | 1115 |
1116 |
def add_required_flag(name, *args, **kwargs):
1117 |     """Add a required flag.
1118 | 
1119 |     Its interface is the same as `add_flag()` but `required=True` is set by
1120 |     default.
1121 |     """
1122 |     add_flag(name, *args, required=True, **kwargs)
1123 | 
1124 |
1125 |
1126 | 1127 |
1128 | 1129 | 1130 |
1131 |
1132 |

def def_evaluate(

batch_inputs=True, prepare_filename_queues=True)

1133 |
1134 | 1135 | 1136 | 1137 | 1138 |

Define evaluate() function.

1139 |

See also help(def_evaluate()).

1140 |
    1141 |
  • 1142 |

    Args

    1143 |
      1144 |
    • batch_inputs: Same as def_train_and_evaluate()'s.
    • 1145 |
    • prepare_filename_queues: Same as def_train_and_evaluate()'s.
    • 1146 |
    1147 |
  • 1148 |
  • 1149 |

    Returns

    1150 |
      1151 |
    • evaluate() function.
    • 1152 |
    1153 |
  • 1154 |
1155 |
1156 | 1157 |
1158 |
def def_evaluate(batch_inputs=True, prepare_filename_queues=True):
1159 |     """Define `evaluate()` function.
1160 | 
1161 |     See also `help(def_evaluate())`.
1162 | 
1163 |     - Args
1164 |         - `batch_inputs`: Same as `def_train_and_evaluate()`'s.
1165 |         - `prepare_filename_queues`: Same as `def_train_and_evaluate()`'s.
1166 | 
1167 |     - Returns
1168 |         - `evaluate()` function.
1169 |     """
1170 |     add_output_dir_flag()
1171 | 
1172 |     estimator = def_estimator(distributed=False)
1173 |     def_eval_input_fn = def_def_infer_input_fn(batch_inputs,
1174 |                                                prepare_filename_queues)
1175 | 
1176 |     def evaluate(model_fn, input_fn):
1177 |         """Evaluate a model with data sample fed by `input_fn`.
1178 | 
1179 |         - Args
1180 |             - `model_fn`: Same as `train_and_evaluate()`'s.
1181 |             - `input_fn`: Same as `eval_input_fn` argument of
1182 |                 `train_and_evaluate()`.
1183 | 
1184 |         - Returns
1185 |             - Evaluation results. See `Evaluable` interface in TensorFlow.
1186 |         """
1187 |         return estimator(model_fn, FLAGS.output_dir).evaluate(
1188 |             input_fn=def_eval_input_fn(input_fn))
1189 | 
1190 |     return evaluate
1191 | 
1192 |
1193 |
1194 | 1195 |
1196 | 1197 | 1198 |
1199 |
1200 |

def def_infer(

batch_inputs=True, prepare_filename_queues=True)

1201 |
1202 | 1203 | 1204 | 1205 | 1206 |

Define infer() function.

1207 |

See also help(def_infer()).

1208 |
    1209 |
  • 1210 |

    Args

    1211 |
      1212 |
    • batch_inputs: Same as def_train_and_evaluate()'s.
    • 1213 |
    • prepare_filename_queues: Same as def_train_and_evaluate()'s.
    • 1214 |
    1215 |
  • 1216 |
  • 1217 |

    Returns

    1218 |
      1219 |
    • infer() function.
    • 1220 |
    1221 |
  • 1222 |
1223 |
1224 | 1225 |
1226 |
def def_infer(batch_inputs=True, prepare_filename_queues=True):
1227 |     """Define `infer()` function.
1228 | 
1229 |     See also `help(def_infer())`.
1230 | 
1231 |     - Args
1232 |         - `batch_inputs`: Same as `def_train_and_evaluate()`'s.
1233 |         - `prepare_filename_queues`: Same as `def_train_and_evaluate()`'s.
1234 | 
1235 |     - Returns
1236 |         - `infer()` function.
1237 |     """
1238 |     add_output_dir_flag()
1239 | 
1240 |     estimator = def_estimator(distributed=False)
1241 |     def_infer_input_fn = def_def_infer_input_fn(batch_inputs,
1242 |                                                 prepare_filename_queues)
1243 | 
1244 |     def infer(model_fn, input_fn):
1245 |         """Infer labels or regression values from features of samples fed by
1246 |         `input_fn`.
1247 | 
1248 |         - Args
1249 |             - `model_fn`: Same as `train_and_evaluate()`'s.
1250 |             - `input_fn`: Same as `train_input_fn` and `eval_input_fn`
1251 |                 arguments of `train_and_evaluate()` but returns only features.
1252 | 
1253 |         - Returns
1254 |             - Generator of inferred label(s) or regression value(s) for each
1255 |                 sample.
1256 |         """
1257 |         return estimator(model_fn, FLAGS.output_dir).predict(
1258 |             input_fn=def_infer_input_fn(input_fn))
1259 | 
1260 |     return infer
1261 | 
1262 |
1263 |
1264 | 1265 |
1266 | 1267 | 1268 |
1269 |
1270 |

def def_serve(

)

1271 |
1272 | 1273 | 1274 | 1275 | 1276 |

Define serve() function.

1277 |

See also help(def_serve()).

1278 |
    1279 |
  • Returns
      1280 |
    • serve() function.
    • 1281 |
    1282 |
  • 1283 |
1284 |
1285 | 1286 |
1287 |
def def_serve():
1288 |     """Define `serve()` function.
1289 | 
1290 |     See also `help(def_serve())`.
1291 | 
1292 |     - Returns
1293 |         - `serve()` function.
1294 |     """
1295 |     add_output_dir_flag()
1296 | 
1297 |     create_estimator = def_estimator(distributed=False)
1298 | 
1299 |     def serve(model_fn, preprocess_fn, port=80):
1300 |         """Serve as a HTTP server.
1301 | 
1302 |         - Args
1303 |             - `model_fn`: Same as `train_and_evaluate()`'s.
1304 |             - `preprocess_fn`: A function to preprocess server request bodies
1305 |                 in JSON.
1306 |         """
1307 |         estimator = create_estimator(model_fn, FLAGS.output_dir)
1308 | 
1309 |         class Handler(http.server.BaseHTTPRequestHandler):
1310 |             def do_POST(self):
1311 |                 self.send_response(200)
1312 |                 self.send_header('Content-type', 'application/json')
1313 |                 self.end_headers()
1314 | 
1315 |                 inputs = json.loads(self.rfile.read(
1316 |                     int(self.headers['Content-Length'])))
1317 | 
1318 |                 predictions = _make_json_serializable(
1319 |                     estimator.predict(input_fn=lambda: preprocess_fn(inputs),
1320 |                                       as_iterable=False))
1321 | 
1322 |                 logging.info('Prediction results: {}'.format(predictions))
1323 | 
1324 |                 self.wfile.write(json.dumps(predictions).encode())
1325 | 
1326 |         http.server.HTTPServer(('', port), Handler).serve_forever()
1327 | 
1328 |     return serve
1329 | 
1330 |
1331 |
1332 | 1333 |
1334 | 1335 | 1336 |
1337 |
1338 |

def def_train_and_evaluate(

batch_inputs=True, prepare_filename_queues=True, distributed=False)

1339 |
1340 | 1341 | 1342 | 1343 | 1344 |

Define train_and_evaluate() function.

1345 |

See also help(def_train_and_evaluate()).

1346 |
    1347 |
  • 1348 |

    Args

    1349 |
      1350 |
    • batch_inputs: If True, create batches from Tensors returned from 1351 | train_input_fn() and eval_input_fn() and feed them to a model.
    • 1352 |
    • prepare_filename_queues: If True, create filename queues for 1353 | train and eval data based on file paths specified by command line 1354 | arguments.
    • 1355 |
    • distributed: If True, configure command line arguments to train 1356 | and evaluate models on a distributed system.
    • 1357 |
    1358 |
  • 1359 |
  • 1360 |

    Returns

    1361 |
      1362 |
    • train_and_evaluate() function.
    • 1363 |
    1364 |
  • 1365 |
1366 |
1367 | 1368 |
1369 |
def def_train_and_evaluate(batch_inputs=True,
1370 |                            prepare_filename_queues=True,
1371 |                            distributed=False):
1372 |     """Define `train_and_evaluate()` function.
1373 | 
1374 |     See also `help(def_train_and_evaluate())`.
1375 | 
1376 |     - Args
1377 |         - `batch_inputs`: If `True`, create batches from Tensors returned from
1378 |             `train_input_fn()` and `eval_input_fn()` and feed them to a model.
1379 |         - `prepare_filename_queues`: If `True`, create filename queues for
1380 |             train and eval data based on file paths specified by command line
1381 |             arguments.
1382 |         - `distributed`: If `True`, configure command line arguments to train
1383 |             and evaluate models on a distributed system.
1384 | 
1385 |     - Returns
1386 |         - `train_and_evaluate()` function.
1387 |     """
1388 |     add_output_dir_flag()
1389 | 
1390 |     def_experiment_fn = def_def_experiment_fn(batch_inputs,
1391 |                                               prepare_filename_queues,
1392 |                                               distributed)
1393 | 
1394 |     def train_and_evaluate(model_fn,
1395 |                            train_input_fn,
1396 |                            eval_input_fn=None,
1397 |                            serving_input_fn=None):
1398 |         """Train and evaluate a model with features and targets fed by
1399 |         `input_fn`s.
1400 | 
1401 |         - Args
1402 |             - `model_fn`: A function to construct a model.
1403 |                 - Types of its arguments must be one of the following:
1404 |                     - `Tensor, ...`
1405 |                     - `Tensor, ..., mode=ModeKeys`
1406 |                 - Types of its return values must be one of the following:
1407 |                     - `Tensor, Tensor, Operation, eval_metric_ops=dict`
1408 |                         (predictions, loss, train_op, and eval_metric_ops (if any))
1409 |                     - `ModelFnOps`
1410 |             - `train_input_fn`, `eval_input_fn`: Functions to create input
1411 |                 Tensors fed into the model. If `eval_input_fn` is `None`,
1412 |                 `train_input_fn` will be used instead.
1413 |                 - Types of its arguments must be one of the following:
1414 |                     - `QueueBase` (a filename queue)
1415 |                     - No argument if `prepare_filename_queues` of
1416 |                         `def_train_and_evaluate()` is `False`.
1417 |                 - Types of its return values must be one of the following:
1418 |                     - `Tensor, Tensor` (features and targets)
1419 |                     - `dict, dict` (features and targets)
1420 |                         - The keys in `dict` objects must match with argument
1421 |                             names of `model_fn`.
1422 | 
1423 |         - Returns
1424 |             - Return value of `tf.contrib.learn.python.learn.learn_runner.run()`.
1425 |         """
1426 |         return run(def_experiment_fn(model_fn,
1427 |                                      train_input_fn,
1428 |                                      eval_input_fn,
1429 |                                      serving_input_fn),
1430 |                    FLAGS.output_dir)
1431 | 
1432 |     return train_and_evaluate
1433 | 
1434 |
1435 |
1436 | 1437 |
1438 | 1439 | 1440 |

Classes

1441 | 1442 |
1443 |

class FlagAdder

1444 | 1445 | 1446 |

Manage addition of flags.

1447 |
1448 | 1449 |
1450 |
class FlagAdder:
1451 |     """Manage addition of flags."""
1452 | 
1453 |     def __init__(self):
1454 |         """Create a `FlagAdder` instance."""
1455 |         self._flags = []
1456 | 
1457 |     def add_flag(self, name, *args, **kwargs):
1458 |         """Add a flag.
1459 | 
1460 |         See `add_flag()`.
1461 |         """
1462 |         add_flag(name, *args, **kwargs)
1463 |         self._flags.append(kwargs.get("dest") or name)
1464 | 
1465 |     def add_required_flag(self, name, *args, **kwargs):
1466 |         """Add a required flag.
1467 | 
1468 |         See `add_required_flag()`.
1469 |         """
1470 |         self.add_flag(name, *args, required=True, **kwargs)
1471 | 
1472 |     @property
1473 |     def flags(self):
1474 |         """Get added flags.
1475 | 
1476 |         - Returns
1477 |             - `dict` of flag names to values added by a `FlagAdder` instance.
1478 |         """
1479 |         return {flag: getattr(FLAGS, flag) for flag in self._flags}
1480 | 
1481 |
1482 |
1483 | 1484 | 1485 |
1486 |

Ancestors (in MRO)

1487 |
    1488 |
  • FlagAdder
  • 1489 |
  • builtins.object
  • 1490 |
1491 |

Static methods

1492 | 1493 |
1494 |
1495 |

def __init__(

self)

1496 |
1497 | 1498 | 1499 | 1500 | 1501 |

Create a FlagAdder instance.

1502 |
1503 | 1504 |
1505 |
def __init__(self):
1506 |     """Create a `FlagAdder` instance."""
1507 |     self._flags = []
1508 | 
1509 |
1510 |
1511 | 1512 |
1513 | 1514 | 1515 |
1516 |
1517 |

def add_flag(

self, name, *args, **kwargs)

1518 |
1519 | 1520 | 1521 | 1522 | 1523 |

Add a flag.

1524 |

See add_flag().

1525 |
1526 | 1527 |
1528 |
def add_flag(self, name, *args, **kwargs):
1529 |     """Add a flag.
1530 |     See `add_flag()`.
1531 |     """
1532 |     add_flag(name, *args, **kwargs)
1533 |     self._flags.append(kwargs.get("dest") or name)
1534 | 
1535 |
1536 |
1537 | 1538 |
1539 | 1540 | 1541 |
1542 |
1543 |

def add_required_flag(

self, name, *args, **kwargs)

1544 |
1545 | 1546 | 1547 | 1548 | 1549 |

Add a required flag.

1550 |

See add_required_flag().

1551 |
1552 | 1553 |
1554 |
def add_required_flag(self, name, *args, **kwargs):
1555 |     """Add a required flag.
1556 |     See `add_required_flag()`.
1557 |     """
1558 |     self.add_flag(name, *args, required=True, **kwargs)
1559 | 
1560 |
1561 |
1562 | 1563 |
1564 | 1565 |

Instance variables

1566 |
1567 |

var flags

1568 | 1569 | 1570 | 1571 | 1572 |

Get added flags.

1573 |
    1574 |
  • Returns
      1575 |
    • dict of flag names to values added by a FlagAdder instance.
    • 1576 |
    1577 |
  • 1578 |
1579 |
1580 |
1581 | 1582 |
1583 |
1584 |
1585 | 1586 |
1587 | 1588 |
1589 |
1590 | 1601 |
1602 | 1603 | 1604 | --------------------------------------------------------------------------------