├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── RELEASE.md ├── docs ├── _toc.yaml ├── build_tft_beam_docs.py ├── build_tft_docs.py ├── common_transformations.md ├── get_started.ipynb ├── install.md └── tf2_support.md ├── examples ├── README.md ├── census_example.py ├── census_example_v2.py ├── census_example_v2_test.py ├── dataset_tfxio_example.py ├── dataset_tfxio_example_test.py ├── local_model_server.py ├── sentiment.md ├── sentiment_example_v2.py ├── sentiment_example_v2_test.py ├── simple_example.py ├── simple_example_test.py ├── simple_sequence_example.py ├── simple_sequence_example_test.py └── testdata │ └── sequence_example │ └── data_tfrecord.gz ├── setup.py └── tensorflow_transform ├── __init__.py ├── analyzer_nodes.py ├── analyzers.py ├── analyzers_test.py ├── annotations.proto ├── annotators.py ├── annotators_test.py ├── beam ├── __init__.py ├── analysis_graph_builder.py ├── analysis_graph_builder_test.py ├── analyzer_cache.py ├── analyzer_cache_test.py ├── analyzer_impls.py ├── analyzer_impls_test.py ├── annotators_test.py ├── beam_nodes.py ├── bucketize_integration_test.py ├── cached_impl_test.py ├── combiner_packing_util.py ├── combiner_packing_util_test.py ├── common.py ├── context.py ├── context_test.py ├── deep_copy.py ├── deep_copy_test.py ├── experimental │ ├── __init__.py │ └── analyzer_impls.py ├── impl.py ├── impl_output_record_batches_test.py ├── impl_test.py ├── test_helpers.py ├── tft_beam_io │ ├── __init__.py │ ├── beam_metadata_io.py │ ├── beam_metadata_io_test.py │ ├── test_metadata.py │ ├── transform_fn_io.py │ └── transform_fn_io_test.py ├── tft_unit.py ├── tukey_hh_params_integration_test.py ├── vocabulary_integration_test.py └── vocabulary_tfrecord_gzip_integration_test.py ├── coders ├── __init__.py ├── csv_coder.py ├── csv_coder_test.py ├── example_proto_coder.py └── example_proto_coder_test.py ├── common.py ├── common_test.py ├── common_types.py ├── experimental ├── __init__.py ├── analyzers.py ├── annotators.py └── mappers.py ├── gaussianization.py ├── gaussianization_test.py ├── graph_context.py ├── graph_tools.py ├── graph_tools_test.py ├── impl_helper.py ├── impl_helper_test.py ├── info_theory.py ├── info_theory_test.py ├── inspect_preprocessing_fn.py ├── inspect_preprocessing_fn_test.py ├── keras_lib.py ├── mappers.py ├── mappers_test.py ├── nodes.py ├── nodes_test.py ├── output_wrapper.py ├── pickle_helper.py ├── pretrained_models.py ├── pretrained_models_test.py ├── py.typed ├── py_func ├── __init__.py ├── api.py └── pyfunc_helper.py ├── saved ├── __init__.py ├── constants.py ├── saved_model_loader.py ├── saved_model_loader_test.py ├── saved_transform_io.py ├── saved_transform_io_test.py ├── saved_transform_io_v2.py └── saved_transform_io_v2_test.py ├── schema_inference.py ├── schema_inference_test.py ├── test_case.py ├── test_case_test.py ├── tf2_utils.py ├── tf2_utils_test.py ├── tf_metadata ├── __init__.py ├── dataset_metadata.py ├── dataset_metadata_test.py ├── metadata_io.py ├── metadata_io_test.py ├── schema_utils.py ├── schema_utils_legacy.py ├── schema_utils_test.py ├── schema_utils_test_cases.py └── test_common.py ├── tf_utils.py ├── tf_utils_test.py └── version.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | .eggs/ 3 | *.egg/ 4 | dist/ 5 | build/ 6 | .idea/ 7 | *.pyc 8 | .tox/ 9 | py27/ 10 | target/ 11 | *coverage* 12 | *.swp 13 | AUTHORS 14 | ChangeLog 15 | .DS_Store 16 | .mypy_cache 17 | 
-------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing guidelines 2 | 3 | ## How to become a contributor and submit your own code 4 | 5 | ### Contributor License Agreements 6 | 7 | We'd love to accept your patches! Before we can take them, we have to jump a couple of legal hurdles. 8 | 9 | Please fill out either the individual or corporate Contributor License Agreement (CLA). 10 | 11 | * If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](http://code.google.com/legal/individual-cla-v1.0.html). 12 | * If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](http://code.google.com/legal/corporate-cla-v1.0.html). 13 | 14 | Follow either of the two links above to access the appropriate CLA and instructions for how to sign and return it. Once we receive it, we'll be able to accept your pull requests. 15 | 16 | ***NOTE***: Only original source code from you and other people who have signed the CLA can be accepted into the main repository. 17 | 18 | ### Contributing code 19 | 20 | If you have improvements to TensorFlow Transform, send us your pull requests! 21 | For those just getting started, GitHub has a [howto](https://help.github.com/articles/using-pull-requests/). 22 | 23 | If you want to contribute but you're not sure where to start, take a look at the 24 | [issues with the "contributions welcome" label](https://github.com/tensorflow/transform/labels/contributions%20welcome). 25 | These are issues that we believe are particularly well suited for outside 26 | contributions, often because we probably won't get to them right now. If you 27 | decide to start on an issue, leave a comment so that other people know that 28 | you're working on it. If you want to help out but don't want to work alone, use the issue 29 | comment thread to coordinate. 30 | -------------------------------------------------------------------------------- /docs/_toc.yaml: -------------------------------------------------------------------------------- 1 | toc: 2 | - title: "Install" 3 | path: /tfx/transform/install 4 | - title: "Get started" 5 | path: /tfx/transform/get_started 6 | - title: "Using tf.Transform with TensorFlow 2.x" 7 | path: /tfx/transform/tf2_support 8 | - title: "Common transformations" 9 | path: /tfx/transform/common_transformations 10 | -------------------------------------------------------------------------------- /docs/build_tft_beam_docs.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | r"""Generate docs for `tft.beam`.
16 | 17 | This requires a local installation of `tft` and `tensorflow_docs`: 18 | 19 | ``` 20 | $ pip install tensorflow_transform git+https://github.com/tensorflow/docs 21 | ``` 22 | 23 | ``` 24 | python build_tft_beam_docs.py --output_dir=/tmp/tft_beam_api/ 25 | ``` 26 | 27 | """ 28 | from absl import app 29 | from absl import flags 30 | 31 | from tensorflow_docs.api_generator import doc_controls 32 | from tensorflow_docs.api_generator import generate_lib 33 | from tensorflow_docs.api_generator import public_api 34 | import tensorflow_transform.beam as tft_beam 35 | 36 | 37 | flags.DEFINE_string('output_dir', '/tmp/tft_beam_api/', 38 | 'The path to output the files to') 39 | 40 | flags.DEFINE_string( 41 | 'code_url_prefix', 42 | 'https://github.com/tensorflow/transform/tree/master/tensorflow_transform', 43 | 'The url prefix for links to code.') 44 | 45 | flags.DEFINE_bool('search_hints', True, 46 | 'Include metadata search hints in the generated files') 47 | 48 | flags.DEFINE_string('site_path', 'tfx/transform/api_docs/python', 49 | 'Path prefix in the _toc.yaml') 50 | 51 | FLAGS = flags.FLAGS 52 | 53 | 54 | def main(args): 55 | if args[1:]: 56 | raise ValueError('Unrecognized command line args', args[1:]) 57 | 58 | doc_controls.do_not_generate_docs(tft_beam.analyzer_impls) 59 | 60 | doc_generator = generate_lib.DocGenerator( 61 | root_title='TFT-Beam', 62 | py_modules=[('tft_beam', tft_beam)], 63 | code_url_prefix=FLAGS.code_url_prefix + '/beam', 64 | search_hints=FLAGS.search_hints, 65 | site_path=FLAGS.site_path, 66 | callbacks=[ 67 | public_api.explicit_package_contents_filter, 68 | public_api.local_definitions_filter 69 | ]) 70 | 71 | doc_generator.build(FLAGS.output_dir) 72 | 73 | 74 | if __name__ == '__main__': 75 | app.run(main) 76 | -------------------------------------------------------------------------------- /docs/build_tft_docs.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | r"""Generate docs for `tft`.
16 | 17 | This requires a local installation of `tft` and `tensorflow_docs`: 18 | 19 | ``` 20 | $ pip install tensorflow_transform git+https://github.com/tensorflow/docs 21 | ``` 22 | 23 | ``` 24 | python build_tft_docs.py --output_dir=/tmp/tft-api 25 | ``` 26 | 27 | """ 28 | from absl import app 29 | from absl import flags 30 | 31 | from tensorflow_docs.api_generator import generate_lib 32 | from tensorflow_docs.api_generator import public_api 33 | import tensorflow_transform as transform 34 | 35 | 36 | flags.DEFINE_string('output_dir', '/tmp/tft_api/', 37 | 'The path to output the files to') 38 | 39 | flags.DEFINE_string( 40 | 'code_url_prefix', 41 | 'https://github.com/tensorflow/transform/tree/master/tensorflow_transform', 42 | 'The url prefix for links to code.') 43 | 44 | flags.DEFINE_bool('search_hints', True, 45 | 'Include metadata search hints in the generated files') 46 | 47 | flags.DEFINE_string('site_path', 'tfx/transform/api_docs/python', 48 | 'Path prefix in the _toc.yaml') 49 | 50 | FLAGS = flags.FLAGS 51 | 52 | 53 | def main(args): 54 | if args[1:]: 55 | raise ValueError('Unrecognized command line args', args[1:]) 56 | 57 | doc_generator = generate_lib.DocGenerator( 58 | root_title='TF-Transform', 59 | py_modules=[('tft', transform)], 60 | code_url_prefix=FLAGS.code_url_prefix, 61 | search_hints=FLAGS.search_hints, 62 | site_path=FLAGS.site_path, 63 | callbacks=[public_api.explicit_package_contents_filter]) 64 | 65 | doc_generator.build(FLAGS.output_dir) 66 | 67 | 68 | if __name__ == '__main__': 69 | app.run(main) 70 | -------------------------------------------------------------------------------- /docs/common_transformations.md: -------------------------------------------------------------------------------- 1 | # Common Transformations 2 | 3 | [TOC] 4 | 5 | In this document we describe how to do common transformations with `tf.Transform`. 6 | 7 | We assume you have already constructed the Beam pipeline along the lines of the 8 | examples, and only describe what needs to be added to `preprocessing_fn` and 9 | possibly the model. 10 | 11 | ## Using String/Categorical data 12 | 13 | The following `preprocessing_fn` will compute a vocabulary over the values of 14 | feature `x` with tokens in descending frequency order, convert feature `x` 15 | values to their index in the vocabulary, and finally perform a one-hot encoding 16 | for the output. 17 | 18 | This is common, for example, in use cases where the label feature is a categorical 19 | string. 20 | The resulting one-hot encoding is ready for training. 21 | 22 | Note: this example produces `x_out` as a potentially large dense tensor. This is 23 | fine as long as the transformed data doesn't get materialized, and this is the 24 | format expected in training. Otherwise, a more efficient representation would be 25 | a `tf.SparseTensor`, in which case only a single index and a value (1) are used to 26 | represent each instance. 27 | 28 | ```python 29 | def preprocessing_fn(inputs): 30 | integerized = tft.compute_and_apply_vocabulary( 31 | inputs['x'], 32 | num_oov_buckets=1, 33 | vocab_filename='x_vocab') 34 | one_hot_encoded = tf.one_hot( 35 | integerized, 36 | depth=tf.cast(tft.experimental.get_vocabulary_size_by_name('x_vocab') + 1, 37 | tf.int32), 38 | on_value=1.0, 39 | off_value=0.0) 40 | return { 41 | 'x_out': one_hot_encoded, 42 | } 43 | ``` 44 | 45 | ## Mean imputation for missing data 46 | 47 | In this example, feature `x` is an optional feature, represented as a 48 | `tf.SparseTensor` in the `preprocessing_fn`.
In order to convert it to a dense 49 | tensor, we compute its mean and set the mean to be the default value when it 50 | is missing from an instance. 51 | 52 | The resulting dense tensor will have the shape `[None, 1]`, where `None` represents 53 | the batch dimension and the second dimension is the number of 54 | values that `x` can have per instance. In this case it's 1. 55 | 56 | ```python 57 | def preprocessing_fn(inputs): 58 | return { 59 | 'x_out': tft.sparse_tensor_to_dense_with_shape( 60 | inputs['x'], default_value=tft.mean(inputs['x']), shape=[None, 1]) 61 | } 62 | ``` 63 | -------------------------------------------------------------------------------- /docs/tf2_support.md: -------------------------------------------------------------------------------- 1 | # Using tf.Transform with TensorFlow 2.x 2 | 3 | Starting with the `0.30` release of `tf.Transform`, the default behavior is to 4 | export a TF 2.x SavedModel unless TF 2.x behaviors are explicitly disabled. This 5 | page provides a guide for using `tf.Transform` to export the transform graph as 6 | a TensorFlow 2.x SavedModel. 7 | 8 | ## New in tf.Transform with TF 2.x 9 | 10 | ### Loading Keras models within the `preprocessing_fn` 11 | 12 | Please use the `tft.make_and_track_object` API to load Keras models as shown in 13 | the example below. 14 | 15 | ```python 16 | def preprocessing_fn(inputs): 17 | keras_model = tft.make_and_track_object(lambda: tf.keras.models.load_model(...), name='_unique_name') 18 | ... 19 | return {'keras_model_output': keras_model(inputs[...])} 20 | ``` 21 | 22 | ### Using TF 2.x tf.hub modules 23 | 24 | TF 2.x hub modules work in `tf.Transform` only when the `preprocessing_fn` is 25 | traced and exported as a TF 2.x SavedModel (this is the default behavior 26 | starting with `tensorflow_transform 0.30`). Please use the 27 | `tft.make_and_track_object` API to load `tf.hub` modules as shown in the example 28 | below. 29 | 30 | ```python 31 | def preprocessing_fn(inputs): 32 | hub_module = tft.make_and_track_object(lambda: hub.load(...)) 33 | ... 34 | return {'hub_module_output': hub_module(inputs[...])} 35 | ``` 36 | 37 | ## Potential migration issues 38 | 39 | If migrating an existing `tf.Transform` pipeline from TF 1.x to TF 2.x, the 40 | following issues may be encountered: 41 | 42 | ### RuntimeError: The order of analyzers in your `preprocessing_fn` appears to be non-deterministic. 43 | 44 | In TF 2.x, the `preprocessing_fn` provided by the user is traced several times. 45 | If the order in which TFT analyzers are encountered changes with each trace, 46 | this error will be raised. This can be fixed by removing any non-determinism in 47 | the order in which TFT analyzers are invoked. 48 | 49 | ### Output of `transform_raw_features` does not contain expected feature. 50 | 51 | Example exceptions: 52 | 53 | ```shell 54 | KeyError: <feature key> 55 | ``` 56 | 57 | or 58 | 59 | ```shell 60 | <feature key> not found in features dictionary. 61 | ``` 62 | 63 | [`TFTransformOutput.transform_raw_features`](https://www.tensorflow.org/tfx/transform/api_docs/python/tft/TFTransformOutput#transform_raw_features) 64 | ignores the `drop_unused_features` parameter and behaves as if it were `True`. 65 | Please update any usages of the output dictionary from this API to check if the 66 | key you are attempting to retrieve exists in it. 67 | 68 | ### tf.estimator.BaselineClassifier sees Table not initialized error.
69 | 70 | Example exception: 71 | 72 | ```shell 73 | tensorflow.python.framework.errors_impl.FailedPreconditionError: Table not initialized. 74 | ``` 75 | 76 | Support for Trainer with an Estimator-based executor is best-effort. While other 77 | estimators work, we have seen issues with table initialization in the 78 | BaselineClassifier. Please 79 | [disable TF 2.x in `tf.Transform`](https://www.tensorflow.org/tfx/transform/tf2_support#retaining_the_legacy_tftransform_behavior). 80 | 81 | ## Known issues / Features not yet supported 82 | 83 | ### Outputting vocabularies in TFRecord format is not yet supported. 84 | 85 | `tfrecord_gzip` is not yet supported as a valid value for the `file_format` 86 | parameter in `tft.vocabulary` (and other vocabulary APIs). 87 | 88 | ## Retaining the legacy tf.Transform behavior 89 | 90 | If your `tf.Transform` pipeline should not run with TF 2.x, you can retain the 91 | legacy behavior in one of the following ways: 92 | 93 | * Disable TF2 in `tf.Transform` by calling 94 | `tf.compat.v1.disable_v2_behavior()` 95 | * Pass `force_tf_compat_v1=True` to `tft_beam.Context` if using 96 | `tf.Transform` as a standalone library, or to the Transform component in TFX. 97 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # TensorFlow Transform Examples 4 | 5 | ## Simple example 6 | 7 | There's a minimal `tf.Transform` example available in the [GitHub repo](./simple_example.py). 8 | 9 | ## Census income example 10 | 11 | The *Census income* dataset is provided by the 12 | [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/Census+Income). 13 | This dataset contains both categorical and numeric data. See 14 | [Get started with TensorFlow Transform](https://www.tensorflow.org/tfx/transform/get_started) 15 | for details. 16 | 17 | ## Sentiment analysis example 18 | 19 | Similar to the *Census income* example, but requires more extensive Apache Beam 20 | processing before `tf.Transform` is invoked. See the 21 | [sentiment analysis example](./sentiment.md) for more information. 22 | -------------------------------------------------------------------------------- /examples/census_example.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | """Example using census data from UCI repository.""" 15 | 16 | # pylint: disable=g-bad-import-order 17 | import os 18 | import pprint 19 | import tempfile 20 | 21 | import tensorflow as tf 22 | from tensorflow import estimator as tf_estimator 23 | import tensorflow_transform as tft 24 | import census_example_common as common 25 | 26 | # Functions for training 27 | 28 | 29 | def _make_inputs_dense(transformed_features): 30 | return { 31 | k: tf.sparse.to_dense(v) if isinstance(v, tf.SparseTensor) else v 32 | for k, v in transformed_features.items() 33 | } 34 | # pylint: disable=g-deprecated-tf-checker 35 | 36 | 37 | def _make_training_input_fn(tf_transform_output, transformed_examples, 38 | batch_size): 39 | """Creates an input function reading from transformed data. 40 | 41 | Args: 42 | tf_transform_output: Wrapper around output of tf.Transform. 43 | transformed_examples: Base filename of examples. 44 | batch_size: Batch size. 45 | 46 | Returns: 47 | The input function for training or eval. 48 | """ 49 | def input_fn(): 50 | """Input function for training and eval.""" 51 | dataset = tf.data.experimental.make_batched_features_dataset( 52 | file_pattern=transformed_examples, 53 | batch_size=batch_size, 54 | features=tf_transform_output.transformed_feature_spec(), 55 | reader=tf.data.TFRecordDataset, 56 | shuffle=True) 57 | 58 | transformed_features = _make_inputs_dense( 59 | tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() 60 | ) 61 | 62 | # Extract features and label from the transformed tensors. 63 | # TODO(b/30367437): make transformed_labels a dict. 64 | transformed_labels = tf.where( 65 | tf.equal(transformed_features.pop(common.LABEL_KEY), 1)) 66 | 67 | return transformed_features, transformed_labels[:, 1] 68 | 69 | return input_fn 70 | 71 | 72 | def _make_serving_input_fn(tf_transform_output): 73 | """Creates an input function reading from raw data. 74 | 75 | Args: 76 | tf_transform_output: Wrapper around output of tf.Transform. 77 | 78 | Returns: 79 | The serving input function. 80 | """ 81 | raw_feature_spec = common.RAW_DATA_FEATURE_SPEC.copy() 82 | # Remove label since it is not available during serving. 83 | raw_feature_spec.pop(common.LABEL_KEY) 84 | 85 | def serving_input_fn(): 86 | """Input function for serving.""" 87 | # Get raw features by generating the basic serving input_fn and calling it. 88 | # Here we generate an input_fn that expects a parsed Example proto to be fed 89 | # to the model at serving time. See also 90 | # tf.estimator.export.build_raw_serving_input_receiver_fn. 91 | raw_input_fn = tf_estimator.export.build_parsing_serving_input_receiver_fn( 92 | raw_feature_spec, default_batch_size=None) 93 | serving_input_receiver = raw_input_fn() 94 | 95 | # Apply the transform function that was used to generate the materialized 96 | # data. 97 | raw_features = serving_input_receiver.features 98 | transformed_features = _make_inputs_dense( 99 | tf_transform_output.transform_raw_features(raw_features) 100 | ) 101 | 102 | return tf_estimator.export.ServingInputReceiver( 103 | transformed_features, serving_input_receiver.receiver_tensors) 104 | 105 | return serving_input_fn 106 | 107 | 108 | def get_feature_columns(tf_transform_output): 109 | """Returns the FeatureColumns for the model. 110 | 111 | Args: 112 | tf_transform_output: A `TFTransformOutput` object. 113 | 114 | Returns: 115 | A list of FeatureColumns. 116 | """ 117 | feature_spec = tf_transform_output.transformed_feature_spec() 118 | # Wrap scalars as real valued columns. 
119 | def get_shape(spec): 120 | if isinstance(spec, tf.io.SparseFeature): 121 | return spec.size 122 | return spec.shape 123 | 124 | return [ 125 | tf.feature_column.numeric_column(key, shape=get_shape(feature_spec[key])) 126 | for key in (common.NUMERIC_FEATURE_KEYS + common.CATEGORICAL_FEATURE_KEYS) 127 | ] 128 | 129 | 130 | def train_and_evaluate(working_dir, 131 | num_train_instances=common.NUM_TRAIN_INSTANCES, 132 | num_test_instances=common.NUM_TEST_INSTANCES): 133 | """Train the model on training data and evaluate on test data. 134 | 135 | Args: 136 | working_dir: Directory to read transformed data and metadata from and to 137 | write exported model to. 138 | num_train_instances: Number of instances in train set 139 | num_test_instances: Number of instances in test set 140 | 141 | Returns: 142 | The results from the estimator's 'evaluate' method 143 | """ 144 | tf_transform_output = tft.TFTransformOutput(working_dir) 145 | 146 | run_config = tf_estimator.RunConfig() 147 | 148 | estimator = tf_estimator.LinearClassifier( 149 | feature_columns=get_feature_columns(tf_transform_output), 150 | config=run_config, 151 | loss_reduction=tf.losses.Reduction.SUM) 152 | 153 | # Fit the model using the default optimizer. 154 | train_input_fn = _make_training_input_fn( 155 | tf_transform_output, 156 | os.path.join(working_dir, common.TRANSFORMED_TRAIN_DATA_FILEBASE + '*'), 157 | batch_size=common.TRAIN_BATCH_SIZE) 158 | estimator.train( 159 | input_fn=train_input_fn, 160 | max_steps=common.TRAIN_NUM_EPOCHS * num_train_instances / 161 | common.TRAIN_BATCH_SIZE) 162 | 163 | # Evaluate model on test dataset. 164 | eval_input_fn = _make_training_input_fn( 165 | tf_transform_output, 166 | os.path.join(working_dir, common.TRANSFORMED_TEST_DATA_FILEBASE + '*'), 167 | batch_size=1) 168 | 169 | # Export the model. 170 | serving_input_fn = _make_serving_input_fn(tf_transform_output) 171 | exported_model_dir = os.path.join(working_dir, common.EXPORTED_MODEL_DIR) 172 | estimator.export_saved_model(exported_model_dir, serving_input_fn) 173 | 174 | return estimator.evaluate(input_fn=eval_input_fn, steps=num_test_instances) 175 | 176 | 177 | def main(): 178 | args = common.get_args() 179 | if args.working_dir: 180 | working_dir = args.working_dir 181 | else: 182 | working_dir = tempfile.mkdtemp(dir=args.input_data_dir) 183 | 184 | train_data_file = os.path.join(args.input_data_dir, 'adult.data') 185 | test_data_file = os.path.join(args.input_data_dir, 'adult.test') 186 | 187 | common.transform_data(train_data_file, test_data_file, working_dir) 188 | 189 | results = train_and_evaluate(working_dir) 190 | 191 | pprint.pprint(results) 192 | 193 | if __name__ == '__main__': 194 | main() 195 | -------------------------------------------------------------------------------- /examples/census_example_v2_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for census_example_v2.""" 15 | 16 | import os 17 | import shutil 18 | from packaging import version 19 | 20 | import tensorflow.compat.v2 as tf 21 | import census_example_v2 22 | from tensorflow_transform import test_case as tft_test_case 23 | import local_model_server 24 | from tensorflow_transform.keras_lib import tf_keras 25 | 26 | from google.protobuf import text_format 27 | 28 | # Use first row of test data set, which has high probability on label 1 (which 29 | # corresponds to '<=50K'). 30 | _PREDICT_TF_EXAMPLE_TEXT_PB = """ 31 | features { 32 | feature { 33 | key: "age" 34 | value { float_list: { value: 25 } } 35 | } 36 | feature { 37 | key: "workclass" 38 | value { bytes_list: { value: "Private" } } 39 | } 40 | feature { 41 | key: "education" 42 | value { bytes_list: { value: "11th" } } 43 | } 44 | feature { 45 | key: "education-num" 46 | value { float_list: { value: 7 } } 47 | } 48 | feature { 49 | key: "marital-status" 50 | value { bytes_list: { value: "Never-married" } } 51 | } 52 | feature { 53 | key: "occupation" 54 | value { bytes_list: { value: "Machine-op-inspct" } } 55 | } 56 | feature { 57 | key: "relationship" 58 | value { bytes_list: { value: "Own-child" } } 59 | } 60 | feature { 61 | key: "race" 62 | value { bytes_list: { value: "Black" } } 63 | } 64 | feature { 65 | key: "sex" 66 | value { bytes_list: { value: "Male" } } 67 | } 68 | feature { 69 | key: "capital-gain" 70 | value { float_list: { value: 0 } } 71 | } 72 | feature { 73 | key: "capital-loss" 74 | value { float_list: { value: 0 } } 75 | } 76 | feature { 77 | key: "hours-per-week" 78 | value { float_list: { value: 40 } } 79 | } 80 | feature { 81 | key: "native-country" 82 | value { bytes_list: { value: "United-States" } } 83 | } 84 | } 85 | """ 86 | 87 | _MODEL_NAME = 'my_model' 88 | 89 | _CLASSIFICATION_REQUEST_TEXT_PB = """model_spec { name: "%s" } 90 | input { 91 | example_list { 92 | examples { 93 | %s 94 | } 95 | } 96 | }""" % (_MODEL_NAME, _PREDICT_TF_EXAMPLE_TEXT_PB) 97 | 98 | 99 | class CensusExampleV2Test(tft_test_case.TransformTestCase): 100 | 101 | def setUp(self): 102 | super().setUp() 103 | if tft_test_case.is_external_environment() and version.parse( 104 | tf.version.VERSION 105 | ) < version.parse('2.3'): 106 | raise tft_test_case.SkipTest('This test requires TF version >= 2.3') 107 | 108 | def _get_data_dir(self): 109 | return os.path.join(os.path.dirname(__file__), 'testdata/census') 110 | 111 | def _get_working_dir(self): 112 | return os.path.join( 113 | os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()), 114 | self._testMethodName) 115 | 116 | def _should_saved_model_load_work(self): 117 | return version.parse(tf.__version__) >= version.parse('2.2') 118 | 119 | @tft_test_case.named_parameters([ 120 | dict( 121 | testcase_name='_read_raw_data_for_training', 122 | read_raw_data_for_training=True), 123 | dict( 124 | testcase_name='_read_transformed_data_for_training', 125 | read_raw_data_for_training=False), 126 | ]) 127 | def testCensusExampleAccuracy(self, read_raw_data_for_training): 128 | 129 | if not self._should_saved_model_load_work(): 130 | self.skipTest('The generated SavedModel cannot be read with TF<2.2') 131 | raw_data_dir = self._get_data_dir() 132 | working_dir = self._get_working_dir() 133 | 134 | train_data_file = os.path.join(raw_data_dir, 'adult.data') 135 | test_data_file = os.path.join(raw_data_dir, 'adult.test') 136 | 137 | 
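    # Run the tf.Transform pipeline so that the transformed data and the transform_fn are materialized under `working_dir` before training.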
census_example_v2.transform_data( 138 | train_data_file, test_data_file, working_dir 139 | ) 140 | 141 | if read_raw_data_for_training: 142 | raw_train_and_eval_patterns = (train_data_file, test_data_file) 143 | transformed_train_and_eval_patterns = None 144 | else: 145 | train_pattern = os.path.join( 146 | working_dir, census_example_v2.TRANSFORMED_TRAIN_DATA_FILEBASE + '*' 147 | ) 148 | eval_pattern = os.path.join( 149 | working_dir, census_example_v2.TRANSFORMED_TEST_DATA_FILEBASE + '*' 150 | ) 151 | raw_train_and_eval_patterns = None 152 | transformed_train_and_eval_patterns = (train_pattern, eval_pattern) 153 | output_dir = os.path.join(working_dir, census_example_v2.EXPORTED_MODEL_DIR) 154 | results = census_example_v2.train_and_evaluate( 155 | raw_train_and_eval_patterns, 156 | transformed_train_and_eval_patterns, 157 | output_dir, 158 | working_dir, 159 | num_train_instances=1000, 160 | num_test_instances=1000) 161 | self.assertGreaterEqual(results[1], 0.7) 162 | 163 | # Removing the tf.Transform output directory in order to show that the 164 | # exported model is hermetic. 165 | shutil.rmtree(os.path.join(working_dir, 'transform_fn')) 166 | 167 | model_path = os.path.join(working_dir, census_example_v2.EXPORTED_MODEL_DIR) 168 | 169 | actual_model_path = os.path.join(model_path, '1') 170 | tf_keras.backend.clear_session() 171 | model = tf_keras.models.load_model(actual_model_path) 172 | model.summary() 173 | 174 | example = text_format.Parse(_PREDICT_TF_EXAMPLE_TEXT_PB, tf.train.Example()) 175 | prediction = model.signatures['serving_default']( 176 | tf.constant([example.SerializeToString()], tf.string)) 177 | self.assertAllEqual([['0', '1']], prediction['classes']) 178 | self.assertAllClose([[0, 1]], prediction['scores'], atol=0.01) 179 | 180 | # This is required in order to support the classify API for this Keras 181 | # model. 182 | updater = tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater( 183 | actual_model_path) 184 | updater.replace_method_name( 185 | signature_key='serving_default', 186 | method_name='tensorflow/serving/classify', 187 | tags=['serve']) 188 | updater.save() 189 | 190 | if local_model_server.local_model_server_supported(): 191 | with local_model_server.start_server(_MODEL_NAME, model_path) as address: 192 | ascii_classification_request = _CLASSIFICATION_REQUEST_TEXT_PB 193 | results = local_model_server.make_classification_request( 194 | address, ascii_classification_request) 195 | self.assertLen(results, 1) 196 | self.assertLen(results[0].classes, 2) 197 | self.assertEqual(results[0].classes[0].label, '0') 198 | self.assertLess(results[0].classes[0].score, 0.01) 199 | self.assertEqual(results[0].classes[1].label, '1') 200 | self.assertGreater(results[0].classes[1].score, 0.99) 201 | 202 | def test_main_runs(self): 203 | census_example_v2.main( 204 | self._get_data_dir(), 205 | self._get_working_dir(), 206 | read_raw_data_for_training=False, 207 | num_train_instances=10, 208 | num_test_instances=10) 209 | 210 | def test_main_runs_raw_data(self): 211 | census_example_v2.main( 212 | self._get_data_dir(), 213 | self._get_working_dir(), 214 | read_raw_data_for_training=True, 215 | num_train_instances=10, 216 | num_test_instances=10) 217 | 218 | 219 | if __name__ == '__main__': 220 | tf.test.main() 221 | -------------------------------------------------------------------------------- /examples/dataset_tfxio_example.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google Inc. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Simple Example of DatasetTFXIO usage.""" 15 | 16 | import pprint 17 | import tempfile 18 | 19 | from absl import app 20 | import apache_beam as beam 21 | import tensorflow as tf 22 | import tensorflow_transform as tft 23 | import tensorflow_transform.beam.impl as tft_beam 24 | from tfx_bsl.tfxio import dataset_tfxio 25 | 26 | 27 | def _print_record_batch(data): 28 | pprint.pprint(data.to_pydict()) 29 | 30 | 31 | def _preprocessing_fn(inputs): 32 | return { 33 | 'x_centered': tf.cast(inputs['feature0'], tf.float32) - tft.mean( 34 | inputs['feature0'] 35 | ), 36 | 'x_scaled': tft.scale_by_min_max(inputs['feature0']), 37 | } 38 | 39 | 40 | def _make_tfxio() -> dataset_tfxio.DatasetTFXIO: 41 | """Make DatasetTFXIO.""" 42 | num_elements = 9 43 | batch_size = 2 44 | dataset = tf.data.Dataset.range(num_elements).batch(batch_size) 45 | 46 | return dataset_tfxio.DatasetTFXIO(dataset=dataset) 47 | 48 | 49 | def main(args): 50 | del args 51 | 52 | input_tfxio = _make_tfxio() 53 | 54 | # User-Defined Processing Pipeline 55 | with beam.Pipeline() as pipeline: 56 | with tft_beam.Context(temp_dir=tempfile.mkdtemp()): 57 | raw_dataset = ( 58 | pipeline | 'ReadRecordBatch' >> input_tfxio.BeamSource(batch_size=5), 59 | input_tfxio.TensorAdapterConfig(), 60 | ) 61 | (transformed_data, _), _ = ( 62 | raw_dataset 63 | | 'AnalyzeAndTransform' 64 | >> tft_beam.AnalyzeAndTransformDataset( 65 | _preprocessing_fn, output_record_batches=True 66 | ) 67 | ) 68 | transformed_data = transformed_data | 'ExtractRecordBatch' >> beam.Keys() 69 | _ = transformed_data | 'PrintTransformedData' >> beam.Map( 70 | _print_record_batch 71 | ) 72 | 73 | 74 | if __name__ == '__main__': 75 | app.run(main) 76 | -------------------------------------------------------------------------------- /examples/dataset_tfxio_example_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for dataset_tfxio.""" 15 | 16 | import tensorflow as tf 17 | import dataset_tfxio_example 18 | from tensorflow_transform.beam import tft_unit 19 | 20 | 21 | _EXPECTED_TRANSFORMED_OUTPUT = [ 22 | {'x_scaled': 0.0, 'x_centered': -4.0}, 23 | {'x_scaled': 0.125, 'x_centered': -3.0}, 24 | {'x_scaled': 0.25, 'x_centered': -2.0}, 25 | {'x_scaled': 0.375, 'x_centered': -1.0}, 26 | {'x_scaled': 0.5, 'x_centered': 0.0}, 27 | {'x_scaled': 0.625, 'x_centered': 1.0}, 28 | {'x_scaled': 0.75, 'x_centered': 2.0}, 29 | {'x_scaled': 0.875, 'x_centered': 3.0}, 30 | {'x_scaled': 1.0, 'x_centered': 4.0}, 31 | ] 32 | 33 | 34 | class SimpleMainTest(tf.test.TestCase): 35 | 36 | def testMainDoesNotCrash(self): 37 | tft_unit.skip_if_not_tf2('Tensorflow 2.x required.') 38 | dataset_tfxio_example.main('') 39 | 40 | 41 | class SimpleProcessingTest(tft_unit.TransformTestCase): 42 | 43 | # Asserts equal for each element. (Does not check batchwise.) 44 | def test_preprocessing_fn(self): 45 | tfxio = dataset_tfxio_example._make_tfxio() 46 | self.assertAnalyzeAndTransformResults( 47 | tfxio.BeamSource(), 48 | tfxio.TensorAdapterConfig(), 49 | dataset_tfxio_example._preprocessing_fn, 50 | _EXPECTED_TRANSFORMED_OUTPUT, 51 | ) 52 | 53 | 54 | if __name__ == '__main__': 55 | tf.test.main() 56 | -------------------------------------------------------------------------------- /examples/local_model_server.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Local model server for testing.""" 15 | 16 | import contextlib 17 | 18 | 19 | def local_model_server_supported(): 20 | return False 21 | 22 | 23 | @contextlib.contextmanager 24 | def start_server(model_name, model_path): 25 | del model_name # unused 26 | del model_path # unused 27 | raise NotImplementedError 28 | 29 | 30 | # TODO(KesterTong): Change the input of make_classification_request to not be a 31 | # string. This will require adding a test-only dependency on 32 | # tensorflow_serving.apis. 33 | def make_classification_request(address, ascii_classification_request): 34 | """Makes a classify request to a local server.""" 35 | del address # unused 36 | del ascii_classification_request # unused 37 | raise NotImplementedError 38 | -------------------------------------------------------------------------------- /examples/sentiment.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Sentiment Analysis 4 | 5 | [`sentiment_example.py`](./sentiment_example.py) 6 | uses the [Large Movie Review Dataset](http://ai.stanford.edu/~amaas/data/sentiment/) 7 | and contains 50,000 movie reviews equally split into train and test sets. To run 8 | this example, download and unzip the data set to a directory. Pass this 9 | directory as an argument to `sentiment_example.py`. The script creates a 10 | temporary sub-directory to add the preprocessed data. 
11 | 12 | This example is similar to the 13 | [Census income example](https://github.com/tensorflow/transform/blob/master/docs/get_started.md) but 14 | requires more extensive Apache Beam processing before invoking `tf.Transform`. 15 | Here, the data must be read from multiple files across separate directories for 16 | positive and negative examples. Then the correct labels are attached and the 17 | dataset is shuffled. 18 | 19 | Since the input data uses separate files for each review (with separate 20 | directories for positive and negative reviews), this example first reads in 21 | the original data and transcodes it into `tf.Example`s in `TFRecords`. Then 22 | we use a pre-canned [TFXIO](https://www.tensorflow.org/tfx/tfx_bsl/api_docs/python/tfx_bsl/public/tfxio) to read those `tf.Example`s into what TFT accepts. 23 | 24 | The `tf.Transform` preprocessing is more complex. Unlike the Census income 25 | example, the data in this example uses a single feature for the full text of a 26 | movie review. This is split into tokens using the `tf.string_split` 27 | function. The `tf.string_split` function takes a rank-1 tensor and converts it 28 | to a rank-2 `SparseTensor` that contains the individual tokens. Then, using 29 | `tft.compute_and_apply_vocabulary`, this `SparseTensor` is converted to a 30 | `SparseTensor` of `int64`s with the same shape. 31 | 32 | During the training and evaluation phase, the `SparseTensor` that represents 33 | the review text (tokenized and integerized) is used as the input to a 34 | bag-of-words model. In particular, the tensor is passed to a `CategoryEncoding` layer. 35 | However, instead of a vector with a length of `1` (per instance), there's a 36 | vector whose length is the number of tokens in the review. In this circumstance, the vector 37 | of integerized tokens is interpreted as a bag-of-words. 38 | The vector is then used to create embeddings, which can be passed to a Keras-based 39 | DNN for training. 40 | -------------------------------------------------------------------------------- /examples/sentiment_example_v2_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | """Tests for sentiment_example_v2.""" 15 | 16 | import os 17 | import shutil 18 | 19 | import tensorflow as tf 20 | import tensorflow_transform as tft 21 | import sentiment_example_v2 22 | from tensorflow_transform import test_case 23 | import local_model_server 24 | 25 | 26 | class SentimentExampleTest(test_case.TransformTestCase): 27 | 28 | def testSentimentExampleAccuracy(self): 29 | raw_data_dir = os.path.join(os.path.dirname(__file__), 'testdata/sentiment') 30 | working_dir = self.get_temp_dir() 31 | 32 | # Copy data from raw data directory to `working_dir` 33 | try: 34 | for filename in ['test_shuffled-00000-of-00001', 35 | 'train_shuffled-00000-of-00001']: 36 | shutil.copy(os.path.join(raw_data_dir, filename), working_dir) 37 | except FileNotFoundError: 38 | # We only use a small sample of the data for testing purposes. 39 | train_neg_filepattern = os.path.join(raw_data_dir, 'train/neg/10000*') 40 | train_pos_filepattern = os.path.join(raw_data_dir, 'train/pos/10000*') 41 | test_neg_filepattern = os.path.join(raw_data_dir, 'test/neg/10000*') 42 | test_pos_filepattern = os.path.join(raw_data_dir, 'test/pos/10000*') 43 | 44 | # Writes the shuffled data under working_dir in TFRecord format. 45 | sentiment_example_v2.read_and_shuffle_data( 46 | train_neg_filepattern, 47 | train_pos_filepattern, 48 | test_neg_filepattern, 49 | test_pos_filepattern, 50 | working_dir, 51 | ) 52 | 53 | sentiment_example_v2.transform_data(working_dir) 54 | # TODO: b/323209255 - Remove this if clause once TF pulls the latest keras 55 | # nightly version. 56 | if not test_case.is_external_environment(): 57 | model_path = os.path.join( 58 | working_dir, sentiment_example_v2.EXPORTED_MODEL_DIR 59 | ) 60 | results = sentiment_example_v2.train_and_evaluate( 61 | working_dir, 62 | model_path, 63 | num_train_instances=1000, 64 | num_test_instances=1000, 65 | ) 66 | if not test_case.is_external_environment(): 67 | # Assert expected accuracy. 68 | self.assertGreaterEqual(results[1], 0.7) 69 | 70 | # Delete temp directory and transform_fn directory. This ensures that the 71 | # test of serving the model below will only pass if the SavedModel saved 72 | # to sentiment_example_v2.EXPORTED_MODEL_DIR is hermetic, i.e does not 73 | # contain references to tft_temp and transform_fn. 74 | shutil.rmtree( 75 | os.path.join(working_dir, sentiment_example_v2.TRANSFORM_TEMP_DIR) 76 | ) 77 | shutil.rmtree( 78 | os.path.join(working_dir, tft.TFTransformOutput.TRANSFORM_FN_DIR)) 79 | 80 | if local_model_server.local_model_server_supported(): 81 | model_name = 'my_model' 82 | with local_model_server.start_server(model_name, model_path) as address: 83 | # Use made up data chosen to give high probability of negative 84 | # sentiment. 85 | ascii_classification_request = """model_spec { name: "my_model" } 86 | input { 87 | example_list { 88 | examples { 89 | features { 90 | feature { 91 | key: "review" 92 | value: { 93 | bytes_list { 94 | value: "errible terrible terrible terrible terrible terrible terrible." 
95 | } 96 | } 97 | } 98 | } 99 | } 100 | } 101 | }""" 102 | results = local_model_server.make_classification_request( 103 | address, ascii_classification_request) 104 | self.assertLen(results, 1) 105 | self.assertLen(results[0].classes, 2) 106 | self.assertEqual(results[0].classes[0].label, '0') 107 | self.assertGreater(results[0].classes[0].score, 0.8) 108 | self.assertEqual(results[0].classes[1].label, '1') 109 | self.assertLess(results[0].classes[1].score, 0.2) 110 | 111 | 112 | if __name__ == '__main__': 113 | tf.test.main() 114 | -------------------------------------------------------------------------------- /examples/simple_example.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Simple Example of tf.Transform usage.""" 15 | 16 | import pprint 17 | import tempfile 18 | 19 | import tensorflow as tf 20 | import tensorflow_transform as tft 21 | import tensorflow_transform.beam as tft_beam 22 | 23 | _RAW_DATA_METADATA = tft.DatasetMetadata.from_feature_spec({ 24 | 's': tf.io.FixedLenFeature([], tf.string), 25 | 'y': tf.io.FixedLenFeature([], tf.float32), 26 | 'x': tf.io.FixedLenFeature([], tf.float32), 27 | }) 28 | 29 | _RAW_DATA = [{ 30 | 'x': 1, 31 | 'y': 1, 32 | 's': 'hello' 33 | }, { 34 | 'x': 2, 35 | 'y': 2, 36 | 's': 'world' 37 | }, { 38 | 'x': 3, 39 | 'y': 3, 40 | 's': 'hello' 41 | }] 42 | 43 | 44 | def _preprocessing_fn(inputs): 45 | """Preprocess input columns into transformed columns.""" 46 | x = inputs['x'] 47 | y = inputs['y'] 48 | s = inputs['s'] 49 | x_centered = x - tft.mean(x) 50 | y_normalized = tft.scale_to_0_1(y) 51 | s_integerized = tft.compute_and_apply_vocabulary(s) 52 | x_centered_times_y_normalized = (x_centered * y_normalized) 53 | return { 54 | 'x_centered': x_centered, 55 | 'y_normalized': y_normalized, 56 | 'x_centered_times_y_normalized': x_centered_times_y_normalized, 57 | 's_integerized': s_integerized 58 | } 59 | 60 | 61 | def main(): 62 | 63 | with tft_beam.Context(temp_dir=tempfile.mkdtemp()): 64 | transformed_dataset, transform_fn = ( # pylint: disable=unused-variable 65 | (_RAW_DATA, _RAW_DATA_METADATA) 66 | | tft_beam.AnalyzeAndTransformDataset(_preprocessing_fn)) 67 | 68 | transformed_data, transformed_metadata = transformed_dataset # pylint: disable=unused-variable 69 | 70 | pprint.pprint(transformed_data) 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /examples/simple_example_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for simple_example.""" 15 | 16 | import tensorflow as tf 17 | import simple_example 18 | from tensorflow_transform.beam import tft_unit 19 | 20 | 21 | _EXPECTED_TRANSFORMED_OUTPUT = [ 22 | { 23 | 'x_centered': 1.0, 24 | 'y_normalized': 1.0, 25 | 'x_centered_times_y_normalized': 1.0, 26 | 's_integerized': 0, 27 | }, 28 | { 29 | 'x_centered': 0.0, 30 | 'y_normalized': 0.5, 31 | 'x_centered_times_y_normalized': 0.0, 32 | 's_integerized': 1, 33 | }, 34 | { 35 | 'x_centered': -1.0, 36 | 'y_normalized': 0.0, 37 | 'x_centered_times_y_normalized': -0.0, 38 | 's_integerized': 0, 39 | }, 40 | ] 41 | 42 | 43 | class SimpleExampleTest(tft_unit.TransformTestCase): 44 | 45 | def test_preprocessing_fn(self): 46 | self.assertAnalyzeAndTransformResults(simple_example._RAW_DATA, 47 | simple_example._RAW_DATA_METADATA, 48 | simple_example._preprocessing_fn, 49 | _EXPECTED_TRANSFORMED_OUTPUT) 50 | 51 | 52 | class SimpleMainTest(tf.test.TestCase): 53 | 54 | def testMainDoesNotCrash(self): 55 | simple_example.main() 56 | 57 | 58 | if __name__ == '__main__': 59 | tf.test.main() 60 | -------------------------------------------------------------------------------- /examples/simple_sequence_example.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Example of reading SequenceExample in tf.Transform.""" 15 | 16 | import os 17 | import tempfile 18 | 19 | from absl import logging 20 | import apache_beam as beam 21 | import tensorflow_transform as tft 22 | import tensorflow_transform.beam as tft_beam 23 | from tfx_bsl.public import tfxio 24 | 25 | from tensorflow_metadata.proto.v0 import schema_pb2 26 | from google.protobuf import text_format 27 | 28 | _TRANSFORM_TEMP_DIR = 'tft_temp' 29 | _SCHEMA = text_format.Parse( 30 | """ 31 | feature { 32 | name: "int_feature" 33 | type: INT 34 | value_count { 35 | min: 1 36 | max: 1 37 | } 38 | } 39 | feature { 40 | name: "float_feature" 41 | type: FLOAT 42 | value_count { 43 | min: 4 44 | max: 4 45 | } 46 | } 47 | feature { 48 | name: "##SEQUENCE##" 49 | type: STRUCT 50 | struct_domain { 51 | feature { 52 | name: "int_feature" 53 | type: INT 54 | value_count { 55 | min: 0 56 | max: 2 57 | } 58 | } 59 | feature { 60 | name: "string_feature" 61 | type: BYTES 62 | value_count { 63 | min: 0 64 | max: 2 65 | } 66 | } 67 | } 68 | } 69 | tensor_representation_group { 70 | key: "" 71 | value { 72 | tensor_representation { 73 | key: "int_feature" 74 | value { varlen_sparse_tensor { column_name: "int_feature" } } 75 | } 76 | tensor_representation { 77 | key: "float_feature" 78 | value { varlen_sparse_tensor { column_name: "float_feature" } } 79 | } 80 | tensor_representation { 81 | key: "seq_string_feature" 82 | value { ragged_tensor { 83 | feature_path { step: "##SEQUENCE##" step: "string_feature" } } } 84 | } 85 | tensor_representation { 86 | key: "seq_int_feature" 87 | value { ragged_tensor { 88 | feature_path { step: "##SEQUENCE##" step: "int_feature" } } } 89 | } 90 | } 91 | } 92 | """, schema_pb2.Schema()) 93 | 94 | _TELEMETRY_DESCRIPTORS = ['TFT', 'SequenceExample'] 95 | 96 | 97 | def _print_record_batch(data): 98 | logging.info(data.to_pydict()) 99 | 100 | 101 | def _make_tfxio(schema): 102 | """Creates TFXIO for SequenceExample. 103 | 104 | Args: 105 | schema: A TFMD Schema describing the dataset. 106 | 107 | Returns: 108 | TFSequenceExampleRecord TFXIO Instance. 109 | 110 | The data_tfrecord.gz file holds Serialized SequenceExample as below: 111 | context { 112 | feature { key: "int_feature" value { int64_list { value: [0] } } } 113 | feature { 114 | key: "float_feature" 115 | value { float_list { value: [1.0, 2.0, 3.0, 4.0] } } 116 | } 117 | } 118 | feature_lists { 119 | feature_list { 120 | key: "int_feature" 121 | value { 122 | feature { int64_list { value: [1, 2] } } 123 | feature { int64_list { value: [3, 4] } } 124 | } 125 | } 126 | feature_list { 127 | key: "string_feature" 128 | value { 129 | feature { bytes_list { value: ["Hello", "World"] } } 130 | feature { bytes_list { value: [] } } 131 | } 132 | } 133 | } 134 | """ 135 | sequence_example_file = os.path.join( 136 | os.path.dirname(__file__), 'testdata/sequence_example/data_tfrecord.gz') 137 | return tfxio.TFSequenceExampleRecord( 138 | sequence_example_file, 139 | schema=schema, 140 | telemetry_descriptors=_TELEMETRY_DESCRIPTORS) 141 | 142 | 143 | def _preprocessing_fn(inputs): 144 | """Preprocess input columns into transformed columns. 145 | 146 | Args: 147 | inputs: Input Tensors. 148 | 149 | Returns: 150 | Dictionary of respective transformed inputs 151 | 152 | Example: 153 | `int_features`: tft.scale_to_0_1(...) 
154 | Input: [[[0]], [[1]], [[2]]] 155 | Output: [[[0]], [[0.5]], [[1]]] 156 | 157 | `float_features`: tft.scale_to_0_1(.., elementwise = True) 158 | Input: [ 159 | [[1.0, 2.0, 3.0, 4.0]], 160 | [[2.0, 3.0, 4.0, 5.0]], 161 | [[3.0, 4.0, 0.0, 0.0]] 162 | ] 163 | Output: [ 164 | [[0.0, 0.0, 0.75, 0.8]], 165 | [[0.5, 0.5, 1.0, 1.0]], 166 | [[1.0, 1.0, 0.0, 0.0]] 167 | ] 168 | 169 | `seq_int_feature`: tft.scale_by_min_max(...) 170 | Input: [ 171 | [ [1, 2], [3, 4]], 172 | [ [5, 6], [7, 8]], 173 | [[9, 10], [11, 12]] 174 | ] 175 | Output: [ 176 | [[ 0.0, 0.0909], [0.1818, 0.2727]], 177 | [[0.3636, 0.4545], [0.5454, 0.6363]], 178 | [[0.7272, 0.8181], [0.9090, 1.0]] 179 | ] 180 | 181 | `seq_string_feature`: tft.compute_and_apply_vocabulary(...) 182 | Input: [ 183 | [[ b'Hello', b'World'], []], 184 | [[ b'foo', b'bar'], []], 185 | [[b'tensor', b'flow'], []] 186 | ] 187 | Output: [ 188 | [[[5, 4], []]], 189 | [[[1, 3], []]], 190 | [[[0, 2], []]] 191 | ] 192 | """ 193 | return { 194 | 'transformed_seq_int_feature': 195 | tft.scale_by_min_max(inputs['seq_int_feature']), 196 | 'transformed_seq_string_feature': 197 | tft.compute_and_apply_vocabulary(inputs['seq_string_feature']), 198 | 'transformed_float_feature': 199 | tft.scale_to_0_1(inputs['float_feature'], elementwise=True), 200 | 'transformed_int_feature': 201 | tft.scale_to_0_1(inputs['int_feature']), 202 | } 203 | 204 | 205 | def _transform_data(sequence_example_tfxio): 206 | """Transform the data and output transformed values. 207 | 208 | Args: 209 | sequence_example_tfxio: tfxio.TFSequenceExampleRecord Object 210 | """ 211 | 212 | with beam.Pipeline() as pipeline: 213 | with tft_beam.Context( 214 | temp_dir=os.path.join(tempfile.mkdtemp(), _TRANSFORM_TEMP_DIR)): 215 | 216 | raw_data = pipeline | 'ReadAndDecode' >> sequence_example_tfxio.BeamSource( 217 | ) 218 | _ = raw_data | 'PrintInputData' >> beam.Map(_print_record_batch) 219 | 220 | (transformed_data, 221 | _), _ = ((raw_data, sequence_example_tfxio.TensorAdapterConfig()) 222 | | 'AnalyzeAndTransform' >> tft_beam.AnalyzeAndTransformDataset( 223 | _preprocessing_fn, output_record_batches=True)) 224 | 225 | # Drop empty pass-through features dictionary that is not relevant 226 | # for this example. 227 | transformed_data = transformed_data | 'ExtractRecordBatch' >> beam.Keys() 228 | _ = transformed_data | 'PrintTransformedData' >> beam.Map( 229 | _print_record_batch) 230 | 231 | 232 | def main(): 233 | _transform_data(_make_tfxio(_SCHEMA)) 234 | 235 | 236 | if __name__ == '__main__': 237 | main() 238 | -------------------------------------------------------------------------------- /examples/simple_sequence_example_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for simple_sequence_example."""
15 | 
16 | import tensorflow as tf
17 | import simple_sequence_example
18 | from tensorflow_transform.beam import tft_unit
19 | 
20 | _EXPECTED_TRANSFORMED_OUTPUT = [{
21 |     'transformed_seq_int_feature$ragged_values': [
22 |         0.0, 0.09090909, 0.18181818, 0.27272727
23 |     ],
24 |     'transformed_seq_int_feature$row_lengths_1': [2, 2],
25 |     'transformed_seq_string_feature$ragged_values': [5, 4],
26 |     'transformed_seq_string_feature$row_lengths_1': [2, 0],
27 |     'transformed_float_feature': [0.0, 0.0, 0.75, 0.8],
28 |     'transformed_int_feature': [0],
29 | }, {
30 |     'transformed_seq_int_feature$ragged_values': [
31 |         0.36363636, 0.45454545, 0.54545454, 0.63636363
32 |     ],
33 |     'transformed_seq_int_feature$row_lengths_1': [2, 2],
34 |     'transformed_seq_string_feature$ragged_values': [1, 3],
35 |     'transformed_seq_string_feature$row_lengths_1': [2, 0],
36 |     'transformed_float_feature': [0.5, 0.5, 1.0, 1.0],
37 |     'transformed_int_feature': [0.5],
38 | }, {
39 |     'transformed_seq_int_feature$ragged_values': [
40 |         0.72727272, 0.81818181, 0.90909090, 1.0
41 |     ],
42 |     'transformed_seq_int_feature$row_lengths_1': [2, 2],
43 |     'transformed_seq_string_feature$ragged_values': [0, 2],
44 |     'transformed_seq_string_feature$row_lengths_1': [2, 0],
45 |     'transformed_float_feature': [1.0, 1.0, 0.0, 0.0],
46 |     'transformed_int_feature': [1],
47 | }]
48 | 
49 | 
50 | class SimpleMainTest(tf.test.TestCase):
51 | 
52 |   def testMainDoesNotCrash(self):
53 |     tft_unit.skip_if_not_tf2('Tensorflow 2.x required.')
54 |     simple_sequence_example.main()
55 | 
56 | 
57 | class SimpleSequenceExampleTest(tft_unit.TransformTestCase):
58 | 
59 |   def testPreprocessingFn(self):
60 |     tft_unit.skip_if_not_tf2('Tensorflow 2.x required.')
61 |     tfxio = simple_sequence_example._make_tfxio(simple_sequence_example._SCHEMA)
62 |     self.assertAnalyzeAndTransformResults(
63 |         tfxio.BeamSource(),
64 |         tfxio.TensorAdapterConfig(),
65 |         simple_sequence_example._preprocessing_fn,
66 |         output_record_batches=True,
67 |         expected_data=_EXPECTED_TRANSFORMED_OUTPUT)
68 | 
69 | 
70 | if __name__ == '__main__':
71 |   tf.test.main()
72 | 
--------------------------------------------------------------------------------
/examples/testdata/sequence_example/data_tfrecord.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/transform/1e8c81a54b94841d55be5fefd4aa860ea49cc389/examples/testdata/sequence_example/data_tfrecord.gz
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
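The `select_constraint` helper defined just below picks one of several pip constraint strings based on the `TFX_DEPENDENCY_SELECTOR` environment variable. A small sketch of that behavior, re-declaring the helper so the snippet is self-contained (the pins mirror the ones used later in this file):

```python
import os

def select_constraint(default, nightly=None, git_master=None):
  # Mirrors the definition in setup.py below.
  selector = os.environ.get('TFX_DEPENDENCY_SELECTOR')
  if selector == 'UNCONSTRAINED':
    return ''
  elif selector == 'NIGHTLY' and nightly is not None:
    return nightly
  elif selector == 'GIT_MASTER' and git_master is not None:
    return git_master
  return default

os.environ['TFX_DEPENDENCY_SELECTOR'] = 'NIGHTLY'
print('tfx-bsl' + select_constraint(default='>=1.16.1,<1.17.0',
                                    nightly='>=1.17.0.dev'))
# tfx-bsl>=1.17.0.dev

os.environ['TFX_DEPENDENCY_SELECTOR'] = 'UNCONSTRAINED'
print('tfx-bsl' + select_constraint(default='>=1.16.1,<1.17.0'))
# tfx-bsl
```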
14 | """Package setup script for tf.Transform."""
15 | import os
16 | 
17 | from setuptools import find_packages
18 | from setuptools import setup
19 | 
20 | 
21 | def select_constraint(default, nightly=None, git_master=None):
22 |   """Select dependency constraint based on TFX_DEPENDENCY_SELECTOR env var."""
23 |   selector = os.environ.get('TFX_DEPENDENCY_SELECTOR')
24 |   if selector == 'UNCONSTRAINED':
25 |     return ''
26 |   elif selector == 'NIGHTLY' and nightly is not None:
27 |     return nightly
28 |   elif selector == 'GIT_MASTER' and git_master is not None:
29 |     return git_master
30 |   else:
31 |     return default
32 | 
33 | 
34 | # Get version from version module.
35 | with open('tensorflow_transform/version.py') as fp:
36 |   globals_dict = {}
37 |   exec(fp.read(), globals_dict)  # pylint: disable=exec-used
38 | __version__ = globals_dict['__version__']
39 | 
40 | 
41 | def _make_required_install_packages():
42 |   # Make sure to sync the versions of common dependencies (absl-py, numpy, and
43 |   # protobuf) with TF, and the pyarrow version with tfx-bsl.
44 |   return [
45 |       'absl-py>=0.9,<2.0.0',
46 |       'apache-beam[gcp]>=2.53,<3;python_version>="3.11"',
47 |       'apache-beam[gcp]>=2.47,<3;python_version<"3.11"',
48 |       'numpy>=1.22.0',
49 |       'protobuf>=4.25.2,<6;python_version>="3.11"',
50 |       'protobuf>=3.20.3,<5;python_version<"3.11"',
51 |       'pyarrow>=10,<11',
52 |       'pydot>=1.2,<2',
53 |       'tensorflow>=2.17,<2.18',
54 |       'tensorflow-metadata'
55 |       + select_constraint(
56 |           default='>=1.16.1,<1.17.0',
57 |           nightly='>=1.17.0.dev',
58 |           git_master='@git+https://github.com/tensorflow/metadata@master',
59 |       ),
60 |       'tf_keras>=2',
61 |       'tfx-bsl'
62 |       + select_constraint(
63 |           default='>=1.16.1,<1.17.0',
64 |           nightly='>=1.17.0.dev',
65 |           git_master='@git+https://github.com/tensorflow/tfx-bsl@master',
66 |       ),
67 |   ]
68 | 
69 | 
70 | # Get the long description from the README file.
71 | with open('README.md') as fp: 72 | _LONG_DESCRIPTION = fp.read() 73 | 74 | setup( 75 | name='tensorflow-transform', 76 | version=__version__, 77 | author='Google Inc.', 78 | author_email='tensorflow-extended-dev@googlegroups.com', 79 | license='Apache 2.0', 80 | classifiers=[ 81 | 'Development Status :: 5 - Production/Stable', 82 | 'Intended Audience :: Developers', 83 | 'Intended Audience :: Education', 84 | 'Intended Audience :: Science/Research', 85 | 'License :: OSI Approved :: Apache Software License', 86 | 'Operating System :: OS Independent', 87 | 'Programming Language :: Python', 88 | 'Programming Language :: Python :: 3', 89 | 'Programming Language :: Python :: 3.9', 90 | 'Programming Language :: Python :: 3.10', 91 | 'Programming Language :: Python :: 3.11', 92 | 'Programming Language :: Python :: 3 :: Only', 93 | 'Topic :: Scientific/Engineering', 94 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 95 | 'Topic :: Scientific/Engineering :: Mathematics', 96 | 'Topic :: Software Development', 97 | 'Topic :: Software Development :: Libraries', 98 | 'Topic :: Software Development :: Libraries :: Python Modules', 99 | ], 100 | namespace_packages=[], 101 | install_requires=_make_required_install_packages(), 102 | python_requires='>=3.9,<4', 103 | packages=find_packages(), 104 | include_package_data=True, 105 | package_data={'tensorflow_transform': ['py.typed']}, 106 | description='A library for data preprocessing with TensorFlow', 107 | long_description=_LONG_DESCRIPTION, 108 | long_description_content_type='text/markdown', 109 | keywords='tensorflow transform tfx', 110 | url='https://www.tensorflow.org/tfx/transform/get_started', 111 | download_url='https://github.com/tensorflow/transform/tags', 112 | requires=[]) 113 | -------------------------------------------------------------------------------- /tensorflow_transform/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Init module for TF.Transform.""" 15 | 16 | # pylint: disable=wildcard-import 17 | from tensorflow_transform import coders 18 | from tensorflow_transform import experimental 19 | from tensorflow_transform.analyzers import * 20 | from tensorflow_transform.annotators import * 21 | from tensorflow_transform.inspect_preprocessing_fn import * 22 | from tensorflow_transform.mappers import * 23 | from tensorflow_transform.output_wrapper import TFTransformOutput 24 | from tensorflow_transform.output_wrapper import TransformFeaturesLayer 25 | from tensorflow_transform.py_func.api import apply_pyfunc 26 | from tensorflow_transform.tf_metadata.dataset_metadata import DatasetMetadata 27 | # pylint: enable=wildcard-import 28 | 29 | # Import version string. 30 | from tensorflow_transform.version import __version__ 31 | 32 | # TF 2.6 split support for filesystems such as Amazon S3 out to the 33 | # `tensorflow_io` package. 
Hence, this import is needed wherever we touch the
34 | # filesystem.
35 | try:
36 |   import tensorflow_io as _  # pytype: disable=import-error # pylint: disable=g-import-not-at-top
37 | except ModuleNotFoundError:
38 |   pass
39 | 
40 | try:
41 |   from tensorflow_transform import google  # pytype: disable=import-error # pylint: disable=g-import-not-at-top
42 | except ImportError:
43 |   pass
44 | 
--------------------------------------------------------------------------------
/tensorflow_transform/annotations.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 | 
3 | // Annotations that can be applied to the output schema by analyzers or mappers.
4 | package third_party.py.tensorflow_transform.annotations;
5 | 
6 | // Represents the bucket boundaries that were used to Bucketize a feature.
7 | message BucketBoundaries {
8 |   repeated float boundaries = 1;
9 | }
10 | 
11 | // Represents metadata about the computed vocabulary.
12 | message VocabularyMetadata {
13 |   optional string file_name = 1;
14 |   // The original size of the vocabulary, prior to any filtering (e.g.
15 |   // filtering to top_k).
16 |   optional int64 unfiltered_vocabulary_size = 2;
17 |   // The filtered size of the vocabulary (e.g. after filtering to top_k).
18 |   optional int64 filtered_vocabulary_size = 3;
19 | }
20 | 
--------------------------------------------------------------------------------
/tensorflow_transform/annotators_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
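Just above, annotations.proto defines the two messages used to annotate output schemas. A short sketch of constructing them through their generated Python bindings; this assumes `annotations_pb2` is importable, which the guarded import in common.py at the end of this dump shows is not guaranteed in OSS builds, and all field values here are illustrative only:

```python
from tensorflow_transform import annotations_pb2

# Boundaries a Bucketize-style analyzer might record (illustrative values).
boundaries = annotations_pb2.BucketBoundaries(boundaries=[0.0, 1.5, 3.0])

# Vocabulary sizes before and after filtering (e.g. to top_k).
vocab = annotations_pb2.VocabularyMetadata(
    file_name='my_vocab',  # hypothetical vocabulary file name
    unfiltered_vocabulary_size=1000,
    filtered_vocabulary_size=100)

print(list(boundaries.boundaries))     # [0.0, 1.5, 3.0]
print(vocab.filtered_vocabulary_size)  # 100
```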
14 | """Tests for tensorflow_transform.annotators.""" 15 | 16 | import tensorflow as tf 17 | from tensorflow_transform import annotators 18 | from tensorflow_transform import test_case 19 | 20 | 21 | class AnnotatorsTest(test_case.TransformTestCase): 22 | 23 | @test_case.named_parameters( 24 | dict(testcase_name='tf_compat_v1', use_tf_compat_v1=True), 25 | dict(testcase_name='tf2', use_tf_compat_v1=False)) 26 | def test_annotate_asset(self, use_tf_compat_v1): 27 | if not use_tf_compat_v1: 28 | test_case.skip_if_not_tf2('Tensorflow 2.x required') 29 | 30 | def foo(): 31 | annotators.annotate_asset('scope/my_key', 'scope/my_value') 32 | annotators.annotate_asset('my_key2', 'should_be_replaced') 33 | annotators.annotate_asset('my_key2', 'my_value2') 34 | 35 | if use_tf_compat_v1: 36 | with tf.Graph().as_default() as graph: 37 | foo() 38 | else: 39 | graph = tf.function(foo).get_concrete_function().graph 40 | 41 | self.assertDictEqual( 42 | annotators.get_asset_annotations(graph), { 43 | 'my_key': 'my_value', 44 | 'my_key2': 'my_value2' 45 | }) 46 | 47 | annotators.clear_asset_annotations(graph) 48 | self.assertDictEqual(annotators.get_asset_annotations(graph), {}) 49 | 50 | def test_object_tracker(self): 51 | test_case.skip_if_not_tf2('Tensorflow 2.x required') 52 | 53 | trackable_object = tf.__internal__.tracking.Trackable() 54 | 55 | @tf.function 56 | def preprocessing_fn(): 57 | _ = annotators.make_and_track_object(lambda: trackable_object) 58 | return 1 59 | 60 | object_tracker = annotators.ObjectTracker() 61 | with annotators.object_tracker_scope(object_tracker): 62 | _ = preprocessing_fn() 63 | 64 | self.assertLen(object_tracker.trackable_objects, 1) 65 | self.assertEqual(trackable_object, object_tracker.trackable_objects[0]) 66 | 67 | 68 | if __name__ == '__main__': 69 | test_case.main() 70 | -------------------------------------------------------------------------------- /tensorflow_transform/beam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Module level imports for tensorflow_transform.beam.""" 15 | 16 | # pylint: disable=wildcard-import 17 | # The doc-generator's `explicit_package_contents_filter` requires that 18 | # sub-modules you want documented are explicitly imported. 19 | # Also: analyzer_impls registers implementation of analyzers. 
20 | from tensorflow_transform.beam import analyzer_cache 21 | from tensorflow_transform.beam import analyzer_impls 22 | from tensorflow_transform.beam import experimental 23 | from tensorflow_transform.beam.context import Context 24 | from tensorflow_transform.beam.impl import AnalyzeAndTransformDataset 25 | from tensorflow_transform.beam.impl import AnalyzeDataset 26 | from tensorflow_transform.beam.impl import AnalyzeDatasetWithCache 27 | from tensorflow_transform.beam.impl import EncodeTransformedDataset 28 | from tensorflow_transform.beam.impl import TransformDataset 29 | from tensorflow_transform.beam.tft_beam_io import * 30 | 31 | # pylint: enable=wildcard-import 32 | 33 | # TF 2.6 split support for filesystems such as Amazon S3 out to the 34 | # `tensorflow_io` package. Hence, this import is needed wherever we touch the 35 | # filesystem. 36 | try: 37 | import tensorflow_io as _ # pytype: disable=import-error # pylint: disable=g-import-not-at-top 38 | except ModuleNotFoundError: 39 | pass 40 | -------------------------------------------------------------------------------- /tensorflow_transform/beam/analyzer_impls_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for tensorflow_transform.beam.analyzer_impls.""" 15 | 16 | import apache_beam as beam 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | from tensorflow_transform.beam import analyzer_impls 21 | from tensorflow_transform.beam import tft_unit 22 | 23 | 24 | class AnalyzerImplsTest(tft_unit.TransformTestCase): 25 | 26 | def testSplitInputsByKey(self): 27 | inputs = [ 28 | np.array(['my_key', 'my_other_key']), 29 | np.array([[1, 2], [3, 4]]), 30 | np.array([5, 6]) 31 | ] 32 | split_inputs = list(analyzer_impls._split_inputs_by_key(inputs)) 33 | self.assertEqual(len(split_inputs), 2) 34 | 35 | self.assertEqual(len(split_inputs[0]), 2) 36 | self.assertEqual(split_inputs[0][0], 'my_key') 37 | self.assertEqual(len(split_inputs[0][1]), 2) 38 | self.assertAllEqual(split_inputs[0][1][0], np.array([1, 2])) 39 | self.assertAllEqual(split_inputs[0][1][1], np.array(5)) 40 | 41 | self.assertEqual(len(split_inputs[1]), 2) 42 | self.assertEqual(split_inputs[1][0], 'my_other_key') 43 | self.assertEqual(len(split_inputs[1][1]), 2) 44 | self.assertAllEqual(split_inputs[1][1][0], np.array([3, 4])) 45 | self.assertAllEqual(split_inputs[1][1][1], np.array(6)) 46 | 47 | def testMergeOutputsByKey(self): 48 | outputs = [ 49 | ('my_key', [np.array(20), np.array([21, 22])]), 50 | ('my_other_key', [np.array(23), np.array([24, 25])]) 51 | ] 52 | outputs_pcoll = [outputs] 53 | merged_outputs_pcolls = tuple(outputs_pcoll | beam.FlatMap( 54 | analyzer_impls._merge_outputs_by_key, 55 | outputs_dtype=[tf.int64, tf.int64]).with_outputs('key', '0', '1')) 56 | self.assertAllEqual(merged_outputs_pcolls[0][0], 57 | np.array(['my_key', 'my_other_key'])) 58 | self.assertAllEqual(merged_outputs_pcolls[1][0], 59 | np.array([20, 23])) 60 | self.assertAllEqual(merged_outputs_pcolls[2][0], 61 | np.array([[21, 22], [24, 25]])) 62 | 63 | def testMergeOutputsByKeyEmptyInput(self): 64 | outputs = [] 65 | outputs_pcoll = [outputs] 66 | merged_outputs_pcolls = tuple(outputs_pcoll | beam.FlatMap( 67 | analyzer_impls._merge_outputs_by_key, 68 | outputs_dtype=[tf.float32, tf.float32]).with_outputs('key', '0', '1')) 69 | self.assertAllEqual(merged_outputs_pcolls[0][0], 70 | np.array([])) 71 | self.assertAllEqual(merged_outputs_pcolls[1][0], np.array([])) 72 | self.assertAllEqual(merged_outputs_pcolls[2][0], np.array([])) 73 | 74 | @tft_unit.named_parameters( 75 | dict( 76 | testcase_name='Increasing', 77 | input_boundaries=np.array([[1, 1.00000001], [1, 2]]), 78 | expected_boundaries=np.array([[1, 1.00000001], [1, 2]])), 79 | dict( 80 | testcase_name='Repeating', 81 | input_boundaries=np.array([[1, 1, 1], [4, 4, 4]]), 82 | expected_boundaries=np.array([[1, 1.000001, 1.000002], 83 | [4, 4.000001, 4.000002]])), 84 | dict( 85 | testcase_name='NonIncreasing', 86 | input_boundaries=np.array([[3, 5.1, 5.1], [4.01, 4.01, 4.2]]), 87 | expected_boundaries=np.array([[3, 5.1, 5.1000021], 88 | [4.01, 4.01000019, 4.20000019]]), 89 | atol=1e-6), 90 | ) 91 | def testMakeStrictlyIncreasingBoundariesRows(self, 92 | input_boundaries, 93 | expected_boundaries, 94 | atol=None): 95 | result = analyzer_impls._make_strictly_increasing_boundaries_rows( 96 | input_boundaries) 97 | if atol is None: 98 | self.assertAllEqual(result, expected_boundaries) 99 | else: 100 | self.assertAllClose(result, expected_boundaries, atol=atol) 101 | 102 | @tft_unit.named_parameters( 103 | dict( 104 | testcase_name='Simple', 105 | input_boundaries=np.array([[0, 1, 2], [0, 1, 2]]), 106 | expected_boundaries=np.array([0, 0.5, 1, 1.5, 2]), 107 | 
expected_scales=np.array([0.5, 0.5]),
108 |           expected_shifts=np.array([0, 1]),
109 |           expected_num_buckets=np.array(4)),
110 |       dict(
111 |           testcase_name='Complex',
112 |           input_boundaries=np.array([[0, 1, 2, 3], [3, 3, 3, 3], [2, 4, 6, 8]]),
113 |           expected_boundaries=np.array([
114 |               0, 0.33333333, 0.66666667, 1, 1.33333333, 1.66666667, 2,
115 |               2.33333333, 2.66666667, 3
116 |           ]),
117 |           expected_scales=np.array([0.333333333, 333333.333, 0.166666667]),
118 |           expected_shifts=np.array([0, -999999, 1.66666667]),
119 |           expected_num_buckets=np.array(5)),
120 |       dict(
121 |           testcase_name='SingleBoundary',
122 |           input_boundaries=np.array([[1], [2]]),
123 |           expected_boundaries=np.array([0]),
124 |           expected_scales=np.array([1., 1.]),
125 |           expected_shifts=np.array([-1, -1]),
126 |           expected_num_buckets=np.array(2)),
127 |   )
128 |   def testJoinBoundariesRows(self, input_boundaries, expected_boundaries,
129 |                              expected_scales, expected_shifts,
130 |                              expected_num_buckets):
131 |     boundaries, scales, shifts, num_buckets = (
132 |         analyzer_impls._join_boundary_rows(input_boundaries))
133 |     self.assertAllClose(boundaries, expected_boundaries)
134 |     self.assertAllClose(scales, expected_scales)
135 |     self.assertAllClose(shifts, expected_shifts)
136 |     self.assertAllEqual(num_buckets, expected_num_buckets)
137 | 
138 | 
139 | if __name__ == '__main__':
140 |   tft_unit.main()
141 | 
--------------------------------------------------------------------------------
/tensorflow_transform/beam/beam_nodes.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Nodes that define the Beam execution graph.
15 | 
16 | `OperationNode`s are objects that define a graph of operations similar to the
17 | graph of `beam.PTransform`s.
18 | `tensorflow_transform.beam.analysis_graph_builder.build` converts the graph
19 | defined by the user's preprocessing_fn into a new graph of `OperationNode`s,
20 | which is in turn implemented as a graph of `beam.PTransform`s by
21 | `tensorflow_transform.beam.common.ConstructBeamPipelineVisitor`. A registration
22 | system is used to register implementations of each individual `OperationDef`.
23 | The `OperationDef`s defined by the user in their preprocessing_fn are all
24 | subclasses of `AnalyzerDef` (except `TensorSource`, which gets converted to
25 | `ExtractFromDict` in `tensorflow_transform.beam.analysis_graph_builder.build`).
26 | The subclasses of `AnalyzerDef` are defined in
27 | `tensorflow_transform.analyzer_nodes` and are implemented in
28 | `tensorflow_transform.beam.analyzer_impls`.
29 | 
30 | This module contains the nodes that are created by
31 | `tensorflow_transform.beam.analysis_graph_builder.build`.
These nodes define
32 | the parts of the beam graph that run a TensorFlow graph in a `beam.ParDo`,
33 | extract `PCollections` containing tuples of tensors required by analyzers,
34 | run the analyzers, and then create a new (deferred) TensorFlow graph where
35 | the results of analyzers are replaced by constant tensors. This happens in a
36 | number of phases, since an analyzer might depend on a tensor that in turn
37 | depends on the result of another analyzer.
38 | 
39 | The `OperationDef` subclasses defined here are implemented in
40 | `tensorflow_transform.beam.impl`.
41 | """
42 | 
43 | import tensorflow as tf
44 | from tensorflow_transform import nodes
45 | # TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple`
46 | # once the Spark issue is resolved.
47 | from tfx_bsl.types import tfx_namedtuple
48 | 
49 | 
50 | class CreateTensorBinding(
51 |     tfx_namedtuple.namedtuple(
52 |         'CreateTensorBinding',
53 |         ['tensor_name', 'dtype_enum', 'is_asset_filepath', 'label']),
54 |     nodes.OperationDef):
55 |   """An operation that represents creating a tensor binding from a value.
56 | 
57 |   This `OperationDef` represents a `beam.PTransform` that applies a ParDo
58 |   (where the input PCollection is assumed to contain a single element), which
59 |   combines the single element with a tensor name and `is_asset_filepath`
60 |   to create a tensor binding.
61 | 
62 |   Attributes:
63 |     tensor_name: The name of the tensor that the given value should replace as a
64 |       constant tensor.
65 |     dtype_enum: The dtype of the tensor as a TF `types_pb2.DataType`.
66 |     is_asset_filepath: If true, then the replaced value will be added to the
67 |       ASSET_FILEPATHS collection if exporting a TF1 Graph.
68 |     label: A unique label for this operation.
69 |   """
70 |   __slots__ = ()
71 | 
72 | 
73 | class CreateSavedModel(
74 |     tfx_namedtuple.namedtuple(
75 |         'CreateSavedModel',
76 |         ['table_initializers', 'output_signature', 'label']),
77 |     nodes.OperationDef):
78 |   """An operation that represents creating a SavedModel with bound values.
79 | 
80 |   This operation represents creating a SavedModel. Its output is a
81 |   PCollection containing a single element which is the directory containing the
82 |   `SavedModel`. The inputs are a PCollection of tensor bindings. A tensor
83 |   binding is the specification of a tensor and a value that it should be
84 |   replaced with in the graph.
85 | 
86 |   This allows us to create a `SavedModel` in a deferred manner, which depends on
87 |   deferred values (the tensor bindings) which were not known when the Beam graph
88 |   was constructed.
89 | 
90 | 
91 |   Attributes:
92 |     table_initializers: A list of table initializer ops that should be run as
93 |       part of this SavedModel.
94 |     output_signature: The output signature of this `SavedModel`, as a dictionary
95 |       whose keys are feature names and values are `Tensor`s or
96 |       `SparseTensor`s.
97 |     label: A unique label for this operation.
98 |   """
99 |   __slots__ = ()
100 | 
101 |   def _get_tensor_type_name(self, tensor):
102 |     if isinstance(tensor, tf.Tensor):
103 |       return 'Tensor'
104 |     elif isinstance(tensor, tf.SparseTensor):
105 |       return 'SparseTensor'
106 |     elif isinstance(tensor, tf.RaggedTensor):
107 |       return 'RaggedTensor'
108 |     raise ValueError('Got a {}, expected a Tensor, SparseTensor or '
109 |                      'RaggedTensor'.format(type(tensor)))
110 | 
111 |   def get_field_str(self, field_name):
112 |     # Overriding the str representation of table initializers since it may be
113 |     # different for various versions of TF.
114 |     if field_name == 'table_initializers':
115 |       return '{}'.format(len(self.table_initializers))
116 |     elif field_name == 'output_signature':
117 |       copied = self.output_signature.copy()
118 |       for key in copied:
119 |         value = self.output_signature[key]
120 |         copied[key] = '{}<shape: {}, {}>'.format(
121 |             self._get_tensor_type_name(value), value.shape.as_list(),
122 |             value.dtype)
123 |       return str(copied)
124 |     return super().get_field_str(field_name)
125 | 
126 | 
127 | class ExtractInputForSavedModel(
128 |     tfx_namedtuple.namedtuple('ExtractInputForSavedModel',
129 |                               ['dataset_key', 'label']), nodes.OperationDef):
130 |   """An operation that forwards the requested dataset in PCollection form.
131 | 
132 |   The resulting PCollection is either the dataset corresponding to
133 |   `dataset_key`, or a flattened PCollection if `dataset_key` is not specified.
134 | 
135 |   Attributes:
136 |     dataset_key: (Optional) dataset key str.
137 |     label: A unique label for this operation.
138 |   """
139 |   __slots__ = ()
140 | 
141 | 
142 | class ApplySavedModel(
143 |     tfx_namedtuple.namedtuple('ApplySavedModel', ['phase', 'label']),
144 |     nodes.OperationDef):
145 |   """An operation that represents applying a SavedModel as a `beam.ParDo`.
146 | 
147 |   This operation represents applying a `SavedModel`, which is the input to this
148 |   operation, to the input values. The input values are not an input to this
149 |   operation, but are provided to the implementation by
150 |   `tensorflow_transform.beam.common.ConstructBeamPipelineVisitor.ExtraArgs`.
151 | 
152 |   The input should be a PCollection containing a single element which is the
153 |   directory containing the SavedModel to be run.
154 | 
155 |   Attributes:
156 |     phase: An integer which is the phase that this operation is run as part of.
157 |     label: A unique label for this operation.
158 |   """
159 |   __slots__ = ()
160 | 
161 |   @property
162 |   def is_partitionable(self):
163 |     return True
164 | 
165 | 
166 | class ExtractFromDict(
167 |     tfx_namedtuple.namedtuple('ExtractFromDict', ['keys', 'label']),
168 |     nodes.OperationDef):
169 |   """An operation that represents extracting values from a dictionary.
170 | 
171 |   This operation represents a `beam.ParDo` that is applied to a PCollection
172 |   whose elements are assumed to be a dictionary of values. For each element of
173 |   the PCollection, the corresponding element of the output PCollection is a
174 |   tuple of values, one for each key.
175 | 
176 |   Attributes:
177 |     keys: The keys whose values should be extracted from each element of the
178 |       input PCollection. keys should either be a tuple or a string.
179 |     label: A unique label for this operation.
180 |   """
181 |   __slots__ = ()
182 | 
183 |   @property
184 |   def is_partitionable(self):
185 |     return True
186 | 
187 | 
188 | class Flatten(
189 |     tfx_namedtuple.namedtuple('Flatten', ['label']), nodes.OperationDef):
190 |   __slots__ = ()
191 | 
192 |   @property
193 |   def is_partitionable(self):
194 |     return True
195 | 
--------------------------------------------------------------------------------
/tensorflow_transform/beam/context_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for tensorflow_transform.beam.context.""" 15 | 16 | import os 17 | 18 | import tensorflow_transform.beam as tft_beam 19 | from tensorflow_transform.beam import tft_unit 20 | 21 | 22 | class ContextTest(tft_unit.TransformTestCase): 23 | 24 | def testNestedContextCreateBaseTempDir(self): 25 | 26 | level_1_dir = self.get_temp_dir() 27 | with tft_beam.Context(temp_dir=level_1_dir): 28 | self.assertEqual( 29 | os.path.join(level_1_dir, tft_beam.Context._TEMP_SUBDIR), 30 | tft_beam.Context.create_base_temp_dir()) 31 | level_2_dir = self.get_temp_dir() 32 | with tft_beam.Context(temp_dir=level_2_dir): 33 | self.assertEqual( 34 | os.path.join(level_2_dir, tft_beam.Context._TEMP_SUBDIR), 35 | tft_beam.Context.create_base_temp_dir()) 36 | self.assertEqual( 37 | os.path.join(level_1_dir, tft_beam.Context._TEMP_SUBDIR), 38 | tft_beam.Context.create_base_temp_dir()) 39 | with self.assertRaises(ValueError): 40 | tft_beam.Context.create_base_temp_dir() 41 | 42 | 43 | if __name__ == '__main__': 44 | tft_unit.main() 45 | -------------------------------------------------------------------------------- /tensorflow_transform/beam/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Module level imports for tensorflow_transform.beam.experimental.""" 15 | 16 | from tensorflow_transform.beam.experimental.analyzer_impls import * 17 | -------------------------------------------------------------------------------- /tensorflow_transform/beam/experimental/analyzer_impls.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
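A minimal sketch of the nesting behavior that context_test.py above asserts, with hypothetical temp paths: the innermost `tft_beam.Context` wins while it is active, the outer one is restored on exit, and `create_base_temp_dir` raises a ValueError outside any context that set a temp dir:

```python
import tensorflow_transform.beam as tft_beam

with tft_beam.Context(temp_dir='/tmp/level1'):
  with tft_beam.Context(temp_dir='/tmp/level2'):
    # Resolves under /tmp/level2 while the inner context is active.
    print(tft_beam.Context.create_base_temp_dir())
  # Restored to /tmp/level1 once the inner context exits.
  print(tft_beam.Context.create_base_temp_dir())
```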
14 | """Beam implementations of experimental tf.Transform canonical analyzers.""" 15 | import apache_beam as beam 16 | 17 | 18 | class PTransformAnalyzer(beam.PTransform): 19 | """A PTransform analyzer's base class which provides a temp dir if needed.""" 20 | 21 | def __init__(self): 22 | self._base_temp_dir = None 23 | 24 | @property 25 | def base_temp_dir(self): 26 | return self._base_temp_dir 27 | 28 | @base_temp_dir.setter 29 | def base_temp_dir(self, val): 30 | self._base_temp_dir = val 31 | -------------------------------------------------------------------------------- /tensorflow_transform/beam/test_helpers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # Copyright 2017 Google Inc. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Helpers for tensorflow_transform.beam tests.""" 17 | 18 | 19 | def make_test_beam_pipeline_kwargs(): 20 | # This is kwargs for apache_beam.Pipeline's __init__, using the default runner 21 | # here. 22 | return {} 23 | -------------------------------------------------------------------------------- /tensorflow_transform/beam/tft_beam_io/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Module level imports for tensorflow_transform.beam.tft_beam_io.""" 15 | 16 | from tensorflow_transform.beam.tft_beam_io.beam_metadata_io import WriteMetadata 17 | from tensorflow_transform.beam.tft_beam_io.transform_fn_io import ReadTransformFn 18 | from tensorflow_transform.beam.tft_beam_io.transform_fn_io import WriteTransformFn 19 | -------------------------------------------------------------------------------- /tensorflow_transform/beam/tft_beam_io/beam_metadata_io.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Transforms to read/write metadata from disk.
15 | 
16 | A write/read cycle will render all metadata deferred, but in general users
17 | should avoid doing this anyway and pass around live metadata objects.
18 | """
19 | 
20 | import json
21 | import os
22 | 
23 | import apache_beam as beam
24 | import tensorflow as tf
25 | from tensorflow_transform import output_wrapper
26 | from tensorflow_transform.beam import common
27 | from tensorflow_transform.tf_metadata import metadata_io
28 | # TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple`
29 | # once the Spark issue is resolved.
30 | from tfx_bsl.types import tfx_namedtuple
31 | 
32 | 
33 | class BeamDatasetMetadata(
34 |     tfx_namedtuple.namedtuple(
35 |         'BeamDatasetMetadata',
36 |         ['dataset_metadata', 'deferred_metadata', 'asset_map'])):
37 |   """A class like DatasetMetadata also holding `PCollection`s and an asset_map.
38 | 
39 |   `deferred_metadata` is a PCollection containing a single DatasetMetadata.
40 |   `asset_map` is a dictionary mapping asset keys to filenames.
41 |   """
42 | 
43 |   @property
44 |   def schema(self):
45 |     return self.dataset_metadata.schema
46 | 
47 | 
48 | class WriteMetadata(beam.PTransform):
49 |   """A PTransform to write Metadata to disk.
50 | 
51 |   Input can either be a DatasetMetadata or a tuple of properties.
52 | 
53 |   Depending on the optional `write_to_unique_subdirectory`, writes the given
54 |   metadata to either `path` or a new unique subdirectory under `path`.
55 | 
56 |   Returns a singleton with the path to which the metadata was written.
57 |   """
58 | 
59 |   # NOTE: The `pipeline` argument is required by this PTransform given that
60 |   # all the inputs may be non-deferred.
61 |   def __init__(self, path, pipeline, write_to_unique_subdirectory=False):
62 |     """Init method.
63 | 
64 |     Args:
65 |       path: A str, the default path that the metadata should be written to.
66 |       pipeline: A beam Pipeline.
67 |       write_to_unique_subdirectory: (Optional) A bool indicating whether to
68 |         write the metadata out to `path` or a unique subdirectory under `path`.
69 | """ 70 | super().__init__() 71 | self._path = path 72 | self._write_to_unique_subdirectory = write_to_unique_subdirectory 73 | self.pipeline = pipeline 74 | 75 | def _extract_input_pvalues(self, metadata): 76 | pvalues = [] 77 | if isinstance(metadata, BeamDatasetMetadata): 78 | pvalues.append(metadata.deferred_metadata) 79 | return metadata, pvalues 80 | 81 | def expand(self, metadata): 82 | if hasattr(metadata, 'deferred_metadata'): 83 | metadata_pcoll = metadata.deferred_metadata 84 | else: 85 | metadata_pcoll = self.pipeline | beam.Create([metadata]) 86 | 87 | asset_map = getattr(metadata, 'asset_map', {}) 88 | 89 | def write_metadata_output(metadata): 90 | output_path = self._path 91 | if self._write_to_unique_subdirectory: 92 | output_path = common.get_unique_temp_path(self._path) 93 | metadata_io.write_metadata(metadata, output_path) 94 | if asset_map: 95 | with tf.io.gfile.GFile( 96 | os.path.join(output_path, 97 | output_wrapper.TFTransformOutput.ASSET_MAP), 'w') as f: 98 | f.write(json.dumps(asset_map)) 99 | return output_path 100 | 101 | return metadata_pcoll | 'WriteMetadata' >> beam.Map(write_metadata_output) 102 | -------------------------------------------------------------------------------- /tensorflow_transform/beam/tft_beam_io/beam_metadata_io_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for beam_metadata_io.""" 15 | 16 | import json 17 | import os 18 | 19 | import apache_beam as beam 20 | import tensorflow as tf 21 | from tensorflow_transform import output_wrapper 22 | from tensorflow_transform.beam.tft_beam_io import beam_metadata_io 23 | from tensorflow_transform.beam import tft_unit 24 | from tensorflow_transform.beam.tft_beam_io import test_metadata 25 | import tensorflow_transform.test_case as tft_test_case 26 | from tensorflow_transform.tf_metadata import metadata_io 27 | 28 | mock = tf.compat.v1.test.mock 29 | 30 | 31 | class BeamMetadataIoTest(tft_unit.TransformTestCase): 32 | 33 | def testWriteMetadataNonDeferred(self): 34 | # Write metadata to disk using WriteMetadata PTransform. 35 | with beam.Pipeline() as pipeline: 36 | path = self.get_temp_dir() 37 | _ = (test_metadata.COMPLETE_METADATA 38 | | beam_metadata_io.WriteMetadata(path, pipeline)) 39 | 40 | # Load from disk and check that it is as expected. 41 | metadata = metadata_io.read_metadata(path) 42 | self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) 43 | 44 | def testWriteMetadataDeferred(self): 45 | # Write metadata to disk using WriteMetadata PTransform, combining 46 | # incomplete metadata with (deferred) complete metadata. 
47 | expected_asset_map = {'key': 'value'} 48 | with beam.Pipeline() as pipeline: 49 | path = self.get_temp_dir() 50 | deferred_metadata = pipeline | 'CreateDeferredMetadata' >> beam.Create( 51 | [test_metadata.COMPLETE_METADATA]) 52 | metadata = beam_metadata_io.BeamDatasetMetadata( 53 | test_metadata.INCOMPLETE_METADATA, deferred_metadata, 54 | expected_asset_map) 55 | _ = metadata | beam_metadata_io.WriteMetadata(path, pipeline) 56 | 57 | # Load from disk and check that it is as expected. 58 | metadata = metadata_io.read_metadata(path) 59 | self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) 60 | 61 | with tf.io.gfile.GFile( 62 | os.path.join(path, output_wrapper.TFTransformOutput.ASSET_MAP)) as f: 63 | asset_map = json.loads(f.read()) 64 | self.assertDictEqual(asset_map, expected_asset_map) 65 | 66 | def testWriteMetadataIsRetryable(self): 67 | tft_test_case.skip_if_external_environment( 68 | 'Retries are currently not available on this environment.') 69 | original_write_metadata = beam_metadata_io.metadata_io.write_metadata 70 | write_metadata_called_list = [] 71 | 72 | def mock_write_metadata(metadata, path): 73 | """Mocks metadata_io.write_metadata to fail the first time it is called by this test, thus forcing a retry which should succeed.""" 74 | if not write_metadata_called_list: 75 | write_metadata_called_list.append(True) 76 | original_write_metadata(metadata, path) 77 | raise ArithmeticError('Some error') 78 | return original_write_metadata(metadata, path) 79 | 80 | # Write metadata to disk using WriteMetadata PTransform. 81 | with mock.patch( 82 | 'tensorflow_transform.tf_metadata.metadata_io.write_metadata', 83 | mock_write_metadata): 84 | with self._makeTestPipeline() as pipeline: 85 | path = self.get_temp_dir() 86 | _ = ( 87 | test_metadata.COMPLETE_METADATA 88 | | beam_metadata_io.WriteMetadata(path, pipeline)) 89 | 90 | # Load from disk and check that it is as expected. 91 | metadata = metadata_io.read_metadata(path) 92 | self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) 93 | 94 | 95 | if __name__ == '__main__': 96 | tf.test.main() 97 | -------------------------------------------------------------------------------- /tensorflow_transform/beam/tft_beam_io/test_metadata.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
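A condensed sketch of the non-deferred `WriteMetadata` path exercised by the tests above, building the metadata with `from_feature_spec` as test_metadata.py below does; the feature spec and output path here are hypothetical:

```python
import apache_beam as beam
import tensorflow as tf
from tensorflow_transform.beam.tft_beam_io import beam_metadata_io
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import metadata_io

metadata = dataset_metadata.DatasetMetadata.from_feature_spec(
    {'x': tf.io.FixedLenFeature([], tf.int64)})

with beam.Pipeline() as pipeline:
  # Non-deferred metadata is piped straight into WriteMetadata; the pipeline
  # handle is passed explicitly since there may be no input PCollection.
  _ = metadata | beam_metadata_io.WriteMetadata('/tmp/tft_metadata', pipeline)

# After the pipeline runs, the metadata can be read back from disk.
print(metadata_io.read_metadata('/tmp/tft_metadata') == metadata)  # True
```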
14 | """Test metadata for tft_beam_io tests."""
15 | 
16 | import tensorflow as tf
17 | from tensorflow_transform.tf_metadata import dataset_metadata
18 | 
19 | from tensorflow_metadata.proto.v0 import schema_pb2
20 | 
21 | _FEATURE_SPEC = {
22 |     'fixed_column': tf.io.FixedLenFeature([3], tf.string),
23 |     'list_column': tf.io.VarLenFeature(tf.int64),
24 | }
25 | 
26 | COMPLETE_METADATA = dataset_metadata.DatasetMetadata.from_feature_spec(
27 |     _FEATURE_SPEC, domains={'list_column': schema_pb2.IntDomain(min=-1, max=5)})
28 | 
29 | INCOMPLETE_METADATA = dataset_metadata.DatasetMetadata.from_feature_spec(
30 |     _FEATURE_SPEC,
31 |     # Values will be overridden by those in COMPLETE_METADATA.
32 |     domains={'list_column': schema_pb2.IntDomain(min=0, max=0)})
33 | 
--------------------------------------------------------------------------------
/tensorflow_transform/beam/tft_beam_io/transform_fn_io.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Transforms to read/write transform functions from disk."""
15 | 
16 | import os
17 | 
18 | import apache_beam as beam
19 | import tensorflow_transform as tft
20 | from tensorflow_transform import impl_helper
21 | from tensorflow_transform.beam import common
22 | from tensorflow_transform.beam.tft_beam_io import beam_metadata_io
23 | from tensorflow_transform.tf_metadata import metadata_io
24 | 
25 | # Users should avoid these aliases; they are provided for backwards
26 | # compatibility only.
27 | TRANSFORMED_METADATA_DIR = tft.TFTransformOutput.TRANSFORMED_METADATA_DIR
28 | TRANSFORM_FN_DIR = tft.TFTransformOutput.TRANSFORM_FN_DIR
29 | 
30 | 
31 | def _copy_tree_to_unique_temp_dir(source, base_temp_dir_path):
32 |   """Copies from source to a unique subdirectory under base_temp_dir_path."""
33 |   destination = common.get_unique_temp_path(base_temp_dir_path)
34 |   _copy_tree(source, destination)
35 |   return destination
36 | 
37 | 
38 | def _copy_tree(source, destination):
39 |   """Recursively copies source to destination."""
40 |   # TODO(b/35363519): Perhaps use Beam IO eventually (which also already
41 |   # supports recursive copy)?
42 |   import tensorflow as tf  # pylint: disable=g-import-not-at-top
43 | 
44 |   if tf.io.gfile.isdir(source):
45 |     source_dir_name = os.path.basename(os.path.normpath(source))
46 |     if source_dir_name == impl_helper.METADATA_DIR_NAME:
47 |       return
48 | 
49 |     tf.io.gfile.makedirs(destination)
50 |     for filename in tf.io.gfile.listdir(source):
51 |       _copy_tree(
52 |           os.path.join(source, filename), os.path.join(destination, filename))
53 |   else:
54 |     tf.io.gfile.copy(source, destination)
55 | 
56 | 
57 | class WriteTransformFn(beam.PTransform):
58 |   """Writes a TransformFn to disk.
59 | 
60 |   The internal structure is a directory containing two subdirectories. The
61 |   first is 'transformed_metadata' and contains metadata of the transformed data.
62 |   The second is 'transform_fn' and contains a SavedModel representing the
63 |   transform function.
64 |   """
65 | 
66 |   def __init__(self, path):
67 |     super().__init__()
68 |     self._path = path
69 | 
70 |   def _extract_input_pvalues(self, transform_fn):
71 |     saved_model_dir, metadata = transform_fn
72 |     pvalues = [saved_model_dir]
73 |     if isinstance(metadata, beam_metadata_io.BeamDatasetMetadata):
74 |       pvalues.append(metadata.deferred_metadata)
75 |     return transform_fn, pvalues
76 | 
77 |   def expand(self, transform_fn):
78 |     saved_model_dir, metadata = transform_fn
79 |     pipeline = saved_model_dir.pipeline
80 | 
81 |     # Using a temp dir within `path` ensures that the source and destination
82 |     # paths for the rename below are in the same file system.
83 |     base_temp_dir = os.path.join(self._path, 'transform_tmp')
84 |     temp_metadata_path = (
85 |         metadata
86 |         | 'WriteMetadataToTemp' >> beam_metadata_io.WriteMetadata(
87 |             base_temp_dir, pipeline, write_to_unique_subdirectory=True))
88 | 
89 |     temp_transform_fn_path = (
90 |         saved_model_dir
91 |         | 'WriteTransformFnToTemp' >> beam.Map(_copy_tree_to_unique_temp_dir,
92 |                                                base_temp_dir))
93 | 
94 |     metadata_path = os.path.join(self._path,
95 |                                  tft.TFTransformOutput.TRANSFORMED_METADATA_DIR)
96 |     transform_fn_path = os.path.join(self._path,
97 |                                      tft.TFTransformOutput.TRANSFORM_FN_DIR)
98 | 
99 |     def publish_outputs(unused_element, metadata_source_path,
100 |                         transform_fn_source_path):
101 |       import tensorflow as tf  # pylint: disable=g-import-not-at-top
102 |       if not tf.io.gfile.exists(self._path):
103 |         tf.io.gfile.makedirs(self._path)
104 | 
105 |       if tf.io.gfile.exists(metadata_path):
106 |         tf.io.gfile.rmtree(metadata_path)
107 |       tf.io.gfile.rename(metadata_source_path, metadata_path, overwrite=True)
108 | 
109 |       if tf.io.gfile.exists(transform_fn_path):
110 |         tf.io.gfile.rmtree(transform_fn_path)
111 |       tf.io.gfile.rename(
112 |           transform_fn_source_path, transform_fn_path, overwrite=True)
113 | 
114 |       # TODO(b/211615643): Remove the exists check once importing TFIO in S3
115 |       # addresses NotFoundError.
116 |       if tf.io.gfile.exists(base_temp_dir):
117 |         tf.io.gfile.rmtree(base_temp_dir)
118 | 
119 |     # TODO(KesterTong): Move this "must follows" logic into a tfx_bsl helper
120 |     # function or into Beam.
121 |     return (
122 |         pipeline
123 |         | 'CreateSole' >> beam.Create([None])
124 |         | 'PublishMetadataAndTransformFn' >> beam.Map(
125 |             publish_outputs,
126 |             metadata_source_path=beam.pvalue.AsSingleton(temp_metadata_path),
127 |             transform_fn_source_path=beam.pvalue.AsSingleton(
128 |                 temp_transform_fn_path)))
129 | 
130 | 
131 | class ReadTransformFn(beam.PTransform):
132 |   """Reads a TransformFn written by WriteTransformFn."""
133 | 
134 |   def __init__(self, path):
135 |     super().__init__()
136 |     self._path = path
137 | 
138 |   def expand(self, pvalue):
139 |     transform_fn_path = os.path.join(self._path,
140 |                                      tft.TFTransformOutput.TRANSFORM_FN_DIR)
141 |     saved_model_dir_pcoll = (
142 |         pvalue.pipeline
143 |         | 'CreateTransformFnPath' >> beam.Create([transform_fn_path]))
144 | 
145 |     metadata = metadata_io.read_metadata(
146 |         os.path.join(self._path,
147 |                      tft.TFTransformOutput.TRANSFORMED_METADATA_DIR))
148 | 
149 |     return saved_model_dir_pcoll, metadata
150 | 
--------------------------------------------------------------------------------
/tensorflow_transform/beam/tft_beam_io/transform_fn_io_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for transform_fn_io.""" 15 | 16 | import os 17 | 18 | import apache_beam as beam 19 | from apache_beam.testing import util as beam_test_util 20 | import tensorflow as tf 21 | import tensorflow_transform as tft 22 | from tensorflow_transform.beam.tft_beam_io import beam_metadata_io 23 | from tensorflow_transform.beam.tft_beam_io import transform_fn_io 24 | from tensorflow_transform.beam import tft_unit 25 | from tensorflow_transform.beam.tft_beam_io import test_metadata 26 | from tensorflow_transform.tf_metadata import metadata_io 27 | 28 | from tensorflow.python.lib.io import file_io # pylint: disable=g-direct-tensorflow-import 29 | 30 | mock = tf.compat.v1.test.mock 31 | # TODO(varshaan): Remove global variable and use a class attribute. 32 | _COPY_TREE_TO_UNIQUE_TEMP_DIR_CALLED = False 33 | 34 | 35 | class TransformFnIoTest(tft_unit.TransformTestCase): 36 | 37 | def testReadTransformFn(self): 38 | path = self.get_temp_dir() 39 | # NOTE: we don't need to create or write to the transform_fn directory since 40 | # ReadTransformFn never inspects this directory. 41 | transform_fn_dir = os.path.join( 42 | path, tft.TFTransformOutput.TRANSFORM_FN_DIR) 43 | transformed_metadata_dir = os.path.join( 44 | path, tft.TFTransformOutput.TRANSFORMED_METADATA_DIR) 45 | metadata_io.write_metadata(test_metadata.COMPLETE_METADATA, 46 | transformed_metadata_dir) 47 | 48 | with beam.Pipeline() as pipeline: 49 | saved_model_dir_pcoll, metadata = ( 50 | pipeline | transform_fn_io.ReadTransformFn(path)) 51 | beam_test_util.assert_that( 52 | saved_model_dir_pcoll, 53 | beam_test_util.equal_to([transform_fn_dir]), 54 | label='AssertSavedModelDir') 55 | # NOTE: metadata is currently read in a non-deferred manner. 56 | self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) 57 | 58 | def testWriteTransformFn(self): 59 | transform_output_dir = os.path.join(self.get_temp_dir(), 'output') 60 | 61 | with beam.Pipeline() as pipeline: 62 | # Create an empty directory for the source saved model dir. 63 | saved_model_dir = os.path.join(self.get_temp_dir(), 'source') 64 | file_io.recursive_create_dir(saved_model_dir) 65 | saved_model_dir_pcoll = ( 66 | pipeline | 'CreateSavedModelDir' >> beam.Create([saved_model_dir])) 67 | # Combine test metadata with a dict of PCollections resolving futures. 
68 | deferred_metadata = pipeline | 'CreateDeferredMetadata' >> beam.Create( 69 | [test_metadata.COMPLETE_METADATA]) 70 | metadata = beam_metadata_io.BeamDatasetMetadata( 71 | test_metadata.INCOMPLETE_METADATA, deferred_metadata, {}) 72 | 73 | _ = ((saved_model_dir_pcoll, metadata) 74 | | transform_fn_io.WriteTransformFn(transform_output_dir)) 75 | 76 | # Test reading with TFTransformOutput 77 | tf_transform_output = tft.TFTransformOutput(transform_output_dir) 78 | metadata = tf_transform_output.transformed_metadata 79 | self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) 80 | 81 | transform_fn_dir = tf_transform_output.transform_savedmodel_dir 82 | self.assertTrue(file_io.file_exists(transform_fn_dir)) 83 | self.assertTrue(file_io.is_directory(transform_fn_dir)) 84 | 85 | def testWriteTransformFnIsIdempotent(self): 86 | transform_output_dir = os.path.join(self.get_temp_dir(), 'output') 87 | 88 | def mock_write_metadata_expand(unused_self, unused_metadata): 89 | raise ArithmeticError('Some error') 90 | 91 | with beam.Pipeline() as pipeline: 92 | # Create an empty directory for the source saved model dir. 93 | saved_model_dir = os.path.join(self.get_temp_dir(), 'source') 94 | saved_model_dir_pcoll = ( 95 | pipeline | 'CreateSavedModelDir' >> beam.Create([saved_model_dir])) 96 | 97 | with mock.patch.object(transform_fn_io.beam_metadata_io.WriteMetadata, 98 | 'expand', mock_write_metadata_expand): 99 | with self.assertRaisesRegex(ArithmeticError, 'Some error'): 100 | _ = ((saved_model_dir_pcoll, object()) 101 | | transform_fn_io.WriteTransformFn(transform_output_dir)) 102 | 103 | self.assertFalse(file_io.file_exists(transform_output_dir)) 104 | 105 | def testWriteTransformFnIsRetryable(self): 106 | tft.test_case.skip_if_external_environment( 107 | 'Retries are currently not available on this environment.') 108 | original_copy_tree_to_unique_temp_dir = ( 109 | transform_fn_io._copy_tree_to_unique_temp_dir) 110 | 111 | def mock_copy_tree_to_unique_temp_dir(source, base_temp_dir_path): 112 | """Mocks transform_fn_io._copy_tree to fail the first time it is called by this test, thus forcing a retry which should succeed.""" 113 | global _COPY_TREE_TO_UNIQUE_TEMP_DIR_CALLED 114 | if not _COPY_TREE_TO_UNIQUE_TEMP_DIR_CALLED: 115 | _COPY_TREE_TO_UNIQUE_TEMP_DIR_CALLED = True 116 | original_copy_tree_to_unique_temp_dir(source, base_temp_dir_path) 117 | raise ArithmeticError('Some error') 118 | return original_copy_tree_to_unique_temp_dir(source, base_temp_dir_path) 119 | 120 | with self._makeTestPipeline() as pipeline: 121 | transform_output_dir = os.path.join(self.get_temp_dir(), 'output') 122 | # Create an empty directory for the source saved model dir. 123 | saved_model_dir = os.path.join(self.get_temp_dir(), 'source') 124 | file_io.recursive_create_dir(saved_model_dir) 125 | saved_model_path = os.path.join(saved_model_dir, 'saved_model') 126 | with file_io.FileIO(saved_model_path, mode='w') as f: 127 | f.write('some content') 128 | saved_model_dir_pcoll = ( 129 | pipeline | 'CreateSavedModelDir' >> beam.Create([saved_model_dir])) 130 | # Combine test metadata with a dict of PCollections resolving futures. 
131 | deferred_metadata = pipeline | 'CreateDeferredMetadata' >> beam.Create( 132 | [test_metadata.COMPLETE_METADATA]) 133 | metadata = beam_metadata_io.BeamDatasetMetadata( 134 | test_metadata.INCOMPLETE_METADATA, deferred_metadata, {}) 135 | with mock.patch.object(transform_fn_io, '_copy_tree_to_unique_temp_dir', 136 | mock_copy_tree_to_unique_temp_dir): 137 | _ = ((saved_model_dir_pcoll, metadata) 138 | | transform_fn_io.WriteTransformFn(transform_output_dir)) 139 | 140 | # Test reading with TFTransformOutput 141 | tf_transform_output = tft.TFTransformOutput(transform_output_dir) 142 | metadata = tf_transform_output.transformed_metadata 143 | self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) 144 | 145 | transform_fn_dir = tf_transform_output.transform_savedmodel_dir 146 | self.assertTrue(file_io.file_exists(transform_fn_dir)) 147 | self.assertTrue(file_io.is_directory(transform_fn_dir)) 148 | # Check temp directory created by failed run was cleaned up. 149 | self.assertEqual(2, len(file_io.list_directory(transform_output_dir))) 150 | 151 | 152 | if __name__ == '__main__': 153 | tf.test.main() 154 | -------------------------------------------------------------------------------- /tensorflow_transform/beam/vocabulary_tfrecord_gzip_integration_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2020 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tests for tfrecord_gzip tft.vocabulary and tft.compute_and_apply_vocabulary.""" 16 | 17 | from tensorflow_transform.beam import vocabulary_integration_test 18 | from tensorflow_transform.beam import tft_unit 19 | 20 | 21 | class TFRecordVocabularyIntegrationTest( 22 | vocabulary_integration_test.VocabularyIntegrationTest): 23 | 24 | def _VocabFormat(self): 25 | return 'tfrecord_gzip' 26 | 27 | 28 | if __name__ == '__main__': 29 | tft_unit.main() 30 | -------------------------------------------------------------------------------- /tensorflow_transform/coders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
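# Illustrative usage sketch (not part of the original source; `schema` stands
# for a hypothetical tensorflow_metadata Schema proto describing the columns):
#
#   from tensorflow_transform.coders import ExampleProtoCoder
#   coder = ExampleProtoCoder(schema)
#   serialized_example = coder.encode({'x': 1.0, 'y': 2.0})  # -> bytes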
14 | """Module level imports for tensorflow_transform.coders.""" 15 | 16 | from tensorflow_transform.coders.csv_coder import CsvCoder 17 | from tensorflow_transform.coders.example_proto_coder import ExampleProtoCoder 18 | -------------------------------------------------------------------------------- /tensorflow_transform/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Constants and types shared by tf.Transform package.""" 15 | 16 | import collections 17 | import contextlib 18 | import functools 19 | from typing import Any, Callable, Generator 20 | 21 | import tensorflow as tf 22 | 23 | from tensorflow.python.util import tf_decorator # pylint: disable=g-direct-tensorflow-import 24 | 25 | ANALYZER_COLLECTION = 'tft_analyzer_use' 26 | MAPPER_COLLECTION = 'tft_mapper_use' 27 | 28 | ANNOTATION_PREFIX_URL = 'type.googleapis.com' 29 | 30 | # TODO(b/132098015): Schema annotations aren't yet supported in OSS builds. 31 | try: 32 | from tensorflow_transform import annotations_pb2 # pylint: disable=g-import-not-at-top, unused-import 33 | IS_ANNOTATIONS_PB_AVAILABLE = True 34 | except ImportError: 35 | IS_ANNOTATIONS_PB_AVAILABLE = False 36 | 37 | _in_logging_context = False 38 | 39 | 40 | @contextlib.contextmanager 41 | def logging_context() -> Generator[None, None, None]: 42 | global _in_logging_context 43 | _in_logging_context = True 44 | try: 45 | yield 46 | finally: 47 | _in_logging_context = False 48 | 49 | 50 | def log_api_use( 51 | collection_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]: 52 | """Creates a decorator that logs function calls in the tensorflow graph.""" 53 | 54 | def decorator(fn): 55 | """Logs function calls in a tensorflow graph collection.""" 56 | 57 | @functools.wraps(fn) 58 | def wrapped_fn(*args, **kwargs): 59 | if not _in_logging_context: 60 | with logging_context(): 61 | graph = tf.compat.v1.get_default_graph() 62 | # Collection is a list that contains a single Counter of {name: count} 63 | # Note: We aggregate counts of function calls instead having one 64 | # collection item per call, since TFT users can use an arbitrarily 65 | # large number of analyzers and mappers and we don't want the graph 66 | # to get too big. 67 | # TODO(rachelim): Make this collection serializable so it can be added 68 | # to the SavedModel. 69 | collection = graph.get_collection_ref(collection_name) 70 | if not collection: 71 | collection.append(collections.Counter()) 72 | collection[0][fn.__name__] += 1 73 | return fn(*args, **kwargs) 74 | else: 75 | return fn(*args, **kwargs) 76 | 77 | # We use tf_decorator here so that TF can correctly introspect into 78 | # functions for docstring generation. 
79 | return tf_decorator.make_decorator(fn, wrapped_fn) 80 | 81 | return decorator 82 | -------------------------------------------------------------------------------- /tensorflow_transform/common_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for tensorflow_transform.common.""" 15 | 16 | import tensorflow as tf 17 | from tensorflow_transform import common 18 | from tensorflow_transform import test_case 19 | 20 | 21 | class CommonTest(test_case.TransformTestCase): 22 | 23 | def testLogAPIUse(self): 24 | 25 | @common.log_api_use("test_collection") 26 | def fn0(): 27 | return None 28 | 29 | @common.log_api_use("test_collection") 30 | def fn1(): 31 | return None 32 | 33 | @common.log_api_use("another_collection") 34 | def fn2(): 35 | return None 36 | 37 | with tf.compat.v1.Graph().as_default() as graph: 38 | fn0() 39 | fn1() 40 | fn2() 41 | fn0() 42 | fn0() 43 | 44 | self.assertAllEqual([{"fn0": 3, "fn1": 1}], 45 | graph.get_collection("test_collection")) 46 | self.assertAllEqual([{"fn2": 1}], 47 | graph.get_collection("another_collection")) 48 | 49 | def testLogAPIUseWithNestedFunction(self): 50 | """Tests that API call is not logged when called from another logged API.""" 51 | 52 | @common.log_api_use("test_collection") 53 | def fn0(): 54 | fn1() 55 | return fn2() 56 | 57 | @common.log_api_use("test_collection") 58 | def fn1(): 59 | return None 60 | 61 | @common.log_api_use("another_collection") 62 | def fn2(): 63 | return None 64 | 65 | with tf.compat.v1.Graph().as_default() as graph: 66 | fn0() 67 | 68 | self.assertEqual([{"fn0": 1}], graph.get_collection("test_collection")) 69 | self.assertAllEqual([], graph.get_collection("another_collection")) 70 | 71 | 72 | if __name__ == "__main__": 73 | test_case.main() 74 | -------------------------------------------------------------------------------- /tensorflow_transform/common_types.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Common types in tf.transform.""" 15 | 16 | from typing import Any, Dict, Iterable, List, TypeVar, Union, Optional 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | from typing_extensions import Literal 21 | 22 | from tensorflow_metadata.proto.v0 import schema_pb2 23 | 24 | # Demonstrational per-row data formats. 25 | PrimitiveType = Union[str, bytes, float, int] 26 | InstanceValueType = Optional[ 27 | Union[np.ndarray, np.generic, PrimitiveType, List[Any]] 28 | ] 29 | InstanceDictType = Dict[str, InstanceValueType] 30 | 31 | # TODO(b/185719271): Define BucketBoundariesType at module level of mappers.py. 32 | BucketBoundariesType = Union[tf.Tensor, Iterable[Union[int, float]]] 33 | 34 | FeatureSpecType = Union[tf.io.FixedLenFeature, tf.io.VarLenFeature, 35 | tf.io.SparseFeature, tf.io.RaggedFeature] 36 | 37 | DomainType = Union[schema_pb2.IntDomain, schema_pb2.FloatDomain, 38 | schema_pb2.StringDomain] 39 | TensorType = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor] 40 | ConsistentTensorType = TypeVar( # pylint: disable=invalid-name 41 | 'ConsistentTensorType', tf.Tensor, tf.SparseTensor, tf.RaggedTensor) 42 | SparseTensorValueType = Union[tf.SparseTensor, tf.compat.v1.SparseTensorValue] 43 | RaggedTensorValueType = Union[tf.RaggedTensor, 44 | tf.compat.v1.ragged.RaggedTensorValue] 45 | TensorValueType = Union[tf.Tensor, np.ndarray, SparseTensorValueType, 46 | RaggedTensorValueType] 47 | TemporaryAnalyzerOutputType = Union[tf.Tensor, tf.saved_model.Asset] 48 | VocabularyFileFormatType = Literal['text', 'tfrecord_gzip'] 49 | -------------------------------------------------------------------------------- /tensorflow_transform/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Module level imports for tensorflow_transform.experimental.""" 15 | 16 | from tensorflow_transform.experimental.analyzers import * 17 | from tensorflow_transform.experimental.annotators import * 18 | from tensorflow_transform.experimental.mappers import * 19 | -------------------------------------------------------------------------------- /tensorflow_transform/experimental/annotators.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Experimental APIs to get annotations."""
15 | 
16 | from typing import Sequence, Union
17 | 
18 | import tensorflow as tf
19 | from tensorflow_transform import annotators
20 | from tensorflow_transform import schema_inference
21 | 
22 | from tensorflow.python.framework import ops # pylint: disable=g-direct-tensorflow-import
23 | 
24 | 
25 | __all__ = [
26 | 'get_vocabulary_size_by_name',
27 | 'annotate_sparse_output_shape',
28 | 'annotate_true_sparse_output',
29 | ]
30 | 
31 | 
32 | def get_vocabulary_size_by_name(vocab_filename: str) -> tf.Tensor:
33 | # pyformat: disable
34 | """Gets the size of a vocabulary created using `tft.vocabulary`.
35 | 
36 | This is the number of keys in the output `vocab_filename` and does not include
37 | the number of OOV buckets.
38 | 
39 | Args:
40 | vocab_filename: The name of the vocabulary file whose size is to be
41 | retrieved.
42 | 
43 | Example:
44 | 
45 | >>> def preprocessing_fn(inputs):
46 | ... num_oov_buckets = 1
47 | ... x_int = tft.compute_and_apply_vocabulary(
48 | ... inputs['x'], vocab_filename='my_vocab',
49 | ... num_oov_buckets=num_oov_buckets)
50 | ... depth = (
51 | ... tft.experimental.get_vocabulary_size_by_name('my_vocab') +
52 | ... num_oov_buckets)
53 | ... x_encoded = tf.one_hot(
54 | ... x_int, depth=tf.cast(depth, tf.int32), dtype=tf.int64)
55 | ... return {'x_encoded': x_encoded}
56 | >>> raw_data = [dict(x='foo'), dict(x='foo'), dict(x='bar')]
57 | >>> feature_spec = dict(x=tf.io.FixedLenFeature([], tf.string))
58 | >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec)
59 | >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
60 | ... transformed_dataset, transform_fn = (
61 | ... (raw_data, raw_data_metadata)
62 | ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
63 | >>> transformed_data, transformed_metadata = transformed_dataset
64 | >>> transformed_data
65 | [{'x_encoded': array([1, 0, 0])}, {'x_encoded': array([1, 0, 0])},
66 | {'x_encoded': array([0, 1, 0])}]
67 | 
68 | Returns:
69 | An integer tensor containing the size of the requested vocabulary.
70 | 
71 | Raises:
72 | ValueError: if no vocabulary size is found for the given `vocab_filename`.
73 | 
74 | """
75 | # pyformat: enable
76 | vocabulary_sizes_coll = ops.get_default_graph().get_collection(
77 | annotators.VOCABULARY_SIZE_BY_NAME_COLLECTION)
78 | 
79 | result = dict(vocabulary_sizes_coll).get(vocab_filename, None)
80 | 
81 | if result is None:
82 | raise ValueError(
83 | f'Vocabulary size not found for {vocab_filename}. If this vocabulary '
84 | 'was created using `tft.vocabulary`, this should be the same as the '
85 | '`vocab_filename` argument passed to it.')
86 | 
87 | return result
88 | 
89 | 
90 | def annotate_sparse_output_shape(
91 | tensor: tf.SparseTensor, shape: Union[Sequence[int], tf.Tensor]):
92 | """Annotates a sparse output to have a given dense_shape.
93 | 
94 | Args:
95 | tensor: A `SparseTensor` to be annotated.
96 | shape: A dense_shape to annotate `tensor` with. Note that this shape does
97 | not include batch_size.
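
Example (an illustrative sketch; assumes this is called inside a
`preprocessing_fn` and that 'x' is a 2-D `tf.SparseTensor`):

def preprocessing_fn(inputs):
x = inputs['x'] # SparseTensor with dense_shape [batch_size, None].
# Declare a fixed per-row width of 10 so the inferred output schema
# carries a known dense_shape instead of a variable-length feature.
tft.experimental.annotate_sparse_output_shape(x, [10])
return {'x': x}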
98 | """ 99 | if not isinstance(shape, tf.Tensor): 100 | if (tensor.shape.rank > 1 and tensor.shape.rank - 1 != len(shape)) or ( 101 | tensor.shape.rank == 1 and len(shape) != 1): 102 | raise ValueError( 103 | f'Annotated shape {shape} was expected to have rank' 104 | f' {tensor.shape.rank - 1}') 105 | if not all(a is None or a <= b for a, b in zip(tensor.shape[1:], shape)): 106 | raise ValueError( 107 | f'Shape {shape} cannot contain annotated tensor {tensor}') 108 | shape = tf.convert_to_tensor(shape, dtype=tf.int64) 109 | elif shape.shape.rank > 1 or ( 110 | shape.shape.rank == 1 and shape.shape[0] != tensor.shape.rank - 1): 111 | raise ValueError( 112 | f'Annotation shape has rank {shape.shape.rank} but expected to have' 113 | f' rank {tensor.shape.rank - 1}') 114 | if shape.shape.rank < 1: 115 | shape = tf.expand_dims(shape, -1) 116 | # There's currently no way to override SparseTensor.dense_shape directly, 117 | # unless composing and returning a new SparseTensor. 118 | tensor._dense_shape = tf.concat( # pylint: disable=protected-access 119 | [tf.expand_dims(tensor.dense_shape[0], -1), tf.cast(shape, tf.int64)], 120 | axis=0) 121 | schema_inference.annotate_sparse_output_shape(tensor, shape) 122 | 123 | 124 | def annotate_true_sparse_output(tensor: tf.SparseTensor): 125 | """Annotates a sparse output to be truely sparse and not varlen.""" 126 | schema_inference.annotate_true_sparse_output(tensor) 127 | -------------------------------------------------------------------------------- /tensorflow_transform/graph_context.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Context manager for TF Graph when it is being traced.""" 15 | 16 | import os 17 | import threading 18 | from typing import Any, Dict, Optional 19 | 20 | import tensorflow as tf 21 | # TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` 22 | # once the Spark issue is resolved. 23 | from tfx_bsl.types import tfx_namedtuple 24 | 25 | 26 | class TFGraphContext: 27 | """A context manager to pass global state to a TF graph when it is traced. 28 | 29 | All the attributes in this context are kept on a thread local state. 30 | 31 | Attributes: 32 | module_to_export: A tf.Module object that can be exported to a SavedModel 33 | and will be used to track objects created within this TF graph. 34 | temp_dir: The base path of the directory to write out any temporary files 35 | in this context block. If None, the TF graph in this context will be 36 | traced with placeholders for asset filepaths and is not serializable to a 37 | SavedModel. 38 | evaluated_replacements: A subset of placeholders/temporary asset files in 39 | `analyzer_nodes.TENSOR_REPLACEMENTS` that have been evaluated in 40 | previous TFT phases. 41 | 42 | Note that the temp dir should be accessible to worker jobs, e.g. 
if running 43 | with the Cloud Dataflow runner, the temp dir should be on GCS and should have 44 | permissions that allow both launcher and workers to access it. 45 | """ 46 | 47 | class _State( 48 | tfx_namedtuple.namedtuple('_State', [ 49 | 'module_to_export', 50 | 'temp_dir', 51 | 'evaluated_replacements', 52 | ])): 53 | """A named tuple storing state passed to this context manager.""" 54 | 55 | @classmethod 56 | def make_empty(cls): 57 | """Return `_State` object with all fields set to `None`.""" 58 | return cls(*(None,) * len(cls._fields)) 59 | 60 | _TEMP_SUBDIR = 'analyzer_temporary_assets' 61 | 62 | _thread_local = threading.local() 63 | 64 | def __init__(self, 65 | module_to_export: tf.Module, 66 | temp_dir: Optional[str] = None, 67 | evaluated_replacements: Optional[Dict[str, Any]] = None): 68 | self._module_to_export = module_to_export 69 | self._temp_dir = temp_dir 70 | self._evaluated_replacements = evaluated_replacements 71 | 72 | def __enter__(self): 73 | assert getattr(self._thread_local, 'current_state', None) is None 74 | self._thread_local.current_state = self._State( 75 | module_to_export=self._module_to_export, 76 | temp_dir=self._temp_dir, 77 | evaluated_replacements=self._evaluated_replacements) 78 | 79 | def __exit__(self, *exn_info): 80 | self._thread_local.current_state = None 81 | 82 | @property 83 | def module_to_export(self): 84 | return self._module_to_export 85 | 86 | @classmethod 87 | def _get_current_state(cls) -> 'TFGraphContext._State': 88 | if hasattr(cls._thread_local, 'current_state'): 89 | return cls._thread_local.current_state 90 | return cls._State.make_empty() 91 | 92 | @classmethod 93 | def get_or_create_temp_dir(cls) -> Optional[str]: 94 | """Generate a temporary location.""" 95 | current_state = cls._get_current_state() 96 | if current_state.temp_dir is None: 97 | return None 98 | if not current_state.temp_dir: 99 | raise ValueError('A temp dir was requested, but empty temp_dir was set. ' 100 | 'Use the TFGraphContext context manager.') 101 | result = os.path.join(current_state.temp_dir, cls._TEMP_SUBDIR) 102 | tf.io.gfile.makedirs(result) 103 | return result 104 | 105 | @classmethod 106 | def get_evaluated_replacements(cls) -> Optional[Dict[str, Any]]: 107 | """Retrieves the value of evaluated_replacements if set. 108 | 109 | None otherwise. 110 | 111 | Returns: 112 | A dictionary from graph tensor names to evaluated values for these 113 | tensors. The keys are a subset of placeholders/temporary asset files in 114 | `analyzer_nodes.TENSOR_REPLACEMENTS` that have been evaluated in 115 | previous TFT phases. 116 | """ 117 | return cls._get_current_state().evaluated_replacements 118 | 119 | @classmethod 120 | def get_module_to_export(cls) -> Optional[tf.Module]: 121 | """Retrieves the value of module_to_export. 122 | 123 | None if called outside a TFGraphContext scope. 124 | 125 | Returns: 126 | A tf.Module object 127 | """ 128 | return cls._get_current_state().module_to_export 129 | -------------------------------------------------------------------------------- /tensorflow_transform/info_theory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utilities for information-theoretic preprocessing algorithms."""
15 | 
16 | import math
17 | 
18 | # math.log2 was added in Python 3.3
19 | log2 = getattr(math, 'log2', lambda x: math.log(x, 2))
20 | 
21 | 
22 | # TODO(b/157302701): Evaluate optimizations or approximations for this function,
23 | # in particular the _hypergeometric_pmf.
24 | def calculate_partial_expected_mutual_information(n, x_i, y_j):
25 | """Calculates the partial expected mutual information (EMI) of two variables.
26 | 
27 | EMI reflects the MI expected by chance, and is used to compute adjusted
28 | mutual information. See www.wikipedia.org/wiki/Adjusted_mutual_information.
29 | 
30 | The EMI for two variables x and y is the sum of the expected mutual info
31 | for each value of x with each value of y. This function computes the EMI
32 | for a single value of each variable (x_i, y_j) and is thus considered a
33 | partial EMI calculation.
34 | 
35 | Specifically:
36 | EMI(x, y) = sum_{n_ij = max(0, x_i + y_j - n) to min(x_i, y_j)} (
37 | (n_ij / n) * log2(n * n_ij / (x_i * y_j))
38 | * ((x_i! * y_j! * (n - x_i)! * (n - y_j)!) /
39 | (n! * n_ij! * (x_i - n_ij)! * (y_j - n_ij)! * (n - x_i - y_j + n_ij)!)))
40 | where n_ij is the joint count of x taking on value i and y taking on
41 | value j, x_i is the count for x taking on value i, y_j is the count for y
42 | taking on value j, and n represents total count.
43 | 
44 | Args:
45 | n: The sum of weights for all values.
46 | x_i: The sum of weights for the first variable taking on value i
47 | y_j: The sum of weights for the second variable taking on value j
48 | 
49 | Returns:
50 | Calculated expected mutual information for x_i, y_j.
51 | """
52 | if x_i == 0 or y_j == 0:
53 | return 0
54 | coefficient = (-log2(x_i) - log2(y_j) + log2(n))
55 | sum_probability = 0.0
56 | partial_result = 0.0
57 | for n_j, p_j in _hypergeometric_pmf(n, x_i, y_j):
58 | if n_j != 0:
59 | partial_result += n_j * (coefficient + log2(n_j)) * p_j
60 | sum_probability += p_j
61 | # The values of p_j should sum to 1, but given approximate calculations for
62 | # log2(x) and exp2(x) with large x, the full pmf might not sum to exactly 1.
63 | # We correct for this by dividing by the sum of the probabilities.
64 | return partial_result / sum_probability
65 | 
66 | 
67 | def calculate_partial_mutual_information(n_ij, x_i, y_j, n):
68 | """Calculates Mutual Information for x=i, y=j from sample counts.
69 | 
70 | The standard formulation of mutual information is:
71 | MI(X,Y) = Sum_i,j {p_ij * log2(p_ij / (p_i * p_j))}
72 | We are operating over counts (p_ij = n_ij / n), so this is transformed into
73 | MI(X,Y) = Sum_i,j {n_ij * (log2(n_ij) + log2(n) - log2(x_i) - log2(y_j))} / n
74 | This function returns the argument to the summation, the mutual information
75 | for a particular pair of values x_i, y_j (the caller is expected to divide
76 | the summation by n to compute the final mutual information result).
77 | 
78 | Args:
79 | n_ij: The co-occurrence of x=i and y=j
80 | x_i: The frequency of x=i.
81 | y_j: The frequency of y=j.
82 | n: The total number of observations.
83 | 
84 | Returns:
85 | Mutual information for the cell x=i, y=j.
86 | """
87 | if n_ij == 0 or x_i == 0 or y_j == 0:
88 | return 0
89 | return n_ij * ((log2(n_ij) + log2(n)) -
90 | (log2(x_i) + log2(y_j)))
91 | 
92 | 
93 | def _hypergeometric_pmf(n, x_i, y_j):
94 | """Probability for expectation computation under hypergeometric distribution.
95 | 
96 | Args:
97 | n: The sum of weights for all values.
98 | x_i: The sum of weights for the first variable taking on value i
99 | y_j: The sum of weights for the second variable taking on value j
100 | 
101 | Yields:
102 | The probability p_j at point n_j in the hypergeometric distribution.
103 | """
104 | start = int(round(max(0, x_i + y_j - n)))
105 | end = int(round(min(x_i, y_j)))
106 | # Use log factorial to preserve calculation precision.
107 | # Note: because the factorials are expensive to compute, we compute the
108 | # denominator incrementally, at the cost of some readability.
109 | numerator = (
110 | _logfactorial(x_i) + _logfactorial(y_j) + _logfactorial(n - x_i) +
111 | _logfactorial(n - y_j))
112 | denominator = (
113 | _logfactorial(n) + _logfactorial(start) + _logfactorial(x_i - start) +
114 | _logfactorial(y_j - start) + _logfactorial(n - x_i - y_j + start))
115 | for n_j in range(start, end + 1):
116 | p_j = math.exp(numerator - denominator)
117 | if n_j != end:
118 | denominator += (
119 | math.log(n_j + 1) - math.log(x_i - n_j) - math.log(y_j - n_j) +
120 | math.log(n - x_i - y_j + n_j + 1))
121 | yield n_j, p_j
122 | 
123 | 
124 | def _logfactorial(n):
125 | """Calculate natural logarithm of n!."""
126 | return math.lgamma(n + 1)
127 | 
--------------------------------------------------------------------------------
/tensorflow_transform/info_theory_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for tensorflow_transform.info_theory.""" 15 | 16 | from tensorflow_transform import info_theory 17 | from tensorflow_transform import test_case 18 | 19 | 20 | import unittest 21 | 22 | 23 | EPSILON = 1e-4 24 | 25 | 26 | def _make_hypergeometric_pmf_sum_up_to_one_parameters(): 27 | start = 1000 28 | end = 10000 29 | range_length = end - start 30 | num_chunks = 15 31 | assert range_length % num_chunks == 0 32 | chunk_size = int(range_length / num_chunks) 33 | sub_ranges = [(x, x + chunk_size) for x in range(start, end, chunk_size)] 34 | return [ # pylint: disable=g-complex-comprehension 35 | dict( 36 | testcase_name='{}_to_{}'.format(a, b), 37 | test_range=range(a, b), 38 | n=end, 39 | y_j=start) for a, b in sub_ranges 40 | ] 41 | 42 | 43 | class InfoTheoryTest(test_case.TransformTestCase): 44 | 45 | def testHypergeometricPmf(self): 46 | expected_results = [(0, 0.75), (1, 0.25)] 47 | results = list(info_theory._hypergeometric_pmf(4, 1, 1)) 48 | for expected_result, result in zip(expected_results, results): 49 | self.assertEqual(expected_result[0], result[0]) 50 | self.assertNear(expected_result[1], result[1], EPSILON) 51 | 52 | def testHypergeometricPmf_LargeN(self): 53 | expected_results = [(0, 0.9508937), (1, 0.0482198), (2, 0.0008794), 54 | (3, 7.1e-06), (4, 2.5e-08), (5, 0.0)] 55 | results = list(info_theory._hypergeometric_pmf(1000, 5, 10)) 56 | for expected_result, result in zip(expected_results, results): 57 | self.assertEqual(expected_result[0], result[0]) 58 | self.assertNear(expected_result[1], result[1], EPSILON) 59 | 60 | @test_case.named_parameters( 61 | *_make_hypergeometric_pmf_sum_up_to_one_parameters()) 62 | def test_hypergeometric_pmf_sum_up_to_one(self, test_range, n, y_j): 63 | for x in test_range: 64 | probs = [prob for _, prob in info_theory._hypergeometric_pmf(n, x, y_j)] 65 | sum_prob = sum(probs) 66 | self.assertNear(sum_prob, 1.0, EPSILON) 67 | 68 | @test_case.named_parameters( 69 | dict( 70 | testcase_name='all_co_occur', 71 | n=10, 72 | x_i=10, 73 | y_j=10, 74 | expected=0, 75 | ), 76 | dict( 77 | testcase_name='2_co_occur_no_observations', 78 | n=10, 79 | x_i=0, 80 | y_j=0, 81 | expected=0, 82 | ), 83 | dict( 84 | testcase_name='2_values_appear_half_the_time', 85 | n=10, 86 | x_i=5, 87 | y_j=5, 88 | expected=0.215411, 89 | ), 90 | dict( 91 | testcase_name='2_values_differing_frequencies', 92 | n=10, 93 | x_i=2, 94 | y_j=4, 95 | expected=0.524209, 96 | ), 97 | ) 98 | def test_calculate_partial_expected_mutual_information( 99 | self, n, x_i, y_j, expected): 100 | self.assertNear( 101 | info_theory.calculate_partial_expected_mutual_information(n, x_i, y_j), 102 | expected, EPSILON) 103 | 104 | @test_case.named_parameters( 105 | dict( 106 | testcase_name='strongly_positive_mi', 107 | cell_count=2, 108 | row_count=10, 109 | col_count=2, 110 | total_count=14, 111 | expected_mi=0.970854), 112 | dict( 113 | testcase_name='weakly_positive_mi', 114 | cell_count=4, 115 | row_count=15, 116 | col_count=6, 117 | total_count=25, 118 | expected_mi=0.608012), 119 | dict( 120 | testcase_name='strongly_negative_mi', 121 | cell_count=2, 122 | row_count=10, 123 | col_count=6, 124 | total_count=25, 125 | expected_mi=-0.526069), 126 | dict( 127 | testcase_name='weakly_negative_mi', 128 | cell_count=3, 129 | row_count=31, 130 | col_count=4, 131 | total_count=41, 132 | expected_mi=-0.0350454), 133 | dict( 134 | testcase_name='zero_mi', 135 | cell_count=4, 136 | row_count=8, 137 | col_count=8, 138 | total_count=16, 139 | expected_mi=0), 140 | dict( 141 | 
testcase_name='invalid_input_zero_cell_count',
142 | cell_count=0,
143 | row_count=8,
144 | col_count=8,
145 | total_count=8,
146 | expected_mi=0),
147 | dict(
148 | testcase_name='invalid_input_zero_row_count',
149 | cell_count=4,
150 | row_count=0,
151 | col_count=8,
152 | total_count=8,
153 | expected_mi=0),
154 | dict(
155 | testcase_name='invalid_input_zero_col_count',
156 | cell_count=4,
157 | row_count=8,
158 | col_count=0,
159 | total_count=8,
160 | expected_mi=0),
161 | )
162 | def test_mutual_information(self, cell_count, row_count, col_count,
163 | total_count, expected_mi):
164 | per_cell_mi = info_theory.calculate_partial_mutual_information(
165 | cell_count, row_count, col_count, total_count)
166 | self.assertNear(per_cell_mi, expected_mi, EPSILON)
167 | 
168 | 
169 | if __name__ == '__main__':
170 | unittest.main()
171 | 
--------------------------------------------------------------------------------
/tensorflow_transform/inspect_preprocessing_fn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utilities for inspecting users' preprocessing_fns."""
15 | 
16 | import itertools
17 | from typing import Callable, List, Mapping, Union
18 | 
19 | import tensorflow as tf
20 | from tensorflow_transform import analyzer_nodes
21 | from tensorflow_transform import common_types
22 | from tensorflow_transform import graph_tools
23 | from tensorflow_transform import impl_helper
24 | from tensorflow_transform import nodes
25 | from tensorflow_transform import tf2_utils
26 | 
27 | 
28 | def get_analyze_input_columns(
29 | preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
30 | Mapping[str, common_types.TensorType]],
31 | specs: Mapping[str, Union[common_types.FeatureSpecType, tf.TypeSpec]],
32 | force_tf_compat_v1: bool = False) -> List[str]:
33 | """Return columns that are required inputs of `AnalyzeDataset`.
34 | 
35 | Args:
36 | preprocessing_fn: A tf.transform preprocessing_fn.
37 | specs: A dict of feature name to tf.TypeSpecs. If `force_tf_compat_v1` is
38 | True, this can also be feature specifications.
39 | force_tf_compat_v1: (Optional) If `True`, use Tensorflow in compat.v1 mode.
40 | Defaults to `False`.
41 | 
42 | Returns:
43 | A list of columns that are required inputs of analyzers.
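
For example (an illustrative sketch, assuming `tft` is imported and `specs`
describes 'x' and 'y'): only 'x' below feeds an analyzer, so only 'x' is
required at analysis time even though 'y' is used at transform time:

def preprocessing_fn(inputs):
return {'x_centered': inputs['x'] - tft.mean(inputs['x']),
'y_copy': inputs['y']}

get_analyze_input_columns(preprocessing_fn, specs) # -> ['x']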
44 | """ 45 | use_tf_compat_v1 = tf2_utils.use_tf_compat_v1(force_tf_compat_v1) 46 | if not use_tf_compat_v1: 47 | assert all([isinstance(s, tf.TypeSpec) for s in specs.values()]), specs 48 | graph, structured_inputs, structured_outputs = ( 49 | impl_helper.trace_preprocessing_function( 50 | preprocessing_fn, specs, use_tf_compat_v1=use_tf_compat_v1)) 51 | 52 | tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS) 53 | visitor = graph_tools.SourcedTensorsVisitor() 54 | for tensor_sink in tensor_sinks: 55 | nodes.Traverser(visitor).visit_value_node(tensor_sink.future) 56 | 57 | if use_tf_compat_v1: 58 | control_dependency_ops = [] 59 | else: 60 | # If traced in TF2 as a tf.function, inputs that end up in control 61 | # dependencies are required for the function to execute. Return such inputs 62 | # as required inputs of analyzers as well. 63 | _, control_dependency_ops = ( 64 | tf2_utils.strip_and_get_tensors_and_control_dependencies( 65 | tf.nest.flatten(structured_outputs, expand_composites=True))) 66 | 67 | output_tensors = list( 68 | itertools.chain(visitor.sourced_tensors, control_dependency_ops)) 69 | analyze_input_tensors = graph_tools.get_dependent_inputs( 70 | graph, structured_inputs, output_tensors) 71 | return list(analyze_input_tensors.keys()) 72 | 73 | 74 | def get_transform_input_columns( 75 | preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]], 76 | Mapping[str, common_types.TensorType]], 77 | specs: Mapping[str, Union[common_types.FeatureSpecType, tf.TypeSpec]], 78 | force_tf_compat_v1: bool = False) -> List[str]: 79 | """Return columns that are required inputs of `TransformDataset`. 80 | 81 | Args: 82 | preprocessing_fn: A tf.transform preprocessing_fn. 83 | specs: A dict of feature name to tf.TypeSpecs. If `force_tf_compat_v1` is 84 | True, this can also be feature specifications. 85 | force_tf_compat_v1: (Optional) If `True`, use Tensorflow in compat.v1 mode. 86 | Defaults to `False`. 87 | 88 | Returns: 89 | A list of columns that are required inputs of the transform `tf.Graph` 90 | defined by `preprocessing_fn`. 91 | """ 92 | use_tf_compat_v1 = tf2_utils.use_tf_compat_v1(force_tf_compat_v1) 93 | if not use_tf_compat_v1: 94 | assert all([isinstance(s, tf.TypeSpec) for s in specs.values()]), specs 95 | graph, structured_inputs, structured_outputs = ( 96 | impl_helper.trace_preprocessing_function( 97 | preprocessing_fn, specs, use_tf_compat_v1=use_tf_compat_v1)) 98 | 99 | transform_input_tensors = graph_tools.get_dependent_inputs( 100 | graph, structured_inputs, structured_outputs) 101 | return list(transform_input_tensors.keys()) 102 | -------------------------------------------------------------------------------- /tensorflow_transform/inspect_preprocessing_fn_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for inspect_preprocessing_fn.""" 15 | 16 | import tensorflow as tf 17 | from tensorflow_transform import analyzers 18 | from tensorflow_transform import inspect_preprocessing_fn 19 | from tensorflow_transform import mappers 20 | from tensorflow_transform import test_case 21 | 22 | _FEATURE_SPEC = { 23 | 'x': tf.io.FixedLenFeature([], tf.float32), 24 | 'y': tf.io.VarLenFeature(tf.int64), 25 | 's': tf.io.FixedLenFeature([], tf.string), 26 | } 27 | 28 | _TYPE_SPEC = { 29 | 'x': tf.TensorSpec([None], tf.float32), 30 | 'y': tf.SparseTensorSpec(shape=[None, None], dtype=tf.int64), 31 | 's': tf.TensorSpec([None], tf.string), 32 | } 33 | 34 | 35 | def _identity_preprocessing_fn(inputs): 36 | return inputs.copy() 37 | 38 | 39 | def _side_affect_preprocessing_fn(inputs): 40 | _ = analyzers.vocabulary(inputs['s']) 41 | return {} 42 | 43 | 44 | def _non_identity_ops_preprocessing_fn(inputs): 45 | outputs = inputs.copy() 46 | outputs['new_feature'] = tf.constant(1) 47 | return outputs 48 | 49 | 50 | def _renaming_preprocessing_fn(inputs): 51 | return {'id_{}'.format(key): value for key, value in inputs.items()} 52 | 53 | 54 | @tf.function 55 | def _plus_one(x): 56 | return x + 1 57 | 58 | 59 | def _one_phase_preprocessing_fn(inputs): 60 | x_plus_one = _plus_one(inputs['x']) 61 | subtracted = tf.sparse.add( 62 | tf.cast(inputs['y'], tf.float32), -analyzers.mean(x_plus_one)) 63 | _ = analyzers.vocabulary(inputs['s']) 64 | return {'subtracted': subtracted} 65 | 66 | 67 | def _two_phases_preprocessing_fn(inputs): 68 | x = inputs['x'] 69 | x_mean = analyzers.mean(x) 70 | x_square_deviations = tf.square(x - x_mean) 71 | x_var = analyzers.mean(x_square_deviations + analyzers.mean(inputs['y'])) 72 | x_normalized = (x - x_mean) / tf.sqrt(x_var) 73 | return { 74 | 'x_normalized': x_normalized, 75 | 's_id': mappers.compute_and_apply_vocabulary(inputs['s']) 76 | } 77 | 78 | 79 | def _preprocessing_fn_with_control_dependency(inputs): 80 | with tf.init_scope(): 81 | initializer = tf.lookup.KeyValueTensorInitializer(['foo', 'bar'], [0, 1]) 82 | table = tf.lookup.StaticHashTable(initializer, default_value=-1) 83 | # The table created here will add an automatic control dependency. 84 | s_int = table.lookup(inputs['s']) + 1 85 | 86 | # Perform some TF Ops to ensure x is part of the graph of dependencies for the 87 | # outputs. 
88 | x_abs = tf.math.abs(inputs['x']) 89 | y_centered = ( 90 | tf.sparse.add( 91 | tf.cast(inputs['y'], tf.float32), -analyzers.mean(inputs['y']))) 92 | return {'s_int': s_int, 'x_abs': x_abs, 'y_centered': y_centered} 93 | 94 | 95 | class InspectPreprocessingFnTest(test_case.TransformTestCase): 96 | 97 | @test_case.named_parameters( 98 | *test_case.cross_named_parameters([ 99 | dict( 100 | testcase_name='identity', 101 | preprocessing_fn=_identity_preprocessing_fn, 102 | expected_analyze_input_columns=[], 103 | expected_transform_input_columns=['x', 'y', 's']), 104 | dict( 105 | testcase_name='side_affect', 106 | preprocessing_fn=_side_affect_preprocessing_fn, 107 | expected_analyze_input_columns=['s'], 108 | expected_transform_input_columns=[]), 109 | dict( 110 | testcase_name='non_identity_ops', 111 | preprocessing_fn=_non_identity_ops_preprocessing_fn, 112 | expected_analyze_input_columns=[], 113 | expected_transform_input_columns=['x', 'y', 's']), 114 | dict( 115 | testcase_name='feature_renaming', 116 | preprocessing_fn=_renaming_preprocessing_fn, 117 | expected_analyze_input_columns=[], 118 | expected_transform_input_columns=['x', 'y', 's']), 119 | dict( 120 | testcase_name='one_phase', 121 | preprocessing_fn=_one_phase_preprocessing_fn, 122 | expected_analyze_input_columns=['x', 's'], 123 | expected_transform_input_columns=['y']), 124 | dict( 125 | testcase_name='two_phases', 126 | preprocessing_fn=_two_phases_preprocessing_fn, 127 | expected_analyze_input_columns=['x', 'y', 's'], 128 | expected_transform_input_columns=['x', 's']) 129 | ], [ 130 | dict(testcase_name='tf_compat_v1', force_tf_compat_v1=True), 131 | dict(testcase_name='tf2', force_tf_compat_v1=False) 132 | ]), 133 | *test_case.cross_named_parameters([ 134 | dict( 135 | testcase_name='control_dependencies', 136 | preprocessing_fn=_preprocessing_fn_with_control_dependency, 137 | expected_transform_input_columns=['x', 'y', 's']) 138 | ], [ 139 | dict( 140 | testcase_name='tf_compat_v1', 141 | force_tf_compat_v1=True, 142 | expected_analyze_input_columns=['y']), 143 | dict( 144 | testcase_name='tf2', 145 | force_tf_compat_v1=False, 146 | expected_analyze_input_columns=['s', 'y']) 147 | ])) 148 | def test_column_inference(self, preprocessing_fn, 149 | expected_analyze_input_columns, 150 | expected_transform_input_columns, 151 | force_tf_compat_v1): 152 | if not force_tf_compat_v1: 153 | test_case.skip_if_not_tf2('Tensorflow 2.x required') 154 | specs = _TYPE_SPEC 155 | else: 156 | specs = _FEATURE_SPEC 157 | 158 | analyze_input_columns = ( 159 | inspect_preprocessing_fn.get_analyze_input_columns( 160 | preprocessing_fn, specs, force_tf_compat_v1)) 161 | transform_input_columns = ( 162 | inspect_preprocessing_fn.get_transform_input_columns( 163 | preprocessing_fn, specs, force_tf_compat_v1)) 164 | self.assertCountEqual(analyze_input_columns, expected_analyze_input_columns) 165 | self.assertCountEqual(transform_input_columns, 166 | expected_transform_input_columns) 167 | 168 | 169 | if __name__ == '__main__': 170 | test_case.main() 171 | -------------------------------------------------------------------------------- /tensorflow_transform/keras_lib.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Imports keras 2."""
15 | import os
16 | 
17 | from absl import logging
18 | import tensorflow as tf
19 | 
20 | if 'TF_USE_LEGACY_KERAS' not in os.environ:
21 | # Make sure we are using Keras 2.
22 | os.environ['TF_USE_LEGACY_KERAS'] = '1'
23 | elif os.environ['TF_USE_LEGACY_KERAS'] not in ('true', 'True', '1'):
24 | logging.warning(
25 | 'TF_USE_LEGACY_KERAS is set to %s, which will not use Keras 2. Tensorflow'
26 | ' Transform is only compatible with Keras 2. Please set'
27 | ' TF_USE_LEGACY_KERAS=1.',
28 | os.environ['TF_USE_LEGACY_KERAS'],
29 | )
30 | 
31 | version_fn = getattr(tf.keras, 'version', None)
32 | if version_fn and version_fn().startswith('3.'):
33 | # `tf.keras` points to Keras 3, so use the `tf_keras` package instead.
34 | try:
35 | import tf_keras # pylint: disable=g-import-not-at-top,unused-import
36 | except ImportError:
37 | raise ImportError( # pylint: disable=raise-missing-from
38 | 'Keras 2 requires the `tf_keras` package. '
39 | 'Please install it with `pip install tf_keras`.'
40 | ) from None
41 | else:
42 | tf_keras = tf.keras # Keras 2
43 | 
--------------------------------------------------------------------------------
/tensorflow_transform/pickle_helper.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
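# Illustrative sketch (not from the original source): once
# `fix_internal_object_pickling()` has been called, objects such as
# `tf.TensorSpec` round-trip through pickle:
#
#   import pickle
#   import tensorflow as tf
#   from tensorflow_transform import pickle_helper
#   pickle_helper.fix_internal_object_pickling()
#   spec = tf.TensorSpec([None], tf.float32)
#   assert pickle.loads(pickle.dumps(spec)) == spec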
14 | """Functions to fix pickling of certain objects (see b/121323638).""" 15 | 16 | import copyreg 17 | import tensorflow as tf 18 | from tensorflow_transform import common 19 | from tensorflow_metadata.proto.v0 import schema_pb2 20 | from tensorflow_metadata.proto.v0 import statistics_pb2 21 | 22 | if common.IS_ANNOTATIONS_PB_AVAILABLE: 23 | from tensorflow_transform import annotations_pb2 # pylint: disable=g-import-not-at-top 24 | 25 | _ANNOTATION_CLASSES = [ 26 | annotations_pb2.VocabularyMetadata, annotations_pb2.BucketBoundaries 27 | ] if common.IS_ANNOTATIONS_PB_AVAILABLE else [] 28 | 29 | _PROTO_CLASSES = [ 30 | tf.compat.v1.ConfigProto, 31 | schema_pb2.Schema, 32 | schema_pb2.TensorRepresentation, 33 | statistics_pb2.DatasetFeatureStatistics, 34 | ] + _ANNOTATION_CLASSES 35 | 36 | 37 | _PROTO_CLS_BY_NAME = {proto_cls.DESCRIPTOR.name: proto_cls 38 | for proto_cls in _PROTO_CLASSES} 39 | 40 | 41 | def _pickle_proto(proto): 42 | return _unpickle_proto, (proto.DESCRIPTOR.name, proto.SerializeToString()) 43 | 44 | 45 | def _unpickle_proto(name, serialized_proto): 46 | return _PROTO_CLS_BY_NAME[name].FromString(serialized_proto) 47 | 48 | 49 | def _pickle_tensor_spec(tensor_spec): 50 | return _unpickle_tensor_spec, (tensor_spec.shape.as_list(), 51 | tensor_spec.dtype.as_numpy_dtype) 52 | 53 | 54 | def _unpickle_tensor_spec(shape, numpy_dtype): 55 | return tf.TensorSpec(shape, tf.as_dtype(numpy_dtype)) 56 | 57 | 58 | def fix_internal_object_pickling(): 59 | """Fix pickling issues (see b/121323638).""" 60 | for proto_cls in _PROTO_CLASSES: 61 | copyreg.pickle(proto_cls, _pickle_proto) 62 | 63 | copyreg.pickle(tf.TensorSpec, _pickle_tensor_spec) 64 | -------------------------------------------------------------------------------- /tensorflow_transform/pretrained_models_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for tensorflow_transform.pretrained_models.""" 15 | 16 | import os 17 | 18 | import tensorflow as tf 19 | from tensorflow_transform import pretrained_models 20 | 21 | 22 | class PretrainedModelsTest(tf.test.TestCase): 23 | 24 | def save_model_with_single_input(self, export_dir): 25 | builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir) 26 | with tf.compat.v1.Graph().as_default() as graph: 27 | with self.test_session(graph=graph) as sess: 28 | input1 = tf.compat.v1.placeholder( 29 | dtype=tf.int32, shape=[5], name='myinput') 30 | initializer = tf.compat.v1.initializers.constant([1, 2, 3, 4, 5]) 31 | with tf.compat.v1.variable_scope( 32 | 'Model', reuse=None, initializer=initializer): 33 | v1 = tf.compat.v1.get_variable('v1', [5], dtype=tf.int32) 34 | output1 = tf.add(v1, input1, name='myadd') 35 | inputs = {'single_input': input1} 36 | outputs = {'single_output': output1} 37 | signature_def_map = { 38 | 'my_signature_single_input': 39 | tf.compat.v1.saved_model.signature_def_utils 40 | .predict_signature_def(inputs, outputs) 41 | } 42 | sess.run(tf.compat.v1.global_variables_initializer()) 43 | builder.add_meta_graph_and_variables( 44 | sess, [tf.saved_model.SERVING], signature_def_map=signature_def_map) 45 | builder.save(False) 46 | 47 | def save_model_with_multi_inputs(self, export_dir): 48 | builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir) 49 | with tf.compat.v1.Graph().as_default() as graph: 50 | with self.test_session(graph=graph) as sess: 51 | input1 = tf.compat.v1.placeholder( 52 | dtype=tf.int32, shape=[5], name='myinput1') 53 | input2 = tf.compat.v1.placeholder( 54 | dtype=tf.int32, shape=[5], name='myinput2') 55 | input3 = tf.compat.v1.placeholder( 56 | dtype=tf.int32, shape=[5], name='myinput3') 57 | initializer = tf.compat.v1.initializers.constant([1, 2, 3, 4, 5]) 58 | with tf.compat.v1.variable_scope( 59 | 'Model', reuse=None, initializer=initializer): 60 | v1 = tf.compat.v1.get_variable('v1', [5], dtype=tf.int32) 61 | o1 = tf.add(v1, input1, name='myadd1') 62 | o2 = tf.add(o1, input2, name='myadd2') 63 | output1 = tf.add(o2, input3, name='myadd3') 64 | inputs = {'input_name1': input1, 'input_name2': input2, 65 | 'input_name3': input3} 66 | outputs = {'single_output': output1} 67 | signature_def_map = { 68 | 'my_signature_multi_input': 69 | tf.compat.v1.saved_model.signature_def_utils 70 | .predict_signature_def(inputs, outputs) 71 | } 72 | sess.run(tf.compat.v1.global_variables_initializer()) 73 | builder.add_meta_graph_and_variables( 74 | sess, [tf.saved_model.SERVING], signature_def_map=signature_def_map) 75 | builder.save(False) 76 | 77 | def make_tensor_fn_two_inputs(self): 78 | def tensor_fn(input1, input2): 79 | initializer = tf.compat.v1.initializers.constant([1, 2, 3]) 80 | with tf.compat.v1.variable_scope( 81 | 'Model', reuse=None, initializer=initializer): 82 | v1 = tf.compat.v1.get_variable('v1', [3], dtype=tf.int64) 83 | o1 = tf.add(v1, input1, name='myadda1') 84 | o = tf.subtract(o1, input2, name='myadda2') 85 | return o 86 | return tensor_fn 87 | 88 | def save_checkpoint_with_two_inputs(self, checkpoint_path): 89 | test_tensor_fn = self.make_tensor_fn_two_inputs() 90 | with tf.compat.v1.Graph().as_default() as graph: 91 | with self.test_session(graph=graph) as sess: 92 | input1 = tf.compat.v1.placeholder( 93 | dtype=tf.int64, shape=[3], name='myinputa') 94 | input2 = tf.compat.v1.placeholder( 95 | dtype=tf.int64, shape=[3], name='myinputb') 96 | test_tensor_fn(input1, input2) 97 | saver = 
tf.compat.v1.train.Saver() 98 | sess.run(tf.compat.v1.global_variables_initializer()) 99 | saver.save(sess, checkpoint_path) 100 | 101 | def testApplySavedModelSingleInput(self): 102 | export_dir = os.path.join(self.get_temp_dir(), 'single_input') 103 | self.save_model_with_single_input(export_dir) 104 | with tf.compat.v1.Graph().as_default() as graph: 105 | with self.test_session(graph=graph) as sess: 106 | input_tensor = tf.compat.v1.placeholder( 107 | dtype=tf.int32, shape=[5], name='input_tensor') 108 | output_tensor = pretrained_models.apply_saved_model( 109 | export_dir, input_tensor, [tf.saved_model.SERVING]) 110 | feed_dict = {input_tensor: [2, 2, 2, 2, 2]} 111 | output_value = sess.run(output_tensor, feed_dict=feed_dict) 112 | self.assertAllEqual(output_value, [3, 4, 5, 6, 7]) 113 | 114 | def testApplySavedModelMultiInputs(self): 115 | export_dir = os.path.join(self.get_temp_dir(), 'multi_inputs') 116 | self.save_model_with_multi_inputs(export_dir) 117 | with tf.compat.v1.Graph().as_default() as graph: 118 | with self.test_session(graph=graph) as sess: 119 | input_tensor_1 = tf.compat.v1.placeholder( 120 | dtype=tf.int32, shape=[5], name='input_tensor_1') 121 | input_tensor_2 = tf.compat.v1.placeholder( 122 | dtype=tf.int32, shape=[5], name='input_tensor_2') 123 | input_tensor_3 = tf.compat.v1.placeholder( 124 | dtype=tf.int32, shape=[5], name='input_tensor_3') 125 | inputs = { 126 | 'input_name1': input_tensor_1, 127 | 'input_name2': input_tensor_2, 128 | 'input_name3': input_tensor_3 129 | } 130 | output_tensor = pretrained_models.apply_saved_model( 131 | export_dir, 132 | inputs, [tf.saved_model.SERVING], 133 | signature_name='my_signature_multi_input') 134 | feed_dict = {input_tensor_1: [2, 3, 4, 5, 6], 135 | input_tensor_2: [1, 1, 1, 1, 1], 136 | input_tensor_3: [1, 1, 1, 1, -1]} 137 | output_value = sess.run(output_tensor, feed_dict=feed_dict) 138 | self.assertAllEqual(output_value, [5, 7, 9, 11, 11]) 139 | 140 | def testApplyFunctionWithCheckpointTwoInputs(self): 141 | checkpoint = os.path.join(self.get_temp_dir(), 'checkpoint_two') 142 | self.save_checkpoint_with_two_inputs(checkpoint) 143 | with tf.compat.v1.Graph().as_default() as graph: 144 | with self.test_session(graph=graph) as sess: 145 | input1 = tf.compat.v1.placeholder( 146 | dtype=tf.int64, shape=[3], name='input1') 147 | input2 = tf.compat.v1.placeholder( 148 | dtype=tf.int64, shape=[3], name='input2') 149 | output_tensor = pretrained_models.apply_function_with_checkpoint( 150 | self.make_tensor_fn_two_inputs(), [input1, input2], checkpoint) 151 | feed_dict = {input1: [1, 2, 3], input2: [3, 2, 1]} 152 | output_value = sess.run(output_tensor, feed_dict=feed_dict) 153 | # [1, 2, 3] + [1, 2, 3] - [3, 2, 1] = [-1, 2, 5] 154 | self.assertAllEqual(output_value, [-1, 2, 5]) 155 | 156 | 157 | if __name__ == '__main__': 158 | tf.test.main() 159 | -------------------------------------------------------------------------------- /tensorflow_transform/py.typed: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | #     http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. -------------------------------------------------------------------------------- /tensorflow_transform/py_func/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Module level imports for tensorflow_transform.py_func.""" 15 | 16 | from tensorflow_transform.py_func.api import apply_pyfunc 17 | from tensorflow_transform.py_func.pyfunc_helper import register_pyfuncs_from_saved_transform 18 | -------------------------------------------------------------------------------- /tensorflow_transform/py_func/api.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Public API for using py_funcs in TFTransform.""" 15 | 16 | from tensorflow_transform.py_func import pyfunc_helper 17 | 18 | 19 | # TODO(b/178867088): Figure out the TF2 compatibility plan for this API. 20 | def apply_pyfunc(func, Tout, stateful=True, name=None, *args):  # pylint: disable=invalid-name 21 |   """Applies a Python function to some `Tensor`s. 22 | 23 |   Applies a Python function to some `Tensor`s given by the argument list. The 24 |   number of arguments should match the number of inputs to the function. 25 | 26 |   This function is for use inside a preprocessing_fn. It is a wrapper around 27 |   `tf.py_func`. A function added this way can run in Transform, and during 28 |   training when the graph is imported using the `transform_raw_features` method 29 |   of the `TFTransformOutput` class. However, if the resulting training graph is 30 |   serialized and deserialized, then the `tf.py_func` op will not work and will 31 |   cause an error. This means that TensorFlow Serving will not be able to serve 32 |   this graph.
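  For example, a minimal sketch of a preprocessing_fn using this API
  (illustrative only; it assumes `import tensorflow_transform as tft`, a
  feature named 'x', and `add_one` is a hypothetical helper, not part of
  the library):

    def add_one(x):
      # Runs as ordinary Python/NumPy code, outside the TF graph.
      return x + 1

    def preprocessing_fn(inputs):
      y = tft.apply_pyfunc(add_one, tf.int64, True, 'add_one', inputs['x'])
      # The shape of a py_func output is unknown, so restore it explicitly.
      y.set_shape(inputs['x'].get_shape())
      return {'x_plus_one': y}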
33 | 34 |   The underlying reason for this limited support is that `tf.py_func` ops were 35 |   not designed to be serialized since they contain a reference to arbitrary 36 |   Python functions. This function pickles those functions and includes them in 37 |   the graph, and `transform_raw_features` similarly unpickles the functions. 38 |   But unpickling requires a Python environment, so it's not possible to 39 |   provide support in non-Python languages for loading such ops. Therefore 40 |   loading these ops in libraries such as TensorFlow Serving is not supported. 41 | 42 |   Note: This API can only be used when TF2 is disabled or 43 |   `tft_beam.Context.force_tf_compat_v1=True`. 44 | 45 |   Args: 46 |     func: A Python function, which accepts a list of NumPy `ndarray` objects 47 |       having element types that match the corresponding `tf.Tensor` objects 48 |       in `*args`, and returns a list of `ndarray` objects (or a single 49 |       `ndarray`) having element types that match the corresponding values 50 |       in `Tout`. 51 |     Tout: A list or tuple of tensorflow data types or a single tensorflow data 52 |       type if there is only one, indicating what `func` returns. 53 |     stateful: (Boolean.) If True, the function should be considered stateful. 54 |       If a function is stateless, when given the same input it will return the 55 |       same output and have no observable side effects. Optimizations such as 56 |       common subexpression elimination are only performed on stateless 57 |       operations. 58 |     name: A name for the operation (optional). 59 |     *args: The `Tensor`s to pass to `func` as arguments. 60 |   Returns: 61 |     A `Tensor` representing the application of the function. 62 |   """ 63 |   return pyfunc_helper.insert_pyfunc(func, Tout, stateful, name, *args) 64 | -------------------------------------------------------------------------------- /tensorflow_transform/py_func/pyfunc_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utility functions to use py_funcs in tf.transform.""" 15 | 16 | import dill 17 | import tensorflow as tf 18 | from tfx_bsl import beam as tfx_bsl_beam 19 | # TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` 20 | # once the Spark issue is resolved. 21 | from tfx_bsl.types import tfx_namedtuple 22 | 23 | # pylint: disable=g-direct-tensorflow-import 24 | from tensorflow.core.framework import attr_value_pb2 25 | from tensorflow.python.framework import ops 26 | # pylint: enable=g-direct-tensorflow-import 27 | 28 | _PYFUNC_COLLECTION_KEY = 'pyfuncs' 29 | 30 | tfx_bsl_beam.fix_code_type_pickling() 31 | 32 | 33 | class _PyFuncDef(tfx_namedtuple.namedtuple('_PyFuncDef', ['token', 'func'])): 34 |   """An internal wrapper around tuple(token, func). 35 | 36 |   `token` can be either a single token (if the py_func returns a tensor), or a 37 |   list of tokens (if the py_func returns a list of tensors).
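  For instance, the two methods described below support a simple pickle
  round trip (an illustrative sketch; `token` and `func` stand for any
  concrete values):

    attr_value = _PyFuncDef(token, func).to_proto()  # pickled into `attr_value.s`
    restored = _PyFuncDef.from_proto(attr_value)     # unpickled back to a _PyFuncDef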
38 | 39 |   The main purpose of this class is to provide the two methods: 40 |   `from_proto` and `to_proto` that enable storing tuple objects in the graph's 41 |   collections as proto objects. 42 |   """ 43 |   __slots__ = () 44 | 45 |   @staticmethod 46 |   def from_proto(attr_value, import_scope=None): 47 |     del import_scope  # Unused 48 |     return dill.loads(attr_value.s) 49 | 50 |   @staticmethod 51 |   def from_proto_string(proto_str, import_scope=None): 52 |     del import_scope  # Unused 53 |     attr_value = attr_value_pb2.AttrValue() 54 |     attr_value.ParseFromString(proto_str) 55 |     return _PyFuncDef.from_proto(attr_value) 56 | 57 |   def to_proto(self, export_scope=None): 58 |     del export_scope  # Unused 59 |     result = attr_value_pb2.AttrValue() 60 |     result.s = dill.dumps(self) 61 |     return result 62 | 63 | # Register the pyfuncs collection to use `AttrValue` proto type. 64 | # The proto object stored in the graph collection will contain the pickled value 65 | # of a `_PyFuncDef` object as a string in its `s` field. 66 | # Note that `AttrValue` is used here only as a convenient placeholder for a 67 | # string, and does not represent the actual attributes of an `op` as in the 68 | # usual case. 69 | ops.register_proto_function(_PYFUNC_COLLECTION_KEY, 70 |                             proto_type=attr_value_pb2.AttrValue, 71 |                             to_proto=_PyFuncDef.to_proto, 72 |                             from_proto=_PyFuncDef.from_proto) 73 | 74 | 75 | def insert_pyfunc(func, Tout, stateful, name, *args):  # pylint: disable=invalid-name 76 |   """Calls tf.py_func and inserts the `func` in the internal registry.""" 77 |   result = tf.compat.v1.py_func( 78 |       func, inp=list(args), Tout=Tout, stateful=stateful, name=name) 79 |   # A py_func can either return a tensor or a list. Since we care only about the 80 |   # op, it doesn't matter which result we take. 81 |   if isinstance(result, list): 82 |     first_result = result[0] if result else None 83 |   else: 84 |     first_result = result 85 |   if first_result is None: 86 |     raise ValueError('func must return a tensor or list of tensors') 87 |   token = first_result.op.node_def.attr['token'].s 88 |   tf.compat.v1.add_to_collection(_PYFUNC_COLLECTION_KEY, 89 |                                  _PyFuncDef(token, func)) 90 |   return result 91 | 92 | 93 | def register_pyfuncs_from_saved_transform(graph, meta_graph, loaded_in_tf2): 94 |   """Registers `py_func`s in the MetaGraphDef. 95 | 96 |   Takes the pickled `py_func`s stored in the MetaGraphDef and adds them to the 97 |   graph. Registered `py_func`s are referred to internally by the token 98 |   attribute of the `py_func` op. We first create some arbitrary ops which 99 |   are not used, but which result in the pickled functions stored in the 100 |   MetaGraphDef being registered. We then take the tokens of these newly 101 |   registered functions, and remap the tokens in the MetaGraphDef to contain 102 |   the new tokens for each function (this remapping is required since we cannot 103 |   specify what token should be used to register a function). 104 | 105 |   Args: 106 |     graph: The tf.Graph into which the meta_graph_def will be imported. 107 |     meta_graph: The MetaGraphDef containing the `py_func`s. All the `py_func` 108 |       ops in the graph will be modified in-place to have their token point to 109 |       the newly registered function. 110 |     loaded_in_tf2: A boolean indicating whether the saved transform is being 111 |       re-loaded in TF2 (rather than TF1). 112 | 113 |   Returns: 114 |     Modified graph_def if pyfuncs were found, else None. 115 | 116 |   Raises: 117 |     ValueError if an unregistered pyfunc is encountered in `graph`.
118 | """ 119 | if _PYFUNC_COLLECTION_KEY not in meta_graph.collection_def: 120 | return None 121 | 122 | # TODO(b/35929054) to enable it in TF itself. Once supported, 123 | # we should refactor this code to remove extra work for pickling and 124 | # re-registering of the py_funcs. 125 | pyfuncs_collection = meta_graph.collection_def[_PYFUNC_COLLECTION_KEY] 126 | 127 | new_tokens_by_old_token = {} 128 | with graph.as_default(): 129 | for func_def_str in pyfuncs_collection.bytes_list.value: 130 | func_def = _PyFuncDef.from_proto_string(func_def_str) 131 | # Re-insert the original python function into the default graph. 132 | # The operation itself in the graph does not matter (hence the dummy 133 | # values for name, Tout, and stateful). This is done only to reinsert 134 | # the function body in the internal TF's function registry. 135 | # TODO(b/123241062): We should even remove this op from the graph if 136 | # possible. 137 | func_temp_name = func_def.token + b'_temp' 138 | output_tensor = insert_pyfunc( 139 | func_def.func, tf.float32, False, func_temp_name) 140 | # Store the token associated with the function associated with the call 141 | # to tf.py_func. 142 | token = output_tensor.op.get_attr('token') 143 | new_tokens_by_old_token[func_def.token] = token 144 | 145 | if loaded_in_tf2: 146 | graph_def = graph.as_graph_def() 147 | # Since we are updating the GraphDef of the graph in whose context pyfuncs 148 | # were re-inserted, new tokens will also be present. 149 | expected_tokens_in_graph_def = ( 150 | list(new_tokens_by_old_token.keys()) + 151 | list(new_tokens_by_old_token.values())) 152 | else: 153 | graph_def = meta_graph.graph_def 154 | expected_tokens_in_graph_def = new_tokens_by_old_token.keys() 155 | # Swap the old token stored for the function with the new one, if there are 156 | # any tokens to change. 157 | if new_tokens_by_old_token: 158 | for node in graph_def.node: 159 | if node.op == 'PyFunc' or node.op == 'PyFuncStateless': 160 | token = node.attr['token'] 161 | new_token = new_tokens_by_old_token.get(token.s, None) 162 | if new_token is not None: 163 | token.s = new_token 164 | else: 165 | if token.s not in expected_tokens_in_graph_def: 166 | raise ValueError(f'Function: {node.name} was not registered') 167 | return graph_def 168 | -------------------------------------------------------------------------------- /tensorflow_transform/saved/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Module level imports for tensorflow_transform.saved.""" 15 | -------------------------------------------------------------------------------- /tensorflow_transform/saved/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Constants for tf.Transform SavedModels.""" 15 | 16 | # TODO(b/123243166) eventually migrate this constant to tag_constants.TRANSFORM. 17 | TRANSFORM_TAG = 'transform' 18 | 19 | TRANSFORM_SIGNATURE = 'transform_signature' 20 | -------------------------------------------------------------------------------- /tensorflow_transform/saved/saved_model_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utility functions for loading tf.Transform SavedModels.""" 15 | 16 | from tensorflow_transform.saved import constants 17 | from tensorflow.python.saved_model import loader_impl  # pylint: disable=g-direct-tensorflow-import 18 | 19 | 20 | def parse_saved_model(saved_model_dir): 21 |   return loader_impl.parse_saved_model(saved_model_dir) 22 | 23 | 24 | def _choose_meta_graph_def_internal(saved_model, tags): 25 |   """Find a MetaGraphDef within the SavedModel with exactly matching tags. 26 | 27 |   Args: 28 |     saved_model: A `SavedModel` protocol buffer. 29 |     tags: Set of string tags to identify the required MetaGraphDef. These should 30 |       correspond to the tags used when saving the variables using the 31 |       SavedModel `save()` API. 32 |   Returns: 33 |     The chosen `MetaGraphDef` protocol buffer. This can be used to further 34 |     extract signature-defs, collection-defs, etc. If tags cannot be found, 35 |     returns None. 36 |   """ 37 |   result = None 38 |   for meta_graph_def in saved_model.meta_graphs: 39 |     if set(meta_graph_def.meta_info_def.tags) == set(tags): 40 |       result = meta_graph_def 41 |       break 42 | 43 |   return result 44 | 45 | 46 | def choose_meta_graph_def(saved_model): 47 |   """Find a MetaGraphDef in the SavedModel with tag `constants.TRANSFORM_TAG`. 48 | 49 |   Args: 50 |     saved_model: A `SavedModel` protocol buffer. 51 | 52 |   Returns: 53 |     The chosen `MetaGraphDef` protocol buffer. This can be used to further 54 |     extract signature-defs, collection-defs, etc. If tags cannot be found, 55 |     returns None. 56 |   """ 57 |   return _choose_meta_graph_def_internal(saved_model, [constants.TRANSFORM_TAG]) 58 | 59 | 60 | def choose_meta_graph_def_and_raise(saved_model): 61 |   """Find a MetaGraphDef in the SavedModel with tag `constants.TRANSFORM_TAG`. 62 | 63 |   Args: 64 |     saved_model: A `SavedModel` protocol buffer.
65 | 66 | Returns: 67 | The chosen `MetaGraphDef` protocol buffer. This can be used to further 68 | extract signature-defs, collection-defs, etc. 69 | 70 | Raises: 71 | RuntimeError: MetaGraphDef associated with the tags cannot be found. 72 | """ 73 | result = choose_meta_graph_def(saved_model) 74 | 75 | if result is None: 76 | raise RuntimeError( 77 | 'MetaGraphDef associated with tags {} could not be found in SavedModel' 78 | .format(constants.TRANSFORM_TAG)) 79 | 80 | return result 81 | 82 | 83 | def get_asset_tensors(saved_model_dir, meta_graph_def_to_load): 84 | return loader_impl.get_asset_tensors(saved_model_dir, meta_graph_def_to_load) 85 | -------------------------------------------------------------------------------- /tensorflow_transform/saved/saved_model_loader_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for saved_model_loader.""" 15 | 16 | import os 17 | import tempfile 18 | 19 | import tensorflow as tf 20 | 21 | from tensorflow_transform.saved import saved_transform_io 22 | 23 | import unittest 24 | 25 | 26 | def _create_test_saved_model_dir(): 27 | export_path = os.path.join(tempfile.mkdtemp(), 'export') 28 | 29 | with tf.compat.v1.Graph().as_default(): 30 | with tf.compat.v1.Session().as_default() as session: 31 | input_float = tf.compat.v1.placeholder(tf.float32, shape=[1]) 32 | output = (input_float - 2.0) / 5.0 33 | inputs = {'x': input_float} 34 | outputs = {'x_scaled': output} 35 | saved_transform_io.write_saved_transform_from_session( 36 | session, inputs, outputs, export_path) 37 | 38 | return export_path 39 | 40 | 41 | class SavedModelLoaderTest(unittest.TestCase): 42 | 43 | @classmethod 44 | def setUpClass(cls): 45 | cls._test_saved_model_dir = _create_test_saved_model_dir() 46 | 47 | # This class has no tests at the moment. 48 | 49 | if __name__ == '__main__': 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /tensorflow_transform/test_case_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for tensorflow_transform.test_case.""" 15 | 16 | import re 17 | 18 | from tensorflow_transform import test_case 19 | 20 | import unittest 21 | 22 | 23 | class TftUnitTest(test_case.TransformTestCase): 24 | 25 | def testCrossNamedParameters(self): 26 | test_cases_1 = [ 27 | {'testcase_name': 'a_1_b_1', 'a': 1, 'b': 1}, 28 | {'testcase_name': 'a_3_b_3', 'a': 3, 'b': 3}, 29 | ] 30 | test_cases_2 = [ 31 | {'testcase_name': 'c_2', 'c': 2}, 32 | {'testcase_name': 'c_4', 'c': 4}, 33 | ] 34 | expected_cross = [ 35 | {'testcase_name': 'a_1_b_1_c_2', 'a': 1, 'b': 1, 'c': 2}, 36 | {'testcase_name': 'a_1_b_1_c_4', 'a': 1, 'b': 1, 'c': 4}, 37 | {'testcase_name': 'a_3_b_3_c_2', 'a': 3, 'b': 3, 'c': 2}, 38 | {'testcase_name': 'a_3_b_3_c_4', 'a': 3, 'b': 3, 'c': 4}, 39 | ] 40 | self.assertEqual( 41 | test_case.cross_named_parameters(test_cases_1, test_cases_2), 42 | expected_cross) 43 | 44 | def testCrossParameters(self): 45 | test_cases_1 = [('a', 1), ('b', 2)] 46 | test_cases_2 = [(True,), (False,)] 47 | expected_cross = [ 48 | ('a', 1, True), ('b', 2, True), 49 | ('a', 1, False), ('b', 2, False), 50 | ] 51 | self.assertCountEqual( 52 | test_case.cross_parameters(test_cases_1, test_cases_2), expected_cross) 53 | 54 | def testAssertDataCloseOrEqual(self): 55 | self.assertDataCloseOrEqual([{'a': 'first', 56 | 'b': 1.0, 57 | 'c': 5, 58 | 'd': ('second', 2.0)}, 59 | {'e': 2, 60 | 'f': 3}], 61 | [{'a': 'first', 62 | 'b': 1.0000001, 63 | 'c': 5, 64 | 'd': ('second', 2.0000001)}, 65 | {'e': 2, 66 | 'f': 3}]) 67 | with self.assertRaisesRegex(AssertionError, r'len\(.*\) != len\(\[\]\)'): 68 | self.assertDataCloseOrEqual([{'a': 1}], []) 69 | with self.assertRaisesRegex( 70 | AssertionError, 71 | re.compile('Element counts were not equal.*: Row 0', re.DOTALL), 72 | ): 73 | self.assertDataCloseOrEqual([{'a': 1}], [{'b': 1}]) 74 | with self.assertRaisesRegex( 75 | AssertionError, 76 | re.compile('Not equal to tolerance.*: Row 0, key a', re.DOTALL), 77 | ): 78 | self.assertDataCloseOrEqual([{'a': 1}], [{'a': 2}]) 79 | 80 | @test_case.parameters((1, 'a'), (2, 'b')) 81 | def testSampleParametrizedTestMethod(self, my_arg, my_other_arg): 82 | self.assertIn((my_arg, my_other_arg), {(1, 'a'), (2, 'b')}) 83 | 84 | 85 | if __name__ == '__main__': 86 | unittest.main() 87 | -------------------------------------------------------------------------------- /tensorflow_transform/tf2_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for tensorflow_transform.tf2_utils.""" 15 | 16 | import itertools 17 | import tensorflow as tf 18 | from tensorflow_transform import tf2_utils 19 | from tensorflow_transform import test_case 20 | 21 | _TEST_BATCH_SIZES = [1, 10] 22 | _TEST_DTYPES = [ 23 | tf.int16, 24 | tf.int32, 25 | tf.int64, 26 | tf.float32, 27 | tf.float64, 28 | tf.string, 29 | ] 30 | 31 | _TEST_TENSORS_TYPES = [ 32 | (lambda dtype: tf.TensorSpec([None], dtype=dtype), tf.Tensor, []), 33 | (lambda dtype: tf.TensorSpec([None, 2], dtype=dtype), tf.Tensor, [2]), 34 | (lambda dtype: tf.RaggedTensorSpec([None, None], dtype=dtype), 35 | tf.RaggedTensor, [None]), 36 | ( 37 | lambda dtype: tf.RaggedTensorSpec( # pylint: disable=g-long-lambda 38 | [None, None, 2], 39 | dtype=dtype, 40 | ragged_rank=1), 41 | tf.RaggedTensor, 42 | [None, 2]), 43 | ] 44 | 45 | 46 | class TF2UtilsTest(test_case.TransformTestCase): 47 | 48 | def test_strip_and_get_tensors_and_control_dependencies(self): 49 | 50 | @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int64)]) 51 | def func(x): 52 | with tf.init_scope(): 53 | initializer_1 = tf.lookup.KeyValueTensorInitializer( 54 | [0, 1, 2], ['a', 'b', 'c'], 55 | key_dtype=tf.int64, 56 | value_dtype=tf.string) 57 | table_1 = tf.lookup.StaticHashTable(initializer_1, default_value='NAN') 58 | size = table_1.size() 59 | initializer_2 = tf.lookup.KeyValueTensorInitializer( 60 | ['a', 'b', 'c'], [-1, 0, 1], 61 | key_dtype=tf.string, 62 | value_dtype=tf.int64) 63 | table_2 = tf.lookup.StaticHashTable(initializer_2, default_value=-777) 64 | y = table_1.lookup(x) 65 | _ = table_2.lookup(y) 66 | z = x + size 67 | return {'x': x, 'z': z} 68 | 69 | concrete_function = func.get_concrete_function() 70 | flat_outputs = tf.nest.flatten( 71 | concrete_function.structured_outputs, expand_composites=True) 72 | expected_flat_outputs = [t.op.inputs[0] for t in flat_outputs] 73 | expected_control_dependencies = itertools.chain( 74 | *[t.op.control_inputs for t in flat_outputs]) 75 | new_flat_outputs, control_dependencies = ( 76 | tf2_utils.strip_and_get_tensors_and_control_dependencies(flat_outputs)) 77 | self.assertEqual(new_flat_outputs, expected_flat_outputs) 78 | self.assertEqual(control_dependencies, set(expected_control_dependencies)) 79 | 80 | @test_case.parameters(*test_case.cross_parameters( 81 | [(x,) for x in _TEST_BATCH_SIZES], 82 | [(x,) for x in _TEST_DTYPES], 83 | _TEST_TENSORS_TYPES, 84 | )) 85 | def test_supply_missing_tensor_inputs(self, batch_size, dtype, 86 | type_spec_getter, tensor_type, 87 | inner_shape): 88 | test_case.skip_if_not_tf2('Tensorflow 2.x required.') 89 | 90 | @tf.function(input_signature=[{ 91 | 'x_1': tf.TensorSpec([None], dtype=tf.int32), 92 | 'x_2': type_spec_getter(dtype), 93 | }]) 94 | def foo(inputs): 95 | return inputs 96 | 97 | conc_fn = foo.get_concrete_function() 98 | # structured_input_signature is a tuple of (args, kwargs). [0][0] retrieves 99 | # the structure of the first arg, which for `foo` is `inputs`. 
100 | structured_inputs = tf.nest.pack_sequence_as( 101 | conc_fn.structured_input_signature[0][0], 102 | conc_fn.inputs, 103 | expand_composites=True) 104 | missing_keys = ['x_2'] 105 | result = tf2_utils.supply_missing_inputs(structured_inputs, batch_size, 106 | missing_keys) 107 | 108 | self.assertCountEqual(missing_keys, result.keys()) 109 | self.assertIsInstance(result['x_2'], tensor_type) 110 | self.assertEqual(result['x_2'].shape.as_list(), [batch_size] + inner_shape) 111 | self.assertEqual(result['x_2'].dtype, dtype) 112 | 113 | 114 | if __name__ == '__main__': 115 | test_case.main() 116 | -------------------------------------------------------------------------------- /tensorflow_transform/tf_metadata/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tensorflow_transform/tf_metadata/dataset_metadata.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """In-memory representation of all metadata associated with a dataset.""" 15 | 16 | from typing import Mapping, Optional, Type, TypeVar 17 | 18 | from tensorflow_transform import common_types 19 | from tensorflow_transform.tf_metadata import schema_utils 20 | from tensorflow_metadata.proto.v0 import schema_pb2 21 | 22 | _DatasetMetadataType = TypeVar('_DatasetMetadataType', bound='DatasetMetadata') 23 | 24 | 25 | class DatasetMetadata: 26 | """Metadata about a dataset used for the "instance dict" format. 27 | 28 | Caution: The "instance dict" format used with `DatasetMetadata` is much less 29 | efficient than TFXIO. For any serious workloads you should use TFXIO with a 30 | `tfxio.TensorAdapterConfig` instance as the metadata. Refer to 31 | [Get started with TF-Transform](https://www.tensorflow.org/tfx/transform/get_started#data_formats_and_schema) 32 | for more details. 33 | 34 | This is an in-memory representation that may be serialized and deserialized to 35 | and from a variety of disk representations. 
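  For example, a minimal sketch (the feature spec here is illustrative only,
  and assumes `import tensorflow as tf`):

    metadata = DatasetMetadata.from_feature_spec(
        {'x': tf.io.FixedLenFeature([], tf.float32)})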
36 | """ 37 | 38 | def __init__(self, schema: schema_pb2.Schema): 39 | self._schema = schema 40 | self._output_record_batches = True 41 | 42 | @classmethod 43 | def from_feature_spec( 44 | cls: Type[_DatasetMetadataType], 45 | feature_spec: Mapping[str, common_types.FeatureSpecType], 46 | domains: Optional[Mapping[str, common_types.DomainType]] = None 47 | ) -> _DatasetMetadataType: 48 | """Creates a DatasetMetadata from a TF feature spec dict.""" 49 | return cls(schema_utils.schema_from_feature_spec(feature_spec, domains)) 50 | 51 | @property 52 | def schema(self) -> schema_pb2.Schema: 53 | return self._schema 54 | 55 | def __eq__(self, other): 56 | if isinstance(other, self.__class__): 57 | return self.schema == other.schema 58 | return NotImplemented 59 | 60 | def __ne__(self, other): 61 | return not self == other 62 | 63 | def __repr__(self): 64 | return self.__dict__.__repr__() 65 | -------------------------------------------------------------------------------- /tensorflow_transform/tf_metadata/dataset_metadata_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for dataset_metadata.""" 15 | 16 | from tensorflow_transform.tf_metadata import test_common 17 | from tensorflow_transform.tf_metadata import dataset_metadata 18 | import unittest 19 | 20 | 21 | class DatasetSchemaTest(unittest.TestCase): 22 | 23 | def test_sanity(self): 24 | metadata = dataset_metadata.DatasetMetadata.from_feature_spec( 25 | test_common.test_feature_spec) 26 | self.assertEqual(metadata.schema, test_common.get_test_schema()) 27 | 28 | 29 | if __name__ == '__main__': 30 | unittest.main() 31 | -------------------------------------------------------------------------------- /tensorflow_transform/tf_metadata/metadata_io.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Utilities to read and write metadata in standardized versioned formats.""" 15 | 16 | import json 17 | import os 18 | 19 | 20 | import tensorflow as tf 21 | from tensorflow_transform.tf_metadata import dataset_metadata 22 | from tensorflow_transform.tf_metadata import schema_utils 23 | 24 | from google.protobuf import text_format 25 | from tensorflow.python.lib.io import file_io # pylint: disable=g-direct-tensorflow-import 26 | from tensorflow_metadata.proto.v0 import schema_pb2 27 | 28 | 29 | def read_metadata(path): 30 | """Load metadata in JSON format from a path into a new DatasetMetadata.""" 31 | schema_file = os.path.join(path, 'schema.pbtxt') 32 | legacy_schema_file = os.path.join(path, 'v1-json', 'schema.json') 33 | if file_io.file_exists(schema_file): 34 | text_proto = file_io.FileIO(schema_file, 'r').read() 35 | schema_proto = text_format.Parse(text_proto, schema_pb2.Schema(), 36 | allow_unknown_extension=True) 37 | elif file_io.file_exists(legacy_schema_file): 38 | schema_json = file_io.FileIO(legacy_schema_file, 'r').read() 39 | schema_proto = _parse_schema_json(schema_json) 40 | else: 41 | raise IOError( 42 | 'Schema file {} does not exist and neither did legacy format file ' 43 | '{}'.format(schema_file, legacy_schema_file)) 44 | return dataset_metadata.DatasetMetadata(schema_proto) 45 | 46 | 47 | def _parse_schema_json(schema_json): 48 | """Translate a JSON schema into a Schema proto.""" 49 | schema_dict = json.loads(schema_json) 50 | feature_spec = { 51 | feature_dict['name']: _column_schema_from_json(feature_dict) 52 | for feature_dict in schema_dict.get('feature', []) 53 | } 54 | domains = { 55 | feature_dict['name']: _domain_from_json(feature_dict['domain']) 56 | for feature_dict in schema_dict.get('feature', []) 57 | } 58 | return schema_utils.schema_from_feature_spec(feature_spec, domains) 59 | 60 | 61 | def _column_schema_from_json(feature_dict): 62 | """Translate a JSON feature dict into a feature spec.""" 63 | dtype = _dtype_from_json(feature_dict['domain']) 64 | tf_options = feature_dict['parsingOptions']['tfOptions'] 65 | if tf_options.get('fixedLenFeature') is not None: 66 | default_value = None 67 | try: 68 | # int() is needed because protobuf JSON encodes int64 as string 69 | default_value = _convert_scalar_or_list( 70 | int, tf_options['fixedLenFeature']['intDefaultValue']) 71 | except KeyError: 72 | try: 73 | default_value = tf_options['fixedLenFeature']['stringDefaultValue'] 74 | except KeyError: 75 | try: 76 | default_value = tf_options['fixedLenFeature']['floatDefaultValue'] 77 | except KeyError: 78 | pass 79 | axes = feature_dict['fixedShape'].get('axis', []) 80 | shape = [int(axis['size']) for axis in axes] 81 | return tf.io.FixedLenFeature(shape, dtype, default_value) 82 | elif tf_options.get('varLenFeature') is not None: 83 | return tf.io.VarLenFeature(dtype) 84 | else: 85 | raise ValueError('Could not interpret tfOptions: {}'.format(tf_options)) 86 | 87 | 88 | def _domain_from_json(domain): 89 | """Translate a JSON domain dict into an IntDomain or None.""" 90 | if domain.get('ints') is not None: 91 | def maybe_to_int(s): 92 | return int(s) if s is not None else None 93 | return schema_pb2.IntDomain( 94 | min=maybe_to_int(domain['ints'].get('min')), 95 | max=maybe_to_int(domain['ints'].get('max')), 96 | is_categorical=domain['ints'].get('isCategorical')) 97 | return None 98 | 99 | 100 | def _dtype_from_json(domain): 101 | """Translate a JSON domain dict into a tf.DType.""" 102 | if domain.get('ints') is not None: 103 | return tf.int64 104 | if 
104 |   if domain.get('floats') is not None: 105 |     return tf.float32 106 |   if domain.get('strings') is not None: 107 |     return tf.string 108 |   raise ValueError('Unknown domain: {}'.format(domain)) 109 | 110 | 111 | def write_metadata(metadata, path): 112 |   """Write metadata to the given path as a 'schema.pbtxt' text proto. 113 | 114 |   Args: 115 |     metadata: A `DatasetMetadata` to write. 116 |     path: A path to a directory where metadata should be written. 117 |   """ 118 |   if not file_io.file_exists(path): 119 |     file_io.recursive_create_dir(path) 120 |   schema_file = os.path.join(path, 'schema.pbtxt') 121 |   ascii_proto = text_format.MessageToString(metadata.schema) 122 |   file_io.atomic_write_string_to_file(schema_file, ascii_proto, overwrite=True) 123 | 124 | 125 | def _convert_scalar_or_list(fn, scalar_or_list): 126 |   if isinstance(scalar_or_list, list): 127 |     return list(map(fn, scalar_or_list)) 128 |   else: 129 |     return fn(scalar_or_list) 130 | -------------------------------------------------------------------------------- /tensorflow_transform/tf_metadata/metadata_io_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for tensorflow_transform.tf_metadata.metadata_io.
15 | """ 16 | 17 | import os 18 | import tempfile 19 | 20 | from tensorflow_transform.tf_metadata import test_common 21 | from tensorflow_transform.tf_metadata import dataset_metadata 22 | from tensorflow_transform.tf_metadata import metadata_io 23 | import unittest 24 | 25 | from tensorflow.python.lib.io import file_io # pylint: disable=g-direct-tensorflow-import 26 | 27 | 28 | _SCHEMA_WITH_INVALID_KEYS = """ 29 | { 30 | "feature": [{ 31 | "name": "my_key", 32 | "fixedShape": { 33 | "axis": [] 34 | }, 35 | "type": "INT", 36 | "domain": { 37 | "ints": {} 38 | }, 39 | "parsingOptions": { 40 | "tfOptions": { 41 | "fixedLenFeature": {} 42 | } 43 | } 44 | }], 45 | "sparseFeature": [{ 46 | "name": "my_key", 47 | "indexFeature": [], 48 | "valueFeature": [{ 49 | "name": "value_key", 50 | "type": "INT", 51 | "domain": { 52 | "ints": {} 53 | } 54 | }] 55 | }] 56 | } 57 | """ 58 | 59 | 60 | class SchemaIOv1JsonTest(unittest.TestCase): 61 | 62 | def _write_schema_to_disk(self, basedir, schema_string): 63 | version_basedir = os.path.join(basedir, 'v1-json') 64 | 65 | # Write a proto by hand to disk 66 | file_io.recursive_create_dir(version_basedir) 67 | file_io.write_string_to_file(os.path.join(version_basedir, 'schema.json'), 68 | schema_string) 69 | 70 | def test_read_with_invalid_keys(self): 71 | # TODO(b/123241798): use TEST_TMPDIR 72 | basedir = tempfile.mkdtemp() 73 | self._write_schema_to_disk(basedir, _SCHEMA_WITH_INVALID_KEYS) 74 | 75 | def test_read_features_default_axis(self): 76 | # TODO(b/123241798): use TEST_TMPDIR 77 | basedir = tempfile.mkdtemp() 78 | schema_no_sparse_features = """ 79 | { 80 | "feature": [{ 81 | "name": "my_key", 82 | "fixedShape": {}, 83 | "type": "INT", 84 | "domain": { 85 | "ints": {} 86 | }, 87 | "parsingOptions": { 88 | "tfOptions": { 89 | "fixedLenFeature": {} 90 | } 91 | } 92 | }] 93 | } 94 | """ 95 | self._write_schema_to_disk(basedir, schema_no_sparse_features) 96 | _ = metadata_io.read_metadata(basedir) 97 | 98 | def test_read_features(self): 99 | # TODO(b/123241798): use TEST_TMPDIR 100 | basedir = tempfile.mkdtemp() 101 | schema_no_sparse_features = """ 102 | { 103 | "feature": [{ 104 | "name": "my_key", 105 | "fixedShape": { 106 | "axis": [{ 107 | "size": 2 108 | }] 109 | }, 110 | "type": "INT", 111 | "domain": { 112 | "ints": {} 113 | }, 114 | "parsingOptions": { 115 | "tfOptions": { 116 | "fixedLenFeature": {} 117 | } 118 | } 119 | }] 120 | } 121 | """ 122 | self._write_schema_to_disk(basedir, schema_no_sparse_features) 123 | _ = metadata_io.read_metadata(basedir) 124 | 125 | def test_write_and_read(self): 126 | # TODO(b/123241798): use TEST_TMPDIR 127 | basedir = tempfile.mkdtemp() 128 | original = dataset_metadata.DatasetMetadata( 129 | schema=test_common.get_test_schema()) 130 | 131 | metadata_io.write_metadata(original, basedir) 132 | reloaded = metadata_io.read_metadata(basedir) 133 | 134 | self.assertEqual(original, reloaded) 135 | 136 | 137 | if __name__ == '__main__': 138 | unittest.main() 139 | -------------------------------------------------------------------------------- /tensorflow_transform/tf_metadata/schema_utils_legacy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Stubs for handling legacy fields of the Schema proto.""" 15 | 16 | 17 | def should_set_generate_legacy_feature_spec(feature_spec): 18 | del feature_spec # unused 19 | return False 20 | 21 | 22 | def set_generate_legacy_feature_spec(schema_proto, value): 23 | del schema_proto # unused 24 | if value: 25 | raise NotImplementedError( 26 | 'The generate_legacy_feature_spec is a legacy field that is not part ' 27 | 'of the OSS tf.Transform codebase') 28 | -------------------------------------------------------------------------------- /tensorflow_transform/tf_metadata/schema_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for tensorflow_transform.tf_metadata.schema_utils.""" 15 | 16 | from absl.testing import parameterized 17 | from tensorflow_transform.tf_metadata import schema_utils_legacy 18 | from tensorflow_transform.tf_metadata import schema_utils_test_cases 19 | from tensorflow_transform.tf_metadata import schema_utils 20 | 21 | from google.protobuf import text_format 22 | import unittest 23 | from tensorflow_metadata.proto.v0 import schema_pb2 24 | 25 | 26 | class SchemaUtilsTest(parameterized.TestCase): 27 | 28 | @parameterized.named_parameters( 29 | *schema_utils_test_cases.EQUIVALENT_FEATURE_SPEC_AND_SCHEMAS) 30 | def test_schema_from_feature_spec( 31 | self, ascii_proto, feature_spec, domains=None, 32 | generate_legacy_feature_spec=False): 33 | expected_schema_proto = text_format.Parse(ascii_proto, schema_pb2.Schema()) 34 | schema_utils_legacy.set_generate_legacy_feature_spec( 35 | expected_schema_proto, generate_legacy_feature_spec) 36 | result = schema_utils.schema_from_feature_spec(feature_spec, domains) 37 | self.assertEqual(result, expected_schema_proto) 38 | 39 | @parameterized.named_parameters( 40 | *(schema_utils_test_cases.EQUIVALENT_FEATURE_SPEC_AND_SCHEMAS + 41 | schema_utils_test_cases.NON_ROUNDTRIP_SCHEMAS)) 42 | def test_schema_as_feature_spec( 43 | self, ascii_proto, feature_spec, domains=None, 44 | generate_legacy_feature_spec=False): 45 | schema_proto = text_format.Parse(ascii_proto, schema_pb2.Schema()) 46 | schema_utils_legacy.set_generate_legacy_feature_spec( 47 | schema_proto, generate_legacy_feature_spec) 48 | result = schema_utils.schema_as_feature_spec(schema_proto) 49 | self.assertEqual( 50 | result, 51 | schema_utils.SchemaAsFeatureSpecResult(feature_spec, domains or {}), 52 | ) 53 | 54 | 
@parameterized.named_parameters( 55 | *schema_utils_test_cases.INVALID_SCHEMA_PROTOS) 56 | def test_schema_as_feature_spec_fails( 57 | self, ascii_proto, error_msg, error_class=ValueError, 58 | generate_legacy_feature_spec=False): 59 | schema_proto = text_format.Parse(ascii_proto, schema_pb2.Schema()) 60 | schema_utils_legacy.set_generate_legacy_feature_spec( 61 | schema_proto, generate_legacy_feature_spec) 62 | with self.assertRaisesRegex(error_class, error_msg): 63 | schema_utils.schema_as_feature_spec(schema_proto) 64 | 65 | @parameterized.named_parameters( 66 | *schema_utils_test_cases.INVALID_FEATURE_SPECS) 67 | def test_schema_from_feature_spec_fails( 68 | self, feature_spec, error_msg, domain=None, error_class=ValueError): 69 | with self.assertRaisesRegex(error_class, error_msg): 70 | schema_utils.schema_from_feature_spec(feature_spec, domain) 71 | 72 | @parameterized.named_parameters( 73 | *schema_utils_test_cases.RAGGED_VALUE_FEATURES_AND_TENSOR_REPRESENTATIONS) 74 | def test_pop_ragged_source_columns(self, name, tensor_representation, 75 | feature_by_name, expected_value_feature, 76 | truncated_feature_by_name): 77 | value_feature = schema_utils.pop_ragged_source_columns( 78 | name, tensor_representation, feature_by_name) 79 | self.assertEqual(value_feature, expected_value_feature) 80 | self.assertEqual(feature_by_name, truncated_feature_by_name) 81 | 82 | 83 | if __name__ == '__main__': 84 | unittest.main() 85 | -------------------------------------------------------------------------------- /tensorflow_transform/tf_metadata/test_common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Common data and utilities for tf_metadata tests.""" 15 | 16 | import tensorflow as tf 17 | 18 | from tensorflow_transform.tf_metadata import schema_utils 19 | 20 | 21 | test_feature_spec = { 22 | # FixedLenFeatures 23 | 'fixed_categorical_int_with_range': 24 | tf.io.FixedLenFeature(shape=[], dtype=tf.int64), 25 | 'fixed_int': 26 | tf.io.FixedLenFeature(shape=[5], dtype=tf.int64), 27 | 'fixed_float': 28 | tf.io.FixedLenFeature(shape=[5], dtype=tf.float32), 29 | 'fixed_string': 30 | tf.io.FixedLenFeature(shape=[5], dtype=tf.string), 31 | 32 | # VarLenFeatures 33 | 'var_int': 34 | tf.io.VarLenFeature(dtype=tf.int64), 35 | 'var_float': 36 | tf.io.VarLenFeature(dtype=tf.float32), 37 | 'var_string': 38 | tf.io.VarLenFeature(dtype=tf.string), 39 | } 40 | 41 | 42 | def get_test_schema(): 43 | return schema_utils.schema_from_feature_spec(test_feature_spec) 44 | -------------------------------------------------------------------------------- /tensorflow_transform/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Contains the version string of TF.Transform.""" 15 | 16 | # Note that setup.py uses this version. 17 | __version__ = '1.17.0.dev' 18 | --------------------------------------------------------------------------------