├── g3doc ├── images │ ├── stats.png │ ├── anomaly.png │ ├── schema.png │ ├── skew_anomaly.png │ └── serving_anomaly.png ├── _toc.yaml └── index.md ├── tensorflow_data_validation ├── BUILD ├── anomalies │ ├── proto │ │ ├── BUILD │ │ ├── validation_config.proto │ │ └── feature_statistics_to_proto.proto │ ├── __init__.py │ ├── test_schema_protos.h │ ├── float_domain_util.h │ ├── int_domain_util.h │ ├── test_util_test.cc │ ├── metrics.h │ ├── bool_domain_util.h │ ├── metrics.cc │ ├── string_domain_util.h │ ├── internal_types.h │ ├── validation_api.i │ ├── test_util.cc │ ├── map_util.h │ ├── statistics_view_test_util.h │ ├── metrics_test.cc │ ├── map_util.cc │ ├── float_domain_test.cc │ ├── path.h │ ├── feature_util.h │ ├── test_util.h │ ├── statistics_view_test_util.cc │ ├── path_test.cc │ ├── schema_anomalies.h │ ├── float_domain_util.cc │ ├── int_domain_util.cc │ ├── map_util_test.cc │ └── path.cc ├── api │ └── __init__.py ├── coders │ ├── __init__.py │ ├── tf_example_decoder.py │ └── tf_example_decoder_test.py ├── utils │ ├── __init__.py │ ├── stats_util.py │ ├── batch_util.py │ ├── batch_util_test.py │ ├── profile_util_test.py │ ├── stats_util_test.py │ ├── schema_util.py │ ├── profile_util.py │ ├── test_util.py │ ├── quantiles_util_test.py │ └── schema_util_test.py ├── statistics │ ├── __init__.py │ ├── generators │ │ ├── __init__.py │ │ ├── stats_generator.py │ │ ├── uniques_stats_generator.py │ │ └── string_stats_generator.py │ ├── stats_impl_test.py │ └── stats_impl.py ├── workspace.bzl ├── types.py ├── version.py ├── types_compat.py ├── build_pip_package.sh ├── repo.bzl ├── data_validation.bzl └── __init__.py ├── RELEASE.md ├── .gitignore ├── CONTRIBUTING.md ├── WORKSPACE ├── setup.py └── README.md /g3doc/images/stats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamicwebpaige/data-validation/master/g3doc/images/stats.png 
-------------------------------------------------------------------------------- /g3doc/images/anomaly.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamicwebpaige/data-validation/master/g3doc/images/anomaly.png -------------------------------------------------------------------------------- /g3doc/images/schema.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamicwebpaige/data-validation/master/g3doc/images/schema.png -------------------------------------------------------------------------------- /g3doc/images/skew_anomaly.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamicwebpaige/data-validation/master/g3doc/images/skew_anomaly.png -------------------------------------------------------------------------------- /g3doc/images/serving_anomaly.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamicwebpaige/data-validation/master/g3doc/images/serving_anomaly.png -------------------------------------------------------------------------------- /tensorflow_data_validation/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | sh_binary( 4 | name = "build_pip_package", 5 | srcs = ["build_pip_package.sh"], 6 | data = [ 7 | "//tensorflow_data_validation/anomalies:_pywrap_tensorflow_data_validation.so", 8 | "//tensorflow_data_validation/anomalies:pywrap_tensorflow_data_validation.py", 9 | ], 10 | ) 11 | -------------------------------------------------------------------------------- /g3doc/_toc.yaml: -------------------------------------------------------------------------------- 1 | toc: 2 | - title: Get Started 3 | path: /tfx/data_validation/get_started 4 | 5 | - heading: Examples 6 | - title: Chicago 
Taxi 7 | path: https://github.com/tensorflow/data-validation/blob/master/examples/chicago_taxi/chicago_taxi_tfdv.ipynb 8 | status: external 9 | - title: Chicago Taxi (end-to-end) 10 | path: https://github.com/tensorflow/model-analysis/tree/master/examples/chicago_taxi 11 | status: external 12 | -------------------------------------------------------------------------------- /RELEASE.md: -------------------------------------------------------------------------------- 1 | # Current version (not yet released; still in development) 2 | 3 | ## Major Features and Improvements 4 | 5 | * Add support for computing weighted common statistics. 6 | 7 | ## Bug Fixes and Other Changes 8 | 9 | * Fix bug in clearing oneof domain\_info field in Feature proto. 10 | * Fix overflow error for large integers by casting them to STRING type. 11 | 12 | ## Breaking changes 13 | 14 | ## Deprecations 15 | 16 | # Release 0.9.0 17 | 18 | * Initial release of TensorFlow Data Validation. 19 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/proto/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//tensorflow_data_validation:__subpackages__"]) 2 | 3 | licenses(["notice"]) # Apache 2.0 4 | 5 | exports_files(["LICENSE"]) 6 | 7 | load("//tensorflow_data_validation:data_validation.bzl", "tfdv_proto_library") 8 | 9 | tfdv_proto_library( 10 | name = "feature_statistics_to_proto_proto", 11 | srcs = ["feature_statistics_to_proto.proto"], 12 | cc_api_version = 2, 13 | ) 14 | 15 | tfdv_proto_library( 16 | name = "validation_config_proto", 17 | srcs = ["validation_config.proto"], 18 | cc_api_version = 2, 19 | ) 20 | -------------------------------------------------------------------------------- /tensorflow_data_validation/api/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # 
Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /tensorflow_data_validation/coders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /tensorflow_data_validation/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /tensorflow_data_validation/statistics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /tensorflow_data_validation/statistics/generators/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .ipynb_checkpoints 3 | node_modules 4 | /.bazelrc 5 | /.tf_configure.bazelrc 6 | /bazel-* 7 | /bazel_pip 8 | /tools/python_bin_path.sh 9 | /pip_test 10 | /_python_build 11 | *.pyc 12 | __pycache__ 13 | *.swp 14 | .vscode/ 15 | cmake_build/ 16 | .idea/** 17 | /build/ 18 | [Bb]uild/ 19 | Pods 20 | Podfile.lock 21 | *.pbxproj 22 | *.xcworkspacedata 23 | xcuserdata/** 24 | dist/ 25 | tensorflow_data_validation.egg-info/ 26 | tensorflow_data_validation/anomalies/_pywrap_tensorflow_data_validation.so 27 | tensorflow_data_validation/anomalies/pywrap_tensorflow_data_validation.py 28 | 29 | # Android 30 | .gradle 31 | .idea 32 | .project 33 | *.iml 34 | local.properties 35 | gradleBuild 36 | -------------------------------------------------------------------------------- /tensorflow_data_validation/workspace.bzl: -------------------------------------------------------------------------------- 1 | """TensorFlow Data Validation external dependencies that can be loaded in WORKSPACE files. 2 | """ 3 | 4 | load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace") 5 | 6 | def tf_data_validation_workspace(): 7 | """All TensorFlow Data Validation external dependencies.""" 8 | tf_workspace( 9 | path_prefix = "", 10 | tf_repo_name = "org_tensorflow", 11 | ) 12 | 13 | # Fetch tf.Metadata repo from GitHub. 
14 | native.git_repository( 15 | name = "com_github_tensorflow_metadata", 16 | # v0.9.0dev 17 | commit = "223923d04c75de71ae782c51872d0e14ce7e657d", 18 | remote = "https://github.com/tensorflow/metadata.git", 19 | ) 20 | -------------------------------------------------------------------------------- /tensorflow_data_validation/types.py: -------------------------------------------------------------------------------- 1 | """Types.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | 6 | from __future__ import print_function 7 | 8 | import apache_beam as beam 9 | import numpy as np 10 | 11 | from tensorflow_data_validation.types_compat import Dict, Text, Union 12 | 13 | FeatureName = Union[bytes, Text] 14 | 15 | # Feature type enum value. 16 | FeatureNameStatisticsType = int 17 | 18 | # Type of the input batch. 19 | ExampleBatch = Dict[FeatureName, np.ndarray] 20 | 21 | # For use in Beam type annotations, because Beam's support for Python types 22 | # in Beam type annotations is not complete. 23 | BeamFeatureName = beam.typehints.Union[bytes, Text] 24 | # pylint: enable=invalid-name 25 | -------------------------------------------------------------------------------- /tensorflow_data_validation/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Contains the version string of TFDV.""" 16 | 17 | # Note that setup.py uses this version. 18 | __version__ = '0.9.0' 19 | -------------------------------------------------------------------------------- /tensorflow_data_validation/types_compat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Types for backwards compatibility with versions that don't support typing.""" 15 | 16 | 17 | from apache_beam.typehints import Any, Dict, Generator, List, Optional, Set, Tuple, Union # pylint: disable=unused-import,g-multiple-import 18 | 19 | # pylint: disable=invalid-name 20 | Callable = None 21 | Text = Any 22 | TypeVar = None 23 | 24 | # pylint: enable=invalid-name 25 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. 
You (or your employer) retain the copyright to your contribution, 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/proto/validation_config.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | // ============================================================================= 15 | 16 | syntax = "proto3"; 17 | 18 | package tensorflow.data_validation; 19 | 20 | // Configuration for example statistics validation. 
21 | message ValidationConfig { 22 | // If true then validation will mark new features (i.e., those that are not 23 | // covered in the schema) as warnings instead of errors. The distinction is 24 | // that warnings do not cause alerts to fire. 25 | bool new_features_are_warnings = 1; 26 | } 27 | -------------------------------------------------------------------------------- /tensorflow_data_validation/build_pip_package.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2018 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Convenience binary to build TFDV from source. 
17 | 18 | # Put wrapped c++ files in place 19 | 20 | set -u -x 21 | 22 | cp -f tensorflow_data_validation/anomalies/pywrap_tensorflow_data_validation.py \ 23 | ${BUILD_WORKSPACE_DIRECTORY}/tensorflow_data_validation/anomalies 24 | cp -f tensorflow_data_validation/anomalies/_pywrap_tensorflow_data_validation.so \ 25 | ${BUILD_WORKSPACE_DIRECTORY}/tensorflow_data_validation/anomalies 26 | 27 | # Create the wheel 28 | cd ${BUILD_WORKSPACE_DIRECTORY} 29 | 30 | python setup.py bdist_wheel 31 | 32 | # Cleanup 33 | cd - 34 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/test_schema_protos.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_DATA_VALIDATION_ANOMALIES_TEST_SCHEMA_PROTOS_H_ 17 | #define TENSORFLOW_DATA_VALIDATION_ANOMALIES_TEST_SCHEMA_PROTOS_H_ 18 | 19 | #include "tensorflow_metadata/proto/v0/schema.pb.h" 20 | 21 | namespace tensorflow { 22 | namespace data_validation { 23 | namespace testing { 24 | 25 | tensorflow::metadata::v0::Schema GetTestAllTypesMessage(); 26 | tensorflow::metadata::v0::Schema GetAnnotatedFieldsMessage(); 27 | tensorflow::metadata::v0::Schema GetTestSchemaAlone(); 28 | 29 | } // namespace testing 30 | } // namespace data_validation 31 | } // namespace tensorflow 32 | 33 | #endif // TENSORFLOW_DATA_VALIDATION_ANOMALIES_TEST_SCHEMA_PROTOS_H_ 34 | -------------------------------------------------------------------------------- /tensorflow_data_validation/repo.bzl: -------------------------------------------------------------------------------- 1 | """ TensorFlow Http Archive 2 | 3 | Modified http_arhive that allows us to override the TensorFlow commit that is 4 | downloaded by setting an environment variable. This override is to be used for 5 | testing purposes. 6 | 7 | Add the following to your Bazel build command in order to override the 8 | TensorFlow revision. 
9 | 10 | build: --action_env TF_REVISION="" 11 | 12 | * `TF_REVISION`: tensorflow revision override (git commit hash) 13 | """ 14 | 15 | _TF_REVISION = "TF_REVISION" 16 | 17 | def _tensorflow_http_archive(ctx): 18 | git_commit = ctx.attr.git_commit 19 | sha256 = ctx.attr.sha256 20 | 21 | override_git_commit = ctx.os.environ.get(_TF_REVISION) 22 | if override_git_commit: 23 | sha256 = "" 24 | git_commit = override_git_commit 25 | 26 | strip_prefix = "tensorflow-%s" % git_commit 27 | urls = [ 28 | "https://mirror.bazel.build/github.com/tensorflow/tensorflow/archive/%s.tar.gz" % git_commit, 29 | "https://github.com/tensorflow/tensorflow/archive/%s.tar.gz" % git_commit, 30 | ] 31 | ctx.download_and_extract( 32 | urls, 33 | "", 34 | sha256, 35 | "", 36 | strip_prefix) 37 | 38 | tensorflow_http_archive = repository_rule( 39 | implementation=_tensorflow_http_archive, 40 | attrs={ 41 | "git_commit": attr.string(mandatory=True), 42 | "sha256": attr.string(mandatory=True), 43 | }) 44 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/float_domain_util.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_DATA_VALIDATION_ANOMALIES_FLOAT_DOMAIN_UTIL_H_ 17 | #define TENSORFLOW_DATA_VALIDATION_ANOMALIES_FLOAT_DOMAIN_UTIL_H_ 18 | 19 | #include "tensorflow_data_validation/anomalies/internal_types.h" 20 | #include "tensorflow_data_validation/anomalies/statistics_view.h" 21 | #include "tensorflow_metadata/proto/v0/schema.pb.h" 22 | 23 | namespace tensorflow { 24 | namespace data_validation { 25 | 26 | // Updates the float_domain based upon the range of values in , be they 27 | // STRING or FLOAT. 28 | // Will recommend the field be cleared if the type is STRING or BYTES but 29 | // the strings do not represent floats. Undefined behavior if the data is INT. 30 | UpdateSummary UpdateFloatDomain( 31 | const FeatureStatsView& stats, 32 | tensorflow::metadata::v0::FloatDomain* float_domain); 33 | 34 | // Returns true if feature_stats is a STRING field has only floats and no 35 | // non-UTF8 strings. 36 | bool IsFloatDomainCandidate(const FeatureStatsView& feature_stats); 37 | 38 | } // namespace data_validation 39 | } // namespace tensorflow 40 | 41 | #endif // TENSORFLOW_DATA_VALIDATION_ANOMALIES_FLOAT_DOMAIN_UTIL_H_ 42 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/int_domain_util.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_DATA_VALIDATION_ANOMALIES_INT_DOMAIN_UTIL_H_ 17 | #define TENSORFLOW_DATA_VALIDATION_ANOMALIES_INT_DOMAIN_UTIL_H_ 18 | 19 | #include "tensorflow_data_validation/anomalies/internal_types.h" 20 | #include "tensorflow_data_validation/anomalies/statistics_view.h" 21 | #include "tensorflow_metadata/proto/v0/schema.pb.h" 22 | 23 | namespace tensorflow { 24 | namespace data_validation { 25 | 26 | // Updates the float_domain based upon the range of values in , be they 27 | // STRING or INT. 28 | // Will recommend the field be cleared if the type is STRING or BYTES but 29 | // the strings do not represent floats. Undefined behavior if the data is FLOAT. 30 | UpdateSummary UpdateIntDomain(const FeatureStatsView& feature_stats, 31 | tensorflow::metadata::v0::IntDomain* int_domain); 32 | 33 | // Returns true if feature_stats is a STRING field has only floats and no 34 | // non-UTF8 strings. 35 | bool IsIntDomainCandidate(const FeatureStatsView& feature_stats); 36 | 37 | } // namespace data_validation 38 | } // namespace tensorflow 39 | 40 | #endif // TENSORFLOW_DATA_VALIDATION_ANOMALIES_INT_DOMAIN_UTIL_H_ 41 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/test_util_test.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_data_validation/anomalies/test_util.h" 17 | 18 | #include 19 | #include 20 | 21 | #include 22 | #include 23 | #include "absl/strings/str_split.h" 24 | #include "tensorflow/core/lib/core/status.h" 25 | #include "tensorflow/core/lib/core/status_test_util.h" 26 | #include "tensorflow/core/platform/logging.h" 27 | #include "tensorflow/core/platform/types.h" 28 | 29 | namespace tensorflow { 30 | namespace data_validation { 31 | namespace testing { 32 | namespace { 33 | 34 | 35 | TEST(TestAnomalies, Basic) { 36 | const tensorflow::metadata::v0::Schema original = 37 | ParseTextProtoOrDie(R"( 38 | feature { 39 | name: "feature_name" 40 | type: INT 41 | skew_comparator: { infinity_norm: { threshold: 0.1 } } 42 | })"); 43 | 44 | tensorflow::metadata::v0::Anomalies result; 45 | *result.mutable_baseline() = original; 46 | TestAnomalies(result, original, std::map()); 47 | } 48 | 49 | } // namespace 50 | } // namespace testing 51 | } // namespace data_validation 52 | } // namespace tensorflow 53 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/metrics.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_DATA_VALIDATION_ANOMALIES_METRICS_H_ 17 | #define TENSORFLOW_DATA_VALIDATION_ANOMALIES_METRICS_H_ 18 | 19 | #include 20 | #include 21 | #include 22 | 23 | #include "tensorflow_data_validation/anomalies/statistics_view.h" 24 | #include "tensorflow/core/platform/types.h" 25 | 26 | namespace tensorflow { 27 | namespace data_validation { 28 | 29 | // Computes the L-infinity distance between the (weighted) histograms of the 30 | // features. 31 | // Only takes into account how many times the feature are present, 32 | // and scales the histograms so that they sum to 1. 33 | // The first value returned is the element with highest deviation, and 34 | // the second value returned is the L infinity distance itself. 
35 | std::pair LInftyDistance(const FeatureStatsView& a, 36 | const FeatureStatsView& b); 37 | 38 | std::pair LInftyDistance( 39 | const std::map& counts_a, 40 | const std::map& counts_b); 41 | 42 | } // namespace data_validation 43 | } // namespace tensorflow 44 | 45 | #endif // TENSORFLOW_DATA_VALIDATION_ANOMALIES_METRICS_H_ 46 | -------------------------------------------------------------------------------- /tensorflow_data_validation/data_validation.bzl: -------------------------------------------------------------------------------- 1 | load("@protobuf_archive//:protobuf.bzl", "cc_proto_library") 2 | load("@protobuf_archive//:protobuf.bzl", "py_proto_library") 3 | 4 | def tfdv_proto_library(name, srcs=[], has_services=False, 5 | deps=[], visibility=None, testonly=0, 6 | cc_grpc_version = None, 7 | cc_api_version=2, go_api_version=2, 8 | java_api_version=2, js_api_version=2, 9 | py_api_version=2): 10 | """Opensource cc_proto_library.""" 11 | _ignore = [has_services, cc_api_version, go_api_version, java_api_version, js_api_version, py_api_version] 12 | native.filegroup(name=name + "_proto_srcs", 13 | srcs=srcs, 14 | testonly=testonly,) 15 | 16 | use_grpc_plugin = None 17 | if cc_grpc_version: 18 | use_grpc_plugin = True 19 | cc_proto_library(name=name, 20 | srcs=srcs, 21 | deps=deps, 22 | cc_libs = ["@protobuf_archive//:protobuf"], 23 | protoc="@protobuf_archive//:protoc", 24 | default_runtime="@protobuf_archive//:protobuf", 25 | use_grpc_plugin=use_grpc_plugin, 26 | testonly=testonly, 27 | visibility=visibility,) 28 | 29 | def tfdv_proto_library_py(name, proto_library, srcs=[], deps=[], visibility=None, testonly=0): 30 | """Opensource py_proto_library.""" 31 | _ignore = [proto_library] 32 | py_proto_library(name=name, 33 | srcs=srcs, 34 | srcs_version = "PY2AND3", 35 | deps=["@protobuf_archive//:protobuf_python"] + deps, 36 | default_runtime="@protobuf_archive//:protobuf_python", 37 | protoc="@protobuf_archive//:protoc", 38 | visibility=visibility, 39 | 
testonly=testonly,) 40 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | workspace(name = "tensorflow_data_validation") 2 | 3 | # To update TensorFlow to a new revision. 4 | # 1. Update the 'git_commit' args below to include the new git hash. 5 | # 2. Get the sha256 hash of the archive with a command such as... 6 | # curl -L https://github.com/tensorflow/tensorflow/archive/.tar.gz | sha256sum 7 | # and update the 'sha256' arg with the result. 8 | # 3. Request the new archive to be mirrored on mirror.bazel.build for more 9 | # reliable downloads. 10 | load("//tensorflow_data_validation:repo.bzl", "tensorflow_http_archive") 11 | 12 | tensorflow_http_archive( 13 | name = "org_tensorflow", 14 | sha256 = "696c4906d6536ed8d9f8f13c4927d3ccf36dcf3e13bb352ab80cba6b1b9038d4", 15 | git_commit = "25c197e02393bd44f50079945409009dd4d434f8", 16 | ) 17 | 18 | # TensorFlow depends on "io_bazel_rules_closure" so we need this here. 19 | # Needs to be kept in sync with the same target in TensorFlow's WORKSPACE file. 20 | http_archive( 21 | name = "io_bazel_rules_closure", 22 | sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae", 23 | strip_prefix = "rules_closure-dbb96841cc0a5fb2664c37822803b06dab20c7d1", 24 | urls = [ 25 | "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", 26 | "https://github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", # 2018-04-13 27 | ], 28 | ) 29 | 30 | # Required by tf.Metadata. 31 | git_repository( 32 | name = "protobuf_bzl", 33 | # v3.4.0 34 | commit = "80a37e0782d2d702d52234b62dd4b9ec74fd2c95", 35 | remote = "https://github.com/google/protobuf.git", 36 | ) 37 | 38 | # Please add all new TensorFlow Data Validation dependencies in workspace.bzl. 
39 | load("//tensorflow_data_validation:workspace.bzl", "tf_data_validation_workspace") 40 | 41 | tf_data_validation_workspace() 42 | 43 | # Specify the minimum required bazel version. 44 | load("@org_tensorflow//tensorflow:version_check.bzl", "check_bazel_version_at_least") 45 | 46 | check_bazel_version_at_least("0.15.0") 47 | -------------------------------------------------------------------------------- /tensorflow_data_validation/coders/tf_example_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Decode TF Examples into in-memory representation for tf data validation.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | 20 | from __future__ import print_function 21 | 22 | import numpy as np 23 | import tensorflow as tf 24 | from tensorflow_data_validation import types 25 | 26 | 27 | def _convert_to_numpy_array(feature): 28 | """Converts a single TF feature to its numpy array representation.""" 29 | kind = feature.WhichOneof('kind') 30 | if kind == 'int64_list': 31 | return np.asarray(feature.int64_list.value, dtype=np.integer) 32 | elif kind == 'float_list': 33 | return np.asarray(feature.float_list.value, dtype=np.floating) 34 | elif kind == 'bytes_list': 35 | return np.asarray(feature.bytes_list.value, dtype=np.object) 36 | else: 37 | # Return an empty array for feature with no value list. In numpy, an empty 38 | # array has a dtype of float, thus we explicitly set it to np.object here. 39 | return np.array([], dtype=np.object) 40 | 41 | 42 | class TFExampleDecoder(object): 43 | """A decoder for decoding TF examples into tf data validation datasets. 44 | """ 45 | 46 | def decode(self, serialized_example_proto): 47 | """Decodes serialized tf.Example to tf data validation input dict.""" 48 | example = tf.train.Example() 49 | example.ParseFromString(serialized_example_proto) 50 | feature_map = example.features.feature 51 | return { 52 | feature_name: _convert_to_numpy_array(feature_map[feature_name]) 53 | for feature_name in feature_map 54 | } 55 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/bool_domain_util.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_DATA_VALIDATION_ANOMALIES_BOOL_DOMAIN_UTIL_H_ 17 | #define TENSORFLOW_DATA_VALIDATION_ANOMALIES_BOOL_DOMAIN_UTIL_H_ 18 | 19 | #include 20 | 21 | #include "tensorflow_data_validation/anomalies/internal_types.h" 22 | #include "tensorflow_data_validation/anomalies/statistics_view.h" 23 | #include "tensorflow_metadata/proto/v0/schema.pb.h" 24 | 25 | namespace tensorflow { 26 | namespace data_validation { 27 | 28 | // Update a BoolDomain by itself. Namely, if the string values corresponding to 29 | // true and false in the domain are the same, clear the value for false. 30 | std::vector UpdateBoolDomainSelf( 31 | tensorflow::metadata::v0::BoolDomain* bool_domain); 32 | 33 | // This updates bool_domain. Should only be called if bool_domain is set. 34 | // If the type is INT and the min and max are out of the range {0,1}, 35 | // this will set int_domain. 36 | std::vector UpdateBoolDomain( 37 | const FeatureStatsView& feature_stats, 38 | tensorflow::metadata::v0::Feature* feature); 39 | 40 | // Determine if this could be a BoolDomain. 41 | // Note this takes precedence over IntDomain and StringDomain. 42 | bool IsBoolDomainCandidate(const FeatureStatsView& feature_stats); 43 | 44 | // Generate a BoolDomain from the stats. 45 | // The behavior is undefined if IsBoolDomainCandidate(stats) is false. 
46 | tensorflow::metadata::v0::BoolDomain BoolDomainFromStats( 47 | const FeatureStatsView& stats); 48 | 49 | } // namespace data_validation 50 | } // namespace tensorflow 51 | 52 | #endif // TENSORFLOW_DATA_VALIDATION_ANOMALIES_BOOL_DOMAIN_UTIL_H_ 53 | -------------------------------------------------------------------------------- /tensorflow_data_validation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Init module for TensorFlow Data Validation.""" 16 | 17 | # Import stats API. 18 | from tensorflow_data_validation.api.stats_api import GenerateStatistics 19 | from tensorflow_data_validation.api.stats_api import StatsOptions 20 | 21 | # Import validation API. 22 | from tensorflow_data_validation.api.validation_api import infer_schema 23 | from tensorflow_data_validation.api.validation_api import validate_statistics 24 | 25 | # Import coders. 26 | from tensorflow_data_validation.coders.csv_decoder import DecodeCSV 27 | from tensorflow_data_validation.coders.tf_example_decoder import TFExampleDecoder 28 | 29 | # Import stats generators. 30 | from tensorflow_data_validation.statistics.generators.stats_generator import CombinerStatsGenerator 31 | from tensorflow_data_validation.statistics.generators.stats_generator import TransformStatsGenerator 32 | 33 | # Import display utilities. 
34 | from tensorflow_data_validation.utils.display_util import display_anomalies 35 | from tensorflow_data_validation.utils.display_util import display_schema 36 | from tensorflow_data_validation.utils.display_util import visualize_statistics 37 | 38 | # Import schema utilities. 39 | from tensorflow_data_validation.utils.schema_util import get_domain 40 | from tensorflow_data_validation.utils.schema_util import get_feature 41 | 42 | # Import stats lib. 43 | from tensorflow_data_validation.utils.stats_gen_lib import generate_statistics_from_csv 44 | from tensorflow_data_validation.utils.stats_gen_lib import generate_statistics_from_tfrecord 45 | from tensorflow_data_validation.utils.stats_gen_lib import load_statistics 46 | 47 | # Import version string. 48 | from tensorflow_data_validation.version import __version__ 49 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/metrics.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_data_validation/anomalies/metrics.h" 17 | 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | #include "tensorflow_data_validation/anomalies/map_util.h" 25 | #include "tensorflow/core/platform/types.h" 26 | 27 | using std::map; 28 | 29 | namespace tensorflow { 30 | namespace data_validation { 31 | 32 | namespace { 33 | 34 | // Gets the L-infty norm of a vector, represented as a map. 35 | // This is the largest absolute value of any value. 36 | // For convenience, the associated key is also returned. 37 | std::pair GetLInftyNorm(const map& vec) { 38 | std::pair best_so_far; 39 | for (const auto& pair : vec) { 40 | const string& key = pair.first; 41 | const double value = std::abs(pair.second); 42 | if (value >= best_so_far.second) { 43 | best_so_far = {key, value}; 44 | } 45 | } 46 | return best_so_far; 47 | } 48 | 49 | } // namespace 50 | 51 | std::pair LInftyDistance(const map& counts_a, 52 | const map& counts_b) { 53 | return GetLInftyNorm(GetDifference(Normalize(counts_a), Normalize(counts_b))); 54 | } 55 | 56 | std::pair LInftyDistance(const FeatureStatsView& a, 57 | const FeatureStatsView& b) { 58 | const map prob_a = Normalize(a.GetStringValuesWithCounts()); 59 | const map prob_b = Normalize(b.GetStringValuesWithCounts()); 60 | 61 | return GetLInftyNorm(GetDifference(prob_a, prob_b)); 62 | } 63 | 64 | } // namespace data_validation 65 | } // namespace tensorflow 66 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/proto/feature_statistics_to_proto.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | // ============================================================================= 15 | 16 | syntax = "proto2"; 17 | 18 | package tensorflow.data_validation; 19 | 20 | // Manual constraints on the automatic generation of a schema. 21 | message ColumnConstraint { 22 | // A column constraint can apply to multiple columns. 23 | repeated string column_name = 1; 24 | // The name of the enum representing all the columns, if present. 25 | optional string enum_name = 2; 26 | } 27 | 28 | // See EnumType::IsSimilar(...). 29 | message EnumsSimilarConfig { 30 | // Equal or below this count, two enums must be identical to be considered 31 | // "similar". 32 | optional int32 min_count = 1 [default = 10]; 33 | // Jaccard similarity is the ratio of the intersection to the union. 34 | // The enum types are viewed as sets, then two enums are similar if they both 35 | // have more than min_similar_count elements and a Jaccard similarity higher 36 | // than min_jaccard_similarity. 37 | optional double min_jaccard_similarity = 2 [default = 0.5]; 38 | } 39 | 40 | // Configuration for creating the first version of a schema or a new field 41 | // within a schema during validation. 42 | message FeatureStatisticsToProtoConfig { 43 | // Deleted fields. 44 | reserved 3, 4; 45 | 46 | // If a string field has less than this number of entries, it will be 47 | // interpreted as an enum. 48 | optional int32 enum_threshold = 1; 49 | optional EnumsSimilarConfig enums_similar_config = 2; 50 | // Constraints on various columns. 
51 | repeated ColumnConstraint column_constraint = 5; 52 | // Ignore the following columns. 53 | repeated string column_to_ignore = 6; 54 | // Sets the severity of an anomaly which indicates a new feature. 55 | optional bool new_features_are_warnings = 7; 56 | } 57 | -------------------------------------------------------------------------------- /tensorflow_data_validation/utils/stats_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utilities for stats generators. 15 | """ 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | 20 | from __future__ import print_function 21 | 22 | import numpy as np 23 | from tensorflow_data_validation import types 24 | from tensorflow_data_validation.types_compat import List, Optional 25 | from tensorflow_metadata.proto.v0 import schema_pb2 26 | from tensorflow_metadata.proto.v0 import statistics_pb2 27 | 28 | 29 | def get_categorical_numeric_features( 30 | schema): 31 | """Get the list of numeric features that should be treated as categorical. 32 | 33 | Args: 34 | schema: The schema for the data. 35 | 36 | Returns: 37 | A list of int features that should be considered categorical. 
38 | """ 39 | categorical_features = [] 40 | for feature in schema.feature: 41 | if (feature.type == schema_pb2.INT and feature.HasField('int_domain') and 42 | feature.int_domain.is_categorical): 43 | categorical_features.append(feature.name) 44 | return categorical_features 45 | 46 | 47 | def make_feature_type(dtype 48 | ): 49 | """Get feature type from numpy dtype. 50 | 51 | Args: 52 | dtype: Numpy dtype. 53 | 54 | Returns: 55 | A statistics_pb2.FeatureNameStatistics.Type value. 56 | """ 57 | if not isinstance(dtype, np.dtype): 58 | raise TypeError( 59 | 'dtype is of type %s, should be a numpy.dtype' % type(dtype).__name__) 60 | 61 | if np.issubdtype(dtype, np.integer): 62 | return statistics_pb2.FeatureNameStatistics.INT 63 | elif np.issubdtype(dtype, np.floating): 64 | return statistics_pb2.FeatureNameStatistics.FLOAT 65 | # The numpy dtype for strings is variable length. 66 | # We need to compare the dtype.type to be sure it's a string type. 67 | elif (dtype == np.object or dtype.type == np.string_ or 68 | dtype.type == np.unicode_): 69 | return statistics_pb2.FeatureNameStatistics.STRING 70 | return None 71 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/string_domain_util.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_DATA_VALIDATION_ANOMALIES_STRING_DOMAIN_UTIL_H_ 17 | #define TENSORFLOW_DATA_VALIDATION_ANOMALIES_STRING_DOMAIN_UTIL_H_ 18 | 19 | #include 20 | 21 | #include "tensorflow_data_validation/anomalies/internal_types.h" 22 | #include "tensorflow_data_validation/anomalies/proto/feature_statistics_to_proto.pb.h" 23 | #include "tensorflow_data_validation/anomalies/statistics_view.h" 24 | #include "tensorflow_metadata/proto/v0/schema.pb.h" 25 | 26 | namespace tensorflow { 27 | namespace data_validation { 28 | 29 | // True if two domains are similar. If they are "small" according to the 30 | // config.min_count, then they must be identical. Otherwise, they must 31 | // have a large jaccard similarity. 32 | bool IsSimilarStringDomain(const tensorflow::metadata::v0::StringDomain& a, 33 | const tensorflow::metadata::v0::StringDomain& b, 34 | const EnumsSimilarConfig& config); 35 | 36 | // Returns true if this feature_stats has less than enum_threshold number of 37 | // unique string values. 38 | bool IsStringDomainCandidate(const FeatureStatsView& feature_stats, 39 | const int enum_threshold); 40 | 41 | 42 | // If there are any values that are repeated, remove them. 43 | std::vector UpdateStringDomainSelf( 44 | tensorflow::metadata::v0::StringDomain* string_domain); 45 | 46 | // Update a string domain. 47 | // stats: the statistics of the string domain. 48 | // max_off_domain: the maximum fraction of mass allowed to be off the domain. 49 | // string_domain: string_domain to be modified. 
50 | UpdateSummary UpdateStringDomain( 51 | const FeatureStatsView& stats, double max_off_domain, 52 | tensorflow::metadata::v0::StringDomain* string_domain); 53 | 54 | } // namespace data_validation 55 | } // namespace tensorflow 56 | 57 | #endif // TENSORFLOW_DATA_VALIDATION_ANOMALIES_STRING_DOMAIN_UTIL_H_ 58 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/internal_types.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_DATA_VALIDATION_ANOMALIES_INTERNAL_TYPES_H_ 17 | #define TENSORFLOW_DATA_VALIDATION_ANOMALIES_INTERNAL_TYPES_H_ 18 | 19 | #include 20 | #include 21 | 22 | #include "tensorflow/core/platform/types.h" 23 | #include "tensorflow_metadata/proto/v0/anomalies.pb.h" 24 | 25 | namespace tensorflow { 26 | namespace data_validation { 27 | 28 | // Represents the description of an anomaly, in short and long form. 
29 | struct Description { 30 | tensorflow::metadata::v0::AnomalyInfo::Type type; 31 | string short_description, long_description; 32 | 33 | friend bool operator==(const Description& a, const Description& b) { 34 | return (a.type == b.type && a.short_description == b.short_description && 35 | a.long_description == b.long_description); 36 | } 37 | 38 | friend std::ostream& operator<<(std::ostream& strm, const Description& a) { 39 | return (strm << "{" << a.type << ", " << a.short_description << ", " << 40 | a.long_description << "}"); 41 | } 42 | }; 43 | 44 | // UpdateSummary for a field. 45 | struct UpdateSummary { 46 | // Clear the field in question. If this is a ``shared'' enum, 47 | // then the field is dropped. 48 | UpdateSummary() { clear_field = false; } 49 | bool clear_field; 50 | std::vector descriptions; 51 | }; 52 | 53 | enum class ComparatorType { SKEW, DRIFT }; 54 | 55 | // The context for a tensorflow::metadata::v0::FeatureComparator. 56 | // In tensorflow::metadata::v0::Feature, there are two comparisons: 57 | // skew_comparator (that compares serving and training) and 58 | // drift_comparator (that compares previous and current). This struct 59 | // allows us to annotate the objects based upon this information. 60 | struct ComparatorContext { 61 | string control_name; 62 | string treatment_name; 63 | }; 64 | 65 | } // namespace data_validation 66 | } // namespace tensorflow 67 | 68 | #endif // TENSORFLOW_DATA_VALIDATION_ANOMALIES_INTERNAL_TYPES_H_ 69 | -------------------------------------------------------------------------------- /tensorflow_data_validation/utils/batch_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for batching input examples.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | 20 | from __future__ import print_function 21 | 22 | import apache_beam as beam 23 | import numpy as np 24 | from tensorflow_data_validation import types 25 | from tensorflow_data_validation.types_compat import List, Optional 26 | 27 | 28 | def _merge_single_batch(batch): 29 | """Merges batched input examples to proper batch format.""" 30 | batch_size = len(batch) 31 | result = {} 32 | for idx, example in enumerate(batch): 33 | for feature in example.keys(): 34 | if feature not in result.keys(): 35 | # New feature. Initialize the list with None. 36 | result[feature] = np.empty(batch_size, dtype=np.object) 37 | result[feature][idx] = example[feature] 38 | return result 39 | 40 | 41 | @beam.typehints.with_input_types(types.ExampleBatch) 42 | @beam.typehints.with_output_types(types.ExampleBatch) 43 | @beam.ptransform_fn 44 | def BatchExamples( # pylint: disable=invalid-name 45 | examples, 46 | desired_batch_size = None): 47 | """Batches input examples to proper batch format. 48 | 49 | Each input example is a dict of feature name to np.ndarray of feature values. 50 | The output batched example format is also a dict of feature name to a 51 | np.ndarray. However, this np.ndarray contains either np.ndarray of feature 52 | values (if one example have this feature), or np.NaN (if one example is 53 | missing this feature). 
54 | 55 | For example, if two input examples are 56 | { 57 | 'a': [1, 2, 3], 58 | 'b': ['a', 'b', 'c'] 59 | }, 60 | { 61 | 'a': [4, 5, 6], 62 | } 63 | 64 | Then the output batched examples will be 65 | { 66 | 'a': [[1, 2, 3], [4, 5, 6]], 67 | 'b': [['a', 'b', 'c'], np.NaN] 68 | } 69 | 70 | Args: 71 | examples: PCollection of examples. Each example should be a dict of 72 | feature name to a numpy array of values (OK to be empty). 73 | desired_batch_size: Optional batch size for batching examples when 74 | computing data statistics. 75 | 76 | Returns: 77 | PCollection of batched examples. 78 | """ 79 | batch_args = {} 80 | if desired_batch_size: 81 | batch_args = dict( 82 | min_batch_size=desired_batch_size, max_batch_size=desired_batch_size) 83 | return (examples 84 | | 'BatchExamples' >> beam.BatchElements(**batch_args) 85 | | 'MergeBatch' >> beam.Map(_merge_single_batch)) 86 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/validation_api.i: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | %{ 17 | #include "tensorflow/core/lib/core/status.h" 18 | #include "tensorflow_data_validation/anomalies/feature_statistics_validator.h" 19 | 20 | #ifdef HAS_GLOBAL_STRING 21 | using ::string; 22 | #else 23 | using std::string; 24 | #endif 25 | 26 | PyObject* ConvertToPythonString(const string& input_str) { 27 | return PyBytes_FromStringAndSize(input_str.data(), input_str.size()); 28 | } 29 | %} 30 | 31 | %{ 32 | PyObject* InferSchema(const string& statistics_proto_string, 33 | int max_string_domain_size) { 34 | string schema_proto_string; 35 | const tensorflow::Status status = tensorflow::data_validation::InferSchema( 36 | statistics_proto_string, max_string_domain_size, &schema_proto_string); 37 | if (!status.ok()) { 38 | PyErr_SetString(PyExc_RuntimeError, status.error_message().c_str()); 39 | return NULL; 40 | } 41 | return ConvertToPythonString(schema_proto_string); 42 | } 43 | 44 | 45 | PyObject* ValidateFeatureStatistics( 46 | const string& statistics_proto_string, 47 | const string& schema_proto_string, 48 | const string& environment, 49 | const string& previous_statistics_proto_string, 50 | const string& serving_statistics_proto_string) { 51 | string anomalies_proto_string; 52 | const tensorflow::Status status = tensorflow::data_validation::ValidateFeatureStatistics( 53 | statistics_proto_string, schema_proto_string, environment, 54 | previous_statistics_proto_string, serving_statistics_proto_string, 55 | &anomalies_proto_string); 56 | if (!status.ok()) { 57 | PyErr_SetString(PyExc_RuntimeError, status.error_message().c_str()); 58 | return NULL; 59 | } 60 | return ConvertToPythonString(anomalies_proto_string); 61 | } 62 | %} 63 | 64 | // Typemap to convert an input argument from Python object to C++ string. 
65 | %typemap(in) const string& (string temp) { 66 | char *buf; 67 | Py_ssize_t len; 68 | if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) SWIG_fail; 69 | temp.assign(buf, len); 70 | $1 = &temp; 71 | } 72 | 73 | PyObject* InferSchema(const string& statistics_proto_string, 74 | int max_string_domain_size); 75 | 76 | PyObject* ValidateFeatureStatistics( 77 | const string& statistics_proto_string, 78 | const string& schema_proto_string, 79 | const string& environment, 80 | const string& previous_statistics_proto_string, 81 | const string& serving_statistics_proto_string); 82 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Package Setup script for TensorFlow Data Validation.""" 15 | 16 | from setuptools import find_packages 17 | from setuptools import setup 18 | from setuptools.dist import Distribution 19 | 20 | 21 | class BinaryDistribution(Distribution): 22 | """This class is needed in order to create OS specific wheels.""" 23 | 24 | def has_ext_modules(self): 25 | return True 26 | 27 | # Get version from version module. 
28 | with open('tensorflow_data_validation/version.py') as fp: 29 | globals_dict = {} 30 | exec (fp.read(), globals_dict) # pylint: disable=exec-used 31 | __version__ = globals_dict['__version__'] 32 | 33 | setup( 34 | name='tensorflow-data-validation', 35 | version=__version__, 36 | author='Google LLC', 37 | author_email='tensorflow-extended-dev@googlegroups.com', 38 | license='Apache 2.0', 39 | classifiers=[ 40 | 'Development Status :: 4 - Beta', 41 | 'Intended Audience :: Developers', 42 | 'Intended Audience :: Education', 43 | 'Intended Audience :: Science/Research', 44 | 'License :: OSI Approved :: Apache Software License', 45 | 'Operating System :: OS Independent', 46 | 'Programming Language :: Python', 47 | 'Programming Language :: Python :: 2', 48 | 'Programming Language :: Python :: 2.7', 49 | 'Programming Language :: Python :: 2 :: Only', 50 | 'Topic :: Scientific/Engineering', 51 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 52 | 'Topic :: Scientific/Engineering :: Mathematics', 53 | 'Topic :: Software Development', 54 | 'Topic :: Software Development :: Libraries', 55 | 'Topic :: Software Development :: Libraries :: Python Modules', 56 | ], 57 | namespace_packages=[], 58 | install_requires=[ 59 | 'absl-py>=0.1.6', 60 | 'apache-beam[gcp]>=2.6,<3', 61 | 'numpy>=1.13.3,<2', 62 | 63 | # TF now requires protobuf>=3.6.0. 64 | 'protobuf>=3.6.0,<4', 65 | 66 | 'six>=1.10,<2', 67 | 68 | 69 | 'tensorflow-metadata>=0.9,<1', 70 | 'tensorflow-transform>=0.9,<1', 71 | 72 | # Dependencies needed for visualization. 
73 | 'IPython>=5.0,<6', 74 | 'pandas>=0.18,<1', 75 | ], 76 | python_requires='>=2.7,<3', 77 | packages=find_packages(), 78 | include_package_data=True, 79 | package_data={'': ['*.so']}, 80 | zip_safe=False, 81 | distclass=BinaryDistribution, 82 | description='A library for exploring and validating machine learning data.', 83 | requires=[]) 84 | -------------------------------------------------------------------------------- /tensorflow_data_validation/utils/batch_util_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for example batching utilities.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest 22 | import apache_beam as beam 23 | from apache_beam.testing import util 24 | import numpy as np 25 | from tensorflow_data_validation.utils import batch_util 26 | 27 | 28 | class BatchUtilTest(absltest.TestCase): 29 | 30 | def test_batch_examples(self): 31 | examples = [{ 32 | 'a': np.array([1.0, 2.0], dtype=np.floating), 33 | 'b': np.array(['a', 'b', 'c', 'e'], dtype=np.object) 34 | }, { 35 | 'a': np.array([3.0, 4.0, np.NaN, 5.0], dtype=np.floating), 36 | }, { 37 | 'b': np.array(['d', 'e', 'f'], dtype=np.object), 38 | 'd': np.array([10, 20, 30], dtype=np.integer), 39 | }, { 40 | 'b': np.array(['a', 'b', 'c'], dtype=np.object) 41 | }, { 42 | 'c': np.array(['d', 'e', 'f'], dtype=np.object) 43 | }] 44 | 45 | expected_batched_examples = [{ 46 | 'a': np.array([np.array([1.0, 2.0]), np.array([3.0, 4.0, np.NaN, 5.0]), 47 | None], dtype=np.object), 48 | 'b': np.array([np.array(['a', 'b', 'c', 'e']), None, 49 | np.array(['d', 'e', 'f'])], dtype=np.object), 50 | 'd': np.array([np.NaN, np.NaN, np.array([10, 20, 30])], dtype=np.object) 51 | }, { 52 | 'b': np.array([np.array(['a', 'b', 'c']), None], dtype=np.object), 53 | 'c': np.array([None, np.array(['d', 'e', 'f'])], dtype=np.object) 54 | }] 55 | 56 | def _batched_example_equal_fn(expected_batched_examples): 57 | """Makes a matcher function for comparing batched examples.""" 58 | def _matcher(actual_batched_examples): 59 | sorted_result = sorted(actual_batched_examples) 60 | sorted_expected_result = sorted(expected_batched_examples) 61 | self.assertEqual(len(sorted_result), len(sorted_expected_result)) 62 | for idx, batched_example in enumerate(sorted_result): 63 | self.assertEqual(sorted(batched_example), 64 | sorted(sorted_expected_result[idx])) 65 | return _matcher 66 | 67 | with beam.Pipeline() 
as p: 68 | result = (p 69 | | beam.Create(examples) 70 | | batch_util.BatchExamples(desired_batch_size=3)) 71 | util.assert_that( 72 | result, _batched_example_equal_fn(expected_batched_examples)) 73 | 74 | 75 | if __name__ == '__main__': 76 | absltest.main() 77 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/test_util.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_data_validation/anomalies/test_util.h" 17 | 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | #include 24 | #include "absl/strings/str_cat.h" 25 | #include "tensorflow_data_validation/anomalies/map_util.h" 26 | #include "tensorflow/core/lib/core/status.h" 27 | #include "tensorflow/core/lib/core/status_test_util.h" 28 | #include "tensorflow/core/platform/logging.h" 29 | #include "tensorflow/core/platform/types.h" 30 | #include "tensorflow_metadata/proto/v0/anomalies.pb.h" 31 | #include "tensorflow_metadata/proto/v0/schema.pb.h" 32 | 33 | namespace tensorflow { 34 | namespace data_validation { 35 | namespace testing { 36 | 37 | ProtoStringMatcher::ProtoStringMatcher(const string& expected) 38 | : expected_(expected) {} 39 | ProtoStringMatcher::ProtoStringMatcher( 40 | const ::tensorflow::protobuf::Message& expected) 41 | : expected_(expected.DebugString()) {} 42 | 43 | 44 | void TestAnomalies( 45 | const tensorflow::metadata::v0::Anomalies& actual, 46 | const tensorflow::metadata::v0::Schema& old_schema, 47 | const std::map& expected_anomalies) { 48 | EXPECT_THAT(actual.baseline(), EqualsProto(old_schema)); 49 | for (const auto& pair : expected_anomalies) { 50 | const string& name = pair.first; 51 | const ExpectedAnomalyInfo& expected = pair.second; 52 | ASSERT_TRUE(ContainsKey(actual.anomaly_info(), name)) 53 | << "Expected anomaly for feature name: " << name 54 | << " not found in Anomalies: " << actual.DebugString(); 55 | TestAnomalyInfo(actual.anomaly_info().at(name), old_schema, expected, 56 | absl::StrCat(" column: ", name)); 57 | } 58 | for (const auto& pair : actual.anomaly_info()) { 59 | const string& name = pair.first; 60 | EXPECT_TRUE(ContainsKey(expected_anomalies, name)) 61 | << "Unexpected anomaly: " << name << " " 62 | << pair.second.DebugString(); 63 | } 64 | } 65 | 66 | void TestAnomalyInfo(const 
tensorflow::metadata::v0::AnomalyInfo& actual, 67 | const tensorflow::metadata::v0::Schema& baseline, 68 | const ExpectedAnomalyInfo& expected, 69 | const string& comment) { 70 | tensorflow::metadata::v0::AnomalyInfo actual_info = actual; 71 | EXPECT_THAT(actual_info, EqualsProto(expected.expected_info_without_diff)) 72 | << comment; 73 | } 74 | 75 | } // namespace testing 76 | } // namespace data_validation 77 | } // namespace tensorflow 78 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/map_util.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_DATA_VALIDATION_ANOMALIES_MAP_UTIL_H_ 17 | #define TENSORFLOW_DATA_VALIDATION_ANOMALIES_MAP_UTIL_H_ 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | #include "tensorflow/core/platform/types.h" 25 | 26 | namespace tensorflow { 27 | namespace data_validation { 28 | 29 | // Returns true if and only if the given container contains the given key. 30 | template 31 | bool ContainsKey(const Container& container, const Key& key) { 32 | return container.find(key) != container.end(); 33 | } 34 | 35 | // Adds the values in the map. 
36 | double SumValues(const std::map& input); 37 | 38 | // Gets the keys from the map. The order of the keys is the same as in 39 | // the map. 40 | std::vector GetKeysFromMap(const std::map& input); 41 | 42 | // Gets the values from the map. The order of the values is the same as in 43 | // the map. 44 | std::vector GetValuesFromMap(const std::map& input); 45 | 46 | // Normalizes the values, such that the sum of the values are 1. 47 | // If the values sum to zero, return the input map. 48 | std::map Normalize(const std::map& input); 49 | 50 | // Gets the difference of the values of two maps. Values that are not 51 | // present are treated as zero. 52 | std::map GetDifference(const std::map& a, 53 | const std::map& b); 54 | 55 | // Gets the sum of the values of two maps. Values that are not 56 | // present are treated as zero. 57 | std::map GetSum(const std::map& a, 58 | const std::map& b); 59 | 60 | // Increments one map by another. Values that are not 61 | // present are treated as zero. 62 | void IncrementMap(const std::map& a, 63 | std::map* b); 64 | 65 | // Applies a function to all the values in the map. 66 | std::map MapValues(const std::map& input, 67 | const std::function& mapFn); 68 | 69 | // Cast the values from int64 to double. Notice that this might lose some 70 | // information. 71 | std::map IntMapToDoubleMap( 72 | const std::map& int_map); 73 | 74 | } // namespace data_validation 75 | } // namespace tensorflow 76 | 77 | #endif // TENSORFLOW_DATA_VALIDATION_ANOMALIES_MAP_UTIL_H_ 78 | -------------------------------------------------------------------------------- /tensorflow_data_validation/utils/profile_util_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for profile utilities.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from absl.testing import absltest 21 | import apache_beam as beam 22 | import numpy as np 23 | import six 24 | from tensorflow_data_validation.utils import profile_util 25 | 26 | 27 | class ProfileUtilTest(absltest.TestCase): 28 | 29 | def test_profile_input(self): 30 | examples = [ 31 | { 32 | 'a': np.array([1.0, 2.0], dtype=np.floating), 33 | 'b': np.array(['a', 'b', 'c', 'e'], dtype=np.object), 34 | }, 35 | { 36 | 'a': np.array([3.0, 4.0, np.NaN, 5.0], dtype=np.floating), 37 | }, 38 | { 39 | 'b': np.array(['d', 'e', 'f'], dtype=np.object), 40 | 'd': np.array([10, 20, 30], dtype=np.integer), 41 | }, 42 | { 43 | 'b': np.array(['a', 'b', 'c'], dtype=np.object), 44 | }, 45 | { 46 | 'c': np.array(['d', 'e', 'f'], dtype=np.object), 47 | }, 48 | ] 49 | 50 | expected_distributions = { 51 | 'int_feature_values_count': [3, 3, 3, 1], 52 | 'float_feature_values_count': [2, 4, 6, 2], 53 | 'string_feature_values_count': [3, 4, 13, 4], 54 | } 55 | p = beam.Pipeline() 56 | _ = ( 57 | p 58 | | 'Create' >> beam.Create(examples) 59 | | 'Profile' >> profile_util.Profile()) 60 | 61 | runner = p.run() 62 | runner.wait_until_finish() 63 | result_metrics = runner.metrics() 64 | 65 | num_metrics = len( 66 | result_metrics.query(beam.metrics.metric.MetricsFilter().with_namespace( 67 | profile_util.METRICS_NAMESPACE))['counters']) 68 | self.assertEqual(num_metrics, 1) 69 | 
70 | counter = result_metrics.query(beam.metrics.metric.MetricsFilter() 71 | .with_name('num_instances'))['counters'] 72 | self.assertEqual(len(counter), 1) 73 | self.assertEqual(counter[0].committed, 5) 74 | 75 | for distribution_name, expected_value in six.iteritems( 76 | expected_distributions): 77 | metric_filter = beam.metrics.metric.MetricsFilter().with_name( 78 | distribution_name) 79 | distribution = result_metrics.query(metric_filter)['distributions'] 80 | self.assertEqual(len(distribution), 1) 81 | self.assertEqual([ 82 | distribution[0].committed.min, distribution[0].committed.max, 83 | distribution[0].committed.sum, distribution[0].committed.count 84 | ], expected_value) 85 | 86 | 87 | if __name__ == '__main__': 88 | absltest.main() 89 | -------------------------------------------------------------------------------- /tensorflow_data_validation/utils/stats_util_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for utilities.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest 22 | import numpy as np 23 | from tensorflow_data_validation.utils import stats_util 24 | 25 | from google.protobuf import text_format 26 | from tensorflow_metadata.proto.v0 import schema_pb2 27 | from tensorflow_metadata.proto.v0 import statistics_pb2 28 | 29 | 30 | class StatsUtilTest(absltest.TestCase): 31 | 32 | def test_make_feature_type_int(self): 33 | self.assertEqual(stats_util.make_feature_type(np.dtype('int8')), 34 | statistics_pb2.FeatureNameStatistics.INT) 35 | self.assertEqual(stats_util.make_feature_type(np.dtype('int16')), 36 | statistics_pb2.FeatureNameStatistics.INT) 37 | self.assertEqual(stats_util.make_feature_type(np.dtype('int32')), 38 | statistics_pb2.FeatureNameStatistics.INT) 39 | self.assertEqual(stats_util.make_feature_type(np.dtype('int64')), 40 | statistics_pb2.FeatureNameStatistics.INT) 41 | 42 | def test_make_feature_type_float(self): 43 | self.assertEqual(stats_util.make_feature_type(np.dtype('float16')), 44 | statistics_pb2.FeatureNameStatistics.FLOAT) 45 | self.assertEqual(stats_util.make_feature_type(np.dtype('float32')), 46 | statistics_pb2.FeatureNameStatistics.FLOAT) 47 | self.assertEqual(stats_util.make_feature_type(np.dtype('float64')), 48 | statistics_pb2.FeatureNameStatistics.FLOAT) 49 | 50 | def test_make_feature_type_string(self): 51 | self.assertEqual(stats_util.make_feature_type(np.dtype('S')), 52 | statistics_pb2.FeatureNameStatistics.STRING) 53 | self.assertEqual(stats_util.make_feature_type(np.dtype('U')), 54 | statistics_pb2.FeatureNameStatistics.STRING) 55 | 56 | def test_make_feature_type_none(self): 57 | self.assertIsNone(stats_util.make_feature_type(np.dtype('complex64'))) 58 | 59 | def test_make_feature_type_invalid_dtype(self): 60 | with self.assertRaises(TypeError): 61 | stats_util.make_feature_type(int) 
62 | 63 | def test_get_categorical_numeric_features(self): 64 | schema = text_format.Parse( 65 | """ 66 | feature { 67 | name: "fa" 68 | type: INT 69 | int_domain { 70 | is_categorical: true 71 | } 72 | } 73 | feature { 74 | name: "fb" 75 | type: BYTES 76 | } 77 | feature { 78 | name: "fc" 79 | type: FLOAT 80 | } 81 | """, schema_pb2.Schema()) 82 | self.assertEqual( 83 | stats_util.get_categorical_numeric_features(schema), ['fa']) 84 | 85 | 86 | if __name__ == '__main__': 87 | absltest.main() 88 | -------------------------------------------------------------------------------- /tensorflow_data_validation/utils/schema_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for manipulating the schema.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | 20 | from __future__ import print_function 21 | 22 | from tensorflow_data_validation import types 23 | from tensorflow_data_validation.types_compat import Union 24 | from tensorflow_metadata.proto.v0 import schema_pb2 25 | 26 | 27 | def get_feature(schema, 28 | feature_name): 29 | """Get a feature from the schema. 30 | 31 | Args: 32 | schema: A Schema protocol buffer. 33 | feature_name: The name of the feature to obtain from the schema. 34 | 35 | Returns: 36 | A Feature protocol buffer. 
37 | 38 | Raises: 39 | TypeError: If the input schema is not of the expected type. 40 | ValueError: If the input feature is not found in the schema. 41 | """ 42 | if not isinstance(schema, schema_pb2.Schema): 43 | raise TypeError('schema is of type %s, should be a Schema proto.' % 44 | type(schema).__name__) 45 | 46 | for feature in schema.feature: 47 | if feature.name == feature_name: 48 | return feature 49 | 50 | raise ValueError('Feature %s not found in the schema.' % feature_name) 51 | 52 | 53 | 54 | 55 | def get_domain(schema, feature_name 56 | ): 57 | """Get the domain associated with the input feature from the schema. 58 | 59 | Args: 60 | schema: A Schema protocol buffer. 61 | feature_name: The name of the feature whose domain needs to be found. 62 | 63 | Returns: 64 | The domain protocol buffer (one of IntDomain, FloatDomain, StringDomain or 65 | BoolDomain) associated with the input feature. 66 | 67 | Raises: 68 | TypeError: If the input schema is not of the expected type. 69 | ValueError: If the input feature is not found in the schema or there is 70 | no domain associated with the feature. 71 | """ 72 | if not isinstance(schema, schema_pb2.Schema): 73 | raise TypeError('schema is of type %s, should be a Schema proto.' % 74 | type(schema).__name__) 75 | 76 | feature = get_feature(schema, feature_name) 77 | domain_info = feature.WhichOneof('domain_info') 78 | 79 | if domain_info is None: 80 | raise ValueError('Feature %s has no domain associated with it.' 
81 | % feature_name) 82 | 83 | if domain_info == 'int_domain': 84 | return feature.int_domain 85 | elif domain_info == 'float_domain': 86 | return feature.float_domain 87 | elif domain_info == 'string_domain': 88 | return feature.string_domain 89 | elif domain_info == 'domain': 90 | for domain in schema.string_domain: 91 | if domain.name == feature.domain: 92 | return domain 93 | elif domain_info == 'bool_domain': 94 | return feature.bool_domain 95 | 96 | raise ValueError('Feature %s has an unsupported domain %s.' 97 | % (feature_name, domain_info)) 98 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/statistics_view_test_util.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_DATA_VALIDATION_ANOMALIES_STATISTICS_VIEW_TEST_UTIL_H_ 17 | #define TENSORFLOW_DATA_VALIDATION_ANOMALIES_STATISTICS_VIEW_TEST_UTIL_H_ 18 | 19 | #include "tensorflow_data_validation/anomalies/statistics_view.h" 20 | #include "tensorflow_metadata/proto/v0/statistics.pb.h" 21 | 22 | namespace tensorflow { 23 | namespace data_validation { 24 | namespace testing { 25 | 26 | // Makes a dataset with one feature. Assumes global counts match the 27 | // count for the feature. 
28 | tensorflow::metadata::v0::DatasetFeatureStatistics 29 | GetDatasetFeatureStatisticsForTesting( 30 | const tensorflow::metadata::v0::FeatureNameStatistics& feature_name_stats); 31 | 32 | // For testing, we often just have information for one feature. 33 | // However, DatasetStatsView and FeatureStatsView point to other objects. 34 | // This structure allows us to set all that up in one call. 35 | // Here is a pattern: 36 | // FuncToTest(DatasetForTesting(stats).feature_stats_view()) 37 | // Here is an anti-pattern. It will make the resulting object point to a 38 | // destroyed object (very bad). 39 | // const FeatureStatsView& MyShortcut( 40 | // const tensorflow::metadata::v0::FeatureNameStatistics& stats) { 41 | // return DatasetForTesting(stats).feature_stats_view(); 42 | // } 43 | class DatasetForTesting { 44 | public: 45 | explicit DatasetForTesting( 46 | const tensorflow::metadata::v0::FeatureNameStatistics& 47 | feature_name_stats); 48 | DatasetForTesting( 49 | const tensorflow::metadata::v0::FeatureNameStatistics& feature_name_stats, 50 | bool by_weight); 51 | 52 | DatasetForTesting(const tensorflow::metadata::v0::DatasetFeatureStatistics& 53 | dataset_feature_stats, 54 | bool by_weight); 55 | 56 | // DatasetForTesting is neither copyable nor movable, as DatasetStatsView 57 | // is neither copyable nor movable. 58 | DatasetForTesting(const DatasetForTesting&) = delete; 59 | DatasetForTesting& operator=(const DatasetForTesting&) = delete; 60 | 61 | const DatasetStatsView& dataset_stats_view() const { 62 | return dataset_stats_view_; 63 | } 64 | 65 | const FeatureStatsView& feature_stats_view() const { 66 | return feature_stats_view_; 67 | } 68 | 69 | private: 70 | // Notice that the destructor will destroy the objects from bottom to top, 71 | // respecting the proper order of destruction. 
72 | const tensorflow::metadata::v0::DatasetFeatureStatistics 73 | dataset_feature_statistics_; 74 | const DatasetStatsView dataset_stats_view_; 75 | const FeatureStatsView feature_stats_view_; 76 | }; 77 | 78 | DatasetForTesting GetDatasetForTesting( 79 | const tensorflow::metadata::v0::FeatureNameStatistics& feature_name_stats); 80 | 81 | tensorflow::metadata::v0::FeatureNameStatistics AddWeightedStats( 82 | const tensorflow::metadata::v0::FeatureNameStatistics& original); 83 | 84 | } // namespace testing 85 | } // namespace data_validation 86 | } // namespace tensorflow 87 | 88 | #endif // TENSORFLOW_DATA_VALIDATION_ANOMALIES_STATISTICS_VIEW_TEST_UTIL_H_ 89 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/metrics_test.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_data_validation/anomalies/metrics.h" 17 | 18 | #include 19 | #include 20 | #include 21 | 22 | #include 23 | #include "tensorflow_data_validation/anomalies/statistics_view_test_util.h" 24 | #include "tensorflow_data_validation/anomalies/test_util.h" 25 | #include "tensorflow/core/platform/types.h" 26 | #include "tensorflow_metadata/proto/v0/statistics.pb.h" 27 | 28 | namespace tensorflow { 29 | namespace data_validation { 30 | 31 | namespace { 32 | 33 | using testing::DatasetForTesting; 34 | using testing::ParseTextProtoOrDie; 35 | 36 | tensorflow::metadata::v0::FeatureNameStatistics 37 | GetFeatureNameStatisticsWithTokens(const std::map& tokens) { 38 | tensorflow::metadata::v0::FeatureNameStatistics result = 39 | ParseTextProtoOrDie( 40 | R"(name: 'bar' 41 | type: STRING 42 | string_stats: { 43 | common_stats: { 44 | num_missing: 0 45 | max_num_values: 1 46 | }})"); 47 | tensorflow::metadata::v0::StringStatistics* string_stats = 48 | result.mutable_string_stats(); 49 | string_stats->set_unique(tokens.size()); 50 | tensorflow::metadata::v0::RankHistogram* histogram = 51 | string_stats->mutable_rank_histogram(); 52 | for (const auto& pair : tokens) { 53 | const string& feature_value = pair.first; 54 | const double feature_occurrences = pair.second; 55 | tensorflow::metadata::v0::RankHistogram::Bucket* bucket = 56 | histogram->add_buckets(); 57 | *bucket->mutable_label() = feature_value; 58 | bucket->set_sample_count(feature_occurrences); 59 | } 60 | return result; 61 | } 62 | 63 | struct LInftyDistanceExample { 64 | string name; 65 | std::map training; 66 | std::map serving; 67 | double expected; 68 | }; 69 | 70 | std::vector GetLInftyDistanceTests() { 71 | return {{"Two empty maps", {}, {}, 0.0}, 72 | {"Normal distribution.", 73 | {{"hello", 0.1}, {"world", 0.9}}, 74 | {{"hello", 0.3}, {"world", 0.7}}, 75 | 0.2}, 76 | {"Missing value in both.", 77 | 
{{"b", 0.9}, {"c", 0.1}}, 78 | {{"a", 0.3}, {"b", 0.7}}, 79 | 0.3}, 80 | {"Missing value in both, flipped.", 81 | {{"a", 0.3}, {"b", 0.7}}, 82 | {{"b", 0.9}, {"c", 0.1}}, 83 | 0.3}}; 84 | } 85 | 86 | TEST(LInftyDistanceTest, All) { 87 | for (const auto& test : GetLInftyDistanceTests()) { 88 | const DatasetForTesting training( 89 | GetFeatureNameStatisticsWithTokens(test.training)); 90 | const DatasetForTesting serving( 91 | GetFeatureNameStatisticsWithTokens(test.serving)); 92 | const double result = LInftyDistance(training.feature_stats_view(), 93 | serving.feature_stats_view()) 94 | .second; 95 | EXPECT_NEAR(result, test.expected, 1e-5) << test.name; 96 | } 97 | } 98 | 99 | } // namespace 100 | 101 | } // namespace data_validation 102 | 103 | } // namespace tensorflow 104 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/map_util.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_data_validation/anomalies/map_util.h" 17 | 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | #include "tensorflow/core/platform/types.h" 25 | 26 | namespace tensorflow { 27 | namespace data_validation { 28 | 29 | double SumValues(const std::map& input) { 30 | std::vector values = GetValuesFromMap(input); 31 | return std::accumulate(values.begin(), values.end(), 0.0); 32 | } 33 | 34 | std::vector GetValuesFromMap(const std::map& input) { 35 | std::vector values; 36 | values.reserve(input.size()); 37 | for (const auto& pair : input) { 38 | values.push_back(pair.second); 39 | } 40 | return values; 41 | } 42 | 43 | std::vector GetKeysFromMap(const std::map& input) { 44 | std::vector keys; 45 | keys.reserve(input.size()); 46 | for (const auto& pair : input) { 47 | keys.push_back(pair.first); 48 | } 49 | return keys; 50 | } 51 | 52 | std::map Normalize(const std::map& input) { 53 | double sum = SumValues(input); 54 | if (sum == 0.0) { 55 | return input; 56 | } 57 | std::map result; 58 | for (const auto& pair : input) { 59 | const string& key = pair.first; 60 | const double value = pair.second; 61 | result[key] = value / sum; 62 | } 63 | return result; 64 | } 65 | 66 | std::map GetDifference(const std::map& a, 67 | const std::map& b) { 68 | std::map result = a; 69 | for (const auto& pair_b : b) { 70 | const string& key_b = pair_b.first; 71 | const double value_b = pair_b.second; 72 | // If the key is not present, this will initialize it to zero. 73 | result[key_b] -= value_b; 74 | } 75 | return result; 76 | } 77 | 78 | void IncrementMap(const std::map& a, 79 | std::map* b) { 80 | for (const auto& pair_a : a) { 81 | const string& key_a = pair_a.first; 82 | const double value_a = pair_a.second; 83 | // If the key is not present, this will initialize it to zero. 
84 | (*b)[key_a] += value_a; 85 | } 86 | } 87 | 88 | std::map GetSum(const std::map& a, 89 | const std::map& b) { 90 | std::map result = a; 91 | IncrementMap(b, &result); 92 | return result; 93 | } 94 | 95 | std::map MapValues(const std::map& input, 96 | const std::function& mapFn) { 97 | std::map result; 98 | for (const auto& pair : input) { 99 | result[pair.first] = mapFn(pair.second); 100 | } 101 | return result; 102 | } 103 | 104 | std::map IntMapToDoubleMap( 105 | const std::map& int_map) { 106 | std::map result; 107 | for (const auto& pair : int_map) { 108 | result[pair.first] = pair.second; 109 | } 110 | return result; 111 | } 112 | 113 | } // namespace data_validation 114 | } // namespace tensorflow 115 | -------------------------------------------------------------------------------- /tensorflow_data_validation/utils/profile_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Utilities for profiling data flowing through TF.DV.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | 19 | from __future__ import print_function 20 | 21 | import apache_beam as beam 22 | import numpy as np 23 | import six 24 | from tensorflow_data_validation import types 25 | from tensorflow_data_validation.utils import stats_util 26 | from tensorflow_data_validation.types_compat import Generator 27 | 28 | from tensorflow_metadata.proto.v0 import statistics_pb2 29 | 30 | # Namespace for all TFDV metrics. 31 | METRICS_NAMESPACE = 'tfx.DataValidation' 32 | 33 | 34 | @beam.typehints.with_input_types(types.ExampleBatch) 35 | @beam.typehints.with_output_types(types.ExampleBatch) 36 | class _ProfileFn(beam.DoFn): 37 | """See docstring of Profile for details.""" 38 | 39 | def __init__(self): 40 | # Counter for number of examples processed. 41 | self._num_instances = beam.metrics.Metrics.counter(METRICS_NAMESPACE, 42 | 'num_instances') 43 | 44 | # Distribution metrics to track the distribution of feature value lengths 45 | # for each type. 
46 | self._int_feature_values_count = beam.metrics.Metrics.distribution( 47 | METRICS_NAMESPACE, 'int_feature_values_count') 48 | self._float_feature_values_count = beam.metrics.Metrics.distribution( 49 | METRICS_NAMESPACE, 'float_feature_values_count') 50 | self._string_feature_values_count = beam.metrics.Metrics.distribution( 51 | METRICS_NAMESPACE, 'string_feature_values_count') 52 | self._unknown_feature_values_count = beam.metrics.Metrics.distribution( 53 | METRICS_NAMESPACE, 'unknown_feature_values_count') 54 | 55 | def process(self, 56 | element 57 | ): 58 | self._num_instances.inc(1) 59 | for _, value in six.iteritems(element): 60 | if not isinstance(value, np.ndarray): 61 | self._unknown_feature_values_count.update(1) 62 | continue 63 | feature_type = stats_util.make_feature_type(value.dtype) 64 | if feature_type == statistics_pb2.FeatureNameStatistics.INT: 65 | self._int_feature_values_count.update(len(value)) 66 | elif feature_type == statistics_pb2.FeatureNameStatistics.FLOAT: 67 | self._float_feature_values_count.update(len(value)) 68 | elif feature_type == statistics_pb2.FeatureNameStatistics.STRING: 69 | self._string_feature_values_count.update(len(value)) 70 | else: 71 | self._unknown_feature_values_count.update(len(value)) 72 | yield element 73 | 74 | 75 | @beam.typehints.with_input_types(types.ExampleBatch) 76 | @beam.typehints.with_output_types(types.ExampleBatch) 77 | class Profile(beam.PTransform): 78 | """Profiles the input examples by emitting counters for different properties. 79 | 80 | Each input example is a dict of feature name to np.ndarray of feature values. 81 | The functor passes through each example to the output while emitting certain 82 | counters that track the number of examples and number of features of each 83 | type. 84 | 85 | Args: 86 | examples: PCollection of examples. Each example should be a dict of feature 87 | name to a numpy array of values (OK to be empty). 
88 | 89 | Returns: 90 | PCollection of examples (same as input). 91 | """ 92 | 93 | def expand(self, pcoll): 94 | return pcoll | beam.ParDo(_ProfileFn()) 95 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/float_domain_test.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_data_validation/anomalies/float_domain_util.h" 17 | 18 | #include 19 | #include 20 | 21 | #include 22 | #include "tensorflow_data_validation/anomalies/internal_types.h" 23 | #include "tensorflow_data_validation/anomalies/statistics_view_test_util.h" 24 | #include "tensorflow_data_validation/anomalies/test_util.h" 25 | #include "tensorflow/core/platform/types.h" 26 | #include "tensorflow_metadata/proto/v0/schema.pb.h" 27 | #include "tensorflow_metadata/proto/v0/statistics.pb.h" 28 | 29 | namespace tensorflow { 30 | namespace data_validation { 31 | namespace { 32 | using ::tensorflow::metadata::v0::FeatureNameStatistics; 33 | using ::tensorflow::metadata::v0::FloatDomain; 34 | using testing::EqualsProto; 35 | using testing::ParseTextProtoOrDie; 36 | 37 | struct UpdateFloatDomainTest { 38 | const string name; 39 | FloatDomain float_domain; 40 | const FeatureNameStatistics input; 41 | 
const bool clear_field; 42 | FloatDomain expected; 43 | }; 44 | 45 | std::vector GetUpdateFloatDomainTests() { 46 | return { 47 | {"transform_as_string", FloatDomain(), 48 | ParseTextProtoOrDie(R"( 49 | name: "transform_as_string" 50 | type: STRING 51 | string_stats: { 52 | common_stats: { 53 | num_missing: 3 54 | max_num_values: 2 55 | } 56 | unique: 3 57 | rank_histogram: { 58 | buckets: { 59 | label: "1.5" 60 | } 61 | buckets: { 62 | label: "0.25" 63 | }}})"), 64 | false, FloatDomain()}, 65 | {"float_value_in_range", 66 | ParseTextProtoOrDie("min: 3.0 max: 5.0"), 67 | ParseTextProtoOrDie(R"( 68 | name: 'bar' 69 | type: FLOAT 70 | num_stats: { 71 | common_stats: { 72 | num_missing: 3 73 | max_num_values: 2 74 | } 75 | min: 3.5 76 | max: 4.5})"), 77 | false, ParseTextProtoOrDie("min: 3.0 max: 5.0")}, 78 | {"low_value", ParseTextProtoOrDie("min: 3.0 max: 5.0"), 79 | ParseTextProtoOrDie(R"( 80 | name: 'bar' 81 | type: FLOAT 82 | num_stats: { 83 | common_stats: { 84 | num_missing: 3 85 | max_num_values: 2 86 | } 87 | min: 2.5 88 | max: 4.5})"), 89 | false, ParseTextProtoOrDie("min: 2.5 max: 5")}}; 90 | } 91 | 92 | TEST(FloatDomainTest, UpdateFloatDomain) { 93 | for (const auto& test : GetUpdateFloatDomainTests()) { 94 | const testing::DatasetForTesting dataset(test.input); 95 | FloatDomain to_modify = test.float_domain; 96 | UpdateSummary summary = 97 | UpdateFloatDomain(dataset.feature_stats_view(), &to_modify); 98 | if (summary.descriptions.empty()) { 99 | // If there are no descriptions, then there should be no changes. 
100 | EXPECT_FALSE(summary.clear_field) << test.name; 101 | EXPECT_THAT(to_modify, EqualsProto(test.float_domain)) << test.name; 102 | } 103 | 104 | EXPECT_EQ(summary.clear_field, test.clear_field) << test.name; 105 | EXPECT_THAT(to_modify, EqualsProto(test.expected)) << test.name; 106 | } 107 | } 108 | 109 | } // namespace 110 | } // namespace data_validation 111 | } // namespace tensorflow 112 | -------------------------------------------------------------------------------- /tensorflow_data_validation/coders/tf_example_decoder_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for TFExampleDecoder.""" 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from absl.testing import absltest 20 | import numpy as np 21 | import tensorflow as tf 22 | from tensorflow_data_validation.coders import tf_example_decoder 23 | 24 | from google.protobuf import text_format 25 | 26 | 27 | class TFExampleDecoderTest(absltest.TestCase): 28 | """Tests for TFExampleDecoder.""" 29 | 30 | def _check_decoding_results(self, actual, expected): 31 | # Check that the numpy array dtypes match. 
32 | self.assertEqual(len(actual), len(expected)) 33 | for key in actual: 34 | self.assertEqual(actual[key].dtype, expected[key].dtype) 35 | np.testing.assert_equal(actual, expected) 36 | 37 | def test_decode_example_empty_input(self): 38 | example = tf.train.Example() 39 | decoder = tf_example_decoder.TFExampleDecoder() 40 | self._check_decoding_results( 41 | decoder.decode(example.SerializeToString()), {}) 42 | 43 | def test_decode_example(self): 44 | example_proto_text = """ 45 | features { 46 | feature { key: "int_feature_1" 47 | value { int64_list { value: [ 0 ] } } } 48 | feature { key: "int_feature_2" 49 | value { int64_list { value: [ 1, 2, 3 ] } } } 50 | feature { key: "float_feature_1" 51 | value { float_list { value: [ 4.0 ] } } } 52 | feature { key: "float_feature_2" 53 | value { float_list { value: [ 5.0, 6.0 ] } } } 54 | feature { key: "str_feature_1" 55 | value { bytes_list { value: [ 'female' ] } } } 56 | feature { key: "str_feature_2" 57 | value { bytes_list { value: [ 'string', 'list' ] } } } 58 | } 59 | """ 60 | expected_decoded = { 61 | 'int_feature_1': np.array([0], dtype=np.integer), 62 | 'int_feature_2': np.array([1, 2, 3], dtype=np.integer), 63 | 'float_feature_1': np.array([4.0], dtype=np.floating), 64 | 'float_feature_2': np.array([5.0, 6.0], dtype=np.floating), 65 | 'str_feature_1': np.array([b'female'], dtype=np.object), 66 | 'str_feature_2': np.array([b'string', b'list'], dtype=np.object), 67 | } 68 | example = tf.train.Example() 69 | text_format.Merge(example_proto_text, example) 70 | 71 | decoder = tf_example_decoder.TFExampleDecoder() 72 | self._check_decoding_results( 73 | decoder.decode(example.SerializeToString()), expected_decoded) 74 | 75 | def test_decode_example_empty_feature(self): 76 | example_proto_text = """ 77 | features { 78 | feature { key: "int_feature" value { int64_list { value: [ 0 ] } } } 79 | feature { key: "int_feature_empty" value { } } 80 | feature { key: "float_feature" value { float_list { value: [ 4.0 ] } } 
} 81 | feature { key: "float_feature_empty" value { } } 82 | feature { key: "str_feature" value { bytes_list { value: [ 'male' ] } } } 83 | feature { key: "str_feature_empty" value { } } 84 | } 85 | """ 86 | expected_decoded = { 87 | 'int_feature': np.array([0], dtype=np.integer), 88 | 'int_feature_empty': np.array([], dtype=np.object), 89 | 'float_feature': np.array([4.0], dtype=np.floating), 90 | 'float_feature_empty': np.array([], dtype=np.object), 91 | 'str_feature': np.array([b'male'], dtype=np.object), 92 | 'str_feature_empty': np.array([], dtype=np.object), 93 | } 94 | example = tf.train.Example() 95 | text_format.Merge(example_proto_text, example) 96 | 97 | decoder = tf_example_decoder.TFExampleDecoder() 98 | self._check_decoding_results( 99 | decoder.decode(example.SerializeToString()), expected_decoded) 100 | 101 | 102 | if __name__ == '__main__': 103 | absltest.main() 104 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/path.h: -------------------------------------------------------------------------------- 1 | #ifndef TENSORFLOW_DATA_VALIDATION_ANOMALIES_PATH_H_ 2 | #define TENSORFLOW_DATA_VALIDATION_ANOMALIES_PATH_H_ 3 | 4 | #include 5 | #include 6 | #include "absl/strings/string_view.h" 7 | #include "tensorflow/core/lib/core/status.h" 8 | #include "tensorflow_metadata/proto/v0/path.pb.h" 9 | 10 | namespace tensorflow { 11 | namespace data_validation { 12 | 13 | // This represents a sequence of steps (i.e. strings) in a structured example. 14 | // The main functionality here is the ability Serialize and Deserialize a 15 | // list of paths in a 'human-readable' way. 16 | // Specifically, individual steps in a path can be arbitrary strings (in the 17 | // byte sense). As we cannot assume that the strings are valid unicode, we 18 | // should not try to parse them as such. 
19 | // However, many steps will look like your conventional alphanumeric variable 20 | // names in a host of languages, i.e.: 21 | // [A-Za-z_][A-Za-z0-9]* 22 | // Others will have the proto2 extension format, (roughly) 23 | // \([A-Za-z0-9_.]*\) 24 | // This is a superset, but it captures the basic idea. In both these cases, it 25 | // is sufficient to simply use dot separators to create the path. 26 | // More generally, for steps of the form: 27 | // ([^.()']+)| (\([^()]\)) 28 | // Serialize will leave them untouched. 29 | // For example: 30 | // foo 31 | // bar 32 | // (foo.bar) 33 | // (foo.'bar) 34 | // Other steps will be encapsulated by single quotes and any internal single 35 | // quotes will be doubled. 36 | // For example: 37 | // ((c) becomes '((c)' 38 | // Marty's becomes 'Marty''s' 39 | // Steps, once serialized, will be concatenated with dots. E.g.: 40 | // {foo, bar, baz} becomes foo.bar.baz 41 | // {foo, ((c), Marty's} becomes foo.'((c)'.'Marty''s' 42 | // Importantly, note that Serialize is an injection (1-1). For any string 43 | // generated by Serialize(), Deserialize() will invert the process. 44 | class Path { 45 | public: 46 | Path() = default; 47 | explicit Path(std::vector step) : step_(std::move(step)) {} 48 | explicit Path(const tensorflow::metadata::v0::Path& p); 49 | Path(const Path& p) = default; 50 | Path(Path&& p) = default; 51 | Path& operator=(const Path& p) = default; 52 | Path& operator=(Path&& p) = default; 53 | 54 | // Returns -1, 0, 1 if *this is greater than, less than or equal to p. 55 | int Compare(const Path& p) const; 56 | 57 | // Number of steps in a path. 58 | size_t size() const { return step_.size(); } 59 | 60 | // Since we store the steps with the separators, sometimes we need to remove 61 | // the separator. 62 | const string& last_step() const { return step_.back(); } 63 | 64 | // Serialize a path into a string that can be Deserialized. 65 | // Intended to be as human-readable as possible. 
66 | // See class-level comments for the style of the string. 67 | string Serialize() const; 68 | 69 | // Serialize the path to a proto. 70 | tensorflow::metadata::v0::Path AsProto() const; 71 | 72 | // Deserializes a string created with Serialize(). 73 | // Note: for any path p (i.e. arbitrary steps): 74 | // Path p2; 75 | // TF_CHECK_OK(Path::Deserialize(p.Serialize(), &p2)); 76 | // EXPECT_EQ(p, p2); 77 | static tensorflow::Status Deserialize(absl::string_view str, Path* result); 78 | 79 | // True if there are no steps. 80 | bool empty() const { return step_.empty(); } 81 | 82 | // Get the parent path. 83 | Path GetParent() const { 84 | return Path(std::vector(step_.begin(), step_.end() - 1)); 85 | } 86 | 87 | Path GetChild(absl::string_view last_step) const; 88 | 89 | private: 90 | // Returns true iff this is equal to p. 91 | // Part of the implementation of Compare(). 92 | bool Equals(const Path& p) const; 93 | 94 | // Returns true iff this is less than p. 95 | // Part of the implementation of Compare(). 96 | bool Less(const Path& p) const; 97 | 98 | // The raw steps of a path. 99 | std::vector step_; 100 | }; 101 | 102 | // Lexicographical ordering on steps. 103 | // Needed for std::less (for std::set). 
104 | bool operator<(const Path& a, const Path& b); 105 | bool operator>(const Path& a, const Path& b); 106 | bool operator==(const Path& a, const Path& b); 107 | bool operator!=(const Path& a, const Path& b); 108 | bool operator>=(const Path& a, const Path& b); 109 | bool operator<=(const Path& a, const Path& b); 110 | 111 | } // namespace data_validation 112 | } // namespace tensorflow 113 | 114 | #endif // TENSORFLOW_DATA_VALIDATION_ANOMALIES_PATH_H_ 115 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/feature_util.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | // Utilities to modify a feature in the schema. 17 | #ifndef TENSORFLOW_DATA_VALIDATION_ANOMALIES_FEATURE_UTIL_H_ 18 | #define TENSORFLOW_DATA_VALIDATION_ANOMALIES_FEATURE_UTIL_H_ 19 | 20 | #include 21 | 22 | #include "tensorflow_data_validation/anomalies/internal_types.h" 23 | #include "tensorflow_data_validation/anomalies/statistics_view.h" 24 | #include "tensorflow_metadata/proto/v0/schema.pb.h" 25 | 26 | namespace tensorflow { 27 | namespace data_validation { 28 | 29 | // If the value count constraints are not satisfied, adjust them. 
30 | std::vector UpdateValueCount( 31 | const FeatureStatsView& feature_stats_view, 32 | tensorflow::metadata::v0::ValueCount* value_count); 33 | 34 | // If a feature occurs in too few examples, or a feature occurs in too small 35 | // a fraction of the examples, adjust the presence constraints to account for 36 | // this. 37 | std::vector UpdatePresence( 38 | const FeatureStatsView& feature_stats_view, 39 | tensorflow::metadata::v0::FeaturePresence* presence); 40 | 41 | bool FeatureHasComparator(const tensorflow::metadata::v0::Feature& feature, 42 | ComparatorType comparator_type); 43 | 44 | // Gets the feature comparator, creating it if it does not exist. 45 | tensorflow::metadata::v0::FeatureComparator* GetFeatureComparator( 46 | tensorflow::metadata::v0::Feature* feature, ComparatorType comparator_type); 47 | 48 | // Updates comparator from the feature stats. 49 | // Note that if the "control" was missing, we have deprecated the column. 50 | std::vector UpdateFeatureComparatorDirect( 51 | const FeatureStatsView& stats, const ComparatorType comparator_type, 52 | tensorflow::metadata::v0::FeatureComparator* comparator); 53 | 54 | // Initializes the value count and presence given a feature_stats_view. 55 | // This is called when a Feature is first created from a FeatureStatsView. 56 | // It infers OPTIONAL, REPEATED, REQUIRED (in the proto sense), 57 | // and REPEATED_REQUIRED (a repeated field that is always present), and 58 | // sets value count and presence analogously. 59 | void InitValueCountAndPresence(const FeatureStatsView& feature_stats_view, 60 | tensorflow::metadata::v0::Feature* feature); 61 | 62 | // Deprecate a feature. Currently sets deprecated==true, but later will 63 | // set the lifecycle_stage==DEPRECATED. The contract of this method is that 64 | // FeatureIsDeprecated is set to true after it is called. 65 | void DeprecateFeature(tensorflow::metadata::v0::Feature* feature); 66 | 67 | // Same as above for SparseFeature. 
68 | void DeprecateSparseFeature( 69 | tensorflow::metadata::v0::SparseFeature* sparse_feature); 70 | 71 | // Tell if a feature is deprecated (i.e., ignored for data validation). 72 | // Note that a deprecated feature is a more relaxed constraint than a feature 73 | // not being present in the schema, as it also suppresses the unexpected column 74 | // anomaly. 75 | // If neither deprecated is set nor lifecycle_stage is set, it is not 76 | // deprecated. 77 | // If deprecated==true, it is deprecated. 78 | // Otherwise, if lifecycle_stage is in {ALPHA, PLANNED, DEPRECATED, DEBUG_ONLY} 79 | // it is deprecated. 80 | // If lifecycle_stage is in {UNKNOWN_STAGE, BETA, PRODUCTION} 81 | // it is not deprecated. 82 | // Setting deprecated==false has no effect. 83 | bool FeatureIsDeprecated(const tensorflow::metadata::v0::Feature& feature); 84 | 85 | // Same as above for SparseFeature. 86 | bool SparseFeatureIsDeprecated( 87 | const tensorflow::metadata::v0::SparseFeature& sparse_feature); 88 | 89 | // Get the maximum allowed off the domain. 90 | double GetMaxOffDomain(const tensorflow::metadata::v0::DistributionConstraints& 91 | distribution_constraints); 92 | 93 | // Clear the domain of the feature. 94 | void ClearDomain(tensorflow::metadata::v0::Feature* feature); 95 | } // namespace data_validation 96 | } // namespace tensorflow 97 | #endif // TENSORFLOW_DATA_VALIDATION_ANOMALIES_FEATURE_UTIL_H_ 98 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # TensorFlow Data Validation [![PyPI](https://img.shields.io/pypi/pyversions/tensorflow-data-validation.svg?style=plastic)](https://github.com/tensorflow/data-validation) 4 | 5 | *TensorFlow Data Validation* (TFDV) is a library for exploring and validating 6 | machine learning data. 
It is designed to be highly scalable 7 | and to work well with TensorFlow and [TensorFlow Extended (TFX)](https://www.tensorflow.org/tfx). 8 | 9 | TF Data Validation includes: 10 | 11 | * Scalable calculation of summary statistics of training and test data. 12 | * Integration with a viewer for data distributions and statistics, as well 13 | as faceted comparison of pairs of features ([Facets](https://github.com/PAIR-code/facets)) 14 | * Automated [data-schema](https://github.com/tensorflow/metadata/blob/master/tensorflow_metadata/proto/v0/schema.proto) 15 | generation to describe expectations about data 16 | like required values, ranges, and vocabularies 17 | * A schema viewer to help you inspect the schema. 18 | * Anomaly detection to identify anomalies, such as missing features, 19 | out-of-range values, or wrong feature types, to name a few. 20 | * An anomalies viewer so that you can see what features have anomalies and 21 | learn more in order to correct them. 22 | 23 | For instructions on using TFDV, see the [get started guide](g3doc/get_started.md) 24 | and try out the [example notebook](https://nbviewer.jupyter.org/github/tensorflow/data-validation/blob/master/examples/chicago_taxi/chicago_taxi_tfdv.ipynb). 25 | 26 | Caution: TFDV may be backwards incompatible before version 1.0. 27 | 28 | ## Installing from PyPI 29 | 30 | The recommended way to install TFDV is using the 31 | [PyPI package](https://pypi.org/project/tensorflow-data-validation/): 32 | 33 | ```bash 34 | pip install tensorflow-data-validation 35 | ``` 36 | 37 | ## Installing from source 38 | 39 | ### 1. Prerequisites 40 | 41 | To compile and use TFDV, you need to set up some prerequisites. 42 | 43 | #### Install NumPy 44 | 45 | If NumPy is not installed on your system, install it now by following [these 46 | directions](https://www.scipy.org/scipylib/download.html). 
#### Install Bazel

If bazel is not installed on your system, install it now by following [these
directions](https://bazel.build/versions/master/docs/install.html).

### 2. Clone the TFDV repository

```shell
git clone https://github.com/tensorflow/data-validation
cd data-validation
```

Note that these instructions will install the latest master branch of TensorFlow
Data Validation. If you want to install a specific branch (such as a release branch),
pass `-b <branch_name>` to the `git clone` command.

### 3. Build the pip package

TFDV uses Bazel to build the pip package from source:

```shell
bazel run -c opt tensorflow_data_validation:build_pip_package
```

You can find the generated `.whl` file in the `dist` subdirectory.

### 4. Install the pip package

```shell
pip install dist/*.whl
```

## Supported platforms

Note: TFDV currently requires Python 2.7. Support for Python 3 is coming
very soon (tracked [here](https://github.com/tensorflow/data-validation/issues/10)).

TFDV is built and tested on the following 64-bit operating systems:

* macOS 10.12.6 (Sierra) or later.
* Ubuntu 14.04 or later.

## Dependencies

TFDV requires TensorFlow but does not depend on the `tensorflow`
[PyPI package](https://pypi.org/project/tensorflow/). See the [TensorFlow install guides](https://www.tensorflow.org/install/)
for instructions on how to get started with TensorFlow.

[Apache Beam](https://beam.apache.org/) is required; it's the way that efficient
distributed computation is supported. By default, Apache Beam runs in local
mode but can also run in distributed mode using
[Google Cloud Dataflow](https://cloud.google.com/dataflow/).
TFDV is designed to be extensible for other Apache Beam runners.
101 | 102 | ## Compatible versions 103 | 104 | The following table shows the package versions that are 105 | compatible with each other. This is determined by our testing framework, but 106 | other *untested* combinations may also work. 107 | 108 | |tensorflow-data-validation |tensorflow |apache-beam[gcp]| 109 | |---------------------------|--------------|----------------| 110 | |GitHub master |nightly (1.x) |2.6.0 | 111 | |0.9.0 |1.9 |2.6.0 | 112 | 113 | ## Questions 114 | 115 | Please direct any questions about working with TF Data Validation to 116 | [Stack Overflow](https://stackoverflow.com) using the 117 | [tensorflow-data-validation](https://stackoverflow.com/questions/tagged/tensorflow-data-validation) 118 | tag. 119 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/test_util.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | // Provides a variety of tools for evaluating methods that output Anomalies. 17 | // In particular, allows for tests written for schema version 0 to apply to 18 | // schema version 1. 19 | // Also, allows us to have expected schema protos instead of 20 | // expected diff regions. 
21 | 22 | #ifndef TENSORFLOW_DATA_VALIDATION_ANOMALIES_TEST_UTIL_H_ 23 | #define TENSORFLOW_DATA_VALIDATION_ANOMALIES_TEST_UTIL_H_ 24 | 25 | #include 26 | #include 27 | #include 28 | #include 29 | 30 | #include 31 | #include "absl/strings/str_join.h" 32 | #include "absl/strings/str_split.h" 33 | #include "absl/strings/string_view.h" 34 | #include "tensorflow/core/lib/core/status.h" 35 | #include "tensorflow/core/platform/logging.h" 36 | #include "tensorflow/core/platform/protobuf.h" 37 | #include "tensorflow/core/platform/types.h" 38 | #include "tensorflow_metadata/proto/v0/anomalies.pb.h" 39 | #include "tensorflow_metadata/proto/v0/schema.pb.h" 40 | 41 | namespace tensorflow { 42 | namespace data_validation { 43 | namespace testing { 44 | 45 | using tensorflow::protobuf::TextFormat; 46 | 47 | // Simple implementation of a proto matcher comparing string representations. 48 | // 49 | // IMPORTANT: Only use this for protos whose textual representation is 50 | // deterministic (that may not be the case for the map collection type). 
51 | 52 | class ProtoStringMatcher { 53 | public: 54 | explicit ProtoStringMatcher(const string& expected); 55 | explicit ProtoStringMatcher(const ::tensorflow::protobuf::Message& expected); 56 | 57 | template 58 | bool MatchAndExplain(const Message& p, 59 | ::testing::MatchResultListener* /* listener */) const; 60 | 61 | void DescribeTo(::std::ostream* os) const { *os << expected_; } 62 | void DescribeNegationTo(::std::ostream* os) const { 63 | *os << "not equal to expected message: " << expected_; 64 | } 65 | 66 | private: 67 | const string expected_; 68 | }; 69 | 70 | template 71 | T CreateProto(const string& textual_proto) { 72 | T proto; 73 | CHECK(TextFormat::ParseFromString(textual_proto, &proto)); 74 | return proto; 75 | } 76 | 77 | template 78 | bool ProtoStringMatcher::MatchAndExplain( 79 | const Message& p, ::testing::MatchResultListener* /* listener */) const { 80 | // Need to CreateProto and then print as string so that the formatting 81 | // matches exactly. 82 | return p.SerializeAsString() == 83 | CreateProto(expected_).SerializeAsString(); 84 | } 85 | 86 | // Polymorphic matcher to compare any two protos. 87 | inline ::testing::PolymorphicMatcher EqualsProto( 88 | const string& x) { 89 | return ::testing::MakePolymorphicMatcher(ProtoStringMatcher(x)); 90 | } 91 | 92 | // Polymorphic matcher to compare any two protos. 93 | inline ::testing::PolymorphicMatcher EqualsProto( 94 | const ::tensorflow::protobuf::Message& x) { 95 | return ::testing::MakePolymorphicMatcher(ProtoStringMatcher(x)); 96 | } 97 | 98 | // Parse input string as a protocol buffer. 99 | template 100 | T ParseTextProtoOrDie(const string& input) { 101 | T result; 102 | CHECK(TextFormat::ParseFromString(input, &result)) 103 | << "Failed to parse: " << input; 104 | return result; 105 | } 106 | 107 | 108 | // Store this as a proto, to make it easier to understand and update tests. 
109 | struct ExpectedAnomalyInfo { 110 | tensorflow::metadata::v0::AnomalyInfo expected_info_without_diff; 111 | tensorflow::metadata::v0::Schema new_schema; 112 | }; 113 | 114 | // Test if anomalies is as expected. 115 | void TestAnomalies( 116 | const tensorflow::metadata::v0::Anomalies& actual, 117 | const tensorflow::metadata::v0::Schema& old_schema, 118 | const std::map& expected_anomalies); 119 | 120 | void TestAnomalyInfo(const tensorflow::metadata::v0::AnomalyInfo& actual, 121 | const tensorflow::metadata::v0::Schema& baseline, 122 | const ExpectedAnomalyInfo& expected, 123 | const string& comment); 124 | 125 | } // namespace testing 126 | } // namespace data_validation 127 | } // namespace tensorflow 128 | 129 | #endif // TENSORFLOW_DATA_VALIDATION_ANOMALIES_TEST_UTIL_H_ 130 | -------------------------------------------------------------------------------- /g3doc/index.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | {% setvar github_path %}tensorflow/data-validation{% endsetvar %} 4 | {% include "_templates/github-bug.html" %} 5 | 6 | # TensorFlow Data Validation [![PyPI](https://img.shields.io/pypi/pyversions/tensorflow-data-validation.svg?style=plastic)](https://github.com/tensorflow/data-validation) 7 | 8 | *TensorFlow Data Validation* (TFDV) is a library for exploring and validating 9 | machine learning data. It is designed to be highly scalable 10 | and to work well with TensorFlow and [TensorFlow Extended (TFX)](https://www.tensorflow.org/tfx). 11 | 12 | TF Data Validation includes: 13 | 14 | * Scalable calculation of summary statistics of training and test data. 
15 | * Integration with a viewer for data distributions and statistics, as well 16 | as faceted comparison of pairs of features ([Facets](https://github.com/PAIR-code/facets)) 17 | * Automated [data-schema](https://github.com/tensorflow/metadata/blob/master/tensorflow_metadata/proto/v0/schema.proto) 18 | generation to describe expectations about data 19 | like required values, ranges, and vocabularies 20 | * A schema viewer to help you inspect the schema. 21 | * Anomaly detection to identify anomalies, such as missing features, 22 | out-of-range values, or wrong feature types, to name a few. 23 | * An anomalies viewer so that you can see what features have anomalies and 24 | learn more in order to correct them. 25 | 26 | For instructions on using TFDV, see the [get started guide](get_started.md) 27 | and try out the [example notebook](https://nbviewer.jupyter.org/github/tensorflow/data-validation/blob/master/examples/chicago_taxi/chicago_taxi_tfdv.ipynb). 28 | 29 | Caution: TFDV may be backwards incompatible before version 1.0. 30 | 31 | ## Installing from PyPI 32 | 33 | The recommended way to install TFDV is using the 34 | [PyPI package](https://pypi.org/project/tensorflow-data-validation/): 35 | 36 | ```bash 37 | pip install tensorflow-data-validation 38 | ``` 39 | 40 | ## Installing from source 41 | 42 | ### 1. Prerequisites 43 | 44 | To compile and use TFDV, you need to set up some prerequisites. 45 | 46 | #### Install NumPy 47 | 48 | If NumPy is not installed on your system, install it now by following [these 49 | directions](https://www.scipy.org/scipylib/download.html). 50 | 51 | #### Install Bazel 52 | 53 | If bazel is not installed on your system, install it now by following [these 54 | directions](https://bazel.build/versions/master/docs/install.html). 55 | 56 | ### 2. 
Clone the TFDV repository 57 | 58 | ```shell 59 | git clone https://github.com/tensorflow/data-validation 60 | cd data-validation 61 | ``` 62 | 63 | Note that these instructions will install the latest master branch of TensorFlow 64 | Data Validation. If you want to install a specific branch (such as a release branch), 65 | pass `-b <branchname>` to the `git clone` command. 66 | 67 | ### 3. Build the pip package 68 | 69 | TFDV uses Bazel to build the pip package from source: 70 | 71 | ```shell 72 | bazel run -c opt tensorflow_data_validation:build_pip_package 73 | ``` 74 | 75 | You can find the generated `.whl` file in the `dist` subdirectory. 76 | 77 | ### 4. Install the pip package 78 | 79 | ```shell 80 | pip install dist/*.whl 81 | ``` 82 | 83 | ## Supported platforms 84 | 85 | Note: TFDV currently requires Python 2.7. Support for Python 3 is coming 86 | very soon (tracked [here](https://github.com/tensorflow/data-validation/issues/10)). 87 | 88 | TFDV is built and tested on the following 64-bit operating systems: 89 | 90 | * macOS 10.12.6 (Sierra) or later. 91 | * Ubuntu 14.04 or later. 92 | 93 | ## Dependencies 94 | 95 | TFDV requires TensorFlow but does not depend on the `tensorflow` 96 | [PyPI package](https://pypi.org/project/tensorflow/). See the [TensorFlow install guides](https://www.tensorflow.org/install/) 97 | for instructions on how to get started with TensorFlow. 98 | 99 | [Apache Beam](https://beam.apache.org/) is required; it's the way that efficient 100 | distributed computation is supported. By default, Apache Beam runs in local 101 | mode but can also run in distributed mode using 102 | [Google Cloud Dataflow](https://cloud.google.com/dataflow/). 103 | TFDV is designed to be extensible for other Apache Beam runners. 104 | 105 | ## Compatible versions 106 | 107 | The following table shows the package versions that are 108 | compatible with each other. This is determined by our testing framework, but 109 | other *untested* combinations may also work.
110 | 111 | |tensorflow-data-validation |tensorflow |apache-beam[gcp]| 112 | |---------------------------|--------------|----------------| 113 | |GitHub master |nightly (1.x) |2.6.0 | 114 | |0.9.0 |1.9 |2.6.0 | 115 | 116 | ## Questions 117 | 118 | Please direct any questions about working with TF Data Validation to 119 | [Stack Overflow](https://stackoverflow.com) using the 120 | [tensorflow-data-validation](https://stackoverflow.com/questions/tagged/tensorflow-data-validation) 121 | tag. 122 | -------------------------------------------------------------------------------- /tensorflow_data_validation/utils/test_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Utilities for writing statistics generator tests.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest 22 | import apache_beam as beam 23 | from apache_beam.testing import util 24 | from tensorflow_data_validation import types 25 | from tensorflow_data_validation.statistics.generators import stats_generator 26 | from tensorflow_data_validation.types_compat import Callable, Dict, List 27 | 28 | from tensorflow.python.util.protobuf import compare 29 | from tensorflow_metadata.proto.v0 import statistics_pb2 30 | 31 | 32 | def make_dataset_feature_stats_list_proto_equal_fn( 33 | test, 34 | expected_result 35 | ): 36 | """Makes a matcher function for comparing DatasetFeatureStatisticsList proto. 37 | 38 | Args: 39 | test: test case object 40 | expected_result: the expected DatasetFeatureStatisticsList proto. 41 | 42 | Returns: 43 | A matcher function for comparing DatasetFeatureStatisticsList proto. 44 | """ 45 | 46 | def _matcher(actual): 47 | """Matcher function for comparing DatasetFeatureStatisticsList proto.""" 48 | try: 49 | test.assertEqual(len(actual), 1) 50 | # Get the dataset stats from DatasetFeatureStatisticsList proto. 
51 | actual_stats = actual[0].datasets[0] 52 | expected_stats = expected_result.datasets[0] 53 | 54 | test.assertEqual(actual_stats.num_examples, expected_stats.num_examples) 55 | test.assertEqual(len(actual_stats.features), len(expected_stats.features)) 56 | 57 | expected_features = {} 58 | for feature in expected_stats.features: 59 | expected_features[feature.name] = feature 60 | 61 | for feature in actual_stats.features: 62 | compare.assertProtoEqual( 63 | test, 64 | feature, 65 | expected_features[feature.name], 66 | normalize_numbers=True) 67 | except AssertionError as e: 68 | raise util.BeamAssertException('Failed assert: ' + str(e)) 69 | 70 | return _matcher 71 | 72 | 73 | class CombinerStatsGeneratorTest(absltest.TestCase): 74 | """Test class with extra combiner stats generator related functionality.""" 75 | 76 | # Runs the provided combiner statistics generator and tests if the output 77 | # matches the expected result. 78 | def assertCombinerOutputEqual( 79 | self, batches, 80 | generator, expected_result): 81 | """Tests a combiner statistics generator.""" 82 | accumulators = [ 83 | generator.add_input(generator.create_accumulator(), batch) 84 | for batch in batches 85 | ] 86 | result = generator.extract_output( 87 | generator.merge_accumulators(accumulators)) 88 | self.assertEqual(len(result.features), len(expected_result)) 89 | for actual_feature_stats in result.features: 90 | compare.assertProtoEqual( 91 | self, 92 | actual_feature_stats, 93 | expected_result[actual_feature_stats.name], 94 | normalize_numbers=True) 95 | 96 | 97 | class TransformStatsGeneratorTest(absltest.TestCase): 98 | """Test class with extra transform stats generator related functionality.""" 99 | 100 | # Runs the provided transform statistics generator and tests if the output 101 | # matches the expected result. 
102 | def assertTransformOutputEqual( 103 | self, batches, 104 | generator, 105 | expected_results): 106 | """Tests a transform statistics generator.""" 107 | 108 | def _make_result_matcher( 109 | test, 110 | expected_results): 111 | """Makes matcher for a list of DatasetFeatureStatistics protos.""" 112 | 113 | def _equal(actual_results): 114 | """Matcher for comparing a list of DatasetFeatureStatistics protos.""" 115 | test.assertEquals(len(expected_results), len(actual_results)) 116 | # Sort both list of protos based on their string presentation to make 117 | # sure the sort is stable. 118 | sorted_expected_results = sorted(expected_results, key=str) 119 | sorted_actual_results = sorted(actual_results, key=str) 120 | for index, actual in enumerate(sorted_actual_results): 121 | compare.assertProtoEqual( 122 | test, 123 | actual, 124 | sorted_expected_results[index], 125 | normalize_numbers=True) 126 | 127 | return _equal 128 | 129 | with beam.Pipeline() as p: 130 | result = p | beam.Create(batches) | generator.ptransform 131 | util.assert_that(result, _make_result_matcher(self, expected_results)) 132 | -------------------------------------------------------------------------------- /tensorflow_data_validation/statistics/generators/stats_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Base class for statistics generators. 15 | 16 | A statistics generator is used to compute the statistics of features in 17 | parallel. We support two types of generators: 18 | 19 | 1) CombinerStatsGenerator 20 | This generator computes statistics using a combiner function. It emits 21 | partial states processing a batch of examples at a time, 22 | merges the partial states, and finally computes the statistics from the 23 | merged partial state at the end. Specifically, the generator 24 | must implement the following four methods, 25 | 26 | Initializes an accumulator to store the partial state and returns it. 27 | create_accumulator() 28 | 29 | Incorporates a batch of input examples into the current accumulator 30 | and returns the updated accumulator. 31 | add_input(accumulator, input_batch) 32 | 33 | Merge the partial states in the accumulators and returns the accumulator 34 | containing the merged state. 35 | merge_accumulators(accumulators) 36 | 37 | Compute statistics from the partial state in the accumulator and 38 | return the result as a DatasetFeatureStatistics proto. 39 | extract_output(accumulator) 40 | 41 | 2) TransformStatsGenerator 42 | This generator computes statistics using a user-provided Beam PTransform. 43 | The PTransform must accept a Beam PCollection where each element is a Python 44 | dict whose keys are feature names and values are numpy arrays representing a 45 | batch of examples. It must return a PCollection containing a single element 46 | which is a DatasetFeatureStatistics proto. 
47 | """ 48 | 49 | from __future__ import absolute_import 50 | from __future__ import division 51 | 52 | from __future__ import print_function 53 | 54 | import apache_beam as beam 55 | from tensorflow_data_validation import types 56 | from tensorflow_data_validation.types_compat import List, Optional, TypeVar 57 | from tensorflow_metadata.proto.v0 import schema_pb2 58 | from tensorflow_metadata.proto.v0 import statistics_pb2 59 | 60 | 61 | class StatsGenerator(object): 62 | """Generate statistics.""" 63 | 64 | def __init__(self, name, 65 | schema = None): 66 | """Initializes a statistics generator. 67 | 68 | Args: 69 | name: A unique name associated with the statistics generator. 70 | schema: An optional schema for the dataset. 71 | """ 72 | self._name = name 73 | self._schema = schema 74 | 75 | @property 76 | def name(self): 77 | return self._name 78 | 79 | @property 80 | def schema(self): 81 | return self._schema 82 | 83 | 84 | 85 | class CombinerStatsGenerator(StatsGenerator): 86 | """Generate statistics using combiner function. 87 | 88 | This object mirrors a beam.CombineFn. 89 | """ 90 | 91 | def create_accumulator(self): # pytype: disable=invalid-annotation 92 | """Return a fresh, empty accumulator. 93 | 94 | Returns: 95 | An empty accumulator. 96 | """ 97 | raise NotImplementedError 98 | 99 | def add_input(self, accumulator, 100 | input_batch): 101 | """Return result of folding a batch of inputs into accumulator. 102 | 103 | Args: 104 | accumulator: The current accumulator. 105 | input_batch: A Python dict whose keys are strings denoting feature 106 | names and values are numpy arrays representing a batch of examples, 107 | which should be added to the accumulator. 108 | 109 | Returns: 110 | The accumulator after updating the statistics for the batch of inputs. 111 | """ 112 | raise NotImplementedError 113 | 114 | def merge_accumulators(self, accumulators): 115 | """Merges several accumulators to a single accumulator value. 
116 | 117 | Args: 118 | accumulators: The accumulators to merge. 119 | 120 | Returns: 121 | The merged accumulator. 122 | """ 123 | raise NotImplementedError 124 | 125 | def extract_output( 126 | self, accumulator 127 | ): # pytype: disable=invalid-annotation 128 | """Return result of converting accumulator into the output value. 129 | 130 | Args: 131 | accumulator: The final accumulator value. 132 | 133 | Returns: 134 | A proto representing the result of this stats generator. 135 | """ 136 | raise NotImplementedError 137 | 138 | 139 | class TransformStatsGenerator(StatsGenerator): 140 | """Generate statistics using a Beam PTransform.""" 141 | 142 | def __init__(self, 143 | name, 144 | ptransform, 145 | schema = None): 146 | self._ptransform = ptransform 147 | super(TransformStatsGenerator, self).__init__(name, schema) 148 | 149 | @property 150 | def ptransform(self): 151 | return self._ptransform 152 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/statistics_view_test_util.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_data_validation/anomalies/statistics_view_test_util.h" 17 | 18 | #include 19 | #include 20 | 21 | #include "absl/types/optional.h" 22 | #include "tensorflow_data_validation/anomalies/statistics_view.h" 23 | #include "tensorflow/core/platform/logging.h" 24 | #include "tensorflow/core/platform/types.h" 25 | 26 | namespace tensorflow { 27 | namespace data_validation { 28 | namespace testing { 29 | namespace { 30 | using tensorflow::metadata::v0::CommonStatistics; 31 | using tensorflow::metadata::v0::DatasetFeatureStatistics; 32 | using tensorflow::metadata::v0::FeatureNameStatistics; 33 | using tensorflow::metadata::v0::WeightedCommonStatistics; 34 | 35 | CommonStatistics* GetCommonStatisticsPtr(FeatureNameStatistics* feature_stats) { 36 | if (feature_stats->has_num_stats()) { 37 | return feature_stats->mutable_num_stats()->mutable_common_stats(); 38 | } else if (feature_stats->has_string_stats()) { 39 | return feature_stats->mutable_string_stats()->mutable_common_stats(); 40 | } else if (feature_stats->has_bytes_stats()) { 41 | return feature_stats->mutable_bytes_stats()->mutable_common_stats(); 42 | } else if (feature_stats->has_struct_stats()) { 43 | return feature_stats->mutable_struct_stats()->mutable_common_stats(); 44 | } 45 | LOG(FATAL) << "Unknown statistics: " << feature_stats->DebugString(); 46 | } 47 | 48 | FeatureStatsView GetByNameOrDie(const DatasetStatsView& dataset, 49 | const string& name) { 50 | absl::optional result = dataset.GetByName(name); 51 | CHECK(absl::nullopt != result) << "Unknown name: " << name; 52 | return *result; 53 | } 54 | 55 | FeatureStatsView GetFirstOrDie(const DatasetStatsView& dataset) { 56 | CHECK(!dataset.features().empty()) << "Must have a feature name statistics"; 57 | return dataset.features()[0]; 58 | } 59 | 60 | } // namespace 61 | 62 | FeatureNameStatistics AddWeightedStats(const FeatureNameStatistics& 
original) { 63 | FeatureNameStatistics result = original; 64 | CommonStatistics& common_stats = *GetCommonStatisticsPtr(&result); 65 | WeightedCommonStatistics& weighted_common_stats = 66 | *common_stats.mutable_weighted_common_stats(); 67 | weighted_common_stats.set_num_non_missing(common_stats.num_non_missing()); 68 | weighted_common_stats.set_num_missing(common_stats.num_missing()); 69 | weighted_common_stats.set_avg_num_values(common_stats.avg_num_values()); 70 | weighted_common_stats.set_tot_num_values(common_stats.tot_num_values()); 71 | if (result.has_string_stats()) { 72 | *result.mutable_string_stats() 73 | ->mutable_weighted_string_stats() 74 | ->mutable_rank_histogram() = result.string_stats().rank_histogram(); 75 | } 76 | return result; 77 | } 78 | 79 | DatasetFeatureStatistics GetDatasetFeatureStatisticsForTesting( 80 | const FeatureNameStatistics& feature_name_stats) { 81 | DatasetFeatureStatistics result; 82 | FeatureNameStatistics& new_stats = *result.add_features(); 83 | new_stats = feature_name_stats; 84 | const CommonStatistics& common_stats = *GetCommonStatisticsPtr(&new_stats); 85 | result.set_num_examples(common_stats.num_missing() + 86 | common_stats.num_non_missing()); 87 | const WeightedCommonStatistics& weighted_common_stats = 88 | common_stats.weighted_common_stats(); 89 | result.set_weighted_num_examples(weighted_common_stats.num_non_missing() + 90 | weighted_common_stats.num_missing()); 91 | return result; 92 | } 93 | 94 | DatasetForTesting::DatasetForTesting( 95 | const FeatureNameStatistics& feature_name_stats) 96 | : dataset_feature_statistics_( 97 | GetDatasetFeatureStatisticsForTesting(feature_name_stats)), 98 | dataset_stats_view_(dataset_feature_statistics_), 99 | feature_stats_view_( 100 | GetByNameOrDie(dataset_stats_view_, feature_name_stats.name())) {} 101 | 102 | DatasetForTesting::DatasetForTesting( 103 | const FeatureNameStatistics& feature_name_stats, bool by_weight) 104 | : dataset_feature_statistics_( 105 | 
GetDatasetFeatureStatisticsForTesting(feature_name_stats)), 106 | dataset_stats_view_(dataset_feature_statistics_, by_weight), 107 | feature_stats_view_( 108 | GetByNameOrDie(dataset_stats_view_, feature_name_stats.name())) {} 109 | 110 | DatasetForTesting::DatasetForTesting( 111 | const DatasetFeatureStatistics& dataset_feature_stats, bool by_weight) 112 | : dataset_feature_statistics_(dataset_feature_stats), 113 | dataset_stats_view_(dataset_feature_statistics_, by_weight), 114 | feature_stats_view_(GetFirstOrDie(dataset_stats_view_)) {} 115 | 116 | } // namespace testing 117 | } // namespace data_validation 118 | } // namespace tensorflow 119 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/path_test.cc: -------------------------------------------------------------------------------- 1 | #include "tensorflow_data_validation/anomalies/path.h" 2 | #include 3 | #include 4 | #include "tensorflow_data_validation/anomalies/test_util.h" 5 | #include "tensorflow/core/lib/core/errors.h" 6 | #include "tensorflow/core/lib/core/status.h" 7 | #include "tensorflow/core/lib/core/status_test_util.h" 8 | #include "tensorflow_metadata/proto/v0/path.pb.h" 9 | 10 | namespace tensorflow { 11 | namespace data_validation { 12 | namespace { 13 | using testing::ParseTextProtoOrDie; 14 | 15 | MATCHER_P(EqualsPath, path, 16 | absl::StrCat((negation ? 
"doesn't equal" : "equals"), 17 | path.Serialize())) { 18 | return path.Compare(arg) == 0; 19 | } 20 | 21 | TEST(Path, Constructor) { 22 | EXPECT_EQ("a.b.c", Path({"a", "b", "c"}).Serialize()); 23 | EXPECT_EQ("", Path().Serialize()); 24 | EXPECT_EQ("a.b.c", Path(ParseTextProtoOrDie( 25 | R"(step: [ "a", "b", "c" ])")) 26 | .Serialize()); 27 | } 28 | 29 | TEST(Path, AsProto) { 30 | const Path p({"a", "b", "c"}); 31 | Path result(p.AsProto()); 32 | EXPECT_THAT(result, EqualsPath(p)) << "result: " << result.Serialize(); 33 | } 34 | 35 | // If compare identifies things as -1, 0, and 1 correctly, the rest of the 36 | // methods can be tested with three examples. 37 | TEST(Path, Compare) { 38 | EXPECT_EQ(1, Path({"a", "b", "c"}).Compare(Path({"a", "b"}))); 39 | EXPECT_EQ(-1, Path({"a", "b"}).Compare(Path({"a", "b", "c"}))); 40 | EXPECT_EQ(0, Path({"a", "b", "c"}).Compare(Path({"a", "b", "c"}))); 41 | EXPECT_EQ(-1, Path({"a", "b", "c"}).Compare(Path({"a", "d", "c"}))); 42 | EXPECT_EQ(1, Path({"a", "d", "c"}).Compare(Path({"a", "b", "c"}))); 43 | EXPECT_EQ(0, Path().Compare(Path())); 44 | } 45 | 46 | // See TEST(Path, Compare) above. 47 | TEST(Path, Less) { 48 | EXPECT_FALSE(Path({"a", "b", "c"}) < Path({"a", "b"})); 49 | EXPECT_TRUE(Path({"a", "b"}) < Path({"a", "b", "c"})); 50 | EXPECT_FALSE(Path({"a", "b", "c"}) < Path({"a", "b", "c"})); 51 | } 52 | 53 | // See TEST(Path, Compare) above. 54 | TEST(Path, GreaterOrEqual) { 55 | EXPECT_TRUE(Path({"a", "b", "c"}) >= Path({"a", "b"})); 56 | EXPECT_FALSE(Path({"a", "b"}) >= Path({"a", "b", "c"})); 57 | EXPECT_TRUE(Path({"a", "b", "c"}) >= Path({"a", "b", "c"})); 58 | } 59 | 60 | // See TEST(Path, Compare) above. 61 | TEST(Path, Greater) { 62 | EXPECT_TRUE(Path({"a", "b", "c"}) > Path({"a", "b"})); 63 | EXPECT_FALSE(Path({"a", "b"}) > Path({"a", "b", "c"})); 64 | EXPECT_FALSE(Path({"a", "b", "c"}) > Path({"a", "b", "c"})); 65 | } 66 | 67 | // See TEST(Path, Compare) above. 
68 | TEST(Path, LessOrEqual) { 69 | EXPECT_FALSE(Path({"a", "b", "c"}) <= Path({"a", "b"})); 70 | EXPECT_TRUE(Path({"a", "b"}) <= Path({"a", "b", "c"})); 71 | EXPECT_TRUE(Path({"a", "b", "c"}) <= Path({"a", "b", "c"})); 72 | } 73 | 74 | // See TEST(Path, Compare) above. 75 | TEST(Path, Equal) { 76 | EXPECT_FALSE(Path({"a", "b", "c"}) == Path({"a", "b"})); 77 | EXPECT_FALSE(Path({"a", "b"}) == Path({"a", "b", "c"})); 78 | EXPECT_TRUE(Path({"a", "b", "c"}) == Path({"a", "b", "c"})); 79 | } 80 | 81 | // See TEST(Path, Compare) above. 82 | TEST(Path, NotEqual) { 83 | EXPECT_TRUE(Path({"a", "b", "c"}) != Path({"a", "b"})); 84 | EXPECT_TRUE(Path({"a", "b"}) != Path({"a", "b", "c"})); 85 | EXPECT_FALSE(Path({"a", "b", "c"}) != Path({"a", "b", "c"})); 86 | } 87 | 88 | TEST(Path, Serialize) { 89 | EXPECT_EQ("a.'.b'.'''c'''", Path({"a", ".b", "'c'"}).Serialize()); 90 | EXPECT_EQ("a.(b'.d).'((c)'", Path({"a", "(b'.d)", "((c)"}).Serialize()); 91 | EXPECT_EQ("''", Path({""}).Serialize()); 92 | EXPECT_EQ("", Path().Serialize()); 93 | } 94 | 95 | TEST(Path, Deserialize) { 96 | std::vector paths_to_check = {Path({"a", ".b", "'c'"}), 97 | Path({"a", "(b'.d)", "((c)"}), Path({""}), 98 | Path()}; 99 | for (const Path& path : paths_to_check) { 100 | Path result; 101 | TF_ASSERT_OK(Path::Deserialize(path.Serialize(), &result)) 102 | << "Failed on " << path.Serialize() << "!=" << result.Serialize(); 103 | EXPECT_THAT(result, EqualsPath(path)) << "result: " << result.Serialize(); 104 | } 105 | } 106 | 107 | // If a path has steps that have quotes that didn't need to be quoted, 108 | // Deserialize works anyway. 109 | TEST(Path, DeserializeSillyQuotes) { 110 | Path no_silly_quotes; 111 | TF_ASSERT_OK(Path::Deserialize("'a'.'b'.'(c'')'", &no_silly_quotes)); 112 | EXPECT_EQ("a.b.(c')", no_silly_quotes.Serialize()); 113 | } 114 | 115 | // If a path has steps that have quotes that didn't need to be quoted, 116 | // Deserialize works anyway. 
117 | TEST(Path, DeserializeBad) { 118 | const std::vector bad_serializations = { 119 | "a'", "'a", "(b", "c'd", "'c'd'", "''cd'", "'c'''d'"}; 120 | for (const string& bad : bad_serializations) { 121 | Path dummy; 122 | tensorflow::Status status = Path::Deserialize(bad, &dummy); 123 | EXPECT_EQ(status.code(), tensorflow::error::INVALID_ARGUMENT) 124 | << "Deserialize did not fail on " << bad; 125 | } 126 | } 127 | 128 | TEST(Path, GetParent) { 129 | EXPECT_EQ("a.b", Path({"a", "b", "c"}).GetParent().Serialize()); 130 | EXPECT_EQ("a", Path({"a", "b"}).GetParent().Serialize()); 131 | } 132 | TEST(Path, GetChild) { 133 | EXPECT_EQ("a.b", Path({"a"}).GetChild("b").Serialize()); 134 | EXPECT_EQ("a", Path().GetChild("a").Serialize()); 135 | } 136 | 137 | TEST(Path, size) { 138 | EXPECT_EQ(3, Path({"a", "b", "c"}).size()); 139 | EXPECT_EQ(0, Path().size()); 140 | } 141 | TEST(Path, empty) { 142 | EXPECT_EQ(false, Path({"a", "b", "c"}).empty()); 143 | EXPECT_EQ(true, Path().empty()); 144 | } 145 | 146 | TEST(Path, GetLastStep) { 147 | EXPECT_EQ("c", Path({"a", "b", "c"}).last_step()); 148 | EXPECT_EQ("a", Path({"a"}).last_step()); 149 | } 150 | 151 | } // namespace 152 | } // namespace data_validation 153 | } // namespace tensorflow 154 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/schema_anomalies.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_DATA_VALIDATION_ANOMALIES_SCHEMA_ANOMALIES_H_ 17 | #define TENSORFLOW_DATA_VALIDATION_ANOMALIES_SCHEMA_ANOMALIES_H_ 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | #include "tensorflow_data_validation/anomalies/internal_types.h" 26 | #include "tensorflow_data_validation/anomalies/proto/feature_statistics_to_proto.pb.h" 27 | #include "tensorflow_data_validation/anomalies/schema.h" 28 | #include "tensorflow_data_validation/anomalies/statistics_view.h" 29 | #include "tensorflow/core/lib/core/status.h" 30 | #include "tensorflow/core/platform/types.h" 31 | #include "tensorflow_metadata/proto/v0/anomalies.pb.h" 32 | #include "tensorflow_metadata/proto/v0/schema.pb.h" 33 | 34 | namespace tensorflow { 35 | namespace data_validation { 36 | 37 | // SchemaAnomaly represents all the issues related to a single column. 38 | class SchemaAnomaly { 39 | public: 40 | SchemaAnomaly(); 41 | 42 | SchemaAnomaly(SchemaAnomaly&& schema_anomaly); 43 | 44 | SchemaAnomaly& operator=(SchemaAnomaly&& schema_anomaly); 45 | 46 | // Initializes schema_. 47 | tensorflow::Status InitSchema( 48 | const tensorflow::metadata::v0::Schema& schema); 49 | 50 | // Updates based upon the relevant current feature statistics. 51 | tensorflow::Status Update(const Schema::Updater& updater, 52 | const FeatureStatsView& feature_stats_view); 53 | 54 | // Update the skew. 55 | void UpdateSkewComparator(const FeatureStatsView& feature_stats_view); 56 | 57 | // Makes a note that the feature is missing. Deprecates the feature, 58 | // and leaves a description. 59 | void ObserveMissing(); 60 | 61 | // If new_severity is more severe that current severity, increases 62 | // severity. Otherwise, does nothing. 
63 | void UpgradeSeverity( 64 | tensorflow::metadata::v0::AnomalyInfo::Severity new_severity); 65 | 66 | // Returns an AnomalyInfo representing the change. 67 | // baseline is the original schema. 68 | tensorflow::metadata::v0::AnomalyInfo GetAnomalyInfo( 69 | const tensorflow::metadata::v0::Schema& baseline) const; 70 | 71 | // Identifies if there is an issue. 72 | bool is_problem() { 73 | return severity_ != tensorflow::metadata::v0::AnomalyInfo::UNKNOWN; 74 | } 75 | void set_feature_name(const string& feature_name) { 76 | feature_name_ = feature_name; 77 | } 78 | 79 | private: 80 | // Returns an AnomalyInfo representing the change. Takes as an input the 81 | // text version of the existing schema and the new schema. 82 | // Called as part of GetAnomalyInfoV0(...) and GetAnomalyInfoV1(...) to do 83 | // the part of the work that is common between them. 84 | tensorflow::metadata::v0::AnomalyInfo GetAnomalyInfoCommon( 85 | const string& existing_schema, const string& new_schema) const; 86 | // A new schema that will make the anomaly go away. 87 | std::unique_ptr schema_; 88 | // The name of the feature being fixed. 89 | string feature_name_; 90 | // Descriptions of what caused the anomaly. 91 | std::vector descriptions_; 92 | // The severity of the anomaly 93 | tensorflow::metadata::v0::AnomalyInfo::Severity severity_; 94 | }; 95 | 96 | // A class for tracking all anomalies that occur based upon the feature that 97 | // created the anomaly. 98 | class SchemaAnomalies { 99 | public: 100 | explicit SchemaAnomalies( 101 | const tensorflow::metadata::v0::Schema& schema) 102 | : serialized_baseline_(schema) {} 103 | 104 | // Finds any columns that have issues, and creates a new Schema proto 105 | // involving only the changes for that column. Returns a map where the key is 106 | // the key of the column with an anomaly, and the Schema proto is a changed 107 | // schema that would allow the column to be valid. 
108 | tensorflow::Status FindChanges( 109 | const DatasetStatsView& statistics, 110 | const FeatureStatisticsToProtoConfig& feature_statistics_to_proto_config); 111 | 112 | tensorflow::Status FindSkew(const DatasetStatsView& dataset_stats_view); 113 | 114 | // Records current anomalies as a schema diff. 115 | tensorflow::metadata::v0::Anomalies GetSchemaDiff() const; 116 | 117 | private: 118 | // 1. If there is a SchemaAnomaly for feature_name, applies update, 119 | // 2. otherwise, creates a new SchemaAnomaly for the feature_name and 120 | // initializes it using the serialized_baseline_. Then, it tries the 121 | // update(...) function. If there is a problem, then the new SchemaAnomaly 122 | // gets added. 123 | tensorflow::Status GenericUpdate( 124 | const std::function& update, 125 | const string& feature_name); 126 | 127 | // Initialize a schema from the serialized_baseline_. 128 | tensorflow::Status InitSchema(Schema* schema) const; 129 | 130 | // A map from feature columns to anomalies in that column. 131 | std::map anomalies_; 132 | 133 | // The initial schema. Each SchemaAnomaly is initialized from this. 134 | tensorflow::metadata::v0::Schema serialized_baseline_; 135 | }; 136 | 137 | } // namespace data_validation 138 | } // namespace tensorflow 139 | 140 | #endif // TENSORFLOW_DATA_VALIDATION_ANOMALIES_SCHEMA_ANOMALIES_H_ 141 | -------------------------------------------------------------------------------- /tensorflow_data_validation/statistics/generators/uniques_stats_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Module for computing number of unique values per string feature.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | 20 | from __future__ import print_function 21 | 22 | import apache_beam as beam 23 | import numpy as np 24 | import six 25 | from tensorflow_data_validation import types 26 | from tensorflow_data_validation.statistics.generators import stats_generator 27 | from tensorflow_data_validation.utils import stats_util 28 | from tensorflow_data_validation.types_compat import Generator, Optional, Set, Tuple 29 | from tensorflow_metadata.proto.v0 import schema_pb2 30 | from tensorflow_metadata.proto.v0 import statistics_pb2 31 | 32 | 33 | def _make_feature_stats_proto( 34 | feature_name, count, 35 | is_categorical): 36 | """Makes a FeatureNameStatistics proto containing the uniques stats.""" 37 | result = statistics_pb2.FeatureNameStatistics() 38 | result.name = feature_name 39 | # If we have a categorical feature, we preserve the type to be the original 40 | # INT type. 
41 | result.type = (statistics_pb2.FeatureNameStatistics.INT if is_categorical 42 | else statistics_pb2.FeatureNameStatistics.STRING) 43 | result.string_stats.unique = count 44 | return result 45 | 46 | 47 | def _make_dataset_feature_stats_proto_with_single_feature( 48 | feature_name_to_value_count, 49 | categorical_features 50 | ): 51 | """Generates a DatasetFeatureStatistics proto containing a single feature.""" 52 | result = statistics_pb2.DatasetFeatureStatistics() 53 | result.features.add().CopyFrom( 54 | _make_feature_stats_proto( 55 | feature_name_to_value_count[0], 56 | feature_name_to_value_count[1], 57 | feature_name_to_value_count[0] in categorical_features)) 58 | return result 59 | 60 | 61 | # Input type check is commented out, as beam python will fail the type check 62 | # when input is an empty dict. 63 | # @beam.typehints.with_input_types(types.ExampleBatch) 64 | @beam.typehints.with_output_types(statistics_pb2.DatasetFeatureStatistics) 65 | class _UniquesStatsGeneratorImpl(beam.PTransform): 66 | """A PTransform that computes the number of unique values 67 | for string features. 68 | """ 69 | 70 | def __init__(self, schema): 71 | """Initializes unique stats generator ptransform. 72 | 73 | Args: 74 | schema: An schema for the dataset. None if no schema is available. 75 | """ 76 | self._categorical_features = set( 77 | stats_util.get_categorical_numeric_features(schema) if schema else []) 78 | 79 | def _filter_irrelevant_features( 80 | self, input_batch 81 | ): 82 | """Filters out non-string features.""" 83 | for feature_name, values_batch in six.iteritems(input_batch): 84 | is_categorical = feature_name in self._categorical_features 85 | for values in values_batch: 86 | # Check if we have a numpy array with at least one value. 87 | if not isinstance(values, np.ndarray) or values.size == 0: 88 | continue 89 | # If the feature is neither categorical nor of string type, then 90 | # skip the feature. 
91 | if not (is_categorical or 92 | stats_util.make_feature_type(values.dtype) == 93 | statistics_pb2.FeatureNameStatistics.STRING): 94 | continue 95 | 96 | yield (feature_name, values.astype(str) if is_categorical else values) 97 | 98 | def expand(self, pcoll): 99 | """Computes number of unique values for string features.""" 100 | # Count the number of appearance of each feature_value. Output is a 101 | # pcollection of DatasetFeatureStatistics protos 102 | return ( 103 | pcoll 104 | | 'Uniques_FilterIrrelevantFeatures' >> 105 | (beam.FlatMap(self._filter_irrelevant_features).with_output_types( 106 | beam.typehints.KV[types.BeamFeatureName, np.ndarray])) 107 | | 'Uniques_FlattenToFeatureNameValueTuples' >> 108 | beam.FlatMap(lambda name_and_value_list: # pylint: disable=g-long-lambda 109 | [(name_and_value_list[0], value) 110 | for value in name_and_value_list[1]]) 111 | | 'Uniques_CountFeatureNameValueTuple' >> 112 | beam.combiners.Count().PerElement() 113 | # Drop the values to only have the feature_name with each repeated the 114 | # number of unique values times. 115 | | 'Uniques_DropValues' >> beam.Map(lambda x: x[0][0]) 116 | | 'Uniques_CountPerFeatureName' >> beam.combiners.Count().PerElement() 117 | | 'Uniques_ConvertToSingleFeatureStats' >> beam.Map( 118 | _make_dataset_feature_stats_proto_with_single_feature, 119 | categorical_features=self._categorical_features)) 120 | 121 | 122 | class UniquesStatsGenerator(stats_generator.TransformStatsGenerator): 123 | """A transform statistics generator that computes the number of unique values 124 | for string features.""" 125 | 126 | def __init__(self, 127 | name = 'UniquesStatsGenerator', 128 | schema = None): 129 | """Initializes unique stats generator. 130 | 131 | Args: 132 | name: An optional unique name associated with the statistics generator. 133 | schema: An optional schema for the dataset. 
134 | """ 135 | super(UniquesStatsGenerator, self).__init__( 136 | name, schema=schema, ptransform=_UniquesStatsGeneratorImpl(schema)) 137 | -------------------------------------------------------------------------------- /tensorflow_data_validation/utils/quantiles_util_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for quantile utilities.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | 20 | from __future__ import print_function 21 | 22 | from absl.testing import absltest 23 | import numpy as np 24 | from tensorflow_data_validation.utils import quantiles_util 25 | from tensorflow_data_validation.types_compat import List, Tuple 26 | 27 | from google.protobuf import text_format 28 | from tensorflow_metadata.proto.v0 import statistics_pb2 29 | 30 | 31 | def _run_quantiles_combiner_test(test, 32 | q_combiner, 33 | batches, 34 | expected_result): 35 | """Tests quantiles combiner.""" 36 | summaries = [q_combiner.add_input(q_combiner.create_accumulator(), 37 | batch) for batch in batches] 38 | result = q_combiner.extract_output(q_combiner.merge_accumulators(summaries)) 39 | test.assertEqual(result.dtype, expected_result.dtype) 40 | test.assertEqual(result.size, expected_result.size) 41 | for i in range(expected_result.size): 42 | 
test.assertAlmostEqual(result[i], expected_result[i]) 43 | 44 | 45 | def _assert_buckets_almost_equal(test, 46 | a, 47 | b): 48 | """Check if the histogram buckets are almost equal.""" 49 | test.assertEqual(len(a), len(b)) 50 | for i in range(len(a)): 51 | test.assertAlmostEqual(a[i].low_value, b[i].low_value) 52 | test.assertAlmostEqual(a[i].high_value, b[i].high_value) 53 | test.assertAlmostEqual(a[i].sample_count, b[i].sample_count) 54 | 55 | 56 | class QuantilesUtilTest(absltest.TestCase): 57 | 58 | def test_quantiles_combiner(self): 59 | batches = [[np.linspace(1, 100, 100)], 60 | [np.linspace(101, 200, 100)], 61 | [np.linspace(201, 300, 100)]] 62 | expected_result = np.array([61.0, 121.0, 181.0, 241.0], dtype=np.float32) 63 | q_combiner = quantiles_util.QuantilesCombiner(5, 0.00001) 64 | _run_quantiles_combiner_test(self, q_combiner, batches, expected_result) 65 | 66 | def test_generate_quantiles_histogram(self): 67 | result = quantiles_util.generate_quantiles_histogram( 68 | quantiles=np.array([61.0, 121.0, 181.0, 241.0], dtype=np.float32), 69 | min_val=1.0, max_val=300.0, total_count=300.0) 70 | expected_result = text_format.Parse( 71 | """ 72 | buckets { 73 | low_value: 1.0 74 | high_value: 61.0 75 | sample_count: 60.0 76 | } 77 | buckets { 78 | low_value: 61.0 79 | high_value: 121.0 80 | sample_count: 60.0 81 | } 82 | buckets { 83 | low_value: 121.0 84 | high_value: 181.0 85 | sample_count: 60.0 86 | } 87 | buckets { 88 | low_value: 181.0 89 | high_value: 241.0 90 | sample_count: 60.0 91 | } 92 | buckets { 93 | low_value: 241.0 94 | high_value: 300.0 95 | sample_count: 60.0 96 | } 97 | type: QUANTILES 98 | """, statistics_pb2.Histogram()) 99 | self.assertEqual(result, expected_result) 100 | 101 | def test_generate_equi_width_histogram(self): 102 | result = quantiles_util.generate_equi_width_histogram( 103 | quantiles=np.array([1, 5, 10, 15, 20], dtype=np.float32), 104 | min_val=0, max_val=24.0, total_count=18, num_buckets=3) 105 | expected_result = 
text_format.Parse( 106 | """ 107 | buckets { 108 | low_value: 0 109 | high_value: 8.0 110 | sample_count: 7.8 111 | } 112 | buckets { 113 | low_value: 8.0 114 | high_value: 16.0 115 | sample_count: 4.8 116 | } 117 | buckets { 118 | low_value: 16.0 119 | high_value: 24.0 120 | sample_count: 5.4 121 | } 122 | type: STANDARD 123 | """, statistics_pb2.Histogram()) 124 | self.assertEqual(result, expected_result) 125 | 126 | def test_generate_equi_width_buckets(self): 127 | _assert_buckets_almost_equal( 128 | self, quantiles_util.generate_equi_width_buckets( 129 | quantiles=[1.0, 5.0, 10.0, 15.0, 20.0], 130 | min_val=0, max_val=24.0, total_count=18, num_buckets=3), 131 | [quantiles_util.Bucket(0, 8.0, 7.8), 132 | quantiles_util.Bucket(8.0, 16.0, 4.8), 133 | quantiles_util.Bucket(16.0, 24.0, 5.4)]) 134 | 135 | _assert_buckets_almost_equal( 136 | self, quantiles_util.generate_equi_width_buckets( 137 | quantiles=[1.0, 2.0, 3.0, 4.0, 5.0], 138 | min_val=1.0, max_val=5.0, total_count=6, num_buckets=3), 139 | [quantiles_util.Bucket(1.0, 2.33333333, 2.33333333), 140 | quantiles_util.Bucket(2.33333333, 3.66666666, 1.33333333), 141 | quantiles_util.Bucket(3.66666666, 5, 2.33333333)]) 142 | 143 | _assert_buckets_almost_equal( 144 | self, quantiles_util.generate_equi_width_buckets( 145 | quantiles=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 146 | min_val=1.0, max_val=1.0, total_count=100, num_buckets=3), 147 | [quantiles_util.Bucket(1.0, 1.0, 100.0)]) 148 | 149 | def test_find_median(self): 150 | self.assertEqual(quantiles_util.find_median([5.0]), 5.0) 151 | self.assertEqual(quantiles_util.find_median([3.0, 5.0]), 4.0) 152 | self.assertEqual(quantiles_util.find_median([3.0, 4.0, 5.0]), 4.0) 153 | self.assertEqual(quantiles_util.find_median([3.0, 4.0, 5.0, 6.0]), 4.5) 154 | 155 | 156 | if __name__ == '__main__': 157 | absltest.main() 158 | -------------------------------------------------------------------------------- /tensorflow_data_validation/utils/schema_util_test.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for schema utilities.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from absl.testing import absltest 21 | from tensorflow_data_validation.utils import schema_util 22 | from google.protobuf import text_format 23 | from tensorflow_metadata.proto.v0 import schema_pb2 24 | 25 | 26 | class SchemaUtilTest(absltest.TestCase): 27 | 28 | def test_get_feature(self): 29 | schema = text_format.Parse( 30 | """ 31 | feature { 32 | name: "feature1" 33 | } 34 | feature { 35 | name: "feature2" 36 | } 37 | """, schema_pb2.Schema()) 38 | 39 | feature2 = schema_util.get_feature(schema, 'feature2') 40 | self.assertEqual(feature2.name, 'feature2') 41 | # Check to verify that we are operating on the same feature object. 
42 | self.assertIs(feature2, schema_util.get_feature(schema, 'feature2')) 43 | 44 | def test_get_feature_not_present(self): 45 | schema = text_format.Parse( 46 | """ 47 | feature { 48 | name: "feature1" 49 | } 50 | """, schema_pb2.Schema()) 51 | 52 | with self.assertRaisesRegexp(ValueError, 53 | 'Feature.*not found in the schema.*'): 54 | _ = schema_util.get_feature(schema, 'feature2') 55 | 56 | def test_get_feature_invalid_schema_input(self): 57 | with self.assertRaisesRegexp(TypeError, '.*should be a Schema proto.*'): 58 | _ = schema_util.get_feature({}, 'feature') 59 | 60 | def test_get_string_domain_schema_level_domain(self): 61 | schema = text_format.Parse( 62 | """ 63 | string_domain { 64 | name: "domain1" 65 | } 66 | string_domain { 67 | name: "domain2" 68 | } 69 | feature { 70 | name: "feature1" 71 | domain: "domain2" 72 | } 73 | """, schema_pb2.Schema()) 74 | 75 | domain2 = schema_util.get_domain(schema, 'feature1') 76 | self.assertIsInstance(domain2, schema_pb2.StringDomain) 77 | self.assertEqual(domain2.name, 'domain2') 78 | # Check to verify that we are operating on the same domain object. 79 | self.assertIs(domain2, schema_util.get_domain(schema, 'feature1')) 80 | 81 | def test_get_string_domain_feature_level_domain(self): 82 | schema = text_format.Parse( 83 | """ 84 | string_domain { 85 | name: "domain2" 86 | } 87 | feature { 88 | name: "feature1" 89 | string_domain { 90 | name: "domain1" 91 | } 92 | } 93 | """, schema_pb2.Schema()) 94 | 95 | domain1 = schema_util.get_domain(schema, 'feature1') 96 | self.assertIsInstance(domain1, schema_pb2.StringDomain) 97 | self.assertEqual(domain1.name, 'domain1') 98 | # Check to verify that we are operating on the same domain object. 
99 | self.assertIs(domain1, schema_util.get_domain(schema, 'feature1')) 100 | 101 | def test_get_int_domain_feature_level_domain(self): 102 | schema = text_format.Parse( 103 | """ 104 | feature { 105 | name: "feature1" 106 | int_domain { 107 | name: "domain1" 108 | } 109 | } 110 | """, schema_pb2.Schema()) 111 | 112 | domain1 = schema_util.get_domain(schema, 'feature1') 113 | self.assertIsInstance(domain1, schema_pb2.IntDomain) 114 | self.assertEqual(domain1.name, 'domain1') 115 | # Check to verify that we are operating on the same domain object. 116 | self.assertIs(domain1, schema_util.get_domain(schema, 'feature1')) 117 | 118 | def test_get_float_domain_feature_level_domain(self): 119 | schema = text_format.Parse( 120 | """ 121 | feature { 122 | name: "feature1" 123 | float_domain { 124 | name: "domain1" 125 | } 126 | } 127 | """, schema_pb2.Schema()) 128 | 129 | domain1 = schema_util.get_domain(schema, 'feature1') 130 | self.assertIsInstance(domain1, schema_pb2.FloatDomain) 131 | self.assertEqual(domain1.name, 'domain1') 132 | # Check to verify that we are operating on the same domain object. 133 | self.assertIs(domain1, schema_util.get_domain(schema, 'feature1')) 134 | 135 | def test_get_bool_domain_feature_level_domain(self): 136 | schema = text_format.Parse( 137 | """ 138 | feature { 139 | name: "feature1" 140 | bool_domain { 141 | name: "domain1" 142 | } 143 | } 144 | """, schema_pb2.Schema()) 145 | 146 | domain1 = schema_util.get_domain(schema, 'feature1') 147 | self.assertIsInstance(domain1, schema_pb2.BoolDomain) 148 | self.assertEqual(domain1.name, 'domain1') 149 | # Check to verify that we are operating on the same domain object. 
150 | self.assertIs(domain1, schema_util.get_domain(schema, 'feature1')) 151 | 152 | def test_get_domain_not_present(self): 153 | schema = text_format.Parse( 154 | """ 155 | string_domain { 156 | name: "domain1" 157 | } 158 | feature { 159 | name: "feature1" 160 | } 161 | """, schema_pb2.Schema()) 162 | 163 | with self.assertRaisesRegexp(ValueError, 164 | '.*has no domain associated.*'): 165 | _ = schema_util.get_domain(schema, 'feature1') 166 | 167 | def test_get_domain_invalid_schema_input(self): 168 | with self.assertRaisesRegexp(TypeError, '.*should be a Schema proto.*'): 169 | _ = schema_util.get_domain({}, 'feature') 170 | 171 | 172 | if __name__ == '__main__': 173 | absltest.main() 174 | -------------------------------------------------------------------------------- /tensorflow_data_validation/statistics/stats_impl_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for the statistics generation implementation.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | 20 | from __future__ import print_function 21 | 22 | from absl.testing import absltest 23 | import apache_beam as beam 24 | from apache_beam.testing import util 25 | import numpy as np 26 | from tensorflow_data_validation.statistics import stats_impl 27 | from tensorflow_data_validation.statistics.generators import string_stats_generator 28 | from tensorflow_data_validation.statistics.generators import uniques_stats_generator 29 | from tensorflow_data_validation.utils import test_util 30 | 31 | from google.protobuf import text_format 32 | from tensorflow_metadata.proto.v0 import statistics_pb2 33 | 34 | 35 | class StatsImplTest(absltest.TestCase): 36 | 37 | def test_generate_stats_impl(self): 38 | # input with two batches: first batch has two examples and second batch 39 | # has a single example. 40 | batches = [{'a': np.array([np.array(['xyz']), np.array(['qwe'])])}, 41 | {'a': np.array([np.array(['ab'])])}] 42 | 43 | generator1 = string_stats_generator.StringStatsGenerator() 44 | generator2 = uniques_stats_generator.UniquesStatsGenerator() 45 | 46 | expected_result = text_format.Parse( 47 | """ 48 | datasets { 49 | features { 50 | name: 'a' 51 | type: STRING 52 | string_stats { 53 | avg_length: 2.66666666 54 | unique: 3 55 | } 56 | 57 | } 58 | } 59 | """, statistics_pb2.DatasetFeatureStatisticsList()) 60 | 61 | with beam.Pipeline() as p: 62 | result = (p | beam.Create(batches) | 63 | stats_impl.GenerateStatisticsImpl( 64 | generators=[generator1, generator2])) 65 | util.assert_that( 66 | result, 67 | test_util.make_dataset_feature_stats_list_proto_equal_fn( 68 | self, expected_result)) 69 | 70 | def test_merge_dataset_feature_stats_protos(self): 71 | proto1 = text_format.Parse( 72 | """ 73 | num_examples: 7 74 | features: { 75 | name: 'feature1' 76 | type: STRING 77 | string_stats: { 78 | common_stats: { 79 
| num_missing: 3 80 | num_non_missing: 4 81 | min_num_values: 1 82 | max_num_values: 1 83 | } 84 | } 85 | } 86 | """, statistics_pb2.DatasetFeatureStatistics()) 87 | 88 | proto2 = text_format.Parse( 89 | """ 90 | features: { 91 | name: 'feature1' 92 | type: STRING 93 | string_stats: { 94 | unique: 3 95 | } 96 | } 97 | """, statistics_pb2.DatasetFeatureStatistics()) 98 | 99 | expected = text_format.Parse( 100 | """ 101 | num_examples: 7 102 | features: { 103 | name: 'feature1' 104 | type: STRING 105 | string_stats: { 106 | common_stats: { 107 | num_missing: 3 108 | num_non_missing: 4 109 | min_num_values: 1 110 | max_num_values: 1 111 | } 112 | unique: 3 113 | } 114 | } 115 | """, statistics_pb2.DatasetFeatureStatistics()) 116 | 117 | actual = stats_impl._merge_dataset_feature_stats_protos([proto1, proto2]) 118 | self.assertEqual(actual, expected) 119 | 120 | def test_merge_dataset_feature_stats_protos_single_proto(self): 121 | proto1 = text_format.Parse( 122 | """ 123 | num_examples: 7 124 | features: { 125 | name: 'feature1' 126 | type: STRING 127 | string_stats: { 128 | common_stats: { 129 | num_missing: 3 130 | num_non_missing: 4 131 | min_num_values: 1 132 | max_num_values: 1 133 | } 134 | } 135 | } 136 | """, statistics_pb2.DatasetFeatureStatistics()) 137 | 138 | expected = text_format.Parse( 139 | """ 140 | num_examples: 7 141 | features: { 142 | name: 'feature1' 143 | type: STRING 144 | string_stats: { 145 | common_stats: { 146 | num_missing: 3 147 | num_non_missing: 4 148 | min_num_values: 1 149 | max_num_values: 1 150 | } 151 | } 152 | } 153 | """, statistics_pb2.DatasetFeatureStatistics()) 154 | 155 | actual = stats_impl._merge_dataset_feature_stats_protos([proto1]) 156 | self.assertEqual(actual, expected) 157 | 158 | def test_merge_dataset_feature_stats_protos_empty(self): 159 | self.assertEqual(stats_impl._merge_dataset_feature_stats_protos([]), 160 | statistics_pb2.DatasetFeatureStatistics()) 161 | 162 | def 
test_make_dataset_feature_statistics_list_proto(self): 163 | input_proto = text_format.Parse( 164 | """ 165 | num_examples: 7 166 | features: { 167 | name: 'feature1' 168 | type: STRING 169 | } 170 | """, statistics_pb2.DatasetFeatureStatistics()) 171 | 172 | expected = text_format.Parse( 173 | """ 174 | datasets { 175 | num_examples: 7 176 | features: { 177 | name: 'feature1' 178 | type: STRING 179 | } 180 | } 181 | """, statistics_pb2.DatasetFeatureStatisticsList()) 182 | 183 | self.assertEqual( 184 | stats_impl._make_dataset_feature_statistics_list_proto(input_proto), 185 | expected) 186 | 187 | 188 | if __name__ == '__main__': 189 | absltest.main() 190 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/float_domain_util.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_data_validation/anomalies/float_domain_util.h" 17 | 18 | #include 19 | #include 20 | 21 | #include "absl/strings/numbers.h" 22 | #include "absl/strings/str_cat.h" 23 | #include "absl/types/optional.h" 24 | #include "absl/types/variant.h" 25 | #include "tensorflow_data_validation/anomalies/internal_types.h" 26 | #include "tensorflow/core/platform/logging.h" 27 | #include "tensorflow/core/platform/types.h" 28 | #include "tensorflow_metadata/proto/v0/anomalies.pb.h" 29 | #include "tensorflow_metadata/proto/v0/statistics.pb.h" 30 | 31 | namespace tensorflow { 32 | namespace data_validation { 33 | namespace { 34 | 35 | constexpr char kOutOfRangeValues[] = "Out-of-range values"; 36 | constexpr char kInvalidValues[] = "Invalid values"; 37 | 38 | using ::absl::get_if; 39 | using ::absl::holds_alternative; 40 | using ::absl::optional; 41 | using ::absl::variant; 42 | using ::tensorflow::metadata::v0::FeatureNameStatistics; 43 | using ::tensorflow::metadata::v0::FloatDomain; 44 | 45 | // A FloatIntervalResult where a byte is not a float. 46 | typedef string ExampleStringNotFloat; 47 | 48 | // An interval of floats. 49 | struct FloatInterval { 50 | // Min and max values of the interval. 51 | float min, max; 52 | }; 53 | 54 | // See GetFloatInterval 55 | using FloatIntervalResult = 56 | absl::optional>; 57 | 58 | // Determines the range of floats represented by the feature_stats, whether 59 | // the data is floats or strings. 60 | // Returns nullopt if there is no data in the field or it is INT. 61 | // Returns ExampleStringNotFloat if there is at least one string that does not 62 | // represent a float. 63 | // Otherwise, returns the interval. 
64 | FloatIntervalResult GetFloatInterval(const FeatureStatsView& feature_stats) { 65 | switch (feature_stats.type()) { 66 | case FeatureNameStatistics::FLOAT: 67 | return FloatInterval{static_cast(feature_stats.num_stats().min()), 68 | static_cast(feature_stats.num_stats().max())}; 69 | case FeatureNameStatistics::BYTES: 70 | case FeatureNameStatistics::STRING: { 71 | absl::optional interval; 72 | for (const string& str : feature_stats.GetStringValues()) { 73 | float value; 74 | if (!absl::SimpleAtof(str, &value)) { 75 | return str; 76 | } 77 | if (!interval) { 78 | interval = FloatInterval{value, value}; 79 | } 80 | if (interval->min > value) { 81 | interval->min = value; 82 | } 83 | if (interval->max < value) { 84 | interval->max = value; 85 | } 86 | } 87 | if (interval) { 88 | return *interval; 89 | } 90 | return absl::nullopt; 91 | } 92 | case FeatureNameStatistics::INT: 93 | return absl::nullopt; 94 | default: 95 | LOG(FATAL) << "Unknown type: " << feature_stats.type(); 96 | } 97 | } 98 | 99 | } // namespace 100 | 101 | UpdateSummary UpdateFloatDomain( 102 | const FeatureStatsView& stats, 103 | tensorflow::metadata::v0::FloatDomain* float_domain) { 104 | UpdateSummary update_summary; 105 | 106 | const FloatIntervalResult result = GetFloatInterval(stats); 107 | if (result) { 108 | const variant actual_result = 109 | *result; 110 | if (holds_alternative(actual_result)) { 111 | update_summary.descriptions.push_back( 112 | {tensorflow::metadata::v0::AnomalyInfo::FLOAT_TYPE_STRING_NOT_FLOAT, 113 | kInvalidValues, 114 | absl::StrCat( 115 | "String values that were not floats were found, such as \"", 116 | *absl::get_if(&actual_result), "\".")}); 117 | update_summary.clear_field = true; 118 | return update_summary; 119 | } 120 | if (holds_alternative(actual_result)) { 121 | const FloatInterval range = 122 | *absl::get_if(&actual_result); 123 | if (float_domain->has_min() && range.min < float_domain->min()) { 124 | float_domain->set_min(range.min); 125 | 
update_summary.descriptions.push_back( 126 | {tensorflow::metadata::v0::AnomalyInfo::FLOAT_TYPE_SMALL_FLOAT, 127 | kOutOfRangeValues, 128 | absl::StrCat( 129 | "Unexpectedly low values: ", absl::SixDigits(range.min), "<", 130 | absl::SixDigits(float_domain->min()), 131 | "(upto six significant digits)")}); 132 | } 133 | 134 | if (float_domain->has_max() && range.max > float_domain->max()) { 135 | update_summary.descriptions.push_back( 136 | {tensorflow::metadata::v0::AnomalyInfo::FLOAT_TYPE_BIG_FLOAT, 137 | kOutOfRangeValues, 138 | absl::StrCat( 139 | "Unexpectedly high value: ", absl::SixDigits(range.max), ">", 140 | absl::SixDigits(float_domain->max()), 141 | "(upto six significant digits)")}); 142 | float_domain->set_max(range.max); 143 | } 144 | } 145 | } 146 | // If no interval is found, then assume everything is OK. 147 | return update_summary; 148 | } 149 | 150 | bool IsFloatDomainCandidate(const FeatureStatsView& feature_stats) { 151 | // We don't set float_domain by default unless we are trying to indicate that 152 | // strings are actually floats. 153 | if (feature_stats.type() != FeatureNameStatistics::STRING || 154 | feature_stats.HasInvalidUTF8Strings()) { 155 | return false; 156 | } 157 | const FloatIntervalResult result = GetFloatInterval(feature_stats); 158 | if (result) { 159 | // If all the examples are floats, then maybe we can make this a 160 | // FloatDomain. 161 | return holds_alternative(*result); 162 | } 163 | return false; 164 | } 165 | 166 | } // namespace data_validation 167 | } // namespace tensorflow 168 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/int_domain_util.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_data_validation/anomalies/int_domain_util.h" 17 | 18 | #include 19 | #include 20 | 21 | #include "absl/strings/numbers.h" 22 | #include "absl/strings/str_cat.h" 23 | #include "absl/types/optional.h" 24 | #include "absl/types/variant.h" 25 | #include "tensorflow_data_validation/anomalies/internal_types.h" 26 | #include "tensorflow/core/platform/logging.h" 27 | #include "tensorflow/core/platform/types.h" 28 | #include "tensorflow_metadata/proto/v0/anomalies.pb.h" 29 | #include "tensorflow_metadata/proto/v0/statistics.pb.h" 30 | 31 | namespace tensorflow { 32 | namespace data_validation { 33 | namespace { 34 | using ::absl::variant; 35 | using ::tensorflow::metadata::v0::FeatureNameStatistics; 36 | using ::tensorflow::metadata::v0::IntDomain; 37 | 38 | constexpr char kOutOfRangeValues[] = "Out-of-range values"; 39 | constexpr char kInvalidValues[] = "Invalid values"; 40 | 41 | // A IntIntervalResult where a byte is not a float. 42 | typedef string ExampleStringNotInt; 43 | 44 | // An interval of ints. 45 | struct IntInterval { 46 | // Min and max values of the interval. 47 | int64 min, max; 48 | }; 49 | 50 | // See GetIntInterval 51 | using IntIntervalResult = 52 | absl::optional>; 53 | 54 | // Returns an Interval, if the values are all integers (either int64_list or a 55 | // bytes_list where every value is a decimal representation of an integer). 56 | // Returns UNDEFINED if the statistics are FLOAT or BYTES. 
57 | // Returns UNDEFINED if the statistics are STRING but do not represent int64 58 | // numbers. 59 | // Returns EMPTY if the statistics are STRING but there are no common values 60 | // (i.e., the statistics (not the range) are empty). 61 | // Returns NONEMPTY otherwise, with valid min and max set in 62 | // result. 63 | // NOTE: if GetIntInterval returns anything but NONEMPTY, result is always 64 | // [0,0]. 65 | IntIntervalResult GetIntInterval(const FeatureStatsView& feature_stats_view) { 66 | // Extract string values upfront as it can be useful for categorical INT 67 | // features. 68 | const std::vector string_values = 69 | feature_stats_view.GetStringValues(); 70 | switch (feature_stats_view.type()) { 71 | case FeatureNameStatistics::FLOAT: 72 | return absl::nullopt; 73 | case FeatureNameStatistics::INT: { 74 | if (string_values.empty()) { 75 | return IntInterval{ 76 | static_cast(feature_stats_view.num_stats().min()), 77 | static_cast(feature_stats_view.num_stats().max())}; 78 | } 79 | // Intentionally fall through BYTES, STRING case for categorical integer 80 | // features. 
81 | ABSL_FALLTHROUGH_INTENDED; 82 | } 83 | case FeatureNameStatistics::BYTES: 84 | case FeatureNameStatistics::STRING: { 85 | absl::optional interval; 86 | for (const string& str : string_values) { 87 | int64 value; 88 | if (!absl::SimpleAtoi(str, &value)) { 89 | return str; 90 | } 91 | if (!interval) { 92 | interval = IntInterval{value, value}; 93 | } 94 | if (interval->min > value) { 95 | interval->min = value; 96 | } 97 | if (interval->max < value) { 98 | interval->max = value; 99 | } 100 | } 101 | if (interval) { 102 | return *interval; 103 | } 104 | return absl::nullopt; 105 | } 106 | default: 107 | LOG(FATAL) << "Unknown type: " << feature_stats_view.type(); 108 | } 109 | } 110 | 111 | } // namespace 112 | 113 | bool IsIntDomainCandidate(const FeatureStatsView& feature_stats) { 114 | // We are not getting bounds here: we are just identifying that it is a string 115 | // encoded as an int. 116 | if (feature_stats.type() != FeatureNameStatistics::STRING || 117 | feature_stats.HasInvalidUTF8Strings()) { 118 | return false; 119 | } 120 | 121 | const IntIntervalResult result = GetIntInterval(feature_stats); 122 | if (result) { 123 | return absl::holds_alternative(*result); 124 | } 125 | return false; 126 | } 127 | 128 | UpdateSummary UpdateIntDomain(const FeatureStatsView& feature_stats, 129 | tensorflow::metadata::v0::IntDomain* int_domain) { 130 | UpdateSummary update_summary; 131 | const IntIntervalResult result = GetIntInterval(feature_stats); 132 | if (result) { 133 | const variant actual_result = *result; 134 | if (absl::holds_alternative(actual_result)) { 135 | update_summary.descriptions.push_back( 136 | {tensorflow::metadata::v0::AnomalyInfo::INT_TYPE_NOT_INT_STRING, 137 | kInvalidValues, 138 | absl::StrCat( 139 | "String values that were not ints were found, such as \"", 140 | *absl::get_if(&actual_result), "\".")}); 141 | update_summary.clear_field = true; 142 | return update_summary; 143 | } 144 | if (absl::holds_alternative(actual_result)) { 145 | 
const IntInterval interval = 146 | *absl::get_if(&actual_result); 147 | if (int_domain->has_min() && int_domain->min() > interval.min) { 148 | update_summary.descriptions.push_back( 149 | {tensorflow::metadata::v0::AnomalyInfo::INT_TYPE_SMALL_INT, 150 | kOutOfRangeValues, 151 | absl::StrCat("Unexpectedly small value: ", interval.min, ".")}); 152 | int_domain->set_min(interval.min); 153 | } 154 | if (int_domain->has_max() && int_domain->max() < interval.max) { 155 | update_summary.descriptions.push_back( 156 | {tensorflow::metadata::v0::AnomalyInfo::INT_TYPE_BIG_INT, 157 | kOutOfRangeValues, 158 | absl::StrCat("Unexpectedly large value: ", interval.max, ".")}); 159 | int_domain->set_max(interval.max); 160 | } 161 | 162 | return update_summary; 163 | } 164 | } 165 | return update_summary; 166 | } 167 | 168 | } // namespace data_validation 169 | } // namespace tensorflow 170 | -------------------------------------------------------------------------------- /tensorflow_data_validation/statistics/generators/string_stats_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Module that computes statistics for features of string type. 16 | 17 | Specifically, we compute the following statistics for each string feature: 18 | - Average length of the values for this feature. 
19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | 24 | from __future__ import print_function 25 | 26 | import numpy as np 27 | import six 28 | from tensorflow_data_validation import types 29 | from tensorflow_data_validation.statistics.generators import stats_generator 30 | from tensorflow_data_validation.utils import stats_util 31 | from tensorflow_data_validation.types_compat import Dict, List, Optional 32 | from tensorflow_metadata.proto.v0 import schema_pb2 33 | from tensorflow_metadata.proto.v0 import statistics_pb2 34 | 35 | 36 | class _PartialStringStats(object): 37 | """Holds partial statistics needed to compute the string statistics 38 | for a single feature.""" 39 | 40 | def __init__(self): 41 | # The total length of all the values for this feature. 42 | self.total_bytes_length = 0 43 | # The total number of values for this feature. 44 | self.total_num_values = 0 45 | 46 | 47 | def _merge_string_stats(left, 48 | right): 49 | """Merge two partial string statistics and return the merged statistics.""" 50 | result = _PartialStringStats() 51 | result.total_bytes_length = (left.total_bytes_length + 52 | right.total_bytes_length) 53 | result.total_num_values = (left.total_num_values + 54 | right.total_num_values) 55 | return result 56 | 57 | 58 | def _make_feature_stats_proto( 59 | string_stats, feature_name, 60 | is_categorical): 61 | """Convert the partial string statistics into FeatureNameStatistics proto.""" 62 | result = statistics_pb2.FeatureNameStatistics() 63 | result.name = feature_name 64 | # If we have a categorical feature, we preserve the type to be the original 65 | # INT type. 
66 | result.type = (statistics_pb2.FeatureNameStatistics.INT if is_categorical 67 | else statistics_pb2.FeatureNameStatistics.STRING) 68 | result.string_stats.avg_length = (string_stats.total_bytes_length / 69 | string_stats.total_num_values) 70 | return result 71 | 72 | 73 | class StringStatsGenerator(stats_generator.CombinerStatsGenerator): 74 | """A combiner statistics generator that computes the statistics 75 | for features of string type.""" 76 | 77 | def __init__( 78 | self, # pylint: disable=useless-super-delegation 79 | name = 'StringStatsGenerator', 80 | schema = None): 81 | """Initializes a string statistics generator. 82 | 83 | Args: 84 | name: An optional unique name associated with the statistics generator. 85 | schema: An optional schema for the dataset. 86 | """ 87 | super(StringStatsGenerator, self).__init__(name, schema) 88 | self._categorical_features = set( 89 | stats_util.get_categorical_numeric_features(schema) if schema else []) 90 | 91 | # Create an accumulator, which maps feature name to the partial stats 92 | # associated with the feature. 93 | def create_accumulator(self): 94 | return {} 95 | 96 | # Incorporates the input (a Python dict whose keys are feature names and 97 | # values are numpy arrays representing a batch of examples) into the 98 | # accumulator. 99 | def add_input(self, accumulator, 100 | input_batch 101 | ): 102 | # Iterate through each feature and update the partial string stats. 103 | for feature_name, values in six.iteritems(input_batch): 104 | # Update the string statistics for every example in the batch. 105 | for value in values: 106 | # Check if we have a numpy array with at least one value. 107 | if not isinstance(value, np.ndarray) or value.size == 0: 108 | continue 109 | 110 | # If the feature is neither categorical nor of string type, then 111 | # skip the feature. 
112 | if not (feature_name in self._categorical_features or 113 | stats_util.make_feature_type(value.dtype) == 114 | statistics_pb2.FeatureNameStatistics.STRING): 115 | continue 116 | 117 | # If we encounter this feature for the first time, create a 118 | # new partial string stats. 119 | if feature_name not in accumulator: 120 | accumulator[feature_name] = _PartialStringStats() 121 | 122 | # If we have a categorical feature, convert the value to string type. 123 | if feature_name in self._categorical_features: 124 | value = value.astype(str) 125 | 126 | # Update the partial string stats. 127 | for v in value: 128 | accumulator[feature_name].total_bytes_length += len(v) 129 | accumulator[feature_name].total_num_values += len(value) 130 | 131 | return accumulator 132 | 133 | # Merge together a list of partial string statistics. 134 | def merge_accumulators( 135 | self, accumulators 136 | ): 137 | result = {} 138 | 139 | for accumulator in accumulators: 140 | for feature_name, string_stats in accumulator.items(): 141 | if feature_name not in result: 142 | result[feature_name] = string_stats 143 | else: 144 | result[feature_name] = _merge_string_stats( 145 | result[feature_name], string_stats) 146 | return result 147 | 148 | # Return final stats as a DatasetFeatureStatistics proto. 149 | def extract_output(self, 150 | accumulator 151 | ): 152 | # Create a new DatasetFeatureStatistics proto. 153 | result = statistics_pb2.DatasetFeatureStatistics() 154 | 155 | for feature_name, string_stats in accumulator.items(): 156 | # Construct the FeatureNameStatistics proto from the partial 157 | # string stats. 158 | feature_stats_proto = _make_feature_stats_proto( 159 | string_stats, feature_name, 160 | feature_name in self._categorical_features) 161 | # Copy the constructed FeatureNameStatistics proto into the 162 | # DatasetFeatureStatistics proto. 
163 | new_feature_stats_proto = result.features.add() 164 | new_feature_stats_proto.CopyFrom(feature_stats_proto) 165 | return result 166 | -------------------------------------------------------------------------------- /tensorflow_data_validation/statistics/stats_impl.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Beam implementation of statistics generators.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | 20 | from __future__ import print_function 21 | 22 | import apache_beam as beam 23 | from tensorflow_data_validation import types 24 | from tensorflow_data_validation.statistics.generators import stats_generator 25 | from tensorflow_data_validation.types_compat import List, TypeVar 26 | 27 | from tensorflow_metadata.proto.v0 import statistics_pb2 28 | 29 | 30 | @beam.typehints.with_input_types(types.ExampleBatch) 31 | @beam.typehints.with_output_types(statistics_pb2.DatasetFeatureStatisticsList) 32 | class GenerateStatisticsImpl(beam.PTransform): 33 | """PTransform that applies a set of generators.""" 34 | 35 | def __init__( 36 | self, 37 | generators): 38 | self._generators = generators 39 | 40 | def expand(self, pcoll): 41 | result_protos = [] 42 | # Iterate over the stats generators. 
For each generator, 43 | # a) if it is a CombinerStatsGenerator, wrap it as a beam.CombineFn 44 | # and run it. 45 | # b) if it is a TransformStatsGenerator, wrap it as a beam.PTransform 46 | # and run it. 47 | for generator in self._generators: 48 | if isinstance(generator, stats_generator.CombinerStatsGenerator): 49 | result_protos.append( 50 | pcoll | 51 | generator.name >> beam.CombineGlobally( 52 | _CombineFnWrapper(generator))) 53 | elif isinstance(generator, stats_generator.TransformStatsGenerator): 54 | result_protos.append( 55 | pcoll | 56 | generator.name >> generator.ptransform) 57 | else: 58 | raise TypeError('Statistics generator must extend one of ' 59 | 'CombinerStatsGenerator or TransformStatsGenerator, ' 60 | 'found object of type %s' % 61 | type(generator).__class__.__name__) 62 | 63 | # Each stats generator will output a PCollection of DatasetFeatureStatistics 64 | # protos. We now flatten the list of PCollections into a single PCollection, 65 | # then merge the DatasetFeatureStatistics protos in the PCollection into a 66 | # single DatasetFeatureStatisticsList proto. 67 | return (result_protos | 'FlattenFeatureStatistics' >> beam.Flatten() 68 | | 'MergeDatasetFeatureStatisticsProtos' >> 69 | beam.CombineGlobally(_merge_dataset_feature_stats_protos) 70 | | 'MakeDatasetFeatureStatisticsListProto' >> 71 | beam.Map(_make_dataset_feature_statistics_list_proto)) 72 | 73 | 74 | def _merge_dataset_feature_stats_protos( 75 | stats_protos 76 | ): 77 | """Merge together a list of DatasetFeatureStatistics protos. 78 | 79 | Args: 80 | stats_protos: A list of DatasetFeatureStatistics protos to merge. 81 | 82 | Returns: 83 | The merged DatasetFeatureStatistics proto. 84 | """ 85 | stats_per_feature = {} 86 | # Iterate over each DatasetFeatureStatistics proto and merge the 87 | # FeatureNameStatistics protos per feature. 
88 | for stats_proto in stats_protos: 89 | for feature_stats_proto in stats_proto.features: 90 | if feature_stats_proto.name not in stats_per_feature: 91 | stats_per_feature[feature_stats_proto.name] = feature_stats_proto 92 | else: 93 | stats_per_feature[feature_stats_proto.name].MergeFrom( 94 | feature_stats_proto) 95 | 96 | # Create a new DatasetFeatureStatistics proto. 97 | result = statistics_pb2.DatasetFeatureStatistics() 98 | num_examples = None 99 | for feature_stats_proto in stats_per_feature.values(): 100 | # Add the merged FeatureNameStatistics proto for the feature 101 | # into the DatasetFeatureStatistics proto. 102 | new_feature_stats_proto = result.features.add() 103 | new_feature_stats_proto.CopyFrom(feature_stats_proto) 104 | 105 | # Get the number of examples from one of the features that 106 | # has common stats. 107 | if num_examples is None: 108 | stats_type = feature_stats_proto.WhichOneof('stats') 109 | stats_proto = None 110 | if stats_type == 'num_stats': 111 | stats_proto = feature_stats_proto.num_stats 112 | else: 113 | stats_proto = feature_stats_proto.string_stats 114 | 115 | if stats_proto.HasField('common_stats'): 116 | num_examples = (stats_proto.common_stats.num_non_missing + 117 | stats_proto.common_stats.num_missing) 118 | 119 | # Set the num_examples field. 120 | if num_examples is not None: 121 | result.num_examples = num_examples 122 | return result 123 | 124 | 125 | def _make_dataset_feature_statistics_list_proto( 126 | stats_proto 127 | ): 128 | """Constructs a DatasetFeatureStatisticsList proto. 129 | 130 | Args: 131 | stats_proto: The input DatasetFeatureStatistics proto. 132 | 133 | Returns: 134 | The DatasetFeatureStatisticsList proto containing the input stats proto. 135 | """ 136 | # Create a new DatasetFeatureStatisticsList proto. 137 | result = statistics_pb2.DatasetFeatureStatisticsList() 138 | 139 | # Add the input DatasetFeatureStatistics proto. 
140 | dataset_stats_proto = result.datasets.add() 141 | dataset_stats_proto.CopyFrom(stats_proto) 142 | return result 143 | 144 | 145 | 146 | 147 | @beam.typehints.with_input_types(types.ExampleBatch) 148 | @beam.typehints.with_output_types( 149 | statistics_pb2.DatasetFeatureStatistics) 150 | class _CombineFnWrapper(beam.CombineFn): 151 | """Class to wrap a CombinerStatsGenerator as a beam.CombineFn.""" 152 | 153 | def __init__( 154 | self, 155 | generator): 156 | self._generator = generator 157 | 158 | def __reduce__(self): 159 | return _CombineFnWrapper, (self._generator,) 160 | 161 | def create_accumulator(self 162 | ): # pytype: disable=invalid-annotation 163 | return self._generator.create_accumulator() 164 | 165 | def add_input(self, accumulator, 166 | input_batch): 167 | return self._generator.add_input(accumulator, input_batch) 168 | 169 | def merge_accumulators(self, accumulators): 170 | return self._generator.merge_accumulators(accumulators) 171 | 172 | def extract_output( 173 | self, 174 | accumulator 175 | ): # pytype: disable=invalid-annotation 176 | return self._generator.extract_output(accumulator) 177 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/map_util_test.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_data_validation/anomalies/map_util.h" 17 | 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | #include 24 | #include 25 | #include "tensorflow/core/platform/types.h" 26 | 27 | namespace tensorflow { 28 | namespace data_validation { 29 | namespace { 30 | 31 | TEST(MapUtilTest, ContainsKeyMapKeyPresent) { 32 | const std::map a = {{"a", 3.0}, {"b", 1.0}}; 33 | EXPECT_TRUE(ContainsKey(a, "a")); 34 | } 35 | 36 | TEST(MapUtilTest, ContainsKeyMapKeyAbsent) { 37 | const std::map a = {{"a", 3.0}, {"b", 1.0}}; 38 | EXPECT_FALSE(ContainsKey(a, "c")); 39 | } 40 | 41 | TEST(MapUtilTest, ContainsKeyMapEmpty) { 42 | const std::map a = {}; 43 | EXPECT_FALSE(ContainsKey(a, "a")); 44 | } 45 | 46 | TEST(MapUtilTest, ContainsKeySetKeyPresent) { 47 | const std::set a = {"a", "b"}; 48 | EXPECT_TRUE(ContainsKey(a, "a")); 49 | } 50 | 51 | TEST(MapUtilTest, ContainsKeySetKeyAbsent) { 52 | const std::set a = {"a", "b"}; 53 | EXPECT_FALSE(ContainsKey(a, "c")); 54 | } 55 | 56 | TEST(MapUtilTest, ContainsKeySetEmpty) { 57 | const std::set a = {}; 58 | EXPECT_FALSE(ContainsKey(a, "a")); 59 | } 60 | 61 | TEST(MapUtilTest, SumValues) { 62 | const std::map a = {{"a", 3.0}, {"b", 1.0}}; 63 | double result = SumValues(a); 64 | EXPECT_NEAR(result, 4.0, 1.0e-7); 65 | } 66 | 67 | TEST(MapUtilTest, SumValuesEmpty) { 68 | const std::map a = {}; 69 | double result = SumValues(a); 70 | EXPECT_NEAR(result, 0.0, 1.0e-7); 71 | } 72 | 73 | TEST(MapUtilTest, GetValuesFromMap) { 74 | const std::map a = {{"a", 3.0}, {"b", 1.0}}; 75 | const std::vector result = GetValuesFromMap(a); 76 | EXPECT_THAT(result, testing::ElementsAre(3.0, 1.0)); 77 | } 78 | 79 | TEST(MapUtilTest, GetValuesFromMapEmpty) { 80 | const std::map a = {}; 81 | const std::vector result = GetValuesFromMap(a); 82 | EXPECT_THAT(result, testing::ElementsAre()); 83 | } 84 | 85 | TEST(MapUtilTest, GetKeysFromMap) { 86 | 
const std::map a = {{"a", 3.0}, {"b", 1.0}}; 87 | const std::vector result = GetKeysFromMap(a); 88 | EXPECT_THAT(result, testing::ElementsAre("a", "b")); 89 | } 90 | 91 | TEST(MapUtilTest, GetKeysFromMapEmpty) { 92 | const std::map a = {}; 93 | const std::vector result = GetKeysFromMap(a); 94 | EXPECT_THAT(result, testing::ElementsAre()); 95 | } 96 | 97 | TEST(MapUtilTest, Normalize) { 98 | const std::map a = {{"a", 3.0}, {"b", 1.0}}; 99 | const std::map result = Normalize(a); 100 | EXPECT_THAT(result, 101 | testing::ElementsAre(std::pair("a", 0.75), 102 | std::pair("b", 0.25))); 103 | } 104 | 105 | TEST(MapUtilTest, NormalizeAllZeros) { 106 | const std::map a = {{"a", 0.0}, {"b", 0.0}}; 107 | const std::map result = Normalize(a); 108 | EXPECT_THAT(result, 109 | testing::ElementsAre(std::pair("a", 0.0), 110 | std::pair("b", 0.0))); 111 | } 112 | 113 | TEST(MapUtilTest, NormalizeEmpty) { 114 | const std::map a = {}; 115 | const std::map result = Normalize(a); 116 | EXPECT_THAT(result, testing::ElementsAre()); 117 | } 118 | 119 | TEST(MapUtilTest, GetDifference) { 120 | const std::map a = {{"a", 3.0}, {"b", 1.0}}; 121 | const std::map b = {{"c", 1.0}, {"b", 1.0}}; 122 | const std::map c = GetDifference(a, b); 123 | EXPECT_THAT(c, testing::ElementsAre(std::pair("a", 3.0), 124 | std::pair("b", 0.0), 125 | std::pair("c", -1.0))); 126 | } 127 | 128 | TEST(MapUtilTest, GetDifferenceEmpty) { 129 | const std::map a = {}; 130 | const std::map b = {}; 131 | const std::map c = GetDifference(a, b); 132 | EXPECT_THAT(c, testing::ElementsAre()); 133 | } 134 | 135 | TEST(MapUtilTest, IncrementMap) { 136 | const std::map a = {{"c", 1.0}, {"b", 1.0}}; 137 | std::map b = {{"a", 3.0}, {"b", 1.0}}; 138 | IncrementMap(a, &b); 139 | EXPECT_THAT(b, testing::ElementsAre(std::pair("a", 3.0), 140 | std::pair("b", 2.0), 141 | std::pair("c", 1.0))); 142 | } 143 | 144 | TEST(MapUtilTest, IncrementMapEmpty) { 145 | const std::map a; 146 | std::map b; 147 | IncrementMap(a, &b); 148 | 
EXPECT_THAT(b, testing::ElementsAre()); 149 | } 150 | 151 | TEST(MapUtilTest, GetSum) { 152 | const std::map a = {{"a", 3.0}, {"b", 1.0}}; 153 | const std::map b = {{"c", 1.0}, {"b", 1.0}}; 154 | const std::map c = GetSum(a, b); 155 | EXPECT_THAT(c, testing::ElementsAre(std::pair("a", 3.0), 156 | std::pair("b", 2.0), 157 | std::pair("c", 1.0))); 158 | } 159 | 160 | TEST(MapUtilTest, GetSumEmpty) { 161 | const std::map a = {}; 162 | const std::map b = {}; 163 | const std::map c = GetSum(a, b); 164 | EXPECT_THAT(c, testing::ElementsAre()); 165 | } 166 | 167 | TEST(MapUtilTest, MapValues) { 168 | const std::map a = {{"a", 3.0}, {"b", 1.0}}; 169 | const std::map c = 170 | MapValues(a, [](double a) { return a + 1.0; }); 171 | EXPECT_THAT(c, testing::ElementsAre(std::pair("a", 4.0), 172 | std::pair("b", 2.0))); 173 | } 174 | 175 | TEST(MapUtilTest, MapValuesEmpty) { 176 | const std::map a = {}; 177 | const std::map c = 178 | MapValues(a, [](double a) { return a + 1.0; }); 179 | EXPECT_THAT(c, testing::ElementsAre()); 180 | } 181 | 182 | } // namespace 183 | } // namespace data_validation 184 | } // namespace tensorflow 185 | -------------------------------------------------------------------------------- /tensorflow_data_validation/anomalies/path.cc: -------------------------------------------------------------------------------- 1 | #include "tensorflow_data_validation/anomalies/path.h" 2 | 3 | #include 4 | 5 | #include "absl/strings/str_cat.h" 6 | #include "absl/strings/str_join.h" 7 | #include "absl/strings/str_replace.h" 8 | #include "absl/strings/str_split.h" 9 | #include "absl/strings/string_view.h" 10 | #include "re2/re2.h" 11 | #include "tensorflow/core/lib/core/errors.h" 12 | #include "tensorflow/core/lib/core/status.h" 13 | 14 | namespace tensorflow { 15 | namespace data_validation { 16 | namespace { 17 | 18 | 19 | #ifdef TFDV_GOOGLE_PLATFORM 20 | using Re2StringPiece = absl::string_view; 21 | #else 22 | using Re2StringPiece = re2::StringPiece; 23 | #endif 24 | 
25 | Re2StringPiece StringViewToRe2StringPiece(absl::string_view view) { 26 | return Re2StringPiece(view.begin(), view.size()); 27 | } 28 | 29 | absl::string_view Re2StringPieceToStringView(Re2StringPiece str_piece) { 30 | return absl::string_view(str_piece.begin(), str_piece.size()); 31 | } 32 | 33 | // This matches a standard step. 34 | // Standard steps include any proto steps (extensions and regular fields). 35 | // Standard steps either begin or end with parentheses with none on the inside 36 | // or they have no parentheses, dot, or single quotes. 37 | static const LazyRE2 kStandardStep = {R"((\([^()]*\))|([^().']+))", 38 | RE2::Latin1}; 39 | 40 | // This matches a serialized step with quotes. 41 | // This begins and ends with a single quote, and all internal single quotes are 42 | // doubled. 43 | static const LazyRE2 kSerializedWithQuotes = {"'(('')|[^'])*'", RE2::Latin1}; 44 | 45 | // This matches a serialized step and the next dot. 46 | // Note that this is the union of the above two regexes followed by a dot. 47 | static const LazyRE2 kSerializedStepAndDot = { 48 | R"(((('(('')|[^'])*')|(\([^()]*\))|([^()'.]+))\.))", RE2::Latin1}; 49 | 50 | // Returns true if: 51 | // str is nonempty and has no ".", "(", ")", or "'", OR: 52 | // str starts with "(", ends with ")", and has no "(" or ")" in the interior. 53 | bool IsStandardStep(const string& str) { 54 | return RE2::FullMatch(str, *kStandardStep); 55 | } 56 | 57 | string SerializeStep(const string& str) { 58 | if (IsStandardStep(str)) { 59 | return str; 60 | } 61 | // Double any single quotes in the string, and encapsulate with single quotes. 62 | return absl::StrCat("'", absl::StrReplaceAll(str, {{"'", "''"}}), "'"); 63 | } 64 | 65 | // Deserialize a step in-place. 66 | // If the step is in the standard format, do nothing. 67 | // Otherwise, remove beginning and ending quote and replace pairs of single 68 | // quotes with single quotes. 
69 | tensorflow::Status DeserializeStep(string* to_modify) { 70 | if (IsStandardStep(*to_modify)) { 71 | return Status::OK(); 72 | } 73 | 74 | // A legal serialized string here begins and ends with a single quote and 75 | // has had all interior single quotes replaced with double quotes. 76 | if (!RE2::FullMatch(*to_modify, *kSerializedWithQuotes)) { 77 | return errors::InvalidArgument("Not a valid serialized step: ", *to_modify); 78 | } 79 | // Remove the first and last quote 80 | const absl::string_view quotes_removed(to_modify->data() + 1, 81 | to_modify->size() - 2); 82 | 83 | // Replace each pair of quotes remaining with a single quote. 84 | *to_modify = absl::StrReplaceAll(quotes_removed, {{"''", "'"}}); 85 | return Status::OK(); 86 | } 87 | 88 | // Find the next step delimiter. 89 | // See absl/strings/str_split.h 90 | // If the next step is not valid in some way, 91 | // return absl::string_view(text.end(), 0) 92 | struct StepDelimiter { 93 | absl::string_view Find(absl::string_view text, size_t pos) { 94 | if (pos >= text.size()) { 95 | return absl::string_view(text.end(), 0); 96 | } 97 | Re2StringPiece solution; 98 | Re2StringPiece remaining_string = 99 | StringViewToRe2StringPiece(text.substr(pos)); 100 | // Regex captures a serialized step followed by a dot. 101 | // Note that this only captures the step if it is not the last one. Since 102 | // on no match we just have the rest of the string be a step, we are OK. 103 | // Note that solution is a view into a view into the original text argument. 104 | // Match(...,&solution, 1) will return 1 submatch (i.e. the substring of 105 | // text matching the regular expression). 106 | if (kSerializedStepAndDot->Match(remaining_string, 0, 107 | remaining_string.size(), RE2::ANCHOR_START, 108 | &solution, 1) && 109 | solution.data() != nullptr) { 110 | // solution now contains the step and the dot after the step. 111 | // Returns a string_view of text that is equal to ".". 
112 | return Re2StringPieceToStringView(solution.substr(solution.size() - 1)); 113 | } 114 | return absl::string_view(text.end(), 0); 115 | } 116 | }; 117 | } // namespace 118 | 119 | Path::Path(const tensorflow::metadata::v0::Path& p) 120 | : step_(p.step().begin(), p.step().end()) {} 121 | 122 | // Part of the implementation of Compare(). 123 | bool Path::Equals(const Path& p) const { 124 | if (p.step_.size() != step_.size()) { 125 | return false; 126 | } 127 | return std::equal(step_.begin(), step_.end(), p.step_.begin()); 128 | } 129 | 130 | bool Path::Less(const Path& p) const { 131 | return std::lexicographical_compare(step_.begin(), step_.end(), 132 | p.step_.begin(), p.step_.end()); 133 | } 134 | 135 | int Path::Compare(const Path& p) const { 136 | if (Equals(p)) { 137 | return 0; 138 | } 139 | return Less(p) ? -1 : +1; 140 | } 141 | 142 | bool operator==(const Path& a, const Path& b) { return a.Compare(b) == 0; } 143 | 144 | bool operator<(const Path& a, const Path& b) { return a.Compare(b) < 0; } 145 | 146 | bool operator>(const Path& a, const Path& b) { return a.Compare(b) > 0; } 147 | 148 | bool operator>=(const Path& a, const Path& b) { return a.Compare(b) >= 0; } 149 | 150 | bool operator<=(const Path& a, const Path& b) { return a.Compare(b) <= 0; } 151 | 152 | bool operator!=(const Path& a, const Path& b) { return a.Compare(b) != 0; } 153 | 154 | string Path::Serialize() const { 155 | const string separator = "."; 156 | std::vector serialized_steps; 157 | for (const string& step : step_) { 158 | serialized_steps.push_back(SerializeStep(step)); 159 | } 160 | return absl::StrJoin(serialized_steps, separator); 161 | } 162 | 163 | tensorflow::metadata::v0::Path Path::AsProto() const { 164 | tensorflow::metadata::v0::Path path; 165 | for (const string& step : step_) { 166 | path.add_step(step); 167 | } 168 | return path; 169 | } 170 | 171 | // Deserializes a string created with Serialize(). 
172 | // Note: for any path p: 173 | // p==Path::Deserialize(p.Serialize()) 174 | tensorflow::Status Path::Deserialize(absl::string_view str, Path* result) { 175 | result->step_.clear(); 176 | if (str.empty()) { 177 | return Status::OK(); 178 | } 179 | result->step_ = absl::StrSplit(str, StepDelimiter()); 180 | for (string& step : result->step_) { 181 | TF_RETURN_IF_ERROR(DeserializeStep(&step)); 182 | } 183 | return Status::OK(); 184 | } 185 | 186 | Path Path::GetChild(absl::string_view last_step) const { 187 | std::vector new_steps(step_.begin(), step_.end()); 188 | new_steps.push_back(string(last_step)); 189 | return Path(std::move(new_steps)); 190 | } 191 | 192 | } // namespace data_validation 193 | } // namespace tensorflow 194 | --------------------------------------------------------------------------------