├── third_party └── py │ ├── BUILD │ ├── BUILD.tpl │ └── python_configure.bzl ├── tensorflow_estimator ├── python │ └── estimator │ │ ├── canned │ │ ├── __init__.py │ │ ├── v1 │ │ │ ├── __init__.py │ │ │ ├── linear_estimator_test_v1.py │ │ │ └── dnn_estimator_test_v1.py │ │ ├── testdata │ │ │ └── wire_vocabulary.txt │ │ ├── linear_optimizer │ │ │ ├── __init__.py │ │ │ ├── BUILD │ │ │ └── python │ │ │ │ └── utils │ │ │ │ └── sharded_mutable_dense_hashtable_test.py │ │ ├── prediction_keys.py │ │ ├── metric_keys.py │ │ ├── timeseries │ │ │ ├── feature_keys.py │ │ │ ├── model_utils.py │ │ │ ├── state_management.py │ │ │ ├── math_utils_test.py │ │ │ ├── ar_model_test.py │ │ │ └── BUILD │ │ ├── optimizers_test.py │ │ ├── optimizers_test_v2.py │ │ ├── canned_estimator_ds_integration_test.py │ │ └── optimizers.py │ │ ├── export │ │ ├── __init__.py │ │ ├── export_output.py │ │ └── export_lib.py │ │ ├── head │ │ ├── __init__.py │ │ └── head_utils.py │ │ ├── hooks │ │ ├── __init__.py │ │ ├── basic_session_run_hooks.py │ │ ├── session_run_hook.py │ │ └── fake_summary_writer.py │ │ ├── inputs │ │ ├── __init__.py │ │ ├── queues │ │ │ ├── __init__.py │ │ │ └── feeding_queue_runner_test.py │ │ ├── inputs.py │ │ └── pandas_io.py │ │ ├── tools │ │ ├── __init__.py │ │ └── analytics.py │ │ ├── tpu │ │ ├── __init__.py │ │ ├── error_handling_test.py │ │ ├── spatial_partitioning_api.md │ │ ├── util.py │ │ └── error_handling.py │ │ ├── api │ │ ├── generator_wrapper.py │ │ ├── extractor_wrapper.py │ │ └── BUILD │ │ ├── mode_keys.py │ │ ├── estimator_export_test.py │ │ ├── estimator_export.py │ │ ├── util_test.py │ │ ├── estimator_lib.py │ │ ├── util.py │ │ ├── extenders.py │ │ ├── extenders_test.py │ │ ├── object_checkpointing_test.py │ │ ├── tf_estimator_doctest.py │ │ ├── gc_test.py │ │ └── gc.py ├── estimator.bzl ├── tools │ └── pip_package │ │ ├── BUILD │ │ ├── setup.py │ │ ├── build_pip_package.sh │ │ └── create_pip_helper.py └── BUILD ├── BUILD ├── .bazelrc ├── .gitignore ├── 
WORKSPACE ├── CONTRIBUTING.md └── README.md /third_party/py/BUILD: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/canned/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/export/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/head/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/inputs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/canned/v1/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/inputs/queues/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/canned/testdata/wire_vocabulary.txt: -------------------------------------------------------------------------------- 1 | omar 2 | stringer 3 | marlo 4 | -------------------------------------------------------------------------------- /BUILD: -------------------------------------------------------------------------------- 1 | # Description: Tensorflow Estimator. 2 | 3 | licenses(["notice"]) # Apache 2.0 4 | 5 | exports_files(["LICENSE"]) 6 | -------------------------------------------------------------------------------- /.bazelrc: -------------------------------------------------------------------------------- 1 | 2 | # Default options should come above this line 3 | 4 | # Put user-specific options in .bazelrc.user 5 | try-import %workspace%/.bazelrc.user 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # editor files 2 | *.swp 3 | *~ 4 | .vscode/ 5 | .DS_Store 6 | 7 | # bazel 8 | /.bazelrc.user 9 | /bazel-* 10 | 11 | # python 12 | *.pyc 13 | *.pyo 14 | __pycache__ 15 | *.whl 16 | .ipynb_checkpoints 17 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | workspace(name = "org_tensorflow_estimator") 2 | 3 | # Use a custom python toolchain to make sure we always use the python binary 4 | # provided by PYTHON_BIN_PATH. 
"""Estimator common skylark macros."""

def py_test(deps = [], **kwargs):
    """Declares an Estimator py_test that can also run against a pip install.

    When the `no_estimator_py_deps` config setting is active (pip release
    testing), all deps are stripped so the test exercises only the installed
    package; otherwise the given deps are used unchanged.

    Args:
        deps: Python dependencies for the test target.
        **kwargs: Forwarded verbatim to native.py_test.
    """
    effective_deps = select({
        "//conditions:default": deps,
        "//tensorflow_estimator:no_estimator_py_deps": [],
    })
    native.py_test(
        deps = effective_deps,
        **kwargs
    )

def tpu_py_test(**kwargs):  # @unused
    """No-op: TPU tests are skipped in the Estimator OSS build."""
    pass

def if_indexing_source_code(
        if_true,  # @unused
        if_false):
    """Returns a select() on whether we are building for source code indexing.

    Generated code is never indexed in the OSS build, so this always resolves
    to `if_false`; a select() is still returned for consistency with internal
    builds.
    """
    return select({
        "//conditions:default": if_false,
    })
load("@bazel_tools//tools/python:toolchain.bzl", "py_runtime_pair")

# Point both runtimes to the same python binary to ensure we always
# use the python binary specified by ./configure.py script.
# %{PYTHON_BIN_PATH} is substituted by python_configure.bzl when this
# template is instantiated into the @local_config_py_toolchain repo.
py_runtime(
    name = "py2_runtime",
    interpreter_path = "%{PYTHON_BIN_PATH}",
    python_version = "PY2",
)

py_runtime(
    name = "py3_runtime",
    interpreter_path = "%{PYTHON_BIN_PATH}",
    python_version = "PY3",
)

# Bundles the two runtimes so Bazel can pick the one matching the
# requested python_version of each target.
py_runtime_pair(
    name = "py_runtime_pair",
    py2_runtime = ":py2_runtime",
    py3_runtime = ":py3_runtime",
)

# Registered from the WORKSPACE via register_toolchains(); required because
# --python_path no longer works since Bazel 0.27.
toolchain(
    name = "py_toolchain",
    toolchain = ":py_runtime_pair",
    toolchain_type = "@bazel_tools//tools/python:toolchain_type",
)
14 | # ============================================================================= 15 | """TPUEstimator.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | -------------------------------------------------------------------------------- /tensorflow_estimator/tools/pip_package/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//tensorflow_estimator:internal"]) 2 | 3 | # Description: 4 | # Tools for building the TensorFlow pip package. 5 | 6 | COMMON_PIP_DEPS = [ 7 | "//tensorflow_estimator", 8 | # Need to include testing libraries in pip package so our pip 9 | # release tests can run. (see py_test rule in estimator.bzl for more context). 10 | # Essentially, everything needed to run the test (except the test file itself) 11 | # must be contained in the pip package since we strip away all deps. 12 | "//tensorflow_estimator/python/estimator:dnn_testing_utils", 13 | "//tensorflow_estimator/python/estimator:dnn_testing_utils_v1", 14 | "//tensorflow_estimator/python/estimator:linear_testing_utils", 15 | "//tensorflow_estimator/python/estimator:linear_testing_utils_v1", 16 | ] 17 | 18 | sh_binary( 19 | name = "build_pip_package", 20 | srcs = ["build_pip_package.sh"], 21 | data = COMMON_PIP_DEPS, 22 | ) 23 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/api/generator_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
"""Thin wrapper to call TensorFlow's API generator script."""
from absl import app
from tensorflow.python.tools.api.generator2.generator import generator

if __name__ == "__main__":
  # Delegate directly to TensorFlow's API generator entry point; absl.app
  # handles flag parsing before invoking generator.main.
  app.run(generator.main)
"""Thin wrapper to call TensorFlow's API extractor script."""
from absl import app

from tensorflow.python.tools.api.generator2.extractor import extractor

if __name__ == "__main__":
  # Delegate directly to TensorFlow's API extractor entry point; absl.app
  # handles flag parsing before invoking extractor.main.
  app.run(extractor.main)
"""Exporting ModeKeys to tf.estimator namespace."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.saved_model.model_utils.mode_keys import EstimatorModeKeys as ModeKeys
from tensorflow_estimator.python.estimator.estimator_export import estimator_export

# Register the aliased EstimatorModeKeys class under the public
# `tf.estimator.ModeKeys` API name as an import-time side effect.
estimator_export('estimator.ModeKeys')(ModeKeys)
def track_usage(tool_id, tags):
  """Usage tracking stub — disabled in the open-source build.

  Args:
    tool_id: A string identifier for tool to be tracked.
    tags: list of string tags that will be added to the tracking.
  """
  # Intentionally discard the arguments: the external library performs
  # no usage tracking at all.
  del tool_id, tags


def track_numerical_issues(exc_info):
  """Numerical-issue tracking stub — disabled in the open-source build.

  Args:
    exc_info: Output from `sys.exc_info` (type, value, traceback)
  """
  # Intentionally discard the argument: no tracking for external library.
  del exc_info
class PredictionKeys(object):
  """Enum for canonical model prediction keys.

  The following values are defined:
  PREDICTIONS: Used by models that predict values, such as regressor models.
  """

  # Constants listed alphabetically; each value is the canonical string key
  # under which the corresponding tensor appears in a predictions dict.
  ALL_CLASSES = 'all_classes'
  ALL_CLASS_IDS = 'all_class_ids'
  CLASSES = 'classes'
  CLASS_IDS = 'class_ids'
  LOGISTIC = 'logistic'
  LOGITS = 'logits'
  PREDICTIONS = 'predictions'
  PROBABILITIES = 'probabilities'
  TOP_K = 'top_k'
24 | 25 | ### The small print 26 | 27 | Contributions made by corporations are covered by a different agreement than 28 | the one above, the 29 | [Software Grant and Corporate Contributor License Agreement] 30 | (https://cla.developers.google.com/about/google-corporate). 31 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/tpu/error_handling_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class ErrorHandlingTest(tf.test.TestCase):
  """Exercises ErrorRendezvous capture/re-raise behavior."""

  def catch_and_raise(self, error):
    """Records `error` in a fresh rendezvous, then asks it to re-raise."""
    rendezvous = error_handling.ErrorRendezvous(1)
    with rendezvous.catch_errors(source='infeed'):
      raise error
    rendezvous.raise_errors()

  def testInterestingError(self):
    """A captured InternalError must propagate out of raise_errors."""
    error = tf.errors.InternalError('message', None, None)
    with self.assertRaises(tf.errors.InternalError):
      self.catch_and_raise(error)

  def testIgnoredError(self):
    """Expect no error to be raised."""
    self.catch_and_raise(tf.errors.AbortedError('message', None, None))
"""Classes for different types of export output."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import
# Re-exported here so estimator code can import these classes from this
# module rather than from tensorflow.python.saved_model internals.
# NOTE(review): TrainOutput and _SupervisedOutput are imported but not given
# public estimator_export names below — presumably internal-only; confirm.
from tensorflow.python.saved_model.model_utils.export_output import _SupervisedOutput
from tensorflow.python.saved_model.model_utils.export_output import ClassificationOutput
from tensorflow.python.saved_model.model_utils.export_output import EvalOutput
from tensorflow.python.saved_model.model_utils.export_output import ExportOutput
from tensorflow.python.saved_model.model_utils.export_output import PredictOutput
from tensorflow.python.saved_model.model_utils.export_output import RegressionOutput
from tensorflow.python.saved_model.model_utils.export_output import TrainOutput
# pylint: enable=unused-import
from tensorflow_estimator.python.estimator.estimator_export import estimator_export

# Register the re-exported classes under their public
# `tf.estimator.export.*` API names as import-time side effects.
estimator_export('estimator.export.ExportOutput')(ExportOutput)
estimator_export('estimator.export.ClassificationOutput')(ClassificationOutput)
estimator_export('estimator.export.RegressionOutput')(RegressionOutput)
estimator_export('estimator.export.PredictOutput')(PredictOutput)
estimator_export('estimator.export.EvalOutput')(EvalOutput)
7 | Estimators encapsulate training, evaluation, prediction, and exporting for your model. 8 | 9 | ## Getting Started 10 | 11 | See our Estimator 12 | [getting started guide](https://www.tensorflow.org/guide/estimator) for an 13 | introduction to the Estimator APIs. 14 | 15 | ## Installation 16 | 17 | `tf.estimator` is installed when you install the TensorFlow pip package. See 18 | [Installing TensorFlow](https://www.tensorflow.org/install) for instructions. 19 | 20 | ## Developing 21 | 22 | If you want to build TensorFlow Estimator locally, you will need to 23 | [install Bazel](https://docs.bazel.build/versions/master/install.html) and 24 | [install TensorFlow](https://www.tensorflow.org/install/pip). 25 | 26 | ```sh 27 | # To build TensorFlow Estimator whl file. 28 | bazel build //tensorflow_estimator/tools/pip_package:build_pip_package 29 | bazel-bin/tensorflow_estimator/tools/pip_package/build_pip_package /tmp/estimator_pip 30 | 31 | # To run all Estimator tests 32 | bazel test //tensorflow_estimator/... 33 | ``` 34 | 35 | ## Contribution guidelines 36 | 37 | If you want to contribute to TensorFlow Estimator, be sure to review the [contribution 38 | guidelines](CONTRIBUTING.md). 39 | 40 | **Note that this repository is included as a component of the main TensorFlow 41 | package, and any issues encountered while using Estimators should be filed under 42 | [TensorFlow GitHub Issues](https://github.com/tensorflow/tensorflow/issues), 43 | as we do not separately track issues in this repository. You can link this 44 | repository in any issues created as necessary.** 45 | 46 | Please see 47 | [TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions 48 | and discussion and please direct specific questions to 49 | [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow). 
50 | 51 | ## License 52 | 53 | [Apache License 2.0](LICENSE) 54 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/api/BUILD: -------------------------------------------------------------------------------- 1 | # Placeholder: load aliased py_binary 2 | load("//tensorflow_estimator/python/estimator/api:api_gen.bzl", "ESTIMATOR_API_INIT_FILES_V1", "ESTIMATOR_API_INIT_FILES_V2", "generate_apis") 3 | 4 | package(default_visibility = ["//tensorflow_estimator:internal"]) 5 | 6 | licenses(["notice"]) 7 | 8 | # This flag specifies whether Estimator 2.0 API should be built instead 9 | # of 1.* API. Note that Estimator 2.0 API is currently under development. 10 | config_setting( 11 | name = "api_version_2", 12 | define_values = {"estimator_api_version": "2"}, 13 | ) 14 | 15 | py_binary( 16 | name = "extractor_wrapper", 17 | srcs = ["extractor_wrapper.py"], 18 | visibility = ["//visibility:public"], 19 | deps = [ 20 | "//tensorflow_estimator/python/estimator:expect_absl_installed", # absl:app 21 | ], 22 | ) 23 | 24 | py_binary( 25 | name = "generator_wrapper", 26 | srcs = ["generator_wrapper.py"], 27 | visibility = ["//visibility:public"], 28 | deps = [ 29 | "//tensorflow_estimator/python/estimator:expect_absl_installed", # absl:app 30 | ], 31 | ) 32 | 33 | genrule( 34 | name = "estimator_python_api_gen", 35 | srcs = select({ 36 | "api_version_2": ["_v2/v2.py"], 37 | "//conditions:default": ["_v1/v1.py"], 38 | }), 39 | outs = ["__init__.py"], 40 | cmd = select({ 41 | "api_version_2": "cp $(location :_v2/v2.py) $(OUTS)", 42 | "//conditions:default": "cp $(location :_v1/v1.py) $(OUTS)", 43 | }), 44 | ) 45 | 46 | generate_apis( 47 | name = "estimator_python_api_gen_compat_v1", 48 | api_version = 1, 49 | output_dir = "_v1/", 50 | output_files = ESTIMATOR_API_INIT_FILES_V1, 51 | output_package = "tensorflow_estimator.python.estimator.api._v1", 52 | root_file_name = "v1.py", 53 | visibility = 
["//visibility:public"], 54 | ) 55 | 56 | generate_apis( 57 | name = "estimator_python_api_gen_compat_v2", 58 | api_version = 2, 59 | output_dir = "_v2/", 60 | output_files = ESTIMATOR_API_INIT_FILES_V2, 61 | output_package = "tensorflow_estimator.python.estimator.api._v2", 62 | root_file_name = "v2.py", 63 | visibility = ["//visibility:public"], 64 | ) 65 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/estimator_export_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
"""estimator_export tests."""

import sys
import tensorflow as tf

from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_export
# pylint: disable=g-deprecated-tf-checker
from tensorflow_estimator.python.estimator import estimator_export


class TestClass(object):
  # Dummy target for the export decorator under test.
  pass


class ValidateExportTest(tf.test.TestCase):
  """Tests for estimator_export class."""

  def setUp(self):
    super().setUp()
    # Names of modules a test registered; deleted again in tearDown.
    self._modules = []

  def tearDown(self):
    super().tearDown()
    for name in self._modules:
      del sys.modules[name]
    self._modules = []
    # Strip the API-name attributes the decorator attaches to TestClass so
    # each test starts from a pristine, undecorated class.
    if hasattr(TestClass, '_estimator_api_names'):
      del TestClass._estimator_api_names
    if hasattr(TestClass, '_estimator_api_names_v1'):
      del TestClass._estimator_api_names_v1

  @tf.compat.v1.test.mock.patch.object(
      logging, 'warning', autospec=True
  )
  def testExportDeprecated(self, mock_warning):
    """Decorated class warns about deprecation exactly once."""
    export_decorator = estimator_export.estimator_export('estimator.TestClass')
    export_decorator(TestClass)

    # Deprecation should trigger a runtime warning
    TestClass()
    self.assertEqual(1, mock_warning.call_count)
    # Deprecation should only warn once, upon first call
    TestClass()
    self.assertEqual(1, mock_warning.call_count)


if __name__ == '__main__':
  tf.test.main()
6 | """ 7 | 8 | _PYTHON_BIN_PATH = "PYTHON_BIN_PATH" 9 | 10 | def _tpl(repository_ctx, tpl, substitutions = {}, out = None): 11 | if not out: 12 | out = tpl 13 | repository_ctx.template( 14 | out, 15 | Label("//third_party/py:%s.tpl" % tpl), 16 | substitutions, 17 | ) 18 | 19 | def _fail(msg): 20 | """Output failure message when auto configuration fails.""" 21 | red = "\033[0;31m" 22 | no_color = "\033[0m" 23 | fail("%sPython Configuration Error:%s %s\n" % (red, no_color, msg)) 24 | 25 | def _get_python_bin(repository_ctx): 26 | """Gets the python bin path.""" 27 | python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH) 28 | if python_bin != None: 29 | return python_bin 30 | python_bin_path = repository_ctx.which("python") 31 | if python_bin_path != None: 32 | return str(python_bin_path) 33 | _fail("Cannot find python in PATH, please make sure " + 34 | "python is installed and add its directory in PATH, or --define " + 35 | "%s='/something/else'.\nPATH=%s" % ( 36 | _PYTHON_BIN_PATH, 37 | repository_ctx.os.environ.get("PATH", ""), 38 | )) 39 | 40 | def _create_local_python_repository(repository_ctx): 41 | """Creates the repository containing files set up to build with Python.""" 42 | python_bin = _get_python_bin(repository_ctx) 43 | _tpl(repository_ctx, "BUILD", { 44 | "%{PYTHON_BIN_PATH}": python_bin, 45 | }) 46 | 47 | def _python_autoconf_impl(repository_ctx): 48 | """Implementation of the python_autoconf repository rule.""" 49 | _create_local_python_repository(repository_ctx) 50 | 51 | python_configure = repository_rule( 52 | implementation = _python_autoconf_impl, 53 | environ = [ 54 | _PYTHON_BIN_PATH, 55 | ], 56 | ) 57 | """Detects and configures the local Python toolchain. 
"""Metric key strings reported by canned Estimator heads."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow_estimator.python.estimator import model_fn


class MetricKeys(object):
  """Metric key strings.

  Canonical names under which evaluation metrics are reported.  Keys ending
  in a printf-style placeholder (`%g`, `%d`, `%s`) are templates: callers
  substitute the threshold value, class id, or class name respectively.
  """
  # Loss keys are shared with model_fn so heads and reports agree on names.
  LOSS = model_fn.LOSS_METRIC_KEY
  LOSS_MEAN = model_fn.AVERAGE_LOSS_METRIC_KEY
  LOSS_REGULARIZATION = 'regularization_loss'

  ACCURACY = 'accuracy'
  PRECISION = 'precision'
  RECALL = 'recall'
  # This is the best the model could do by always predicting one class.
  # Should be < ACCURACY in a trained model.
  ACCURACY_BASELINE = 'accuracy_baseline'
  AUC = 'auc'
  AUC_PR = 'auc_precision_recall'
  LABEL_MEAN = 'label/mean'
  PREDICTION_MEAN = 'prediction/mean'

  # The following require a threshold applied, should be float in range (0, 1).
  ACCURACY_AT_THRESHOLD = 'accuracy/positive_threshold_%g'
  PRECISION_AT_THRESHOLD = 'precision/positive_threshold_%g'
  RECALL_AT_THRESHOLD = 'recall/positive_threshold_%g'

  # The following require a constraint on a competing metric to be applied,
  # float in range (0, 1).
  PRECISION_AT_RECALL = 'precision_at_recall_%g'
  RECALL_AT_PRECISION = 'recall_at_precision_%g'
  SENSITIVITY_AT_SPECIFICITY = 'sensitivity_at_specificity_%g'
  SPECIFICITY_AT_SENSITIVITY = 'specificity_at_sensitivity_%g'

  # The following require a class id applied.
  PROBABILITY_MEAN_AT_CLASS = 'probability_mean/class%d'
  AUC_AT_CLASS = 'auc/class%d'
  AUC_PR_AT_CLASS = 'auc_precision_recall/class%d'

  # The following require a class name applied.
  PROBABILITY_MEAN_AT_NAME = 'probability_mean/%s'
  AUC_AT_NAME = 'auc/%s'
  AUC_PR_AT_NAME = 'auc_precision_recall/%s'
class State(object):
  """Key formats for accepting/returning state."""
  # The model-dependent state to start from, as a single tuple.
  STATE_TUPLE = "start_tuple"
  # Same meaning as STATE_TUPLE, but prefixes keys representing flattened model
  # state rather than mapping to a nested tuple containing model state,
  # primarily for use with export_saved_model.
  STATE_PREFIX = "model_state"


class Times(object):
  """Key formats for accepting/returning times."""
  # An increasing vector of integers.
  TIMES = "times"


class Values(object):
  """Key formats for accepting/returning values."""
  # Floating point, with one or more values corresponding to each time in TIMES.
  VALUES = "values"


# The classes below compose the key mixins above into the sets of feature
# names each Estimator mode accepts or returns.


class TrainEvalFeatures(Times, Values):
  """Feature names used during training and evaluation."""
  pass


class PredictionFeatures(Times, State):
  """Feature names used during prediction."""
  pass


class FilteringFeatures(Times, Values, State):
  """Special feature names for filtering."""
  pass


class PredictionResults(Times):
  """Keys returned when predicting (not comprehensive)."""
  pass


class FilteringResults(Times, State):
  """Keys returned from evaluation/filtering."""
  pass


class SavedModelLabels(object):
  """Names of signatures exported with export_saved_model."""
  PREDICT = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
  FILTER = "filter"
  COLD_START_FILTER = "cold_start_filter"
31 | deps = [ 32 | ":sharded_mutable_dense_hashtable_py", 33 | "//tensorflow_estimator/python/estimator:expect_tensorflow_installed", 34 | "//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed", 35 | ], 36 | ) 37 | 38 | py_test( 39 | name = "sdca_ops_test", 40 | size = "medium", 41 | srcs = ["python/utils/sdca_ops_test.py"], 42 | python_version = "PY3", 43 | shard_count = 4, 44 | srcs_version = "PY3", 45 | tags = [ 46 | "no_gpu", 47 | "no_pip_gpu", 48 | ], 49 | deps = [ 50 | ":sdca_ops_py", 51 | "//tensorflow_estimator/python/estimator:expect_proto_cpp_installed", 52 | "//tensorflow_estimator/python/estimator:expect_tensorflow_installed", 53 | "//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed", 54 | ], 55 | ) 56 | 57 | py_library( 58 | name = "sharded_mutable_dense_hashtable_py", 59 | srcs = ["python/utils/sharded_mutable_dense_hashtable.py"], 60 | srcs_version = "PY3", 61 | deps = [ 62 | "//tensorflow_estimator/python/estimator:expect_tensorflow_installed", 63 | "//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed", 64 | ], 65 | ) 66 | 67 | py_test( 68 | name = "sharded_mutable_dense_hashtable_test", 69 | size = "small", 70 | srcs = ["python/utils/sharded_mutable_dense_hashtable_test.py"], 71 | python_version = "PY3", 72 | srcs_version = "PY3", 73 | deps = [ 74 | ":sharded_mutable_dense_hashtable_py", 75 | "//tensorflow_estimator/python/estimator:expect_proto_cpp_installed", 76 | "//tensorflow_estimator/python/estimator:expect_tensorflow_installed", 77 | "//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed", 78 | ], 79 | ) 80 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Some common SessionRunHook classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from tensorflow.python.training.basic_session_run_hooks import CheckpointSaverHook 22 | from tensorflow.python.training.basic_session_run_hooks import CheckpointSaverListener 23 | from tensorflow.python.training.basic_session_run_hooks import FeedFnHook 24 | from tensorflow.python.training.basic_session_run_hooks import FinalOpsHook 25 | from tensorflow.python.training.basic_session_run_hooks import GlobalStepWaiterHook 26 | from tensorflow.python.training.basic_session_run_hooks import LoggingTensorHook 27 | from tensorflow.python.training.basic_session_run_hooks import NanLossDuringTrainingError 28 | from tensorflow.python.training.basic_session_run_hooks import NanTensorHook 29 | from tensorflow.python.training.basic_session_run_hooks import ProfilerHook 30 | from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer 31 | from tensorflow.python.training.basic_session_run_hooks import StepCounterHook 32 | from tensorflow.python.training.basic_session_run_hooks import StopAtStepHook 33 | from tensorflow.python.training.basic_session_run_hooks import SummarySaverHook 34 | from 
from tensorflow_estimator.python.estimator.estimator_export import estimator_export

# Re-export the core SessionRunHook implementations under their public
# `tf.estimator.*` names.  `estimator_export` registers the API name for the
# symbol (and, per its contract, marks it deprecated).
_API_EXPORTS = (
    ("estimator.SecondOrStepTimer", SecondOrStepTimer),
    ("estimator.LoggingTensorHook", LoggingTensorHook),
    ("estimator.StopAtStepHook", StopAtStepHook),
    ("estimator.CheckpointSaverListener", CheckpointSaverListener),
    ("estimator.CheckpointSaverHook", CheckpointSaverHook),
    ("estimator.StepCounterHook", StepCounterHook),
    ("estimator.NanLossDuringTrainingError", NanLossDuringTrainingError),
    ("estimator.NanTensorHook", NanTensorHook),
    ("estimator.SummarySaverHook", SummarySaverHook),
    ("estimator.GlobalStepWaiterHook", GlobalStepWaiterHook),
    ("estimator.FinalOpsHook", FinalOpsHook),
    ("estimator.FeedFnHook", FeedFnHook),
    ("estimator.ProfilerHook", ProfilerHook),
)

for _api_name, _symbol in _API_EXPORTS:
  estimator_export(_api_name)(_symbol)
14 | # ============================================================================== 15 | """All public utility methods for exporting Estimator to SavedModel. 16 | 17 | This file includes functions and constants from core (model_utils) and export.py 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | # pylint: disable=unused-import,line-too-long, wildcard-import 25 | from tensorflow.python.saved_model.model_utils import build_all_signature_defs 26 | from tensorflow.python.saved_model.model_utils import export_outputs_for_mode 27 | from tensorflow.python.saved_model.model_utils import EXPORT_TAG_MAP 28 | from tensorflow.python.saved_model.model_utils import get_export_outputs 29 | from tensorflow.python.saved_model.model_utils import get_temp_export_dir 30 | from tensorflow.python.saved_model.model_utils import get_timestamped_export_dir 31 | from tensorflow.python.saved_model.model_utils import SIGNATURE_KEY_MAP 32 | from tensorflow.python.saved_model.model_utils.export_output import _SupervisedOutput 33 | from tensorflow.python.saved_model.model_utils.export_output import ClassificationOutput 34 | from tensorflow.python.saved_model.model_utils.export_output import EvalOutput 35 | from tensorflow.python.saved_model.model_utils.export_output import ExportOutput 36 | from tensorflow.python.saved_model.model_utils.export_output import PredictOutput 37 | from tensorflow.python.saved_model.model_utils.export_output import RegressionOutput 38 | from tensorflow.python.saved_model.model_utils.export_output import TrainOutput 39 | from tensorflow_estimator.python.estimator.export.export import build_parsing_serving_input_receiver_fn 40 | from tensorflow_estimator.python.estimator.export.export import build_raw_serving_input_receiver_fn 41 | from tensorflow_estimator.python.estimator.export.export import build_raw_supervised_input_receiver_fn 42 | from 
tensorflow_estimator.python.estimator.export.export import build_supervised_input_receiver_fn_from_input_fn 43 | from tensorflow_estimator.python.estimator.export.export import ServingInputReceiver 44 | from tensorflow_estimator.python.estimator.export.export import SupervisedInputReceiver 45 | from tensorflow_estimator.python.estimator.export.export import TensorServingInputReceiver 46 | from tensorflow_estimator.python.estimator.export.export import UnsupervisedInputReceiver 47 | from tensorflow_estimator.python.estimator.export.export import wrap_and_check_input_tensors 48 | # pylint: enable=unused-import,line-too-long, wildcard-import 49 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/estimator_export.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Utilities for exporting TensorFlow Estimator symbols to the API. 16 | 17 | Exporting a function or a class: 18 | 19 | To export a function or a class use the estimator_export decorator. For e.g.: 20 | ```python 21 | @estimator_export('foo', 'bar.foo') 22 | def foo(...): 23 | ... 
T = TypeVar('T')

ESTIMATOR_API_NAME = 'estimator'


# Register the `estimator` API namespace with tf_export so that decorated
# symbols are tracked via `_estimator_api_names`/`_estimator_api_constants`
# attributes (plus the `_v1` variants for the TF1 API).
# pylint: disable=protected-access
for _attrs_map, _suffix in (
    (tf_export.API_ATTRS, ''),
    (tf_export.API_ATTRS_V1, '_v1'),
):
  if ESTIMATOR_API_NAME not in _attrs_map:
    _attrs_map[ESTIMATOR_API_NAME] = tf_export._Attributes(
        '_estimator_api_names' + _suffix,
        '_estimator_api_constants' + _suffix,
    )
# pylint: enable=protected-access


class estimator_export(tf_export.api_export):  # pylint: disable=invalid-name
  """Provides ways to export symbols to the TensorFlow Estimator API."""

  def __init__(self, *args: str, v1: Optional[Sequence[str]] = None):
    """Export under the names *args (first one is considered canonical).

    All symbols exported by this decorator are exported under the `estimator`
    API name.

    Args:
      *args: API names in dot delimited format.
      v1: Names for the TensorFlow V1 API. If not set, we will use V2 API names
        both for TensorFlow V1 and V2 APIs.
    """
    super().__init__(*args, api_name=ESTIMATOR_API_NAME, v1=v1)

  def __call__(self, func: T) -> T:
    """Calls this decorator.

    Args:
      func: decorated symbol (function or class).

    Returns:
      The input symbol with its API-name attributes set, wrapped so that
      first use logs a deprecation warning.
    """
    deprecated = deprecation.deprecated(None, 'Use tf_keras instead.')(func)
    return super().__call__(deprecated)
# `--project_name <name>` lets release tooling publish the same sources under
# an alternative package name.  Both tokens are stripped from argv so that
# setuptools never sees them.
project_name = 'tensorflow_estimator'
if '--project_name' in sys.argv:
  _flag_index = sys.argv.index('--project_name')
  project_name = sys.argv[_flag_index + 1]
  # Delete the value first so the flag's index stays valid.
  del sys.argv[_flag_index + 1]
  del sys.argv[_flag_index]

setuptools.setup(
    name=project_name,
    version=_VERSION.replace('-', ''),
    description=DOCLINES[0],
    long_description='\n'.join(DOCLINES[2:]),
    url='https://www.tensorflow.org/',
    download_url='https://github.com/tensorflow/estimator/tags',
    author='Google Inc.',
    packages=setuptools.find_packages(),
    install_requires=REQUIRED_PACKAGES,
    # PyPI package information.
    # Supported Python versions
    python_requires='>=3.7',
    classifiers=[
        'Development Status :: 5 - Production/Stable',
        'Intended Audience :: Developers',
        'Intended Audience :: Education',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache Software License',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
        'Programming Language :: Python :: 3.10',
        'Topic :: Scientific/Engineering',
        'Topic :: Scientific/Engineering :: Mathematics',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Software Development',
        'Topic :: Software Development :: Libraries',
        'Topic :: Software Development :: Libraries :: Python Modules',
    ],
    license='Apache 2.0',
    keywords='tensorflow estimator tensor machine learning',
)
4 | load( 5 | "//tensorflow_estimator/python/estimator/api:api_gen.bzl", 6 | "ESTIMATOR_API_INIT_FILES_V1", 7 | "ESTIMATOR_API_INIT_FILES_V2", 8 | "generate_apis", 9 | ) 10 | 11 | licenses(["notice"]) 12 | 13 | package(default_visibility = ["//tensorflow_estimator:internal"]) 14 | 15 | exports_files(["LICENSE"]) 16 | 17 | # TODO(mikecase): Clean up. Remove all non estimator packages. 18 | package_group( 19 | name = "internal", 20 | packages = [ 21 | "//learning/brain/...", 22 | "//learning/deepmind/research/...", 23 | "//learning/tfx/models/uplift/estimators/...", 24 | "//nlp/nlx/ads/expmatch/model/...", 25 | "//nlp/nlx/common/query_bert/...", 26 | "//nlp/nlx/i18n/pangloss/...", 27 | "//tensorflow_estimator/...", 28 | "//third_party/py/tensorflow_privacy/...", 29 | "//third_party/tensorflow/python/estimator/...", 30 | ], 31 | ) 32 | 33 | # This flag specifies whether Estimator 2.0 API should be built instead 34 | # of 1.* API. Note that Estimator 2.0 API is currently under development. 35 | config_setting( 36 | name = "api_version_2", 37 | define_values = {"estimator_api_version": "2"}, 38 | ) 39 | 40 | config_setting( 41 | name = "no_estimator_py_deps", 42 | define_values = {"no_estimator_py_deps": "true"}, 43 | visibility = ["//visibility:public"], 44 | ) 45 | 46 | py_library( 47 | name = "tensorflow_estimator", 48 | srcs = [ 49 | ":root_init_gen", 50 | ":estimator_python_api_gen_compat_v1", 51 | ":estimator_python_api_gen_compat_v2", 52 | # Old API files. Delete once TensorFlow is updated to import from new location. 
53 | "//tensorflow_estimator/python/estimator/api:estimator_python_api_gen", 54 | "//tensorflow_estimator/python/estimator/api:estimator_python_api_gen_compat_v1", 55 | "//tensorflow_estimator/python/estimator/api:estimator_python_api_gen_compat_v2", 56 | ], 57 | srcs_version = "PY3", 58 | visibility = [ 59 | "//tensorflow_estimator:internal", 60 | "//third_party/tensorflow/tools/docs/google:__subpackages__", 61 | ], 62 | deps = [ 63 | "//tensorflow_estimator/python/estimator:estimator_py", 64 | ], 65 | ) 66 | 67 | genrule( 68 | name = "root_init_gen", 69 | srcs = select({ 70 | "api_version_2": ["_api/v2/v2.py"], 71 | "//conditions:default": ["_api/v1/v1.py"], 72 | }), 73 | outs = ["__init__.py"], 74 | cmd = select({ 75 | "api_version_2": "cp $(location :_api/v2/v2.py) $(OUTS)", 76 | "//conditions:default": "cp $(location :_api/v1/v1.py) $(OUTS)", 77 | }), 78 | ) 79 | 80 | generate_apis( 81 | name = "estimator_python_api_gen_compat_v1", 82 | api_version = 1, 83 | output_dir = "_api/v1/", 84 | output_files = ESTIMATOR_API_INIT_FILES_V1, 85 | output_package = "tensorflow_estimator._api.v1", 86 | root_file_name = "v1.py", 87 | ) 88 | 89 | generate_apis( 90 | name = "estimator_python_api_gen_compat_v2", 91 | api_version = 2, 92 | output_dir = "_api/v2/", 93 | output_files = ESTIMATOR_API_INIT_FILES_V2, 94 | output_package = "tensorflow_estimator._api.v2", 95 | root_file_name = "v2.py", 96 | ) 97 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/canned/timeseries/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# TODO(agarwal): Remove and replace with functionality from tf.slim
def fully_connected(inp,
                    inp_size,
                    layer_size,
                    name,
                    activation=tf.nn.relu,
                    dtype=tf.dtypes.float32):
  """Helper method to create a fully connected hidden layer."""
  # Weight matrix and (zero-initialized) bias live in TF1 variable scope,
  # keyed off `name` so repeated layers do not collide.
  weight = tf.compat.v1.get_variable(
      name="{}_weight".format(name), shape=[inp_size, layer_size], dtype=dtype)
  bias = tf.compat.v1.get_variable(
      name="{}_bias".format(name),
      shape=[layer_size],
      initializer=tf.compat.v1.initializers.zeros())
  result = tf.compat.v1.nn.xw_plus_b(inp, weight, bias)
  if activation is None:
    return result
  assert callable(activation)
  return activation(result)


def canonicalize_times_or_steps_from_output(times, steps,
                                            previous_model_output):
  """Canonicalizes either relative or absolute times, with error checking.

  Args:
    times: Absolute prediction times (array-like), or None.
    steps: Number of steps after the previous model output, or None.  Exactly
      one of `times`/`steps` must be given.
    previous_model_output: Mapping holding the previous filtering output;
      its TIMES entry provides the batch of last-observed times.

  Returns:
    A 2-dimensional [batch, time] array of prediction times.

  Raises:
    ValueError: If neither or both of `times`/`steps` are given, or if
      `times` has a mismatched batch size or is not strictly after the
      previous output.
  """
  if steps is not None and times is not None:
    raise ValueError("Only one of `steps` and `times` may be specified.")
  if steps is None and times is None:
    raise ValueError("One of `steps` and `times` must be specified.")
  previous_times = previous_model_output[feature_keys.FilteringResults.TIMES]
  if steps is not None:
    # Relative request: count `steps` consecutive times past the last
    # previously-seen time in each batch entry.
    return previous_times[:, -1:] + 1 + numpy.arange(steps)[None, ...]
  # Absolute request: normalize to [batch, time] and validate.
  times = numpy.array(times)
  if len(times.shape) != 2:
    times = times[None, ...]
  if previous_times.shape[0] != times.shape[0]:
    raise ValueError(
        ("`times` must have a batch dimension matching"
         " the previous model output (got a batch dimension of {} for `times`"
         " and {} for the previous model output).").format(
             times.shape[0], previous_times.shape[0]))
  if not (previous_times[:, -1] < times[:, 0]).all():
    raise ValueError("Prediction times must be after the corresponding "
                     "previous model output.")
  return times
For most image 14 | model, activations use more memory than the model weights. 15 | 16 | ## Enabling Spatial Partitioning with TPUEstimator 17 | 18 | Spatial partitioning doesn't require any code change in your model. You only 19 | need to specify the spatial partition parameters in your TPUConfig. 20 | 21 | ``` 22 | tpu_config=tpu_config.TPUConfig( 23 | iterations_per_loop=100, 24 | num_cores_per_replica=4, 25 | per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V2, 26 | input_partition_dims=[[1, 4, 1, 1], None]] 27 | 28 | ``` 29 | 30 | `per_host_input_for_training` must be set to PER_HOST_V2 for spatial 31 | partitioning: this means you must have a tf.data based input pipeline. 32 | `num_cores_per_replica` determines the maximum number partitions we can split. 33 | `input_partition_dims` is a list with two elements: `feature_partition_dims` and 34 | `label_partition_dims` describes how to partition the input tensors. The 35 | structure of `feature_partition_dims` and `label_partition_dims` must match the 36 | structure of features and labels from input_fn. 37 | 38 | ### Partitioning when features and labels are single tensors 39 | 40 | `features` or `labels` can be a single tensor. In this case, 41 | `feature_partition_dims` or `label_partition_dims` must be a list/tuple of 42 | integers or None. The length of the list/tuple must equal to the number of 43 | dimensions of the tensor. For example, if `features` is an image tensor with 44 | shape [N, H, W, C], the `feature_partition_dims` must be a list/tuple with 4 45 | integers. 46 | 47 | ``` 48 | features = image_tensor # [N, H, W, C] 49 | labels = class_label # [N] 50 | 51 | input_partition_dims = [[1,4,1,1], None] 52 | 53 | ``` 54 | 55 | ### Partitioning when features or labels are a dictionary 56 | 57 | `features` or `labels` can alternatively be a dictionary from `feature_name` to 58 | a `Tensor`. 
In this case `feature_partition_dims` or `label_partition_dims` must 59 | be a dict with exactly the same keys, and the value is a list/tuple of integers 60 | or None. 61 | 62 | ``` 63 | features = {'image': image_tensor, 'image_mask': mask_tensor} 64 | labels = {'class_label': class_id, 'mask': mask_id} 65 | 66 | input_partition_dims = [ 67 | {'image': [1,4,1,1], 'image_mask': [1, 2, 2,1]}, 68 | {'class_label': [1], mask: None}] 69 | 70 | ``` 71 | 72 | In this example, both `features` and `labels` are dictionaries. Therefore the 73 | `input_partition_dims` contains two dicts with the same structure: the first 74 | dict in `input_partition_dims` has two keys ‘image’ and ‘image_mask’ to match 75 | the tensors in features. The value is a list of integers describes how to 76 | partition the tensor. 'class_label': [1] means we send the class_label tensor to 77 | core 0 only. 78 | 79 | ### Partitioning when features are a dict, labels are a single tensor 80 | 81 | `features` and `labels` could be any of the aforementation’s format. The rule 82 | for `feature_partition_dims` and `label_partition_dims` are applied separately. 83 | 84 | ``` 85 | features = {'image': image_tensor, 'image_mask': mask_tensor} 86 | labels = class_label # [N] 87 | 88 | input_partition_dims = [ 89 | {'image': [1,4,1,1], 'image_mask': [1, 2, 2,1]}, 90 | [1]] 91 | 92 | ``` 93 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/tpu/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# Matches either a bare positive integer (a raw iteration count) or a
# positive integer followed by a time-unit suffix: `s` (seconds), `m`
# (minutes) or `h` (hours). Examples: `100`, `3600s`, `60m`, `1h`.
# NOTE: the named groups `<value>`/`<suffix>` were lost to markup stripping
# in a previous revision (`(?P[1-9]...` is a re syntax error); restored here.
# The suffix class was also `[s|m|h]`, which unintentionally accepted a
# literal `|` as a suffix; tightened to `[smh]`.
_ITERATIONS_PER_LOOP_VALUE_REGEX = re.compile(
    r'^(?P<value>[1-9]\d*)((?P<suffix>[smh])$|$)')

# Parsed form of `iterations_per_loop`: `unit` is 'count' for a plain
# iteration count, or 'seconds' for a time-based value normalized to seconds.
IterationsPerLoopCounter = collections.namedtuple('IterationsPerLoopCounter',
                                                  ['value', 'unit'])


def check_positive_integer(value, name):
  """Checks whether `value` is a positive integer.

  Args:
    value: The value to validate.
    name: Name of the value, used in error messages.

  Raises:
    TypeError: If `value` is not a Python or NumPy integer.
    ValueError: If `value` is not strictly positive.
  """
  if not isinstance(value, (six.integer_types, np.integer)):
    raise TypeError('{} must be int, got {}'.format(name, type(value)))

  if value <= 0:
    raise ValueError('{} must be positive, got {}'.format(name, value))


def parse_iterations_per_loop(iterations_per_loop):
  """Parses the `iterations_per_loop` value.

  The parser expects the value of the `iterations_per_loop` value to be a
  positive integer value with unit:`count` or time-based value `<N><unit>`
  where <N> is any positive integer and `s`, `m`, `h` are unit of time in
  seconds, minutes, hours respectively. Examples of valid values: `3600s`,
  `60m`, `1h`.

  Args:
    iterations_per_loop: Number of iterations or time alloted to spend on per
      device loop.

  Returns:
    An `IterationsPerLoopCounter` namedtuple of `value` and `unit`. The
    `unit` value can be either a raw `count`, or time in `seconds`.
    {
      "value": <positive integer>,
      "unit": <`count` or `seconds`>
    }

  Raises:
    ValueError: If `iterations_per_loop` is neither a positive integer nor a
      valid time-based value.
  """
  m = _ITERATIONS_PER_LOOP_VALUE_REGEX.match(str(iterations_per_loop))
  if m is None:
    raise ValueError(
        'Invalid TPUConfig `iterations_per_loop` value. Value must be positive '
        'integer value or time-based value `<N><unit>` where <N> is any '
        'positive integer and `s`, `m`, `h` are unit of time in seconds, '
        'minutes, hours respectively. Examples of valid values: `3600s`, '
        '`60m`, `1h`.')
  # Time-based values are normalized to seconds; a bare integer keeps the
  # raw iteration count.
  unit_value = 'seconds' if m.group('suffix') in ['h', 'm', 's'] else 'count'
  value = int(m.group('value'))
  if m.group('suffix') == 'm':
    value *= 60
  elif m.group('suffix') == 'h':
    value *= 3600
  return IterationsPerLoopCounter(value, unit_value)
class PassthroughStateManager(object):
  """A minimal wrapper for models which do not need state management."""

  def __init__(self):
    # No statistics are available until initialize_graph() runs.
    self._input_statistics = None
    self._graph_initialized = False

  def initialize_graph(self, model, input_statistics=None):
    """Adds required operations to the graph."""
    del model  # unused
    self._graph_initialized = True
    self._input_statistics = input_statistics

  def define_loss(self, model, features, mode):
    """Wrap "model" with StateManager-specific operations.

    Args:
      model: The model (inheriting from TimeSeriesModel) to manage state for.
      features: A dictionary with the following key/value pairs:
        feature_keys.TrainEvalFeatures.TIMES: A [batch size x window size]
          Tensor with times for each observation.
        feature_keys.TrainEvalFeatures.VALUES: A [batch size x window size x
          num features] Tensor with values for each observation.
      mode: The tf.estimator.ModeKeys mode to use (TRAIN or EVAL).

    Returns:
      A ModelOutputs object.

    Raises:
      ValueError: If start state was specified.
    """
    if feature_keys.State.STATE_TUPLE in features:
      raise ValueError(
          "Overriding start state is not supported for this model.")
    return model.define_loss(features, mode)


class _OverridableStateManager(PassthroughStateManager):
  """Base class for state managers which support overriding model state."""

  # NOTE(review): there is no ABCMeta metaclass on this hierarchy, so
  # @abc.abstractmethod is not enforced at instantiation time; subclasses are
  # simply expected to override this method.
  @abc.abstractmethod
  def _define_loss_with_saved_state(self, model, features, mode):
    pass

  def define_loss(self, model, features, mode):
    """Switches between explicit start state and managed state."""
    state_key = feature_keys.FilteringFeatures.STATE_TUPLE
    if state_key not in features:
      # No explicit start state; use managed state.
      return self._define_loss_with_saved_state(
          model=model, features=features, mode=mode)
    # Explicit start state has been provided, so we should use that.
    if mode == estimator_lib.ModeKeys.TRAIN:
      raise ValueError(
          "Overriding saved state for training is not supported (but a value "
          "for feature {} was specified).".format(state_key))
    # Consume the state feature so the model only sees data features.
    start_state = features.pop(state_key)
    return model.get_batch_loss(
        features=features, mode=mode, state=start_state)


class FilteringOnlyStateManager(_OverridableStateManager):
  """State manager for models which use state only for filtering.

  Window-based models (ARModel) do not require state to be fed during training
  (instead requiring a specific window size). Rather than requiring a minimum
  window size for filtering, these models maintain this window in their state,
  and so need state to be fed.
  """

  def _define_loss_with_saved_state(self, model, features, mode):
    # These models keep their own window in state, so the saved-state path is
    # just the model's plain loss.
    return model.define_loss(features, mode)
class InputStatisticsTests(tf.test.TestCase):
  """Checks the streaming moment estimates produced by math_utils."""

  def _input_statistics_test_template(self,
                                      stat_object,
                                      num_features,
                                      dtype,
                                      warmup_iterations=0,
                                      rtol=1e-6,
                                      data_length=4):
    """Feeds a small synthetic series to `stat_object` and checks its moments.

    Args:
      stat_object: Statistics object under test; must expose
        `initialize_graph(features=...)`.
      num_features: Number of feature dimensions in the synthetic values.
      dtype: TensorFlow dtype used for the generated times/values.
      warmup_iterations: Number of update evaluations to run before asserting,
        so adaptive (mini-batch) statistics can converge.
      rtol: Relative tolerance for the moment comparisons.
      data_length: Number of time steps in the synthetic series.
    """
    graph = tf.Graph()
    with graph.as_default():
      data_length_range = tf.range(data_length, dtype=dtype)
      num_features_range = tf.range(num_features, dtype=dtype)
      # Times are an affine transform of the step index (start -3, stride 2),
      # so `statistics.start_time` is expected to be -3 below.
      times = 2 * data_length_range[None, :] - 3
      # values[0, t, f] == t + f: each feature is the step index shifted by
      # the feature index, which makes the expected moments easy to compute.
      values = (data_length_range[:, None] + num_features_range[None, :])[None,
                                                                          ...]
      features = {
          TrainEvalFeatures.TIMES: times,
          TrainEvalFeatures.VALUES: values,
      }
      statistics = stat_object.initialize_graph(features=features)
      with self.session(graph=graph) as session:
        tf.compat.v1.initializers.global_variables().run()
        # Queue-based statistics need running queue runners to update.
        coordinator = tf.train.Coordinator()
        tf.compat.v1.train.queue_runner.start_queue_runners(
            session, coord=coordinator)
        for _ in range(warmup_iterations):
          # A control dependency should ensure that, for queue-based
          # statistics, a use of any statistic is preceded by an update of
          # all adaptive statistics.
          self.evaluate(statistics.total_observation_count)
        # Per-feature mean of the first window: feature index plus the mean
        # of the step indices.
        self.assertAllClose(
            tf.range(num_features, dtype=dtype) +
            tf.math.reduce_mean(data_length_range)[None],
            self.evaluate(statistics.series_start_moments.mean),
            rtol=rtol)
        # Variance is the same for every feature (shift does not change it).
        self.assertAllClose(
            tf.tile(
                tf.math.reduce_variance(data_length_range)[None],
                [num_features]),
            self.evaluate(statistics.series_start_moments.variance),
            rtol=rtol)
        # Overall moments are taken across the time axis of the single batch.
        self.assertAllClose(
            tf.math.reduce_mean(values[0], axis=0),
            self.evaluate(statistics.overall_feature_moments.mean),
            rtol=rtol)
        self.assertAllClose(
            tf.math.reduce_variance(values[0], axis=0),
            self.evaluate(statistics.overall_feature_moments.variance),
            rtol=rtol)
        self.assertAllClose(-3, self.evaluate(statistics.start_time), rtol=rtol)
        self.assertAllClose(
            data_length,
            self.evaluate(statistics.total_observation_count),
            rtol=rtol)
        coordinator.request_stop()
        coordinator.join()

  def test_queue(self):
    # Mini-batch statistics are adaptive; a large warmup with a loose rtol
    # lets them converge before the assertions run.
    for dtype in [tf.dtypes.float32, tf.dtypes.float64]:
      for num_features in [1, 2, 3]:
        self._input_statistics_test_template(
            math_utils.InputStatisticsFromMiniBatch(
                num_features=num_features, dtype=dtype),
            num_features=num_features,
            dtype=dtype,
            warmup_iterations=1000,
            rtol=0.1)
def binary_or_multi_class_head(n_classes, weight_column, label_vocabulary,
                               loss_reduction):
  """Creates either binary or multi-class head.

  Args:
    n_classes: Number of label classes.
    weight_column: A string or a `NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training.
      It will be multiplied by the loss of the example. If it is a string, it
      is used as a key to fetch weight tensor from the `features`. If it is a
      `NumericColumn`, raw tensor is fetched by key `weight_column.key`, then
      weight_column.normalizer_fn is applied on it to get weight tensor.
    label_vocabulary: A list of strings representing possible label values.
      If given, labels must be string type and take values from
      `label_vocabulary`. If omitted, labels must already be encoded as
      integer or float within [0, 1] for `n_classes=2`, and as integer values
      in {0, 1, ..., n_classes-1} for `n_classes>2`. An error will be raised
      if the vocabulary is not provided and labels are string.
    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Defines how to
      reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.

  Returns:
    A `Head` instance.
  """
  # Both head types share the same keyword configuration; only the class and
  # the extra `n_classes` positional differ.
  shared_kwargs = dict(
      weight_column=weight_column,
      label_vocabulary=label_vocabulary,
      loss_reduction=loss_reduction)
  if n_classes == 2:
    return binary_class_head.BinaryClassHead(**shared_kwargs)
  return multi_class_head.MultiClassHead(n_classes, **shared_kwargs)


def _initialize_variables(test_case, scaffold):
  """Finalizes `scaffold` and runs its full initialization sequence.

  Asserts along the way that the scaffold has neither an init feed dict nor
  an init fn, and that a saver was created.
  """
  # Statement order mirrors the session startup sequence: init, local-init
  # readiness, local init, full readiness.
  scaffold.finalize()
  test_case.assertIsNone(scaffold.init_feed_dict)
  test_case.assertIsNone(scaffold.init_fn)
  scaffold.init_op.run()
  scaffold.ready_for_local_init_op.eval()
  scaffold.local_init_op.run()
  scaffold.ready_op.eval()
  test_case.assertIsNotNone(scaffold.saver)


def _assert_simple_summaries(test_case,
                             expected_summaries,
                             summary_str,
                             tol=1e-6):
  """Asserts the summary carries exactly the specified simple values.

  Args:
    test_case: test case.
    expected_summaries: Dict mapping summary tags to expected simple values.
    summary_str: Serialized `summary_pb2.Summary`.
    tol: Tolerance used for both the relative and absolute comparison.
  """
  summary = tf.compat.v1.summary.Summary()
  summary.ParseFromString(summary_str)
  actual_summaries = {value.tag: value.simple_value for value in summary.value}
  test_case.assertAllClose(
      expected_summaries, actual_summaries, rtol=tol, atol=tol)


def _assert_no_hooks(test_case, spec):
  """Asserts that `spec` declares no chief or regular training hooks."""
  test_case.assertAllEqual([], spec.training_chief_hooks)
  test_case.assertAllEqual([], spec.training_hooks)
@test_util.deprecated_graph_mode_only
class UtilTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for miscellaneous Estimator utils."""

  def test_parse_input_fn_result_tuple(self):
    """A (features, labels) tuple passes through unchanged, with no hooks."""

    def _input_fn():
      features = tf.constant(np.arange(100))
      labels = tf.constant(np.arange(100, 200))
      return features, labels

    features, labels, hooks = util.parse_input_fn_result(_input_fn())

    with self.cached_session() as sess:
      vals = sess.run([features, labels])

    self.assertAllEqual(vals[0], np.arange(100))
    self.assertAllEqual(vals[1], np.arange(100, 200))
    self.assertEqual(hooks, [])

  @parameterized.named_parameters(('DatasetV1', tf.compat.v1.data.Dataset),
                                  ('DatasetV2', tf.data.Dataset))
  def test_parse_input_fn_result_dataset(self, dataset_class):
    """A (features, labels) dataset yields tensors plus an initializer hook."""

    def _input_fn():
      features = np.expand_dims(np.arange(100), 0)
      labels = np.expand_dims(np.arange(100, 200), 0)
      return dataset_class.from_tensor_slices((features, labels))

    features, labels, hooks = util.parse_input_fn_result(_input_fn())

    # The returned hook must run inside a MonitoredSession so the dataset
    # iterator is initialized before the tensors are evaluated.
    with tf.compat.v1.train.MonitoredSession(hooks=hooks) as sess:
      vals = sess.run([features, labels])

    self.assertAllEqual(vals[0], np.arange(100))
    self.assertAllEqual(vals[1], np.arange(100, 200))
    self.assertIsInstance(hooks[0], util._DatasetInitializerHook)

  def test_parse_input_fn_result_features_only(self):
    """A bare features tensor is returned with labels=None and no hooks."""

    def _input_fn():
      return tf.constant(np.arange(100))

    features, labels, hooks = util.parse_input_fn_result(_input_fn())

    with self.cached_session() as sess:
      vals = sess.run([features])

    self.assertAllEqual(vals[0], np.arange(100))
    self.assertEqual(labels, None)
    self.assertEqual(hooks, [])

  @parameterized.named_parameters(('DatasetV1', tf.compat.v1.data.Dataset),
                                  ('DatasetV2', tf.data.Dataset))
  def test_parse_input_fn_result_features_only_dataset(self, dataset_class):
    """A features-only dataset yields labels=None plus an initializer hook."""

    def _input_fn():
      features = np.expand_dims(np.arange(100), 0)
      return dataset_class.from_tensor_slices(features)

    features, labels, hooks = util.parse_input_fn_result(_input_fn())

    with tf.compat.v1.train.MonitoredSession(hooks=hooks) as sess:
      vals = sess.run([features])

    self.assertAllEqual(vals[0], np.arange(100))
    self.assertEqual(labels, None)
    self.assertIsInstance(hooks[0], util._DatasetInitializerHook)

  @parameterized.named_parameters(('DatasetV1', tf.compat.v1.data.Dataset),
                                  ('DatasetV2', tf.data.Dataset))
  def test_parse_input_fn_result_invalid(self, dataset_class):
    """Datasets with more than two elements per example are rejected."""

    def _input_fn():
      features = np.expand_dims(np.arange(100), 0)
      labels = np.expand_dims(np.arange(100, 200), 0)
      return dataset_class.from_tensor_slices((features, labels, labels))

    # assertRaisesRegexp is a deprecated alias; assertRaisesRegex matches the
    # style used by the other tests in this repository.
    with self.assertRaisesRegex(ValueError, 'input_fn should return'):
      util.parse_input_fn_result(_input_fn())
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """A SessionRunHook extends `session.run()` calls for the `MonitoredSession`. 16 | 17 | SessionRunHooks are useful to track training, report progress, request early 18 | stopping and more. SessionRunHooks use the observer pattern and notify at the 19 | following points: 20 | - when a session starts being used 21 | - before a call to the `session.run()` 22 | - after a call to the `session.run()` 23 | - when the session closed 24 | 25 | A SessionRunHook encapsulates a piece of reusable/composable computation that 26 | can piggyback a call to `MonitoredSession.run()`. A hook can add any 27 | ops-or-tensor/feeds to the run call, and when the run call finishes with success 28 | gets the outputs it requested. Hooks are allowed to add ops to the graph in 29 | `hook.begin()`. The graph is finalized after the `begin()` method is called. 30 | 31 | There are a few pre-defined hooks: 32 | - StopAtStepHook: Request stop based on global_step 33 | - CheckpointSaverHook: saves checkpoint 34 | - LoggingTensorHook: outputs one or more tensor values to log 35 | - NanTensorHook: Request stop if given `Tensor` contains Nans. 
36 | - SummarySaverHook: saves summaries to a summary writer 37 | 38 | For more specific needs, you can create custom hooks: 39 | class ExampleHook(SessionRunHook): 40 | def begin(self): 41 | # You can add ops to the graph here. 42 | print('Starting the session.') 43 | self.your_tensor = ... 44 | 45 | def after_create_session(self, session, coord): 46 | # When this is called, the graph is finalized and 47 | # ops can no longer be added to the graph. 48 | print('Session created.') 49 | 50 | def before_run(self, run_context): 51 | print('Before calling session.run().') 52 | return SessionRunArgs(self.your_tensor) 53 | 54 | def after_run(self, run_context, run_values): 55 | print('Done running one step. The value of my tensor: %s', 56 | run_values.results) 57 | if you-need-to-stop-loop: 58 | run_context.request_stop() 59 | 60 | def end(self, session): 61 | print('Done with the session.') 62 | 63 | To understand how hooks interact with calls to `MonitoredSession.run()`, 64 | look at following code: 65 | with MonitoredTrainingSession(hooks=your_hooks, ...) as sess: 66 | while not sess.should_stop(): 67 | sess.run(your_fetches) 68 | 69 | Above user code leads to following execution: 70 | call hooks.begin() 71 | sess = tf.Session() 72 | call hooks.after_create_session() 73 | while not stop is requested: 74 | call hooks.before_run() 75 | try: 76 | results = sess.run(merged_fetches, feed_dict=merged_feeds) 77 | except (errors.OutOfRangeError, StopIteration): 78 | break 79 | call hooks.after_run() 80 | call hooks.end() 81 | sess.close() 82 | 83 | Note that if sess.run() raises OutOfRangeError or StopIteration then 84 | hooks.after_run() will not be called but hooks.end() will still be called. 85 | If sess.run() raises any other exception then neither hooks.after_run() nor 86 | hooks.end() will be called. 
87 | """ 88 | 89 | from __future__ import absolute_import 90 | from __future__ import division 91 | from __future__ import print_function 92 | from tensorflow.python.training.session_run_hook import SessionRunArgs 93 | from tensorflow.python.training.session_run_hook import SessionRunContext 94 | from tensorflow.python.training.session_run_hook import SessionRunHook 95 | from tensorflow.python.training.session_run_hook import SessionRunValues 96 | from tensorflow_estimator.python.estimator.estimator_export import estimator_export 97 | 98 | estimator_export("estimator.SessionRunHook")(SessionRunHook) 99 | estimator_export("estimator.SessionRunArgs")(SessionRunArgs) 100 | estimator_export("estimator.SessionRunContext")(SessionRunContext) 101 | estimator_export("estimator.SessionRunValues")(SessionRunValues) 102 | -------------------------------------------------------------------------------- /tensorflow_estimator/tools/pip_package/build_pip_package.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
set -e

# True if $1 is an absolute path (POSIX "/..." or a Windows drive path).
function is_absolute {
  [[ "$1" = /* ]] || [[ "$1" =~ ^[a-zA-Z]:[/\\].* ]]
}

# Prints $1 as an absolute path, resolving relative paths against $PWD.
function real_path() {
  is_absolute "$1" && echo "$1" || echo "$PWD/${1#./}"
}

# Copies bazel build outputs and setup.py into the staging directory $1, then
# verifies the staged pip tree (and creates missing __init__.py files).
# NOTE: TMPDIR is deliberately not `local` here; this script historically
# relied on it leaking, but cleanup below now uses SRCDIR explicitly.
function prepare_src() {
  TMPDIR="$1"

  mkdir -p "$TMPDIR"
  echo $(date) : "=== Preparing sources in dir: ${TMPDIR}"

  if [ ! -d bazel-bin/tensorflow_estimator ]; then
    echo "Could not find bazel-bin. Did you run from the root of the build tree?"
    exit 1
  fi
  cp -r "bazel-bin/tensorflow_estimator/tools/pip_package/build_pip_package.runfiles/org_tensorflow_estimator/tensorflow_estimator" "$TMPDIR"
  cp tensorflow_estimator/tools/pip_package/setup.py "$TMPDIR"

  # Verifies all expected files are in pip.
  # Creates init files in all directory in pip.
  python tensorflow_estimator/tools/pip_package/create_pip_helper.py --pip-root "${TMPDIR}/tensorflow_estimator/" --bazel-root "./tensorflow_estimator"
}

# Builds a wheel from staged sources in $1 and copies it into $2, using the
# project name $3.
function build_wheel() {
  if [ $# -lt 2 ] ; then
    echo "No src and dest dir provided"
    exit 1
  fi

  TMPDIR="$1"
  DEST="$2"
  PROJECT_NAME="$3"

  pushd ${TMPDIR} > /dev/null
  echo $(date) : "=== Building wheel"
  "${PYTHON_BIN_PATH:-python}" setup.py bdist_wheel --universal --project_name $PROJECT_NAME
  mkdir -p ${DEST}
  cp dist/* ${DEST}
  popd > /dev/null
  echo $(date) : "=== Output wheel file is in: ${DEST}"
}

function usage() {
  echo "Usage:"
  echo "$0 [--src srcdir] [--dst dstdir] [options]"
  echo "$0 dstdir [options]"
  echo ""
  echo "    --src                 prepare sources in srcdir"
  echo "                              will use temporary dir if not specified"
  echo ""
  echo "    --dst                 build wheel in dstdir"
  echo "                              if dstdir is not set do not build, only prepare sources"
  echo ""
  echo "  Options:"
  echo "    --project_name <name> set project name to name"
  echo "    --nightly             build tensorflow_estimator nightly"
  echo ""
  exit 1
}

function main() {
  NIGHTLY_BUILD=0
  PROJECT_NAME=""
  SRCDIR=""
  DSTDIR=""
  # Only clean up the staging dir when we created it ourselves (--src unset).
  CLEANSRC=1

  while true; do
    if [[ -z "$1" ]]; then
      break
    elif [[ "$1" == "--help" ]]; then
      usage
      exit 1
    elif [[ "$1" == "--nightly" ]]; then
      NIGHTLY_BUILD=1
    elif [[ "$1" == "--project_name" ]]; then
      shift
      if [[ -z "$1" ]]; then
        break
      fi
      PROJECT_NAME="$1"
    elif [[ "$1" == "--src" ]]; then
      shift
      if [[ -z "$1" ]]; then
        break
      fi
      SRCDIR="$(real_path $1)"
      CLEANSRC=0
    elif [[ "$1" == "--dst" ]]; then
      shift
      if [[ -z "$1" ]]; then
        break
      fi
      DSTDIR="$(real_path $1)"
    else
      DSTDIR="$(real_path $1)"
    fi
    shift
  done

  if [[ -z ${PROJECT_NAME} ]]; then
    PROJECT_NAME="tensorflow_estimator"
    if [[ ${NIGHTLY_BUILD} == "1" ]]; then
      PROJECT_NAME="tf_estimator_nightly"
    fi
  fi

  if [[ -z "$DSTDIR" ]] && [[ -z "$SRCDIR" ]]; then
    echo "No destination dir provided"
    usage
    exit 1
  fi

  if [[ -z "$SRCDIR" ]]; then
    # make temp srcdir if none set
    SRCDIR="$(mktemp -d -t tmp.XXXXXXXXXX)"
  fi

  prepare_src "$SRCDIR"

  if [[ -z "$DSTDIR" ]]; then
    # only want to prepare sources
    exit
  fi

  build_wheel "$SRCDIR" "$DSTDIR" "$PROJECT_NAME"

  if [[ $CLEANSRC -ne 0 ]]; then
    # Remove the temporary staging directory. Use SRCDIR explicitly: the old
    # `rm -rf "${TMPDIR}"` only worked because TMPDIR leaked out of
    # prepare_src/build_wheel as a global (and it shadows the conventional
    # TMPDIR environment variable), which is fragile.
    rm -rf "${SRCDIR}"
  fi
}

main "$@"
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for optimizers.py.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from tensorflow_estimator.python.estimator.util import tf_keras 23 | from tensorflow_estimator.python.estimator.canned import optimizers 24 | 25 | 26 | class _TestOptimizer(tf.compat.v1.train.Optimizer): 27 | 28 | def __init__(self): 29 | super(_TestOptimizer, self).__init__( 30 | use_locking=False, name='TestOptimizer') 31 | 32 | 33 | class GetOptimizerInstance(tf.test.TestCase): 34 | 35 | def test_unsupported_name(self): 36 | with self.assertRaisesRegex( 37 | ValueError, 'Unsupported optimizer name: unsupported_name'): 38 | optimizers.get_optimizer_instance('unsupported_name', learning_rate=0.1) 39 | 40 | def test_supported_name_but_learning_rate_none(self): 41 | with self.assertRaisesRegex( 42 | ValueError, 'learning_rate must be specified when opt is string'): 43 | optimizers.get_optimizer_instance('Adagrad', learning_rate=None) 44 | 45 | def test_keras_optimizer_after_tf_2_11(self): 46 | new_opt = tf_keras.optimizers.Adagrad() 47 | 48 | # In eager mode it should automatically convert to legacy optimizer. 
49 | opt = optimizers.get_optimizer_instance_v2(new_opt, learning_rate=0.1) 50 | self.assertIsInstance(opt, tf_keras.optimizers.legacy.Adagrad) 51 | 52 | # In graph mode errors should be thrown. 53 | @tf.function 54 | def foo(): 55 | with self.assertRaisesRegex( 56 | ValueError, 57 | r'Please set your.*tf_keras\.optimizers\.legacy\.Adagrad.*'): 58 | optimizers.get_optimizer_instance_v2(new_opt, learning_rate=0.1) 59 | foo() 60 | 61 | def test_adagrad(self): 62 | opt = optimizers.get_optimizer_instance('Adagrad', learning_rate=0.1) 63 | self.assertIsInstance(opt, tf.compat.v1.train.AdagradOptimizer) 64 | self.assertAlmostEqual(0.1, opt._learning_rate) 65 | 66 | def test_adam(self): 67 | opt = optimizers.get_optimizer_instance('Adam', learning_rate=0.1) 68 | self.assertIsInstance(opt, tf.compat.v1.train.AdamOptimizer) 69 | self.assertAlmostEqual(0.1, opt._lr) 70 | 71 | def test_ftrl(self): 72 | opt = optimizers.get_optimizer_instance('Ftrl', learning_rate=0.1) 73 | self.assertIsInstance(opt, tf.compat.v1.train.FtrlOptimizer) 74 | self.assertAlmostEqual(0.1, opt._learning_rate) 75 | 76 | def test_rmsprop(self): 77 | opt = optimizers.get_optimizer_instance('RMSProp', learning_rate=0.1) 78 | self.assertIsInstance(opt, tf.compat.v1.train.RMSPropOptimizer) 79 | self.assertAlmostEqual(0.1, opt._learning_rate) 80 | 81 | def test_sgd(self): 82 | opt = optimizers.get_optimizer_instance('SGD', learning_rate=0.1) 83 | self.assertIsInstance(opt, tf.compat.v1.train.GradientDescentOptimizer) 84 | self.assertAlmostEqual(0.1, opt._learning_rate) 85 | 86 | def test_object(self): 87 | opt = optimizers.get_optimizer_instance(_TestOptimizer()) 88 | self.assertIsInstance(opt, _TestOptimizer) 89 | 90 | def test_object_invalid(self): 91 | with self.assertRaisesRegex( 92 | ValueError, 'The given object is not an Optimizer instance'): 93 | optimizers.get_optimizer_instance((1, 2, 3)) 94 | 95 | def test_callable(self): 96 | 97 | def _optimizer_fn(): 98 | return _TestOptimizer() 99 | 100 | opt 
= optimizers.get_optimizer_instance(_optimizer_fn) 101 | self.assertIsInstance(opt, _TestOptimizer) 102 | 103 | def test_lambda(self): 104 | opt = optimizers.get_optimizer_instance(lambda: _TestOptimizer()) # pylint: disable=unnecessary-lambda 105 | self.assertIsInstance(opt, _TestOptimizer) 106 | 107 | def test_callable_returns_invalid(self): 108 | 109 | def _optimizer_fn(): 110 | return (1, 2, 3) 111 | 112 | with self.assertRaisesRegex( 113 | ValueError, 'The given object is not an Optimizer instance'): 114 | optimizers.get_optimizer_instance(_optimizer_fn) 115 | 116 | 117 | if __name__ == '__main__': 118 | tf.test.main() 119 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/estimator_lib.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Estimator: High level tools for working with models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import,line-too-long,wildcard-import 22 | from tensorflow_estimator.python.estimator.canned.baseline import BaselineClassifier 23 | from tensorflow_estimator.python.estimator.canned.baseline import BaselineEstimator 24 | from tensorflow_estimator.python.estimator.canned.baseline import BaselineRegressor 25 | from tensorflow_estimator.python.estimator.canned.dnn import dnn_logit_fn_builder 26 | from tensorflow_estimator.python.estimator.canned.dnn import DNNClassifier 27 | from tensorflow_estimator.python.estimator.canned.dnn import DNNEstimator 28 | from tensorflow_estimator.python.estimator.canned.dnn import DNNRegressor 29 | from tensorflow_estimator.python.estimator.canned.dnn_linear_combined import DNNLinearCombinedClassifier 30 | from tensorflow_estimator.python.estimator.canned.dnn_linear_combined import DNNLinearCombinedEstimator 31 | from tensorflow_estimator.python.estimator.canned.dnn_linear_combined import DNNLinearCombinedRegressor 32 | from tensorflow_estimator.python.estimator.canned.kmeans import KMeansClustering 33 | from tensorflow_estimator.python.estimator.canned.linear import linear_logit_fn_builder 34 | from tensorflow_estimator.python.estimator.canned.linear import LinearClassifier 35 | from tensorflow_estimator.python.estimator.canned.linear import LinearEstimator 36 | from tensorflow_estimator.python.estimator.canned.linear import LinearRegressor 37 | from tensorflow_estimator.python.estimator.canned.parsing_utils import classifier_parse_example_spec 38 | from tensorflow_estimator.python.estimator.canned.parsing_utils import regressor_parse_example_spec 39 | from tensorflow_estimator.python.estimator.canned.rnn import RNNClassifier 
40 | from tensorflow_estimator.python.estimator.canned.rnn import RNNEstimator 41 | from tensorflow_estimator.python.estimator.early_stopping import * 42 | from tensorflow_estimator.python.estimator.estimator import Estimator 43 | from tensorflow_estimator.python.estimator.estimator import VocabInfo 44 | from tensorflow_estimator.python.estimator.estimator import WarmStartSettings 45 | from tensorflow_estimator.python.estimator.export import export_lib as export 46 | from tensorflow_estimator.python.estimator.exporter import Exporter 47 | from tensorflow_estimator.python.estimator.exporter import FinalExporter 48 | from tensorflow_estimator.python.estimator.exporter import LatestExporter 49 | from tensorflow_estimator.python.estimator.extenders import add_metrics 50 | from tensorflow_estimator.python.estimator.head.base_head import Head 51 | from tensorflow_estimator.python.estimator.head.binary_class_head import BinaryClassHead 52 | from tensorflow_estimator.python.estimator.head.multi_class_head import MultiClassHead 53 | from tensorflow_estimator.python.estimator.head.multi_head import MultiHead 54 | from tensorflow_estimator.python.estimator.head.multi_label_head import MultiLabelHead 55 | from tensorflow_estimator.python.estimator.head.regression_head import LogisticRegressionHead 56 | from tensorflow_estimator.python.estimator.head.regression_head import PoissonRegressionHead 57 | from tensorflow_estimator.python.estimator.head.regression_head import RegressionHead 58 | from tensorflow_estimator.python.estimator.hooks import basic_session_run_hooks 59 | from tensorflow_estimator.python.estimator.hooks import hooks 60 | from tensorflow_estimator.python.estimator.hooks import session_run_hook 61 | from tensorflow_estimator.python.estimator.inputs import inputs 62 | from tensorflow_estimator.python.estimator.keras_lib import model_to_estimator 63 | from tensorflow_estimator.python.estimator.mode_keys import ModeKeys 64 | from 
tensorflow_estimator.python.estimator.model_fn import call_logit_fn 65 | from tensorflow_estimator.python.estimator.model_fn import EstimatorSpec 66 | from tensorflow_estimator.python.estimator.run_config import RunConfig 67 | from tensorflow_estimator.python.estimator.tpu.tpu_estimator import TPUEstimator 68 | from tensorflow_estimator.python.estimator.training import EvalSpec 69 | from tensorflow_estimator.python.estimator.training import train_and_evaluate 70 | from tensorflow_estimator.python.estimator.training import TrainSpec 71 | 72 | # pylint: enable=unused-import,line-too-long,wildcard-import 73 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/canned/linear_optimizer/python/utils/sharded_mutable_dense_hashtable_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for sharded_mutable_dense_hashtable.py.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from tensorflow.python.platform import googletest 23 | from tensorflow_estimator.python.estimator.canned.linear_optimizer.python.utils.sharded_mutable_dense_hashtable import _ShardedMutableDenseHashTable 24 | 25 | 26 | class _ShardedMutableDenseHashTableTest(tf.test.TestCase): 27 | """Tests for the ShardedMutableHashTable class.""" 28 | 29 | def testShardedMutableHashTable(self): 30 | for num_shards in [1, 3, 10]: 31 | with self.cached_session(): 32 | default_val = -1 33 | empty_key = 0 34 | deleted_key = -1 35 | keys = tf.constant([11, 12, 13], tf.dtypes.int64) 36 | values = tf.constant([0, 1, 2], tf.dtypes.int64) 37 | table = _ShardedMutableDenseHashTable( 38 | tf.dtypes.int64, 39 | tf.dtypes.int64, 40 | default_val, 41 | empty_key, 42 | deleted_key, 43 | num_shards=num_shards) 44 | self.assertAllEqual(0, self.evaluate(table.size())) 45 | 46 | self.evaluate(table.insert(keys, values)) 47 | self.assertAllEqual(3, self.evaluate(table.size())) 48 | 49 | input_string = tf.constant([11, 12, 14], tf.dtypes.int64) 50 | output = table.lookup(input_string) 51 | self.assertAllEqual([3], output.get_shape()) 52 | self.assertAllEqual([0, 1, -1], self.evaluate(output)) 53 | 54 | def testShardedMutableHashTableVectors(self): 55 | for num_shards in [1, 3, 10]: 56 | with self.cached_session(): 57 | default_val = [-0.1, 0.2] 58 | empty_key = [0, 1] 59 | deleted_key = [1, 0] 60 | keys = tf.constant([[11, 12], [13, 14], [15, 16]], tf.dtypes.int64) 61 | values = tf.constant([[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]], 62 | tf.dtypes.float32) 63 | table = _ShardedMutableDenseHashTable( 64 | tf.dtypes.int64, 65 | tf.dtypes.float32, 66 | default_val, 67 | empty_key, 68 | deleted_key, 69 | 
num_shards=num_shards) 70 | self.assertAllEqual(0, self.evaluate(table.size())) 71 | 72 | self.evaluate(table.insert(keys, values)) 73 | self.assertAllEqual(3, self.evaluate(table.size())) 74 | 75 | input_string = tf.constant([[11, 12], [13, 14], [11, 14]], 76 | tf.dtypes.int64) 77 | output = table.lookup(input_string) 78 | self.assertAllEqual([3, 2], output.get_shape()) 79 | self.assertAllClose([[0.5, 0.6], [1.5, 1.6], [-0.1, 0.2]], 80 | self.evaluate(output)) 81 | 82 | def testExportSharded(self): 83 | with self.cached_session(): 84 | empty_key = -2 85 | deleted_key = -3 86 | default_val = -1 87 | num_shards = 2 88 | keys = tf.constant([10, 11, 12], tf.dtypes.int64) 89 | values = tf.constant([2, 3, 4], tf.dtypes.int64) 90 | table = _ShardedMutableDenseHashTable( 91 | tf.dtypes.int64, 92 | tf.dtypes.int64, 93 | default_val, 94 | empty_key, 95 | deleted_key, 96 | num_shards=num_shards) 97 | self.assertAllEqual(0, self.evaluate(table.size())) 98 | 99 | self.evaluate(table.insert(keys, values)) 100 | self.assertAllEqual(3, self.evaluate(table.size())) 101 | 102 | keys_list, values_list = table.export_sharded() 103 | self.assertAllEqual(num_shards, len(keys_list)) 104 | self.assertAllEqual(num_shards, len(values_list)) 105 | 106 | # Exported keys include empty key buckets set to the empty_key 107 | self.assertAllEqual( 108 | set([-2, 10, 12]), set(self.evaluate(keys_list[0]).flatten())) 109 | self.assertAllEqual( 110 | set([-2, 11]), set(self.evaluate(keys_list[1]).flatten())) 111 | # Exported values include empty value buckets set to 0 112 | self.assertAllEqual( 113 | set([0, 2, 4]), set(self.evaluate(values_list[0]).flatten())) 114 | self.assertAllEqual( 115 | set([0, 3]), set(self.evaluate(values_list[1]).flatten())) 116 | 117 | 118 | if __name__ == '__main__': 119 | googletest.main() 120 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/util.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Utilities for Estimators.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import time 22 | import tensorflow as tf 23 | 24 | # import keras 2 25 | version_fn = getattr(tf.keras, 'version', None) 26 | if version_fn and version_fn().startswith('3.'): 27 | import tf_keras # pylint: disable=g-import-not-at-top,unused-import 28 | from tf_keras.api._v1 import keras as tf_keras_v1 # pylint: disable=g-import-not-at-top,unused-import 29 | from tf_keras.api._v2 import keras as tf_keras_v2 # pylint: disable=g-import-not-at-top,unused-import 30 | else: 31 | tf_keras = tf.keras # Keras 2 32 | tf_keras_v1 = tf.compat.v1.keras 33 | tf_keras_v2 = tf.compat.v2.keras 34 | 35 | from tensorflow.python.util import function_utils 36 | 37 | fn_args = function_utils.fn_args 38 | 39 | # When we create a timestamped directory, there is a small chance that the 40 | # directory already exists because another process is also creating these 41 | # directories. In this case we just wait one second to get a new timestamp and 42 | # try again. 
If this fails several times in a row, then something is seriously 43 | # wrong. 44 | MAX_DIRECTORY_CREATION_ATTEMPTS = 10 45 | 46 | 47 | def parse_input_fn_result(result): 48 | """Gets features, labels, and hooks from the result of an Estimator input_fn. 49 | 50 | Args: 51 | result: output of an input_fn to an estimator, which should be one of: 52 | * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a tuple 53 | (features, labels) with same constraints as below. 54 | * A tuple (features, labels): Where `features` is a `Tensor` or a 55 | dictionary of string feature name to `Tensor` and `labels` is a `Tensor` 56 | or a dictionary of string label name to `Tensor`. Both `features` and 57 | `labels` are consumed by `model_fn`. They should satisfy the expectation 58 | of `model_fn` from inputs. 59 | 60 | Returns: 61 | Tuple of features, labels, and input_hooks, where features are as described 62 | above, labels are as described above or None, and input_hooks are a list 63 | of SessionRunHooks to be included when running. 64 | 65 | Raises: 66 | ValueError: if the result is a list or tuple of length != 2. 
67 | """ 68 | input_hooks = [] 69 | if isinstance(result, tf.compat.v2.data.Dataset): 70 | iterator = tf.compat.v1.data.make_initializable_iterator(result) 71 | input_hooks.append(_DatasetInitializerHook(iterator)) 72 | result = iterator.get_next() 73 | return parse_iterator_result(result) + (input_hooks,) 74 | 75 | 76 | def parse_iterator_result(result): 77 | """Gets features, labels from result.""" 78 | if isinstance(result, (list, tuple)): 79 | if len(result) != 2: 80 | raise ValueError( 81 | 'input_fn should return (features, labels) as a len 2 tuple.') 82 | return result[0], result[1] 83 | return result, None 84 | 85 | 86 | class _DatasetInitializerHook(tf.compat.v1.train.SessionRunHook): 87 | """Creates a SessionRunHook that initializes the passed iterator.""" 88 | 89 | def __init__(self, iterator): 90 | self._iterator = iterator 91 | 92 | def begin(self): 93 | self._initializer = self._iterator.initializer 94 | 95 | def after_create_session(self, session, coord): 96 | del coord 97 | session.run(self._initializer) 98 | 99 | 100 | class DistributedIteratorInitializerHook(tf.compat.v1.train.SessionRunHook): 101 | """Creates a SessionRunHook that initializes the passed iterator.""" 102 | 103 | def __init__(self, iterator): 104 | self._iterator = iterator 105 | 106 | def begin(self): 107 | self._initializer = self._iterator.initialize() 108 | 109 | def after_create_session(self, session, coord): 110 | del coord 111 | session.run(self._initializer) 112 | 113 | 114 | class MultiHostDatasetInitializerHook(tf.compat.v1.train.SessionRunHook): 115 | """Creates a SessionRunHook that initializes all passed iterators.""" 116 | 117 | def __init__(self, dataset_initializers): 118 | self._initializers = dataset_initializers 119 | 120 | def after_create_session(self, session, coord): 121 | del coord 122 | start = time.time() 123 | session.run(self._initializers) 124 | tf.compat.v1.logging.info('Initialized dataset iterators in %d seconds', 125 | time.time() - start) 126 | 
-------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/extenders.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Extenders of tf.estimator.Estimator.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from tensorflow.python.util import function_utils 22 | from tensorflow_estimator.python.estimator import estimator as estimator_lib 23 | from tensorflow_estimator.python.estimator.estimator_export import estimator_export 24 | from tensorflow_estimator.python.estimator.mode_keys import ModeKeys 25 | 26 | _VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config']) 27 | 28 | 29 | @estimator_export('estimator.add_metrics') 30 | def add_metrics(estimator, metric_fn): 31 | """Creates a new `tf.estimator.Estimator` which has given metrics. 32 | 33 | Example: 34 | 35 | ```python 36 | def my_auc(labels, predictions): 37 | auc_metric = tf_keras.metrics.AUC(name="my_auc") 38 | auc_metric.update_state(y_true=labels, y_pred=predictions['logistic']) 39 | return {'auc': auc_metric} 40 | 41 | estimator = tf.estimator.DNNClassifier(...) 
42 | estimator = tf.estimator.add_metrics(estimator, my_auc) 43 | estimator.train(...) 44 | estimator.evaluate(...) 45 | ``` 46 | Example usage of custom metric which uses features: 47 | 48 | ```python 49 | def my_auc(labels, predictions, features): 50 | auc_metric = tf_keras.metrics.AUC(name="my_auc") 51 | auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'], 52 | sample_weight=features['weight']) 53 | return {'auc': auc_metric} 54 | 55 | estimator = tf.estimator.DNNClassifier(...) 56 | estimator = tf.estimator.add_metrics(estimator, my_auc) 57 | estimator.train(...) 58 | estimator.evaluate(...) 59 | ``` 60 | 61 | Args: 62 | estimator: A `tf.estimator.Estimator` object. 63 | metric_fn: A function which should obey the following signature: 64 | - Args: can only have following four arguments in any order: 65 | * predictions: Predictions `Tensor` or dict of `Tensor` created by given 66 | `estimator`. 67 | * features: Input `dict` of `Tensor` objects created by `input_fn` which 68 | is given to `estimator.evaluate` as an argument. 69 | * labels: Labels `Tensor` or dict of `Tensor` created by `input_fn` 70 | which is given to `estimator.evaluate` as an argument. 71 | * config: config attribute of the `estimator`. 72 | - Returns: Dict of metric results keyed by name. Final metrics are a 73 | union of this and `estimator's` existing metrics. If there is a name 74 | conflict between this and `estimator`s existing metrics, this will 75 | override the existing one. The values of the dict are the results of 76 | calling a metric function, namely a `(metric_tensor, update_op)` tuple. 77 | 78 | Returns: 79 | A new `tf.estimator.Estimator` which has a union of original metrics with 80 | given ones. 
81 | """ 82 | _verify_metric_fn_args(metric_fn) 83 | 84 | def new_model_fn(features, labels, mode, config): 85 | spec = estimator.model_fn(features, labels, mode, config) 86 | if mode != ModeKeys.EVAL: 87 | return spec 88 | new_metrics = _call_metric_fn(metric_fn, features, labels, spec.predictions, 89 | config) 90 | all_metrics = spec.eval_metric_ops or {} 91 | all_metrics.update(new_metrics) 92 | return spec._replace(eval_metric_ops=all_metrics) 93 | 94 | return estimator_lib.Estimator( 95 | model_fn=new_model_fn, 96 | model_dir=estimator.model_dir, 97 | config=estimator.config, 98 | # pylint: disable=protected-access 99 | warm_start_from=estimator._warm_start_settings) 100 | # pylint: enable=protected-access 101 | 102 | 103 | def _verify_metric_fn_args(metric_fn): 104 | args = set(function_utils.fn_args(metric_fn)) 105 | invalid_args = list(args - _VALID_METRIC_FN_ARGS) 106 | if invalid_args: 107 | raise ValueError('metric_fn (%s) has following not expected args: %s' % 108 | (metric_fn, invalid_args)) 109 | 110 | 111 | def _call_metric_fn(metric_fn, features, labels, predictions, config): 112 | """Calls metric fn with proper arguments.""" 113 | metric_fn_args = function_utils.fn_args(metric_fn) 114 | kwargs = {} 115 | if 'features' in metric_fn_args: 116 | kwargs['features'] = features 117 | if 'labels' in metric_fn_args: 118 | kwargs['labels'] = labels 119 | if 'predictions' in metric_fn_args: 120 | kwargs['predictions'] = predictions 121 | if 'config' in metric_fn_args: 122 | kwargs['config'] = config 123 | return metric_fn(**kwargs) 124 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/canned/timeseries/ar_model_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for ar_model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import functools 22 | import tensorflow as tf 23 | from tensorflow.python.framework import test_util 24 | from tensorflow_estimator.python.estimator import estimator_lib 25 | from tensorflow_estimator.python.estimator.canned.timeseries import ar_model 26 | from tensorflow_estimator.python.estimator.canned.timeseries.estimators import LSTMAutoRegressor 27 | from tensorflow_estimator.python.estimator.canned.timeseries.feature_keys import PredictionFeatures 28 | from tensorflow_estimator.python.estimator.canned.timeseries.feature_keys import TrainEvalFeatures 29 | 30 | 31 | @test_util.run_v1_only("Currently incompatible with ResourceVariable") 32 | class ARModelTest(tf.test.TestCase): 33 | 34 | def test_wrong_window_size(self): 35 | estimator = LSTMAutoRegressor( 36 | periodicities=10, 37 | input_window_size=10, 38 | output_window_size=6, 39 | num_features=1) 40 | 41 | def _bad_window_size_input_fn(): 42 | return ({ 43 | TrainEvalFeatures.TIMES: [[1]], 44 | TrainEvalFeatures.VALUES: [[[1.]]] 45 | }, None) 46 | 47 | def _good_data(): 48 | return ({ 49 | TrainEvalFeatures.TIMES: tf.range(16)[None, :], 50 | TrainEvalFeatures.VALUES: 
tf.reshape(tf.range(16), [1, 16, 1]) 51 | }, None) 52 | 53 | with self.assertRaisesRegexp(ValueError, "set window_size=16"): 54 | estimator.train(input_fn=_bad_window_size_input_fn, steps=1) 55 | # Get a checkpoint for evaluation 56 | estimator.train(input_fn=_good_data, steps=1) 57 | with self.assertRaisesRegexp(ValueError, "requires a window of at least"): 58 | estimator.evaluate(input_fn=_bad_window_size_input_fn, steps=1) 59 | 60 | def test_predictions_direct_lstm(self): 61 | model = ar_model.ARModel( 62 | periodicities=2, 63 | num_features=1, 64 | num_time_buckets=10, 65 | input_window_size=2, 66 | output_window_size=2, 67 | prediction_model_factory=functools.partial( 68 | ar_model.LSTMPredictionModel, num_units=16)) 69 | with tf.compat.v1.Session(): 70 | predicted_values = model.predict({ 71 | PredictionFeatures.TIMES: [[4, 6, 10]], 72 | PredictionFeatures.STATE_TUPLE: ([[1, 2]], [[[1.], [2.]]], [[[], []]]) 73 | }) 74 | tf.compat.v1.initializers.global_variables().run() 75 | self.assertAllEqual(predicted_values["mean"].eval().shape, [1, 3, 1]) 76 | 77 | def test_long_eval(self): 78 | model = ar_model.ARModel( 79 | periodicities=2, 80 | num_features=1, 81 | num_time_buckets=10, 82 | input_window_size=2, 83 | output_window_size=1) 84 | raw_features = { 85 | TrainEvalFeatures.TIMES: [[1, 3, 5, 7, 11]], 86 | TrainEvalFeatures.VALUES: [[[1.], [2.], [3.], [4.], [5.]]] 87 | } 88 | model.initialize_graph() 89 | with tf.compat.v1.variable_scope("armodel"): 90 | raw_evaluation = model.define_loss( 91 | raw_features, mode=estimator_lib.ModeKeys.EVAL) 92 | with tf.compat.v1.Session() as sess: 93 | tf.compat.v1.initializers.global_variables().run() 94 | raw_evaluation_evaled = sess.run(raw_evaluation) 95 | self.assertAllEqual([[5, 7, 11]], raw_evaluation_evaled.prediction_times) 96 | for feature_name in raw_evaluation.predictions: 97 | self.assertAllEqual( 98 | [1, 3, 1], # batch, window, num_features. The window size has 2 99 | # cut off for the first input_window. 
100 | raw_evaluation_evaled.predictions[feature_name].shape) 101 | 102 | def test_long_eval_discard_indivisible(self): 103 | model = ar_model.ARModel( 104 | periodicities=2, 105 | num_features=1, 106 | num_time_buckets=10, 107 | input_window_size=2, 108 | output_window_size=2) 109 | raw_features = { 110 | TrainEvalFeatures.TIMES: [[1, 3, 5, 7, 11]], 111 | TrainEvalFeatures.VALUES: [[[1.], [2.], [3.], [4.], [5.]]] 112 | } 113 | model.initialize_graph() 114 | raw_evaluation = model.define_loss( 115 | raw_features, mode=estimator_lib.ModeKeys.EVAL) 116 | with tf.compat.v1.Session() as sess: 117 | tf.compat.v1.initializers.global_variables().run() 118 | raw_evaluation_evaled = sess.run(raw_evaluation) 119 | self.assertAllEqual([[7, 11]], raw_evaluation_evaled.prediction_times) 120 | for feature_name in raw_evaluation.predictions: 121 | self.assertAllEqual( 122 | [1, 2, 1], # batch, window, num_features. The window has two cut 123 | # off for the first input window and one discarded so 124 | # that the remainder is divisible into output windows. 125 | raw_evaluation_evaled.predictions[feature_name].shape) 126 | 127 | 128 | if __name__ == "__main__": 129 | tf.test.main() 130 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests `FeedingQueueRunner` using arrays and `DataFrames`.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | from tensorflow.python.framework import test_util 24 | from tensorflow_estimator.python.estimator.inputs.queues import feeding_functions as ff 25 | 26 | try: 27 | # pylint: disable=g-import-not-at-top 28 | import pandas as pd 29 | HAS_PANDAS = True 30 | except IOError: 31 | # Pandas writes a temporary file during import. If it fails, don't use pandas. 32 | HAS_PANDAS = False 33 | except ImportError: 34 | HAS_PANDAS = False 35 | 36 | 37 | def get_rows(array, row_indices): 38 | rows = [array[i] for i in row_indices] 39 | return np.vstack(rows) 40 | 41 | 42 | @test_util.deprecated_graph_mode_only 43 | class FeedingQueueRunnerTestCase(tf.test.TestCase): 44 | """Tests for `FeedingQueueRunner`.""" 45 | 46 | def testArrayFeeding(self): 47 | with tf.Graph().as_default(): 48 | array = np.arange(32).reshape([16, 2]) 49 | q = ff._enqueue_data(array, capacity=100) 50 | batch_size = 3 51 | dq_op = q.dequeue_many(batch_size) 52 | with tf.compat.v1.Session() as sess: 53 | coord = tf.train.Coordinator() 54 | threads = tf.compat.v1.train.queue_runner.start_queue_runners( 55 | sess=sess, coord=coord) 56 | for i in range(100): 57 | indices = [ 58 | j % array.shape[0] 59 | for j in range(batch_size * i, batch_size * (i + 1)) 60 | ] 61 | expected_dq = get_rows(array, indices) 62 | dq = sess.run(dq_op) 63 | np.testing.assert_array_equal(indices, dq[0]) 64 | np.testing.assert_array_equal(expected_dq, dq[1]) 65 | coord.request_stop() 66 | coord.join(threads) 67 | 68 | def testArrayFeedingMultiThread(self): 69 | with 
tf.Graph().as_default(): 70 | array = np.arange(256).reshape([128, 2]) 71 | q = ff._enqueue_data(array, capacity=128, num_threads=8, shuffle=True) 72 | batch_size = 3 73 | dq_op = q.dequeue_many(batch_size) 74 | with tf.compat.v1.Session() as sess: 75 | coord = tf.train.Coordinator() 76 | threads = tf.compat.v1.train.queue_runner.start_queue_runners( 77 | sess=sess, coord=coord) 78 | for _ in range(100): 79 | dq = sess.run(dq_op) 80 | indices = dq[0] 81 | expected_dq = get_rows(array, indices) 82 | np.testing.assert_array_equal(expected_dq, dq[1]) 83 | coord.request_stop() 84 | coord.join(threads) 85 | 86 | def testPandasFeeding(self): 87 | if not HAS_PANDAS: 88 | return 89 | with tf.Graph().as_default(): 90 | array1 = np.arange(32) 91 | array2 = np.arange(32, 64) 92 | df = pd.DataFrame({"a": array1, "b": array2}, index=np.arange(64, 96)) 93 | q = ff._enqueue_data(df, capacity=100) 94 | batch_size = 5 95 | dq_op = q.dequeue_many(5) 96 | with tf.compat.v1.Session() as sess: 97 | coord = tf.train.Coordinator() 98 | threads = tf.compat.v1.train.queue_runner.start_queue_runners( 99 | sess=sess, coord=coord) 100 | for i in range(100): 101 | indices = [ 102 | j % array1.shape[0] 103 | for j in range(batch_size * i, batch_size * (i + 1)) 104 | ] 105 | expected_df_indices = df.index[indices] 106 | expected_rows = df.iloc[indices] 107 | dq = sess.run(dq_op) 108 | np.testing.assert_array_equal(expected_df_indices, dq[0]) 109 | for col_num, col in enumerate(df.columns): 110 | np.testing.assert_array_equal(expected_rows[col].values, 111 | dq[col_num + 1]) 112 | coord.request_stop() 113 | coord.join(threads) 114 | 115 | def testPandasFeedingMultiThread(self): 116 | if not HAS_PANDAS: 117 | return 118 | with tf.Graph().as_default(): 119 | array1 = np.arange(128, 256) 120 | array2 = 2 * array1 121 | df = pd.DataFrame({"a": array1, "b": array2}, index=np.arange(128)) 122 | q = ff._enqueue_data(df, capacity=128, num_threads=8, shuffle=True) 123 | batch_size = 5 124 | dq_op = 
q.dequeue_many(batch_size) 125 | with tf.compat.v1.Session() as sess: 126 | coord = tf.train.Coordinator() 127 | threads = tf.compat.v1.train.queue_runner.start_queue_runners( 128 | sess=sess, coord=coord) 129 | for _ in range(100): 130 | dq = sess.run(dq_op) 131 | indices = dq[0] 132 | expected_rows = df.iloc[indices] 133 | for col_num, col in enumerate(df.columns): 134 | np.testing.assert_array_equal(expected_rows[col].values, 135 | dq[col_num + 1]) 136 | coord.request_stop() 137 | coord.join(threads) 138 | 139 | 140 | if __name__ == "__main__": 141 | tf.test.main() 142 | -------------------------------------------------------------------------------- /tensorflow_estimator/tools/pip_package/create_pip_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Utils to help build and verify pip package for TensorFlow Estimator.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import fnmatch 23 | import os 24 | 25 | PIP_EXCLUDED_FILES = frozenset([ 26 | 'tensorflow_estimator/python/estimator/canned/optimizers_test_v2.py', 27 | 'tensorflow_estimator/python/estimator/canned/dnn_test_fc_v2.py', 28 | 'tensorflow_estimator/python/estimator/canned/dnn_test_fc_v1.py', 29 | 'tensorflow_estimator/python/estimator/canned/v1/dnn_estimator_test_v1.py', 30 | 'tensorflow_estimator/python/estimator/canned/v1/linear_test_v1.py', 31 | 'tensorflow_estimator/python/estimator/canned/v1/dnn_linear_combined_estimator_test_v1.py', 32 | 'tensorflow_estimator/python/estimator/canned/v1/dnn_linear_combined_test_v1.py', 33 | 'tensorflow_estimator/python/estimator/canned/v1/baseline_estimator_test_v1.py', 34 | 'tensorflow_estimator/python/estimator/canned/v1/linear_estimator_test_v1.py', 35 | 'tensorflow_estimator/python/estimator/canned/v1/baseline_test_v1.py', 36 | 'tensorflow_estimator/python/estimator/canned/v1/dnn_test_fc_v1_v1.py', 37 | 'tensorflow_estimator/python/estimator/canned/v1/dnn_test_fc_v2_v1.py', 38 | 'tensorflow_estimator/python/estimator/api/extractor_wrapper.py', 39 | 'tensorflow_estimator/python/estimator/api/generator_wrapper.py', 40 | 'tensorflow_estimator/tools/pip_package/setup.py', 41 | 'tensorflow_estimator/tools/pip_package/create_pip_helper.py', 42 | ]) 43 | 44 | # Directories that should not have __init__.py files generated within them. 45 | EXCLUDED_INIT_FILE_DIRECTORIES = frozenset(['tensorflow_estimator/tools']) 46 | 47 | 48 | class PipPackagingError(Exception): 49 | pass 50 | 51 | 52 | def create_init_files(pip_root): 53 | """Create __init__.py in pip directory tree. 
54 | 55 | These files are auto-generated by Bazel when doing typical build/test, but 56 | do not get auto-generated by the pip build process. Currently, the entire 57 | directory tree is just python files, so its fine to just create all of the 58 | init files. 59 | 60 | Args: 61 | pip_root: Root directory of code being packaged into pip. 62 | 63 | Returns: 64 | True: contrib code is included in pip. 65 | """ 66 | has_contrib = False 67 | for path, subdirs, _ in os.walk(pip_root): 68 | has_contrib = has_contrib or '/contrib/' in path 69 | for subdir in subdirs: 70 | init_file_path = os.path.join(path, subdir, '__init__.py') 71 | if any(excluded_path in init_file_path 72 | for excluded_path in EXCLUDED_INIT_FILE_DIRECTORIES): 73 | continue 74 | if not os.path.exists(init_file_path): 75 | # Create empty file 76 | open(init_file_path, 'w').close() 77 | return has_contrib 78 | 79 | 80 | def verify_python_files_in_pip(pip_root, bazel_root, has_contrib): 81 | """Verifies all expected files are packaged into Pip. 82 | 83 | Args: 84 | pip_root: Root directory of code being packaged into pip. 85 | bazel_root: Root directory of Estimator Bazel workspace. 86 | has_contrib: Code from contrib/ should be included in pip. 87 | 88 | Raises: 89 | PipPackagingError: Missing file in pip. 90 | """ 91 | for path, _, files in os.walk(bazel_root): 92 | if not has_contrib and '/contrib/' in path: 93 | continue 94 | python_files = set(fnmatch.filter(files, '*.py')) 95 | python_test_files = set(fnmatch.filter(files, '*test.py')) 96 | # We only care about python files in the pip package, see create_init_files. 
97 | files = python_files - python_test_files 98 | for f in files: 99 | pip_path = os.path.join(pip_root, os.path.relpath(path, bazel_root), f) 100 | file_name = os.path.join(path, f) 101 | path_exists = os.path.exists(pip_path) 102 | file_excluded = file_name.lstrip('./') in PIP_EXCLUDED_FILES 103 | if not path_exists and not file_excluded: 104 | raise PipPackagingError( 105 | ('Pip package missing the file %s. If this is expected, add it ' 106 | 'to PIP_EXCLUDED_FILES in create_pip_helper.py. Otherwise, ' 107 | 'make sure it is a build dependency of the pip package') % 108 | file_name) 109 | if path_exists and file_excluded: 110 | raise PipPackagingError( 111 | ('File in PIP_EXCLUDED_FILES included in pip. %s' % file_name)) 112 | 113 | 114 | def main(): 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument( 117 | '--bazel-root', 118 | type=str, 119 | required=True, 120 | help='Root directory of Estimator Bazel workspace.') 121 | parser.add_argument( 122 | '--pip-root', 123 | type=str, 124 | required=True, 125 | help='Root directory of code being packaged into pip.') 126 | 127 | args = parser.parse_args() 128 | has_contrib = create_init_files(args.pip_root) 129 | verify_python_files_in_pip(args.pip_root, args.bazel_root, has_contrib) 130 | 131 | 132 | if __name__ == '__main__': 133 | main() 134 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/extenders_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """extenders tests.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | from tensorflow_estimator.python.estimator import extenders 24 | from tensorflow_estimator.python.estimator.util import tf_keras 25 | from tensorflow_estimator.python.estimator import run_config 26 | from tensorflow_estimator.python.estimator.canned import linear 27 | 28 | 29 | def get_input_fn(x, y): 30 | 31 | def input_fn(): 32 | dataset = tf.compat.v1.data.Dataset.from_tensor_slices({'x': x, 'y': y}) 33 | iterator = tf.compat.v1.data.make_one_shot_iterator(dataset) 34 | features = iterator.get_next() 35 | labels = features.pop('y') 36 | return features, labels 37 | 38 | return input_fn 39 | 40 | 41 | class AddMetricsTest(tf.test.TestCase): 42 | 43 | def test_should_add_metrics(self): 44 | 45 | def _test_metric_fn(metric_fn): 46 | input_fn = get_input_fn( 47 | x=np.arange(4)[:, None, None], y=np.ones(4)[:, None]) 48 | config = run_config.RunConfig(log_step_count_steps=1) 49 | estimator = linear.LinearClassifierV2( 50 | [tf.feature_column.numeric_column('x')], config=config) 51 | 52 | estimator = extenders.add_metrics(estimator, metric_fn) 53 | 54 | estimator.train(input_fn=input_fn) 55 | metrics = estimator.evaluate(input_fn=input_fn) 56 | self.assertIn('mean_x', metrics) 57 | self.assertEqual(1.5, metrics['mean_x']) 58 | # 
assert that it keeps original estimators metrics 59 | self.assertIn('auc', metrics) 60 | 61 | def metric_fn(features): 62 | metric = tf_keras.metrics.Mean() 63 | metric.update_state(features['x']) 64 | return {'mean_x': metric} 65 | 66 | _test_metric_fn(metric_fn) 67 | 68 | def test_should_error_out_for_not_recognized_args(self): 69 | estimator = linear.LinearClassifierV2( 70 | [tf.feature_column.numeric_column('x')]) 71 | 72 | def metric_fn(features, not_recognized): 73 | _, _ = features, not_recognized 74 | return {} 75 | 76 | with self.assertRaisesRegexp(ValueError, 'not_recognized'): 77 | estimator = extenders.add_metrics(estimator, metric_fn) 78 | 79 | def test_all_supported_args(self): 80 | input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]]) 81 | estimator = linear.LinearClassifierV2( 82 | [tf.feature_column.numeric_column('x')]) 83 | 84 | def metric_fn(features, predictions, labels, config): 85 | self.assertIn('x', features) 86 | self.assertIsNotNone(labels) 87 | self.assertIn('logistic', predictions) 88 | self.assertTrue(isinstance(config, run_config.RunConfig)) 89 | return {} 90 | 91 | estimator = extenders.add_metrics(estimator, metric_fn) 92 | 93 | estimator.train(input_fn=input_fn) 94 | estimator.evaluate(input_fn=input_fn) 95 | 96 | def test_all_supported_args_in_different_order(self): 97 | input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]]) 98 | estimator = linear.LinearClassifierV2( 99 | [tf.feature_column.numeric_column('x')]) 100 | 101 | def metric_fn(labels, config, features, predictions): 102 | self.assertIn('x', features) 103 | self.assertIsNotNone(labels) 104 | self.assertIn('logistic', predictions) 105 | self.assertTrue(isinstance(config, run_config.RunConfig)) 106 | return {} 107 | 108 | estimator = extenders.add_metrics(estimator, metric_fn) 109 | 110 | estimator.train(input_fn=input_fn) 111 | estimator.evaluate(input_fn=input_fn) 112 | 113 | def test_all_args_are_optional(self): 114 | 115 | def _test_metric_fn(metric_fn): 116 | input_fn = 
get_input_fn(x=[[[0.]]], y=[[[1]]]) 117 | estimator = linear.LinearClassifierV2( 118 | [tf.feature_column.numeric_column('x')]) 119 | estimator = extenders.add_metrics(estimator, metric_fn) 120 | 121 | estimator.train(input_fn=input_fn) 122 | metrics = estimator.evaluate(input_fn=input_fn) 123 | self.assertEqual(2., metrics['two']) 124 | 125 | def metric_fn(): 126 | metric = tf_keras.metrics.Mean() 127 | metric.update_state(tf.constant([2.])) 128 | return {'two': metric} 129 | 130 | _test_metric_fn(metric_fn) 131 | 132 | def test_overrides_existing_metrics(self): 133 | 134 | def _test_metric_fn(metric_fn): 135 | input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]]) 136 | estimator = linear.LinearClassifierV2( 137 | [tf.feature_column.numeric_column('x')]) 138 | estimator.train(input_fn=input_fn) 139 | metrics = estimator.evaluate(input_fn=input_fn) 140 | self.assertNotEqual(2., metrics['auc']) 141 | 142 | estimator = extenders.add_metrics(estimator, metric_fn) 143 | metrics = estimator.evaluate(input_fn=input_fn) 144 | self.assertEqual(2., metrics['auc']) 145 | 146 | def metric_fn(): 147 | metric = tf_keras.metrics.Mean() 148 | metric.update_state(tf.constant([2.])) 149 | return {'auc': metric} 150 | 151 | _test_metric_fn(metric_fn) 152 | 153 | 154 | if __name__ == '__main__': 155 | tf.test.main() 156 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/tpu/error_handling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # =================================================================== 15 | """ErrorRendezvous handler for collecting errors from multiple threads.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import contextlib 22 | import sys 23 | import threading 24 | import time 25 | 26 | import six 27 | import tensorflow as tf 28 | from tensorflow_estimator.python.estimator.tools import analytics 29 | 30 | _UNINTERESTING_ERRORS = (tf.errors.CancelledError,) 31 | _IGNORED_ERRORS = ( 32 | tf.errors.AbortedError, 33 | tf.errors.UnavailableError, 34 | ) 35 | 36 | _CHECK_NUMERIC_OP_NAME = 'CheckNumerics' 37 | 38 | 39 | class ErrorRendezvous(object): 40 | """Resolve errors from multiple threads during TPU execution. 41 | 42 | TPU errors can occur on the infeed or outfeed threads as well as the main 43 | training thread. 44 | 45 | Depending on which thread "wins" and receives the session error first, we may 46 | end up showing users a confusing and non-actionable error message (session 47 | cancelled) instead of a root cause (e.g. a bad filename). 48 | 49 | The rendezvous object provides a location to capture these errors until all 50 | threads terminate. At that point we can choose the most informative error 51 | to report. 
52 | """ 53 | 54 | def __init__(self, num_sources): 55 | # string -> (message, traceback) 56 | self._errors = {} 57 | self._num_sources = num_sources 58 | self._session_cancel_timer = None 59 | 60 | def record_error(self, source, exc_info, session=None): 61 | """Report an exception from the given source. 62 | 63 | If a session is passed, a timer will be registered to close it after a few 64 | seconds. This is necessary to ensure the main training loop does not hang 65 | if an infeed/oufeed error occurs. We sleep a few seconds to allow a more 66 | interesting error from another thread to propagate. 67 | 68 | Args: 69 | source: string, source of the error 70 | exc_info: Output from `sys.exc_info` (type, value, traceback) 71 | session: Session to close after delay. 72 | """ 73 | _, value, _ = exc_info 74 | # Ignore errors already handled by MonitoredSession 75 | if isinstance(value, _IGNORED_ERRORS): 76 | return 77 | 78 | self._errors[source] = exc_info 79 | 80 | # If the error is a numeric type, e.g., NaN error, we can assume that the 81 | # loop execution completed successfully. In this case, we can skip the 82 | # `session.close()` logic and wait for the infeed/outfeed threads to 83 | # complete as normal. 84 | try: 85 | if value.op.type == _CHECK_NUMERIC_OP_NAME: 86 | analytics.track_numerical_issues(exc_info) 87 | return 88 | except AttributeError as _: 89 | pass 90 | 91 | if session is not None and self._session_cancel_timer is None: 92 | 93 | def _cancel_session(): 94 | time.sleep(5) 95 | tf.compat.v1.logging.error('Closing session due to error %s' % value) 96 | try: 97 | session.close() 98 | except: # pylint: disable=bare-except 99 | tf.compat.v1.logging.error( 100 | '\n\n\nFailed to close session after error.' 
101 | 'Other threads may hang.\n\n\n') 102 | 103 | self._session_cancel_timer = threading.Thread(target=_cancel_session,) 104 | self._session_cancel_timer.daemon = True 105 | self._session_cancel_timer.start() 106 | 107 | def record_done(self, source): 108 | """Mark execution source `source` as done. 109 | 110 | If an error was originally reported from `source` it is left intact. 111 | 112 | Args: 113 | source: `str`, source being recorded 114 | """ 115 | tf.compat.v1.logging.info('%s marked as finished', source) 116 | if source not in self._errors: 117 | self._errors[source] = None 118 | 119 | @contextlib.contextmanager 120 | def catch_errors(self, source, session=None): 121 | """Context manager to report any errors within a block.""" 122 | try: 123 | yield 124 | except Exception: # pylint: disable=broad-except 125 | self.record_error(source, sys.exc_info(), session) 126 | 127 | def raise_errors(self, timeout_sec=0): 128 | """Wait for up to `timeout` seconds for all error sources to finish. 129 | 130 | Preferentially raise "interesting" errors (errors not in the 131 | _UNINTERESTING_ERRORS) set. 132 | 133 | Args: 134 | timeout_sec: Seconds to wait for other error sources. 135 | """ 136 | for _ in range(timeout_sec): 137 | if len(self._errors) == self._num_sources: 138 | break 139 | time.sleep(1) 140 | 141 | kept_errors = [(k, v) for (k, v) in self._errors.items() if v is not None] 142 | 143 | # First check for any interesting errors, then fall back on the session 144 | # cancelled errors etc. 
145 | for k, (typ, value, traceback) in kept_errors: 146 | if isinstance(value, _UNINTERESTING_ERRORS): 147 | continue 148 | else: 149 | tf.compat.v1.logging.warn('Reraising captured error') 150 | six.reraise(typ, value, traceback) 151 | 152 | for k, (typ, value, traceback) in kept_errors: 153 | tf.compat.v1.logging.warn('Reraising captured error') 154 | six.reraise(typ, value, traceback) 155 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/hooks/fake_summary_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Fake summary writer for unit tests.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from tensorflow.core.framework import summary_pb2 21 | from tensorflow.python.framework import test_util 22 | from tensorflow.python.summary.writer import writer 23 | from tensorflow.python.summary.writer import writer_cache 24 | 25 | 26 | # TODO(ptucker): Replace with mock framework. 
class FakeSummaryWriter(object):
  """In-memory stand-in for `FileWriter` that records everything it is given."""

  # The real `FileWriter` class saved by `install`, restored by `uninstall`.
  _replaced_summary_writer = None

  @classmethod
  def install(cls):
    """Swaps this fake in for `writer.FileWriter` / `writer_cache.FileWriter`."""
    if cls._replaced_summary_writer:
      raise ValueError('FakeSummaryWriter already installed.')
    cls._replaced_summary_writer = writer.FileWriter
    writer.FileWriter = FakeSummaryWriter
    writer_cache.FileWriter = FakeSummaryWriter

  @classmethod
  def uninstall(cls):
    """Puts the previously saved real `FileWriter` back in place."""
    if not cls._replaced_summary_writer:
      raise ValueError('FakeSummaryWriter not installed.')
    writer.FileWriter = cls._replaced_summary_writer
    writer_cache.FileWriter = cls._replaced_summary_writer
    cls._replaced_summary_writer = None

  def __init__(self, logdir, graph=None):
    self._logdir = logdir
    self._graph = graph
    self._summaries = {}  # global_step -> list of recorded summaries
    self._added_graphs = []
    self._added_meta_graphs = []
    self._added_session_logs = []
    self._added_run_metadata = {}  # tag -> run_metadata

  @property
  def summaries(self):
    return self._summaries

  def assert_summaries(self,
                       test_case,
                       expected_logdir=None,
                       expected_graph=None,
                       expected_summaries=None,
                       expected_added_graphs=None,
                       expected_added_meta_graphs=None,
                       expected_session_logs=None):
    """Assert expected items have been added to summary writer."""
    if expected_logdir is not None:
      test_case.assertEqual(expected_logdir, self._logdir)
    if expected_graph is not None:
      test_case.assertTrue(expected_graph is self._graph)
    for step in (expected_summaries or {}):
      test_case.assertTrue(
          step in self._summaries,
          msg='Missing step %s from %s.' % (step, self._summaries.keys()))
      # Collect tag -> simple_value pairs for this step, skipping
      # global_step/sec: it is written by Supervisor in a separate thread,
      # so it's non-deterministic how many get written.
      observed_values = {
          value.tag: value.simple_value
          for recorded in self._summaries[step]
          for value in recorded.value
          if value.tag != 'global_step/sec'
      }
      test_case.assertEqual(expected_summaries[step], observed_values)
    if expected_added_graphs is not None:
      test_case.assertEqual(expected_added_graphs, self._added_graphs)
    if expected_added_meta_graphs is not None:
      test_case.assertEqual(
          len(expected_added_meta_graphs), len(self._added_meta_graphs))
      for wanted, got in zip(expected_added_meta_graphs,
                             self._added_meta_graphs):
        test_util.assert_meta_graph_protos_equal(test_case, wanted, got)
    if expected_session_logs is not None:
      test_case.assertEqual(expected_session_logs, self._added_session_logs)

  def add_summary(self, summ, current_global_step):
    """Record a summary (serialized bytes are parsed first) under its step."""
    if isinstance(summ, bytes):
      summary_proto = summary_pb2.Summary()
      summary_proto.ParseFromString(summ)
      summ = summary_proto
    self._summaries.setdefault(current_global_step, []).append(summ)

  # NOTE: Ignore global_step since its value is non-deterministic.
  def add_graph(self, graph, global_step=None, graph_def=None):
    """Record a graph; rejects negative steps and explicit graph_defs."""
    if global_step is not None and global_step < 0:
      raise ValueError('Invalid global_step %s.' % global_step)
    if graph_def is not None:
      raise ValueError('Unexpected graph_def %s.' % graph_def)
    self._added_graphs.append(graph)

  def add_meta_graph(self, meta_graph_def, global_step=None):
    """Record a metagraph; rejects negative steps."""
    if global_step is not None and global_step < 0:
      raise ValueError('Invalid global_step %s.' % global_step)
    self._added_meta_graphs.append(meta_graph_def)

  # NOTE: Ignore global_step since its value is non-deterministic.
  def add_session_log(self, session_log, global_step=None):
    # pylint: disable=unused-argument
    self._added_session_logs.append(session_log)

  def add_run_metadata(self, run_metadata, tag, global_step=None):
    if global_step is not None and global_step < 0:
      raise ValueError('Invalid global_step %s.' % global_step)
    self._added_run_metadata[tag] = run_metadata

  def flush(self):
    pass

  def reopen(self):
    pass

  def close(self):
    pass
# -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/canned/v1/linear_estimator_test_v1.py: --------------------------------------------------------------------------------
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | # ============================================================================== 15 | """Tests for LinearEstimator.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import shutil 22 | import tempfile 23 | 24 | import numpy as np 25 | import six 26 | import tensorflow as tf 27 | from tensorflow.python.framework import test_util 28 | from tensorflow_estimator.python.estimator.canned import head as head_lib 29 | from tensorflow_estimator.python.estimator.canned import linear 30 | from tensorflow_estimator.python.estimator.canned import prediction_keys 31 | from tensorflow_estimator.python.estimator.canned.v1 import linear_testing_utils_v1 32 | from tensorflow_estimator.python.estimator.export import export 33 | from tensorflow_estimator.python.estimator.inputs import numpy_io 34 | 35 | 36 | def _linear_estimator_fn(weight_column=None, label_dimension=1, **kwargs): 37 | """Returns a LinearEstimator that uses regression_head.""" 38 | return linear.LinearEstimator( 39 | head=head_lib._regression_head( 40 | weight_column=weight_column, 41 | label_dimension=label_dimension, 42 | # Tests in core (from which this test inherits) test the sum loss. 
        # NOTE(review): these two lines are the tail of `_linear_estimator_fn`,
        # whose definition begins above this chunk; reproduced unchanged.
        loss_reduction=tf.compat.v1.losses.Reduction.SUM),
    **kwargs)


@test_util.run_v1_only('Tests v1 only symbols')
class LinearEstimatorEvaluateTest(
    linear_testing_utils_v1.BaseLinearRegressorEvaluationTest,
    tf.test.TestCase):
  """Runs the shared v1 regressor evaluation suite against LinearEstimator."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    tf.test.TestCase.__init__(self, methodName)
    # Bind the shared test base to the estimator factory under test.
    linear_testing_utils_v1.BaseLinearRegressorEvaluationTest.__init__(
        self, _linear_estimator_fn)


@test_util.run_v1_only('Tests v1 only symbols')
class LinearEstimatorPredictTest(
    linear_testing_utils_v1.BaseLinearRegressorPredictTest, tf.test.TestCase):
  """Runs the shared v1 regressor predict suite against LinearEstimator."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    tf.test.TestCase.__init__(self, methodName)
    linear_testing_utils_v1.BaseLinearRegressorPredictTest.__init__(
        self, _linear_estimator_fn)


@test_util.run_v1_only('Tests v1 only symbols')
class LinearEstimatorTrainTest(
    linear_testing_utils_v1.BaseLinearRegressorTrainingTest, tf.test.TestCase):
  """Runs the shared v1 regressor training suite against LinearEstimator."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    tf.test.TestCase.__init__(self, methodName)
    linear_testing_utils_v1.BaseLinearRegressorTrainingTest.__init__(
        self, _linear_estimator_fn)


@test_util.run_v1_only('Tests v1 only symbols')
class LinearEstimatorIntegrationTest(tf.test.TestCase):
  """End-to-end train/evaluate/predict/export flow for LinearEstimator."""

  def setUp(self):
    # Fresh model dir per test so checkpoints never leak between tests.
    self._model_dir = tempfile.mkdtemp()

  def tearDown(self):
    if self._model_dir:
      # Flush cached summary writers before deleting their directory.
      tf.compat.v1.summary.FileWriterCache.clear()
      shutil.rmtree(self._model_dir)

  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
                          input_dimension, label_dimension, batch_size):
    """Trains, evaluates, predicts and exports, asserting at each stage."""
    feature_columns = [
        tf.feature_column.numeric_column('x', shape=(input_dimension,))
    ]
    est = linear.LinearEstimator(
        head=head_lib._regression_head(label_dimension=label_dimension),
        feature_columns=feature_columns,
        model_dir=self._model_dir)

    # Train
    num_steps = 10
    est.train(train_input_fn, steps=num_steps)

    # Evaluate: the global step recorded in the eval metrics must equal the
    # number of training steps just run.
    scores = est.evaluate(eval_input_fn)
    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])
    self.assertIn('loss', six.iterkeys(scores))

    # Predict: one prediction row per input example, label_dimension wide.
    predictions = np.array([
        x[prediction_keys.PredictionKeys.PREDICTIONS]
        for x in est.predict(predict_input_fn)
    ])
    self.assertAllEqual((batch_size, label_dimension), predictions.shape)

    # Export a SavedModel and check the export directory was created.
    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(
        feature_columns)
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)
    export_dir = est.export_saved_model(tempfile.mkdtemp(),
                                        serving_input_receiver_fn)
    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))

  def test_numpy_input_fn(self):
    """Tests complete flow with numpy_input_fn."""
    label_dimension = 2
    batch_size = 10
    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
    data = data.reshape(batch_size, label_dimension)
    # learn y = x
    train_input_fn = numpy_io.numpy_input_fn(
        x={'x': data},
        y=data,
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)
    eval_input_fn = numpy_io.numpy_input_fn(
        x={'x': data}, y=data, batch_size=batch_size, shuffle=False)
    predict_input_fn = numpy_io.numpy_input_fn(
        x={'x': data}, batch_size=batch_size, shuffle=False)

    self._test_complete_flow(
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        predict_input_fn=predict_input_fn,
        input_dimension=label_dimension,
        label_dimension=label_dimension,
        batch_size=batch_size)


if __name__ == '__main__':
  tf.test.main()
--------------------------------------------------------------------------------
/tensorflow_estimator/python/estimator/canned/timeseries/BUILD:
--------------------------------------------------------------------------------
# Placeholder: load py_library
load("//tensorflow_estimator:estimator.bzl", "py_test")

package(default_visibility = ["//tensorflow_estimator:__subpackages__"])

licenses(["notice"])

py_library(
    name = "feature_keys",
    srcs = [
        "feature_keys.py",
    ],
    srcs_version = "PY3",
    deps = ["//tensorflow_estimator/python/estimator:expect_tensorflow_installed"],
)

py_library(
    name = "saved_model_utils",
    srcs = [
        "saved_model_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":feature_keys",
        ":head",
        ":model_utils",
        "//tensorflow_estimator/python/estimator:expect_tensorflow_installed",
    ],
)

py_library(
    name = "model",
    srcs = [
        "model.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":feature_keys",
        ":math_utils",
        "//tensorflow_estimator/python/estimator:expect_tensorflow_installed",
    ],
)

py_library(
    name = "estimators",
    srcs = [
        "estimators.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":ar_model",
        ":feature_keys",
        ":head",
        ":math_utils",
        ":saved_model_utils",
        ":state_management",
        "//tensorflow_estimator/python/estimator:expect_tensorflow_installed",
    ],
)

py_test(
    name = "estimators_test",
    srcs = [
        "estimators_test.py",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    tags = [
        "notap",  # TODO(b/132129465): Re-enable.
    ],
    deps = [
        ":ar_model",
        ":estimators",
        ":feature_keys",
        ":saved_model_utils",
        "//tensorflow_estimator/python/estimator:estimator_py",
        "//tensorflow_estimator/python/estimator:expect_numpy_installed",
        "//tensorflow_estimator/python/estimator:expect_proto_cpp_installed",
        "//tensorflow_estimator/python/estimator:expect_tensorflow_installed",
    ],
)

# Timeseries head: model-fn plumbing shared by the estimators above.
py_library(
    name = "head",
    srcs = [
        "head.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":feature_keys",
        "//tensorflow_estimator/python/estimator:expect_tensorflow_installed",
    ],
)

py_test(
    name = "head_test",
    srcs = [
        "head_test.py",
    ],
    python_version = "PY3",
    shard_count = 4,
    srcs_version = "PY3",
    deps = [
        ":estimators",
        ":feature_keys",
        ":head",
        ":model",
        ":state_management",
        "//tensorflow_estimator/python/estimator:estimator_py",
        "//tensorflow_estimator/python/estimator:expect_absl_installed",
        "//tensorflow_estimator/python/estimator:expect_numpy_installed",
        "//tensorflow_estimator/python/estimator:expect_proto_cpp_installed",
        "//tensorflow_estimator/python/estimator:expect_six_installed",
        "//tensorflow_estimator/python/estimator:expect_tensorflow_installed",
    ],
)

py_library(
    name = "model_utils",
    srcs = [
        "model_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":feature_keys",
        "//tensorflow_estimator/python/estimator:expect_numpy_installed",
        "//tensorflow_estimator/python/estimator:expect_tensorflow_installed",
    ],
)

py_library(
    name = "state_management",
    srcs = [
        "state_management.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":feature_keys",
        ":math_utils",
        ":model",
        "//tensorflow_estimator/python/estimator:expect_tensorflow_installed",
    ],
)

# Autoregressive model plus its two test targets (unit + training).
py_library(
    name = "ar_model",
    srcs = [
        "ar_model.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":feature_keys",
        ":model",
        ":model_utils",
        "//tensorflow_estimator/python/estimator:expect_tensorflow_installed",
    ],
)

py_test(
    name = "ar_model_test",
    srcs = [
        "ar_model_test.py",
    ],
    python_version = "PY3",
    shard_count = 4,
    srcs_version = "PY3",
    deps = [
        ":ar_model",
        ":estimators",
        ":feature_keys",
        "//tensorflow_estimator/python/estimator:estimator_py",
        "//tensorflow_estimator/python/estimator:expect_proto_cpp_installed",
        "//tensorflow_estimator/python/estimator:expect_tensorflow_installed",
    ],
)

py_test(
    name = "ar_model_training_test",
    srcs = [
        "ar_model_training_test.py",
    ],
    python_version = "PY3",
    shard_count = 4,
    srcs_version = "PY3",
    deps = [
        ":ar_model",
        ":estimators",
        ":feature_keys",
        "//tensorflow_estimator/python/estimator:estimator_py",
        "//tensorflow_estimator/python/estimator:expect_proto_cpp_installed",
        "//tensorflow_estimator/python/estimator:expect_tensorflow_installed",
    ],
)

py_library(
    name = "math_utils",
    srcs = [
        "math_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":feature_keys",
        "//tensorflow_estimator/python/estimator:expect_tensorflow_installed",
    ],
)

py_test(
    name = "math_utils_test",
    srcs = [
        "math_utils_test.py",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":feature_keys",
        ":math_utils",
        "//tensorflow_estimator/python/estimator:expect_proto_cpp_installed",
        "//tensorflow_estimator/python/estimator:expect_tensorflow_installed",
    ],
)
--------------------------------------------------------------------------------
/tensorflow_estimator/python/estimator/object_checkpointing_test.py:
--------------------------------------------------------------------------------
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration tests for Estimator + object checkpointing."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

import os
# pylint: disable=g-import-not-at-top
try:
  # Newer TF versions expose the object-based checkpoint module here.
  from tensorflow.python.checkpoint import checkpoint as util
except ImportError:
  # TODO(allenl): Remove this after cl/229814711 syncs
  from tensorflow.python.training.checkpointable import util

from tensorflow_estimator.python.estimator import estimator as estimator_lib
from tensorflow_estimator.python.estimator import model_fn as model_fn_lib
from tensorflow_estimator.python.estimator.util import tf_keras
from tensorflow_estimator.python.estimator.export import export_lib


class SubclassedModel(tf_keras.models.Model):
  """Tiny two-layer dense model used as the checkpointing subject."""

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self.dense_one = tf_keras.layers.Dense(5)
    self.dense_two = tf_keras.layers.Dense(1)

  def call(self, inputs):
    return self.dense_two(self.dense_one(inputs))


def _serving_input_receiver_fn():
  """Builds a serving receiver with a single float feature of shape [None, 1]."""
  receiver = tf.compat.v1.placeholder(
      tf.dtypes.float32, shape=[None, 1], name='input')
  return export_lib.ServingInputReceiver(
      features={'feature': receiver}, receiver_tensors=receiver)


class ObjectCheckpointingTest(tf.test.TestCase):
  """Verifies object-based checkpoints round-trip through Estimator."""

  def _make_estimator(self, model_dir):
    """Returns (estimator, input_fn) whose model_fn saves via `util.Checkpoint`."""

    def _model_fn(features, labels, mode):
      del labels
      model = SubclassedModel()
      optimizer = tf_keras.optimizers.Adam(0.01)
      checkpoint = util.Checkpoint(
          step=tf.compat.v1.train.get_or_create_global_step(),
          optimizer=optimizer,
          model=model)
      # Make the save counter to satisfy the assert_consumed() assertion later
      checkpoint.save_counter  # pylint: disable=pointless-statement
      with tf.GradientTape() as tape:
        output = model(features['feature'])
        loss = tf.math.reduce_sum(output)
      variables = model.trainable_variables
      gradients = tape.gradient(loss, variables)
      train_op = tf.group(
          optimizer.apply_gradients(zip(gradients, variables)),
          checkpoint.step.assign_add(1))
      return model_fn_lib.EstimatorSpec(
          mode,
          loss=loss,
          train_op=train_op,
          # Predictions expose the bias and step, tiled to the batch size, so
          # the tests below can observe restored values through `predict()`.
          predictions=dict(
              output=output,
              bias=tf.tile(model.dense_two.bias[None, :],
                           [tf.compat.v1.shape(output)[0], 1]),
              step=tf.tile(checkpoint.step[None],
                           [tf.compat.v1.shape(output)[0]])),
          # The object-based Checkpoint acts as the Scaffold's saver, so
          # Estimator writes object-format checkpoints.
          scaffold=tf.compat.v1.train.Scaffold(saver=checkpoint))

    est = estimator_lib.EstimatorV2(model_fn=_model_fn, model_dir=model_dir)

    def _input_map_fn(tensor):
      """Converts a tensor into `features, labels` format used by Estimator."""
      return {'feature': tensor}, tensor

    def _input_fn():
      return tf.compat.v1.data.Dataset.from_tensors(
          [1.]).repeat().batch(10).map(_input_map_fn)

    return est, _input_fn

  def testTwoWayCompatibility(self):
    """Estimator-written checkpoints load eagerly, and vice versa."""
    save_model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
    save_est, input_fn = self._make_estimator(save_model_dir)

    save_est.train(input_fn, steps=3)

    # Restore the Estimator-written checkpoint into a fresh eager model.
    model = SubclassedModel()
    optimizer = tf_keras.optimizers.Adam(0.01)
    checkpoint = util.Checkpoint(
        step=tf.Variable(0, dtype=tf.dtypes.int64),
        optimizer=optimizer,
        model=model)
    status = checkpoint.restore(tf.train.latest_checkpoint(save_model_dir))
    self.assertEqual(3, self.evaluate(checkpoint.step))
    with tf.GradientTape() as tape:
      output = model(tf.constant([[1.]]))
      loss = tf.math.reduce_sum(output)
    variables = model.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    # Every object in the checkpoint must have been matched by now.
    status.assert_consumed()

    # The optimizer uses this for some reason...
    tf_keras.backend.clear_session()

    # Write a checkpoint eagerly with sentinel values, then check an
    # Estimator restores them and surfaces them through predictions.
    load_model_dir = os.path.join(self.get_temp_dir(), 'load_model_dir/')
    checkpoint.step.assign(40)
    checkpoint.model.dense_two.bias.assign([13.])
    checkpoint.save(load_model_dir)
    load_est, input_fn = self._make_estimator(load_model_dir)
    predictions = load_est.predict(input_fn)
    predictions = next(predictions)
    self.assertAllClose([13.], predictions['bias'])
    self.assertEqual(40, predictions['step'])

  def testSavedModelExport(self):
    """Exporting from an object-checkpointed Estimator produces a loadable SavedModel."""
    model_dir = os.path.join(self.get_temp_dir(), 'estimator_train_dir')
    estimator, input_fn = self._make_estimator(model_dir)
    estimator.train(input_fn, steps=1)  # Train to generate a checkpoint.

    export_dir_base = os.path.join(self.get_temp_dir(), 'estimator_export_dir')
    export_dir = estimator.export_saved_model(export_dir_base,
                                              _serving_input_receiver_fn)

    # Check the saved model loads and simple inference runs.
    model = tf.compat.v2.saved_model.load(export_dir)
    model.signatures['serving_default'](tf.constant([[1.]]))


if __name__ == '__main__':
  tf.compat.v1.enable_eager_execution()
  tf.test.main()
--------------------------------------------------------------------------------
/tensorflow_estimator/python/estimator/canned/optimizers_test_v2.py:
--------------------------------------------------------------------------------
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for optimizers.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_estimator.python.estimator.util import tf_keras
from tensorflow_estimator.python.estimator.canned import optimizers


class _TestOptimizerV2(tf_keras.optimizers.legacy.Optimizer):
  """Minimal concrete optimizer used to exercise the object pass-through paths."""

  def __init__(self):
    super(_TestOptimizerV2, self).__init__(name='TestOptimizer')

  def get_config(self):
    # The abstract method only needs to exist for instantiation; no config
    # round-trips happen in these tests.
    pass


class GetOptimizerInstanceV2(tf.test.TestCase):
  """Tests for Optimizer V2."""

  def test_unsupported_name(self):
    """An unknown optimizer string raises a descriptive ValueError."""
    # `assertRaisesRegexp` was removed in Python 3.12; `assertRaisesRegex` is
    # the supported spelling (available since Python 3.2).
    with self.assertRaisesRegex(
        ValueError, 'Unsupported optimizer name: unsupported_name'):
      optimizers.get_optimizer_instance_v2(
          'unsupported_name', learning_rate=0.1)

  def test_adagrad_but_no_learning_rate(self):
    """Omitting learning_rate falls back to the documented 0.001 default."""
    with self.cached_session():
      opt = optimizers.get_optimizer_instance_v2('Adagrad')
      # The creation of variables in optimizer_v2 is deferred to when it's
      # called, so we need to manually create it here. Same for all other tests.
49 | self.assertIsInstance(opt.learning_rate, tf.Variable) 50 | self.evaluate(tf.compat.v1.initializers.global_variables()) 51 | self.assertIsInstance( 52 | opt, 53 | (tf_keras.optimizers.Adagrad, tf_keras.optimizers.legacy.Adagrad)) 54 | self.assertAlmostEqual(0.001, self.evaluate(opt.learning_rate)) 55 | 56 | def test_adam_but_no_learning_rate(self): 57 | with self.cached_session(): 58 | opt = optimizers.get_optimizer_instance_v2('Adam') 59 | self.assertIsInstance(opt.learning_rate, tf.Variable) 60 | self.evaluate(tf.compat.v1.initializers.global_variables()) 61 | self.assertIsInstance( 62 | opt, (tf_keras.optimizers.Adam, tf_keras.optimizers.legacy.Adam)) 63 | self.assertAlmostEqual(0.001, self.evaluate(opt.learning_rate)) 64 | 65 | def test_adagrad(self): 66 | with self.cached_session(): 67 | opt = optimizers.get_optimizer_instance_v2('Adagrad', learning_rate=0.1) 68 | self.assertIsInstance(opt.learning_rate, tf.Variable) 69 | self.evaluate(tf.compat.v1.initializers.global_variables()) 70 | self.assertIsInstance( 71 | opt, 72 | (tf_keras.optimizers.Adagrad, tf_keras.optimizers.legacy.Adagrad)) 73 | self.assertAlmostEqual(0.1, self.evaluate(opt.learning_rate)) 74 | 75 | def test_adam(self): 76 | with self.cached_session(): 77 | opt = optimizers.get_optimizer_instance_v2('Adam', learning_rate=0.1) 78 | self.assertIsInstance(opt.learning_rate, tf.Variable) 79 | self.evaluate(tf.compat.v1.initializers.global_variables()) 80 | self.assertIsInstance( 81 | opt, (tf_keras.optimizers.Adam, tf_keras.optimizers.legacy.Adam)) 82 | self.assertAlmostEqual(0.1, self.evaluate(opt.learning_rate)) 83 | 84 | def test_ftrl(self): 85 | with self.cached_session(): 86 | opt = optimizers.get_optimizer_instance_v2('Ftrl', learning_rate=0.1) 87 | self.assertIsInstance(opt.learning_rate, tf.Variable) 88 | self.evaluate(tf.compat.v1.initializers.global_variables()) 89 | self.assertIsInstance( 90 | opt, (tf_keras.optimizers.Ftrl, tf_keras.optimizers.legacy.Ftrl)) 91 | 
self.assertAlmostEqual(0.1, self.evaluate(opt.learning_rate)) 92 | 93 | def test_rmsprop(self): 94 | with self.cached_session(): 95 | opt = optimizers.get_optimizer_instance_v2('RMSProp', learning_rate=0.1) 96 | self.assertIsInstance(opt.learning_rate, tf.Variable) 97 | self.evaluate(tf.compat.v1.initializers.global_variables()) 98 | self.assertIsInstance( 99 | opt, 100 | (tf_keras.optimizers.RMSprop, tf_keras.optimizers.legacy.RMSprop)) 101 | self.assertAlmostEqual(0.1, self.evaluate(opt.learning_rate)) 102 | 103 | def test_sgd(self): 104 | with self.cached_session(): 105 | opt = optimizers.get_optimizer_instance_v2('SGD', learning_rate=0.1) 106 | self.assertIsInstance(opt.learning_rate, tf.Variable) 107 | self.evaluate(tf.compat.v1.initializers.global_variables()) 108 | self.assertIsInstance( 109 | opt, (tf_keras.optimizers.SGD, tf_keras.optimizers.legacy.SGD)) 110 | self.assertAlmostEqual(0.1, self.evaluate(opt.learning_rate)) 111 | 112 | def test_object(self): 113 | opt = optimizers.get_optimizer_instance_v2(_TestOptimizerV2()) 114 | self.assertIsInstance(opt, _TestOptimizerV2) 115 | 116 | def test_object_invalid(self): 117 | with self.assertRaisesRegexp( 118 | ValueError, 119 | 'The given object is not a tf_keras.optimizers.Optimizer instance'): 120 | optimizers.get_optimizer_instance_v2((1, 2, 3)) 121 | 122 | def test_callable(self): 123 | 124 | def _optimizer_fn(): 125 | return _TestOptimizerV2() 126 | 127 | opt = optimizers.get_optimizer_instance_v2(_optimizer_fn) 128 | self.assertIsInstance(opt, _TestOptimizerV2) 129 | 130 | def test_lambda(self): 131 | opt = optimizers.get_optimizer_instance_v2(lambda: _TestOptimizerV2()) # pylint: disable=unnecessary-lambda 132 | self.assertIsInstance(opt, _TestOptimizerV2) 133 | 134 | def test_callable_returns_invalid(self): 135 | 136 | def _optimizer_fn(): 137 | return (1, 2, 3) 138 | 139 | with self.assertRaisesRegexp( 140 | ValueError, 141 | 'The given object is not a tf_keras.optimizers.Optimizer instance'): 142 | 
optimizers.get_optimizer_instance_v2(_optimizer_fn) 143 | 144 | 145 | if __name__ == '__main__': 146 | tf.test.main() 147 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/canned/canned_estimator_ds_integration_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests canned estimators with distribution strategy.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import inspect 22 | import tempfile 23 | 24 | from absl.testing import parameterized 25 | import numpy as np 26 | import tensorflow as tf 27 | from tensorflow_estimator.python.estimator import run_config 28 | from tensorflow_estimator.python.estimator.util import tf_keras 29 | from tensorflow_estimator.python.estimator.canned import dnn 30 | from tensorflow_estimator.python.estimator.canned import dnn_linear_combined 31 | from tensorflow_estimator.python.estimator.canned import linear 32 | from tensorflow_estimator.python.estimator.extenders import add_metrics 33 | 34 | 35 | class CannedEstimatorDistributionStrategyTest(tf.test.TestCase, 36 | parameterized.TestCase): 37 | 38 | def setUp(self): 39 | super(CannedEstimatorDistributionStrategyTest, self).setUp() 40 | np.random.seed(1337) 41 | tf.compat.v1.random.set_random_seed(1337) 42 | 43 | self._model_dir = tempfile.mkdtemp() 44 | 45 | def dataset_input_fn(self, x, y, batch_size, shuffle): 46 | 47 | def input_fn(): 48 | dataset = tf.compat.v1.data.Dataset.from_tensor_slices((x, y)) 49 | if shuffle: 50 | dataset = dataset.shuffle(batch_size) 51 | dataset = dataset.repeat(10).batch(batch_size) 52 | return dataset 53 | 54 | return input_fn 55 | 56 | @tf.compat.v2.__internal__.distribute.combinations.generate( 57 | tf.compat.v2.__internal__.test.combinations.combine( 58 | mode=['graph', 'eager'], 59 | distribution=[ 60 | tf.compat.v2.__internal__.distribute.combinations.one_device_strategy, 61 | tf.compat.v2.__internal__.distribute.combinations.mirrored_strategy_with_gpu_and_cpu, 62 | tf.compat.v2.__internal__.distribute.combinations.mirrored_strategy_with_two_gpus, 63 | ], 64 | estimator_cls=[ 65 | 
dnn_linear_combined.DNNLinearCombinedRegressorV2, 66 | dnn.DNNRegressorV2, 67 | linear.LinearRegressorV2, 68 | ])) 69 | def test_canned_estimator(self, distribution, estimator_cls): 70 | label_dimension = 2 71 | batch_size = 10 72 | # Adding one extra row (+ label_dimension) to test the last partial batch 73 | # use case. 74 | data = np.linspace( 75 | 0., 76 | 2., 77 | batch_size * label_dimension + label_dimension, 78 | dtype=np.float32) 79 | data = data.reshape(batch_size + 1, label_dimension) 80 | fc = tf.feature_column.numeric_column('x', shape=(2,)) 81 | 82 | # Set kwargs based on the current canned estimator class. 83 | estimator_kw_args = { 84 | 'model_dir': self._model_dir, 85 | 'label_dimension': 2, 86 | } 87 | 88 | cls_args = inspect.getargspec(estimator_cls.__init__).args 89 | if 'hidden_units' in cls_args: 90 | estimator_kw_args['hidden_units'] = [2, 2] 91 | elif 'dnn_hidden_units' in cls_args: 92 | estimator_kw_args['dnn_hidden_units'] = [2, 2] 93 | 94 | if 'optimizer' in cls_args: 95 | estimator_kw_args['optimizer'] = 'SGD' 96 | else: 97 | estimator_kw_args['linear_optimizer'] = 'SGD' 98 | estimator_kw_args['dnn_optimizer'] = 'SGD' 99 | 100 | if 'feature_columns' in cls_args: 101 | estimator_kw_args['feature_columns'] = [fc] 102 | else: 103 | estimator_kw_args['linear_feature_columns'] = [fc] 104 | estimator_kw_args['dnn_feature_columns'] = [fc] 105 | 106 | def my_metrics(features): 107 | metric = tf_keras.metrics.Mean() 108 | metric.update_state(features['x']) 109 | return {'mean_x': metric} 110 | 111 | # Create a canned estimator and train to save a checkpoint. 112 | input_fn = self.dataset_input_fn( 113 | x={'x': data}, y=data, batch_size=batch_size, shuffle=False) 114 | canned_est = estimator_cls(**estimator_kw_args) 115 | canned_est.train(input_fn=input_fn) 116 | 117 | # Create a second canned estimator, warm-started from the first. 
118 | del estimator_kw_args['model_dir'] 119 | estimator_kw_args['warm_start_from'] = canned_est.model_dir 120 | warm_started_canned_est = estimator_cls(**estimator_kw_args) 121 | warm_started_canned_est.train(input_fn=input_fn) 122 | 123 | # Create a third canned estimator, warm-started from the first. 124 | input_fn = self.dataset_input_fn( 125 | x={'x': data}, 126 | y=data, 127 | batch_size=batch_size // distribution.num_replicas_in_sync, 128 | shuffle=False) 129 | estimator_kw_args['config'] = run_config.RunConfig( 130 | train_distribute=distribution, eval_distribute=distribution) 131 | warm_started_canned_est_with_ds = estimator_cls(**estimator_kw_args) 132 | warm_started_canned_est_with_ds.train(input_fn=input_fn) 133 | 134 | for variable_name in warm_started_canned_est.get_variable_names(): 135 | self.assertAllClose( 136 | warm_started_canned_est_with_ds.get_variable_value(variable_name), 137 | warm_started_canned_est.get_variable_value(variable_name)) 138 | 139 | warm_started_canned_est = add_metrics(warm_started_canned_est, my_metrics) 140 | warm_started_canned_est_with_ds = add_metrics( 141 | warm_started_canned_est_with_ds, my_metrics) 142 | 143 | scores = warm_started_canned_est.evaluate(input_fn) 144 | scores_with_ds = warm_started_canned_est_with_ds.evaluate(input_fn) 145 | self.assertAlmostEqual(scores['loss'], scores_with_ds['loss'], 5) 146 | self.assertAlmostEqual(scores['mean_x'], scores_with_ds['mean_x'], 5) 147 | 148 | 149 | if __name__ == '__main__': 150 | tf.test.main() 151 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/inputs/pandas_io.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Methods to allow pandas.DataFrame.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import uuid 22 | import numpy as np 23 | import six 24 | from tensorflow_estimator.python.estimator.estimator_export import estimator_export 25 | from tensorflow_estimator.python.estimator.inputs.queues import feeding_functions 26 | 27 | try: 28 | # pylint: disable=g-import-not-at-top 29 | # pylint: disable=unused-import 30 | import pandas as pd 31 | HAS_PANDAS = True 32 | except IOError: 33 | # Pandas writes a temporary file during import. If it fails, don't use pandas. 34 | HAS_PANDAS = False 35 | except ImportError: 36 | HAS_PANDAS = False 37 | 38 | 39 | def _get_unique_target_key(features, target_column_name): 40 | """Returns a key that does not exist in the input DataFrame `features`. 41 | 42 | Args: 43 | features: DataFrame 44 | target_column_name: Name of the target column as a `str` 45 | 46 | Returns: 47 | A unique key that can be used to insert the target into 48 | features. 
49 | """ 50 | if target_column_name in features: 51 | target_column_name += '_' + str(uuid.uuid4()) 52 | return target_column_name 53 | 54 | 55 | @estimator_export(v1=['estimator.inputs.pandas_input_fn']) 56 | def pandas_input_fn(x, 57 | y=None, 58 | batch_size=128, 59 | num_epochs=1, 60 | shuffle=None, 61 | queue_capacity=1000, 62 | num_threads=1, 63 | target_column='target'): 64 | """Returns input function that would feed Pandas DataFrame into the model. 65 | 66 | Note: `y`'s index must match `x`'s index. 67 | 68 | Args: 69 | x: pandas `DataFrame` object. 70 | y: pandas `Series` object or `DataFrame`. `None` if absent. 71 | batch_size: int, size of batches to return. 72 | num_epochs: int, number of epochs to iterate over data. If not `None`, read 73 | attempts that would exceed this value will raise `OutOfRangeError`. 74 | shuffle: bool, whether to read the records in random order. 75 | queue_capacity: int, size of the read queue. If `None`, it will be set 76 | roughly to the size of `x`. 77 | num_threads: Integer, number of threads used for reading and enqueueing. In 78 | order to have predicted and repeatable order of reading and enqueueing, 79 | such as in prediction and evaluation mode, `num_threads` should be 1. 80 | target_column: str, name to give the target column `y`. This parameter is 81 | not used when `y` is a `DataFrame`. 82 | 83 | Returns: 84 | Function, that has signature of ()->(dict of `features`, `target`) 85 | 86 | Raises: 87 | ValueError: if `x` already contains a column with the same name as `y`, or 88 | if the indexes of `x` and `y` don't match. 89 | ValueError: if 'shuffle' is not provided or a bool. 
90 | """ 91 | if not HAS_PANDAS: 92 | raise TypeError( 93 | 'pandas_input_fn should not be called without pandas installed') 94 | 95 | if not isinstance(shuffle, bool): 96 | raise ValueError('shuffle must be provided and explicitly set as boolean ' 97 | '(it is recommended to set it as True for training); ' 98 | 'got {}'.format(shuffle)) 99 | 100 | if not isinstance(target_column, six.string_types): 101 | raise TypeError('target_column must be a string type') 102 | 103 | x = x.copy() 104 | if y is not None: 105 | if target_column in x: 106 | raise ValueError( 107 | 'Cannot use name %s for target column: DataFrame already has a ' 108 | 'column with that name: %s' % (target_column, x.columns)) 109 | if not np.array_equal(x.index, y.index): 110 | raise ValueError('Index for x and y are mismatched.\nIndex for x: %s\n' 111 | 'Index for y: %s\n' % (x.index, y.index)) 112 | if isinstance(y, pd.DataFrame): 113 | y_columns = [ 114 | (column, _get_unique_target_key(x, column)) for column in list(y) 115 | ] 116 | target_column = [v for _, v in y_columns] 117 | x[target_column] = y 118 | else: 119 | x[target_column] = y 120 | 121 | # TODO(mdan): These are memory copies. We probably don't need 4x slack space. 122 | # The sizes below are consistent with what I've seen elsewhere. 
123 | if queue_capacity is None: 124 | if shuffle: 125 | queue_capacity = 4 * len(x) 126 | else: 127 | queue_capacity = len(x) 128 | min_after_dequeue = max(queue_capacity / 4, 1) 129 | 130 | def input_fn(): 131 | """Pandas input function.""" 132 | queue = feeding_functions._enqueue_data( # pylint: disable=protected-access 133 | x, 134 | queue_capacity, 135 | shuffle=shuffle, 136 | min_after_dequeue=min_after_dequeue, 137 | num_threads=num_threads, 138 | enqueue_size=batch_size, 139 | num_epochs=num_epochs) 140 | if num_epochs is None: 141 | features = queue.dequeue_many(batch_size) 142 | else: 143 | features = queue.dequeue_up_to(batch_size) 144 | assert len(features) == len(x.columns) + 1, ('Features should have one ' 145 | 'extra element for the index.') 146 | features = features[1:] 147 | features = dict(zip(list(x.columns), features)) 148 | if y is not None: 149 | if isinstance(target_column, list): 150 | keys = [k for k, _ in y_columns] 151 | values = [features.pop(column) for column in target_column] 152 | target = {k: v for k, v in zip(keys, values)} 153 | else: 154 | target = features.pop(target_column) 155 | return features, target 156 | return features 157 | 158 | return input_fn 159 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/tf_estimator_doctest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run doctests for tensorflow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import re
import sys
import textwrap
import tensorflow as tf
import numpy as np

from absl import flags
from absl.testing import absltest

import tensorflow_estimator.python.estimator.estimator_lib as tfe

# NOTE(review): this rebinds the `tf` name imported above to the compat.v2
# surface, then grafts the estimator API onto it for the doctests —
# presumably intentional, but the earlier `import tensorflow as tf` is
# effectively dead; confirm before cleaning up.
import tensorflow.compat.v2 as tf
tf.estimator = tfe
tf.compat.v1.enable_v2_behavior()

# We put doctest after absltest so that it picks up the unittest monkeypatch.
# Otherwise doctest tests aren't runnable at all.
import doctest  # pylint: disable=g-import-not-at-top, g-bad-import-order

FLAGS = flags.FLAGS

flags.DEFINE_string('module', None, 'A specific module to run doctest on.')
flags.DEFINE_boolean('list', None,
                     'List all the modules in the core package imported.')
flags.DEFINE_string('file', None, 'A specific file to run doctest on.')

flags.mark_flags_as_mutual_exclusive(['module', 'file'])
flags.mark_flags_as_mutual_exclusive(['list', 'file'])

PACKAGE = 'tensorflow_estimator.python.'


def find_modules():
  """Finds all the modules in the core package imported.

  Returns:
    A list containing all the modules in tensorflow.python.
  """

  tf_modules = []
  # Only already-imported modules are considered (sys.modules snapshot).
  for name, module in sys.modules.items():
    if name.startswith(PACKAGE):
      tf_modules.append(module)

  return tf_modules


def filter_on_submodules(all_modules, submodule):
  """Filters all the modules based on the module flag.

  The module flag has to be relative to the core package imported.
  For example, if `submodule=keras.layers` then, this function will return
  all the modules in the submodule.

  Args:
    all_modules: All the modules in the core package.
    submodule: Submodule to filter from all the modules.

  Returns:
    All the modules in the submodule.
  """

  filtered_modules = [
      mod for mod in all_modules if PACKAGE + submodule in mod.__name__
  ]
  return filtered_modules


def get_module_and_inject_docstring(file_path):
  """Replaces the docstring of the module with the changed file's content.

  Args:
    file_path: Path to the file

  Returns:
    A list containing the module changed by the file.
  """

  # Derive the dotted module name from the path suffix under PACKAGE.
  file_path = os.path.abspath(file_path)
  mod_index = file_path.find(PACKAGE.replace('.', os.sep))
  file_mod_name, _ = os.path.splitext(file_path[mod_index:])
  file_module = sys.modules[file_mod_name.replace(os.sep, '.')]

  with open(file_path, 'r') as f:
    content = f.read()

  # The whole file text becomes the module docstring so doctest scans it.
  file_module.__doc__ = content

  return [file_module]


class TfTestCase(tf.test.TestCase):
  """Adapter exposing TestCase set-up/tear-down as doctest-style hooks."""

  def set_up(self, test):
    self.setUp()

  def tear_down(self, test):
    self.tearDown()


class CustomOutputChecker(doctest.OutputChecker):
  """Changes the `want` and `got` strings.

  This allows it to be customized before they are compared.
  """
  # Volatile object identity markers that must not fail a doctest.
  ID_RE = re.compile(r'\bid=(\d+)\b')
  ADDRESS_RE = re.compile(r'\bat 0x[0-9a-f]*?>')

  def check_output(self, want, got, optionflags):
    # Replace tf.Tensor's id with ellipsis(...) because tensor's id can change
    # on each execution. Users may forget to use ellipsis while writing
    # examples in docstrings, so replacing the id with `...` makes it safe.
    want = self.ID_RE.sub('id=...', want)
    want = self.ADDRESS_RE.sub('at ...>', want)
    return doctest.OutputChecker.check_output(self, want, got, optionflags)

  _MESSAGE = textwrap.dedent("""\n
    #############################################################
    Check the documentation
    (go/testable-docstrings) on how to write testable docstrings.
    #############################################################""")

  def output_difference(self, example, got, optionflags):
    # Append a pointer to the style guide to every mismatch report.
    got = got + self._MESSAGE
    return doctest.OutputChecker.output_difference(self, example, got,
                                                   optionflags)


def load_tests(unused_loader, tests, unused_ignore):
  """Loads all the tests in the docstrings and runs them."""

  tf_modules = find_modules()

  if FLAGS.module:
    tf_modules = filter_on_submodules(tf_modules, FLAGS.module)

  if FLAGS.list:
    # List-only mode: print module names and register no tests.
    print('**************************************************')
    for mod in tf_modules:
      print(mod.__name__)
    print('**************************************************')
    return tests

  if FLAGS.file:
    tf_modules = get_module_and_inject_docstring(FLAGS.file)

  for module in tf_modules:
    testcase = TfTestCase()
    tests.addTests(
        doctest.DocTestSuite(
            module,
            test_finder=doctest.DocTestFinder(exclude_empty=False),
            extraglobs={
                'tf': tf,
                'np': np,
                'os': os
            },
            setUp=testcase.set_up,
            tearDown=testcase.tear_down,
            checker=CustomOutputChecker(),
            optionflags=(doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE
                         | doctest.IGNORE_EXCEPTION_DETAIL
                         | doctest.DONT_ACCEPT_BLANKLINE),
        ))
  return tests


if __name__ == '__main__':
  absltest.main()
--------------------------------------------------------------------------------
/tensorflow_estimator/python/estimator/gc_test.py:
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for garbage collection utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import re

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow_estimator.python.estimator import gc


def _create_parser(base_dir):
  """Builds a gc.Path parser that reads export_version from the directory name.

  Args:
    base_dir: Directory whose immediate children named with an integer are
      treated as export versions.

  Returns:
    A function mapping a gc.Path to the same Path with `export_version` set,
    or None when the path does not match `<base_dir>/<digits>`.
  """

  # create a simple parser that pulls the export_version from the directory.
  def parser(path):
    # Modify the path object for RegEx match for Windows Paths
    if os.name == "nt":
      match = re.match(
          "^" + tf.compat.as_str_any(base_dir).replace("\\", "/") + "/(\\d+)$",
          tf.compat.as_str_any(path.path).replace("\\", "/"))
    else:
      match = re.match("^" + tf.compat.as_str_any(base_dir) + "/(\\d+)$",
                       tf.compat.as_str_any(path.path))
    if not match:
      # Non-matching directories (e.g. "ignore") are dropped by gc._get_paths.
      return None
    return path._replace(export_version=int(match.group(1)))

  return parser


class GcTest(tf.test.TestCase):
  """Golden-value tests for the gc path-filter combinators."""

  def testLargestExportVersions(self):
    # Keeping the 2 largest of versions {8, 9, 10} must yield {9, 10}.
    paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)]
    newest = gc._largest_export_versions(2)
    n = newest(paths)
    self.assertEqual(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)])

  def testLargestExportVersionsDoesNotDeleteZeroFolder(self):
    # Version 0 is a valid export and must survive when among the n largest.
    paths = [gc.Path("/foo", 0), gc.Path("/foo", 3)]
    newest = gc._largest_export_versions(2)
    n = newest(paths)
    self.assertEqual(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)])

  def testModExportVersion(self):
    # Keep only versions divisible by n.
    paths = [
        gc.Path("/foo", 4),
        gc.Path("/foo", 5),
        gc.Path("/foo", 6),
        gc.Path("/foo", 9)
    ]
    mod = gc._mod_export_version(2)
    self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)])
    mod = gc._mod_export_version(3)
    self.assertEqual(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)])

  def testOneOfEveryNExportVersions(self):
    # One keeper per interval (0, n], (n, 2n], ...; largest in each wins.
    paths = [
        gc.Path("/foo", 0),
        gc.Path("/foo", 1),
        gc.Path("/foo", 3),
        gc.Path("/foo", 5),
        gc.Path("/foo", 6),
        gc.Path("/foo", 7),
        gc.Path("/foo", 8),
        gc.Path("/foo", 33)
    ]
    one_of = gc._one_of_every_n_export_versions(3)
    self.assertEqual(
        one_of(paths), [
            gc.Path("/foo", 3),
            gc.Path("/foo", 6),
            gc.Path("/foo", 8),
            gc.Path("/foo", 33)
        ])

  def testOneOfEveryNExportVersionsZero(self):
    # Zero is a special case since it gets rolled into the first interval.
    # Test that here.
    paths = [gc.Path("/foo", 0), gc.Path("/foo", 4), gc.Path("/foo", 5)]
    one_of = gc._one_of_every_n_export_versions(3)
    self.assertEqual(one_of(paths), [gc.Path("/foo", 0), gc.Path("/foo", 5)])

  def testUnion(self):
    # Union of "largest 3" and "multiples of 3" over versions 0..9.
    paths = []
    for i in xrange(10):
      paths.append(gc.Path("/foo", i))
    f = gc._union(gc._largest_export_versions(3), gc._mod_export_version(3))
    self.assertEqual(
        f(paths), [
            gc.Path("/foo", 0),
            gc.Path("/foo", 3),
            gc.Path("/foo", 6),
            gc.Path("/foo", 7),
            gc.Path("/foo", 8),
            gc.Path("/foo", 9)
        ])

  def testNegation(self):
    # Negation returns exactly the paths the wrapped filter would drop.
    paths = [
        gc.Path("/foo", 4),
        gc.Path("/foo", 5),
        gc.Path("/foo", 6),
        gc.Path("/foo", 9)
    ]
    mod = gc._negation(gc._mod_export_version(2))
    self.assertEqual(mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)])
    mod = gc._negation(gc._mod_export_version(3))
    self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)])

  def testPathsWithParse(self):
    # Real directories 0,1,2 are parsed; the non-numeric one is skipped.
    base_dir = os.path.join(tf.compat.v1.test.get_temp_dir(), "paths_parse")
    self.assertFalse(tf.compat.v1.gfile.Exists(base_dir))
    for p in xrange(3):
      tf.compat.v1.gfile.MakeDirs(os.path.join(base_dir, "%d" % p))
    # add a base_directory to ignore
    tf.compat.v1.gfile.MakeDirs(os.path.join(base_dir, "ignore"))

    self.assertEqual(
        gc._get_paths(base_dir, _create_parser(base_dir)), [
            gc.Path(os.path.join(base_dir, "0"), 0),
            gc.Path(os.path.join(base_dir, "1"), 1),
            gc.Path(os.path.join(base_dir, "2"), 2)
        ])
    tf.compat.v1.gfile.DeleteRecursively(base_dir)

  def testMixedStrTypes(self):
    # _get_paths must tolerate str, bytes and unicode base directories.
    temp_dir = tf.compat.as_bytes(tf.compat.v1.test.get_temp_dir())

    for sub_dir in ["str", b"bytes", u"unicode"]:
      base_dir = os.path.join(
          (temp_dir if isinstance(sub_dir, bytes) else temp_dir.decode()),
          sub_dir)
      self.assertFalse(tf.compat.v1.gfile.Exists(base_dir))
      tf.compat.v1.gfile.MakeDirs(
          os.path.join(tf.compat.as_str_any(base_dir), "42"))
      gc._get_paths(base_dir, _create_parser(base_dir))
      tf.compat.v1.gfile.DeleteRecursively(base_dir)

  def testGcsDirWithSeparator(self):
    # GCS listings come back with a trailing '/', which _get_paths strips.
    base_dir = "gs://bucket/foo"
    with tf.compat.v1.test.mock.patch.object(
        gfile, "ListDirectory") as mock_list_directory:
      # gfile.ListDirectory returns directory names with separator '/'
      mock_list_directory.return_value = ["0/", "1/"]
      self.assertEqual(
          gc._get_paths(base_dir, _create_parser(base_dir)), [
              gc.Path(os.path.join(base_dir, "0"), 0),
              gc.Path(os.path.join(base_dir, "1"), 1)
          ])


if __name__ == "__main__":
  tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DNNEstimator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import shutil
import tempfile

import numpy as np
import six
import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow_estimator.python.estimator.canned import dnn
from tensorflow_estimator.python.estimator.canned import head as head_lib
from tensorflow_estimator.python.estimator.canned import prediction_keys
from tensorflow_estimator.python.estimator.canned.v1 import dnn_testing_utils_v1
from tensorflow_estimator.python.estimator.export import export
from tensorflow_estimator.python.estimator.inputs import numpy_io


def _dnn_estimator_fn(weight_column=None, label_dimension=1, **kwargs):
  """Returns a DNNEstimator that uses regression_head."""
  return dnn.DNNEstimator(
      head=head_lib._regression_head(
          weight_column=weight_column,
          label_dimension=label_dimension,
          # Tests in core (from which this test inherits) test the sum loss.
          loss_reduction=tf.compat.v1.losses.Reduction.SUM),
      **kwargs)


def _dnn_estimator_classifier_fn(n_classes=3, **kwargs):
  """Returns a DNNEstimator that uses a multi-class softmax head."""
  return dnn.DNNEstimator(
      head=head_lib._multi_class_head_with_softmax_cross_entropy_loss(
          n_classes=n_classes),
      **kwargs)


# The classes below reuse shared test bodies from dnn_testing_utils_v1;
# each mixin's __init__ is handed the estimator factory to exercise.
@test_util.run_v1_only('Tests v1 only symbols')
class DNNEstimatorEvaluateTest(
    dnn_testing_utils_v1.BaseDNNRegressorEvaluateTest, tf.test.TestCase):

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils_v1.BaseDNNRegressorEvaluateTest.__init__(
        self, _dnn_estimator_fn)


@test_util.run_v1_only('Tests v1 only symbols')
class DNNEstimatorPredictTest(dnn_testing_utils_v1.BaseDNNRegressorPredictTest,
                              tf.test.TestCase):

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils_v1.BaseDNNRegressorPredictTest.__init__(
        self, _dnn_estimator_fn)


@test_util.run_v1_only('Tests v1 only symbols')
class DNNEstimatorTrainTest(dnn_testing_utils_v1.BaseDNNRegressorTrainTest,
                            tf.test.TestCase):

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils_v1.BaseDNNRegressorTrainTest.__init__(
        self, _dnn_estimator_fn)


@test_util.run_v1_only('Tests v1 only symbols')
class DNNEstimatorWarmStartingTest(dnn_testing_utils_v1.BaseDNNWarmStartingTest,
                                   tf.test.TestCase):

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    tf.test.TestCase.__init__(self, methodName)
    # Warm-starting needs both a classifier and a regressor factory.
    dnn_testing_utils_v1.BaseDNNWarmStartingTest.__init__(
        self, _dnn_estimator_classifier_fn, _dnn_estimator_fn)


@test_util.run_v1_only('Tests v1 only symbols')
class DNNEstimatorIntegrationTest(tf.test.TestCase):
  """End-to-end train/evaluate/predict/export flow for DNNEstimator."""

  def setUp(self):
    # Fresh model dir per test so checkpoints never leak across cases.
    self._model_dir = tempfile.mkdtemp()

  def tearDown(self):
    if self._model_dir:
      # Clear the writer cache first, else the dir removal can race open files.
      tf.compat.v1.summary.FileWriterCache.clear()
      shutil.rmtree(self._model_dir)

  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
                          input_dimension, label_dimension, batch_size):
    """Trains, evaluates, predicts and exports; asserts at each stage."""
    feature_columns = [
        tf.feature_column.numeric_column('x', shape=(input_dimension,))
    ]
    est = dnn.DNNEstimator(
        head=head_lib._regression_head(label_dimension=label_dimension),
        hidden_units=(2, 2),
        feature_columns=feature_columns,
        model_dir=self._model_dir)

    # Train
    num_steps = 10
    est.train(train_input_fn, steps=num_steps)

    # Evaluate
    scores = est.evaluate(eval_input_fn)
    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])
    self.assertIn('loss', six.iterkeys(scores))

    # Predict
    predictions = np.array([
        x[prediction_keys.PredictionKeys.PREDICTIONS]
        for x in est.predict(predict_input_fn)
    ])
    self.assertAllEqual((batch_size, label_dimension), predictions.shape)

    # Export
    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(
        feature_columns)
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)
    export_dir = est.export_saved_model(tempfile.mkdtemp(),
                                        serving_input_receiver_fn)
    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))

  def test_numpy_input_fn(self):
    """Tests complete flow with numpy_input_fn."""
    label_dimension = 2
    batch_size = 10
    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
    data = data.reshape(batch_size, label_dimension)
    # learn y = x
    train_input_fn = numpy_io.numpy_input_fn(
        x={'x': data},
        y=data,
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)
    eval_input_fn = numpy_io.numpy_input_fn(
        x={'x': data}, y=data, batch_size=batch_size, shuffle=False)
    predict_input_fn = numpy_io.numpy_input_fn(
        x={'x': data}, batch_size=batch_size, shuffle=False)

    self._test_complete_flow(
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        predict_input_fn=predict_input_fn,
        input_dimension=label_dimension,
        label_dimension=label_dimension,
        batch_size=batch_size)


if __name__ == '__main__':
  tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | # ============================================================================== 15 | """Methods related to optimizers used in canned_estimators.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import inspect 22 | from absl import logging 23 | import six 24 | import tensorflow as tf 25 | from tensorflow_estimator.python.estimator.util import tf_keras 26 | 27 | _OPTIMIZER_CLS_NAMES = { 28 | 'Adagrad': tf.compat.v1.train.AdagradOptimizer, 29 | 'Adam': tf.compat.v1.train.AdamOptimizer, 30 | 'Ftrl': tf.compat.v1.train.FtrlOptimizer, 31 | 'RMSProp': tf.compat.v1.train.RMSPropOptimizer, 32 | 'SGD': tf.compat.v1.train.GradientDescentOptimizer, 33 | } 34 | 35 | _OPTIMIZER_CLS_NAMES_V2 = { 36 | 'Adagrad': tf_keras.optimizers.legacy.Adagrad, 37 | 'Adam': tf_keras.optimizers.legacy.Adam, 38 | 'Ftrl': tf_keras.optimizers.legacy.Ftrl, 39 | 'RMSProp': tf_keras.optimizers.legacy.RMSprop, 40 | 'SGD': tf_keras.optimizers.legacy.SGD, 41 | } 42 | 43 | # The default learning rate of 0.05 is a historical artifact of the initial 44 | # implementation, but seems a reasonable choice. 45 | _LEARNING_RATE = 0.05 46 | 47 | 48 | def get_optimizer_instance(opt, learning_rate=None): 49 | """Returns an optimizer instance. 50 | 51 | Supports the following types for the given `opt`: 52 | * An `Optimizer` instance: Returns the given `opt`. 53 | * A string: Creates an `Optimizer` subclass with the given `learning_rate`. 54 | Supported strings: 55 | * 'Adagrad': Returns an `AdagradOptimizer`. 56 | * 'Adam': Returns an `AdamOptimizer`. 57 | * 'Ftrl': Returns an `FtrlOptimizer`. 58 | * 'RMSProp': Returns an `RMSPropOptimizer`. 59 | * 'SGD': Returns a `GradientDescentOptimizer`. 60 | 61 | Args: 62 | opt: An `Optimizer` instance, or string, as discussed above. 63 | learning_rate: A float. Only used if `opt` is a string. 64 | 65 | Returns: 66 | An `Optimizer` instance. 
67 | 68 | Raises: 69 | ValueError: If `opt` is an unsupported string. 70 | ValueError: If `opt` is a supported string but `learning_rate` was not 71 | specified. 72 | ValueError: If `opt` is none of the above types. 73 | """ 74 | if isinstance(opt, six.string_types): 75 | if opt in six.iterkeys(_OPTIMIZER_CLS_NAMES): 76 | if not learning_rate: 77 | raise ValueError('learning_rate must be specified when opt is string.') 78 | return _OPTIMIZER_CLS_NAMES[opt](learning_rate=learning_rate) 79 | raise ValueError( 80 | 'Unsupported optimizer name: {}. Supported names are: {}'.format( 81 | opt, tuple(sorted(six.iterkeys(_OPTIMIZER_CLS_NAMES))))) 82 | if callable(opt): 83 | opt = opt() 84 | if not isinstance(opt, tf.compat.v1.train.Optimizer): 85 | raise ValueError( 86 | 'The given object is not an Optimizer instance. Given: {}'.format(opt)) 87 | return opt 88 | 89 | 90 | def _optimizer_has_default_learning_rate(opt): 91 | signature = inspect.getfullargspec(opt.__init__) 92 | default_name_to_value = dict(zip(signature.args[::-1], signature.defaults)) 93 | for name in signature.kwonlyargs: 94 | if name in signature.kwonlydefaults: 95 | default_name_to_value[name] = signature.kwonlydefaults[name] 96 | return 'learning_rate' in default_name_to_value 97 | 98 | 99 | def get_optimizer_instance_v2(opt, learning_rate=None): 100 | """Returns an optimizer_v2.OptimizerV2 instance. 101 | 102 | Supports the following types for the given `opt`: 103 | * An `optimizer_v2.OptimizerV2` instance: Returns the given `opt`. 104 | * A string: Creates an `optimizer_v2.OptimizerV2` subclass with the given 105 | `learning_rate`. 106 | Supported strings: 107 | * 'Adagrad': Returns an tf_keras.optimizers.Adagrad. 108 | * 'Adam': Returns an tf_keras.optimizers.Adam. 109 | * 'Ftrl': Returns an tf_keras.optimizers.Ftrl. 110 | * 'RMSProp': Returns an tf_keras.optimizers.RMSProp. 111 | * 'SGD': Returns a tf_keras.optimizers.SGD. 
112 | 113 | Args: 114 | opt: An `tf_keras.optimizers.Optimizer` instance, or string, as discussed 115 | above. 116 | learning_rate: A float. Only used if `opt` is a string. If None, and opt is 117 | string, it will use the default learning_rate of the optimizer. 118 | 119 | Returns: 120 | An `tf_keras.optimizers.Optimizer` instance. 121 | 122 | Raises: 123 | ValueError: If `opt` is an unsupported string. 124 | ValueError: If `opt` is a supported string but `learning_rate` was not 125 | specified. 126 | ValueError: If `opt` is none of the above types. 127 | """ 128 | if isinstance(opt, six.string_types): 129 | if opt in six.iterkeys(_OPTIMIZER_CLS_NAMES_V2): 130 | if not learning_rate: 131 | if _optimizer_has_default_learning_rate(_OPTIMIZER_CLS_NAMES_V2[opt]): 132 | return _OPTIMIZER_CLS_NAMES_V2[opt]() 133 | else: 134 | return _OPTIMIZER_CLS_NAMES_V2[opt](learning_rate=_LEARNING_RATE) 135 | return _OPTIMIZER_CLS_NAMES_V2[opt](learning_rate=learning_rate) 136 | raise ValueError( 137 | 'Unsupported optimizer name: {}. Supported names are: {}'.format( 138 | opt, tuple(sorted(six.iterkeys(_OPTIMIZER_CLS_NAMES_V2))))) 139 | if callable(opt): 140 | opt = opt() 141 | if isinstance(opt, tf_keras.optimizers.experimental.Optimizer): 142 | if tf.executing_eagerly(): 143 | logging.warning( 144 | 'You are using `tf_keras.optimizers.experimental.Optimizer` in TF ' 145 | 'estimator, which only supports ' 146 | '`tf_keras.optimizers.legacy.Optimizer`. Automatically converting ' 147 | 'your optimizer to `tf_keras.optimizers.legacy.Optimizer`.') 148 | opt = tf_keras.__internal__.optimizers.convert_to_legacy_optimizer(opt) 149 | else: 150 | raise ValueError('Please set your optimizer as an instance of ' 151 | '`tf_keras.optimizers.legacy.Optimizer`, e.g., ' 152 | f'`tf_keras.optimizers.legacy.{opt.__class__.__name__}`.' 
153 | f'Received optimizer type: {type(opt)}.') 154 | if not isinstance( 155 | opt, 156 | (tf_keras.optimizers.legacy.Optimizer, tf_keras.optimizers.Optimizer)): 157 | raise ValueError( 158 | 'The given object is not a tf_keras.optimizers.Optimizer instance.' 159 | ' Given: {}'.format(opt)) 160 | return opt 161 | -------------------------------------------------------------------------------- /tensorflow_estimator/python/estimator/gc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""System for specifying garbage collection (GC) of path based data. 16 | 17 | This framework allows for GC of data specified by path names, for example files 18 | on disk. gc.Path objects each represent a single item stored at a path and may 19 | be a base directory, 20 | /tmp/exports/0/... 21 | /tmp/exports/1/... 22 | ... 23 | or a fully qualified file, 24 | /tmp/train-1.ckpt 25 | /tmp/train-2.ckpt 26 | ... 27 | 28 | A gc filter function takes and returns a list of gc.Path items. Filter 29 | functions are responsible for selecting Path items for preservation or deletion. 30 | Note that functions should always return a sorted list. 31 | 32 | For example, 33 | base_dir = "/tmp" 34 | # Create the directories. 
35 | for e in xrange(10): 36 | os.mkdir("%s/%d" % (base_dir, e), 0o755) 37 | 38 | # Create a simple parser that pulls the export_version from the directory. 39 | path_regex = "^" + re.escape(base_dir) + "/(\\d+)$" 40 | def parser(path): 41 | match = re.match(path_regex, path.path) 42 | if not match: 43 | return None 44 | return path._replace(export_version=int(match.group(1))) 45 | 46 | path_list = gc._get_paths("/tmp", parser) # contains all ten Paths 47 | 48 | every_fifth = gc._mod_export_version(5) 49 | print(every_fifth(path_list)) # shows ["/tmp/0", "/tmp/5"] 50 | 51 | largest_three = gc.largest_export_versions(3) 52 | print(largest_three(all_paths)) # shows ["/tmp/7", "/tmp/8", "/tmp/9"] 53 | 54 | both = gc._union(every_fifth, largest_three) 55 | print(both(all_paths)) # shows ["/tmp/0", "/tmp/5", 56 | # "/tmp/7", "/tmp/8", "/tmp/9"] 57 | # Delete everything not in 'both'. 58 | to_delete = gc._negation(both) 59 | for p in to_delete(all_paths): 60 | gfile.DeleteRecursively(p.path) # deletes: "/tmp/1", "/tmp/2", 61 | # "/tmp/3", "/tmp/4", "/tmp/6", 62 | """ 63 | 64 | from __future__ import absolute_import 65 | from __future__ import division 66 | from __future__ import print_function 67 | import collections 68 | import heapq 69 | import math 70 | import os 71 | import tensorflow as tf 72 | from tensorflow.python.platform import gfile 73 | 74 | Path = collections.namedtuple('Path', 'path export_version') 75 | 76 | 77 | def _largest_export_versions(n): 78 | """Creates a filter that keeps the largest n export versions. 79 | 80 | Args: 81 | n: number of versions to keep. 82 | 83 | Returns: 84 | A filter function that keeps the n largest paths. 
85 | """ 86 | 87 | def keep(paths): 88 | heap = [] 89 | for idx, path in enumerate(paths): 90 | if path.export_version is not None: 91 | heapq.heappush(heap, (path.export_version, idx)) 92 | keepers = [paths[i] for _, i in heapq.nlargest(n, heap)] 93 | return sorted(keepers) 94 | 95 | return keep 96 | 97 | 98 | def _one_of_every_n_export_versions(n): 99 | """Creates a filter that keeps one of every n export versions. 100 | 101 | Args: 102 | n: interval size. 103 | 104 | Returns: 105 | A filter function that keeps exactly one path from each interval 106 | [0, n], (n, 2n], (2n, 3n], etc... If more than one path exists in an 107 | interval the largest is kept. 108 | """ 109 | 110 | def keep(paths): 111 | """A filter function that keeps exactly one out of every n paths.""" 112 | 113 | keeper_map = {} # map from interval to largest path seen in that interval 114 | for p in paths: 115 | if p.export_version is None: 116 | # Skip missing export_versions. 117 | continue 118 | # Find the interval (with a special case to map export_version = 0 to 119 | # interval 0. 120 | interval = math.floor( 121 | (p.export_version - 1) / n) if p.export_version else 0 122 | existing = keeper_map.get(interval, None) 123 | if (not existing) or (existing.export_version < p.export_version): 124 | keeper_map[interval] = p 125 | return sorted(keeper_map.values()) 126 | 127 | return keep 128 | 129 | 130 | def _mod_export_version(n): 131 | """Creates a filter that keeps every export that is a multiple of n. 132 | 133 | Args: 134 | n: step size. 135 | 136 | Returns: 137 | A filter function that keeps paths where export_version % n == 0. 138 | """ 139 | 140 | def keep(paths): 141 | keepers = [] 142 | for p in paths: 143 | if p.export_version % n == 0: 144 | keepers.append(p) 145 | return sorted(keepers) 146 | 147 | return keep 148 | 149 | 150 | def _union(lf, rf): 151 | """Creates a filter that keeps the union of two filters. 
152 | 153 | Args: 154 | lf: first filter 155 | rf: second filter 156 | 157 | Returns: 158 | A filter function that keeps the n largest paths. 159 | """ 160 | 161 | def keep(paths): 162 | l = set(lf(paths)) 163 | r = set(rf(paths)) 164 | return sorted(list(l | r)) 165 | 166 | return keep 167 | 168 | 169 | def _negation(f): 170 | """Negate a filter. 171 | 172 | Args: 173 | f: filter function to invert 174 | 175 | Returns: 176 | A filter function that returns the negation of f. 177 | """ 178 | 179 | def keep(paths): 180 | l = set(paths) 181 | r = set(f(paths)) 182 | return sorted(list(l - r)) 183 | 184 | return keep 185 | 186 | 187 | def _get_paths(base_dir, parser): 188 | """Gets a list of Paths in a given directory. 189 | 190 | Args: 191 | base_dir: directory. 192 | parser: a function which gets the raw Path and can augment it with 193 | information such as the export_version, or ignore the path by returning 194 | None. An example parser may extract the export version from a path such 195 | as "/tmp/exports/100" an another may extract from a full file name such as 196 | "/tmp/checkpoint-99.out". 197 | 198 | Returns: 199 | A list of Paths contained in the base directory with the parsing function 200 | applied. 201 | By default the following fields are populated, 202 | - Path.path 203 | The parsing function is responsible for populating, 204 | - Path.export_version 205 | """ 206 | # We are mocking this in the test, hence we should not use public API 207 | raw_paths = gfile.ListDirectory(base_dir) 208 | paths = [] 209 | for r in raw_paths: 210 | # ListDirectory() return paths with "/" at the last if base_dir was GCS URL 211 | r = tf.compat.as_str_any(r) 212 | if r[-1] == '/': 213 | r = r[0:len(r) - 1] 214 | p = parser(Path(os.path.join(tf.compat.as_str_any(base_dir), r), None)) 215 | if p: 216 | paths.append(p) 217 | return sorted(paths) 218 | --------------------------------------------------------------------------------