├── .bazelignore ├── .bazelrc ├── .bazelversion ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── WORKSPACE ├── configure ├── BUILD ├── MANIFEST.in ├── requirements.in ├── requirements_lock_3_10.txt ├── requirements_lock_3_11.txt ├── requirements_lock_3_9.txt └── setup.py ├── documentation ├── BUILD ├── _book.yaml ├── _index.yaml ├── developer_manual.md ├── directory_structure.md ├── distributed_training.md ├── image │ ├── logo.png │ └── ydf_logo.png ├── installation.md ├── known_issues.md ├── migration.md ├── more_examples.md ├── simple_ml_for_sheets │ ├── image │ │ ├── simple_ml_owl_1.png │ │ └── simple_ml_predict_1.png │ └── index.md ├── tensorflow_serving.md ├── text_features.md ├── tf_df_in_tf_js.md └── tutorials │ ├── BUILD │ ├── advanced_colab.ipynb │ ├── automatic_tuning_colab.ipynb │ ├── beginner_colab.ipynb │ ├── dtreeviz_colab.ipynb │ ├── images │ ├── beginner_tensorboard.png │ └── model_composition_structure.png │ ├── index.md │ ├── intermediate_colab.ipynb │ ├── model_composition_colab.ipynb │ ├── predict_colab.ipynb │ ├── proximities_colab.ipynb │ ├── ranking_colab.ipynb │ └── uplift_colab.ipynb ├── examples ├── BUILD ├── distributed_hyperparameter_optimization.py ├── distributed_training.py ├── fake_distributed_training.sh ├── hyperparameter_optimization.py └── minimal.py ├── tensorflow_decision_forests ├── BUILD ├── __init__.py ├── component │ ├── BUILD │ ├── __init__.py │ ├── builder │ │ ├── BUILD │ │ ├── __init__.py │ │ ├── builder.py │ │ └── builder_test.py │ ├── inspector │ │ ├── BUILD │ │ ├── __init__.py │ │ ├── blob_sequence.py │ │ ├── blob_sequence_test.py │ │ ├── inspector.py │ │ └── inspector_test.py │ ├── model_plotter │ │ ├── BUILD │ │ ├── __init__.py │ │ ├── model_plotter.py │ │ ├── model_plotter_test.py │ │ └── plotter.js │ ├── py_tree │ │ ├── BUILD │ │ ├── README.md │ │ ├── __init__.py │ │ ├── condition.py │ │ ├── condition_test.py │ │ ├── dataspec.py │ │ ├── dataspec_test.py │ │ ├── node.py │ │ ├── node_test.py │ │ 
├── objective.py │ │ ├── objective_test.py │ │ ├── tree.py │ │ ├── tree_test.py │ │ ├── value.py │ │ └── value_test.py │ └── tuner │ │ ├── BUILD │ │ ├── __init__.py │ │ ├── tuner.py │ │ └── tuner_test.py ├── contrib │ ├── BUILD │ ├── README.md │ ├── example_of_contrib │ │ ├── BUILD │ │ ├── README.md │ │ ├── __init__.py │ │ ├── example_of_contrib.py │ │ └── example_of_contrib_test.py │ ├── scikit_learn_model_converter │ │ ├── BUILD │ │ ├── README.md │ │ ├── __init__.py │ │ ├── scikit_learn_model_converter.py │ │ └── scikit_learn_model_converter_test.py │ └── training_preprocessing │ │ ├── BUILD │ │ ├── __init__.py │ │ ├── training_preprocessing.py │ │ └── training_preprocessing_test.py ├── keras │ ├── BUILD │ ├── __init__.py │ ├── core.py │ ├── core_inference.py │ ├── keras_distributed_test.py │ ├── keras_internal.py │ ├── keras_internal_test.py │ ├── keras_test.py │ ├── keras_tuner_test.py │ ├── test_runner.py │ ├── wrapper │ │ ├── BUILD │ │ ├── wrapper.bzl │ │ ├── wrapper.cc │ │ ├── wrapper.h │ │ ├── wrapper_main.cc │ │ └── wrapper_test.cc │ └── wrappers_pre_generated.py ├── tensorflow │ ├── BUILD │ ├── __init__.py │ ├── cc_logging.py │ ├── check_version.py │ ├── check_version_test.py │ ├── core.py │ ├── core_inference.py │ ├── core_test.py │ ├── distribute │ │ ├── BUILD │ │ └── tensorflow_std_server.py │ ├── ops │ │ ├── BUILD │ │ ├── __init__.py │ │ ├── inference │ │ │ ├── BUILD │ │ │ ├── __init__.py │ │ │ ├── api.py │ │ │ ├── kernel.cc │ │ │ ├── op.cc │ │ │ ├── op.py │ │ │ ├── op_dynamic.py │ │ │ ├── test_utils.py │ │ │ ├── tf1_test.py │ │ │ └── tf2_test.py │ │ └── training │ │ │ ├── BUILD │ │ │ ├── __init__.py │ │ │ ├── feature_on_file.cc │ │ │ ├── feature_on_file.h │ │ │ ├── features.h │ │ │ ├── kernel.cc │ │ │ ├── kernel.h │ │ │ ├── kernel_grpc_worker.cc │ │ │ ├── kernel_long_process.cc │ │ │ ├── kernel_on_file.cc │ │ │ ├── op.cc │ │ │ ├── op.py │ │ │ ├── op_dynamic.py │ │ │ ├── op_grpc_worker.cc │ │ │ └── op_test.py │ ├── tf1_compatibility.py │ ├── 
tf_logging.py │ └── utils.bzl ├── tensorflow_decision_forests.bzl ├── test_data │ ├── BUILD │ ├── README.md │ └── model │ │ └── saved_model_adult_rf │ │ ├── assets │ │ ├── data_spec.pb │ │ ├── done │ │ ├── header.pb │ │ ├── nodes-00000-of-00001 │ │ └── random_forest_header.pb │ │ ├── keras_metadata.pb │ │ ├── saved_model.pb │ │ └── variables │ │ ├── variables.data-00000-of-00001 │ │ └── variables.index └── tools │ └── run_e2e_tfdf_test.sh ├── third_party ├── BUILD ├── absl │ ├── BUILD │ └── workspace.bzl ├── absl_py │ ├── BUILD │ └── workspace.bzl ├── benchmark │ ├── BUILD │ └── workspace.bzl ├── gtest │ ├── BUILD │ └── workspace.bzl ├── protobuf │ ├── BUILD │ └── workspace.bzl ├── tensorflow │ ├── BUILD │ ├── tf-216.patch │ ├── tf.patch │ └── workspace.bzl ├── tensorflow_pypi │ ├── BUILD │ ├── BUILD.tpl │ ├── README │ └── tf_configure.bzl └── yggdrasil_decision_forests │ ├── BUILD │ ├── workspace.bzl │ └── ydf.patch └── tools ├── BUILD ├── build_docs.py ├── build_pip_package.bat ├── build_pip_package.sh ├── start_compile_docker.sh ├── submit_pip_package.sh ├── test_bazel.bat ├── test_bazel.sh └── tf_serving ├── README.md ├── build_tf_serving_with_tf_df.sh └── build_tf_serving_workspace_extra.txt /.bazelignore: -------------------------------------------------------------------------------- 1 | tensorflow_decision_forests/contrib/temporal_feature_processor 2 | -------------------------------------------------------------------------------- /.bazelrc: -------------------------------------------------------------------------------- 1 | # Bazel configuration for Yggdrasil Decision Forests 2 | # 3 | # OPTIONS 4 | # 5 | # Linux 6 | # 7 | # linux: Linux build. 8 | # 9 | # Windows 10 | # 11 | # windows: Windows build (unsupported). 12 | # 13 | # MacOS 14 | # 15 | # macos: MacOS build. 16 | # 17 | 18 | # Common flags. 19 | common --experimental_repo_remote_exec 20 | build --announce_rc 21 | build -c opt 22 | 23 | # Allow to use tensorflow internal methods. 
24 | build --nocheck_visibility 25 | 26 | # Yggdrasil uses tensorflow for IO. 27 | build --define=use_tensorflow_io=1 28 | 29 | # If 1, use the pre-compiled wrappers. 30 | build --define=use_precompiled_wrappers=0 31 | 32 | # The user can stop the training with Ctrl+C. 33 | build --define=stop_training_on_interrupt=1 34 | 35 | # Yggdrasil uses the tensorflow headers specialized for dynamic linking. 36 | build --define=tensorflow_with_header_lib=1 37 | 38 | # Nice print 39 | build:linux --copt=-fdiagnostics-color=always 40 | build --color=yes 41 | 42 | # Suppress C++ compiler warnings. 43 | build:linux --copt=-w 44 | build:linux --host_copt=-w 45 | build:macos --copt=-w 46 | build:windows --copt=/W0 47 | 48 | # Common build options. 49 | build --define=framework_shared_object=true 50 | build --enable_platform_specific_config 51 | build --define=dynamic_loaded_kernels=true 52 | build --copt=-DAUTOLOAD_DYNAMIC_KERNELS 53 | build --spawn_strategy=standalone 54 | 55 | build --strategy=Genrule=standalone 56 | build --action_env TF_NEED_CUDA="0" 57 | 58 | # Host specifics options. 59 | build:macos --features=-supports_dynamic_linker 60 | # On Intel machines, AVX2 is available 61 | build:macos_intel --config=avx_linux 62 | # Absl synchronization primitives currently do not work on MacOS. 63 | # TODO 64 | build:macos --define std_synchronization_primitives=1 65 | 66 | build:windows --config=release_cpu_windows 67 | -------------------------------------------------------------------------------- /.bazelversion: -------------------------------------------------------------------------------- 1 | 6.5.0 2 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | Contributions for TensorFlow Decision Forests and its engine Yggdrasil Decision Forests are welcome. 
4 | 5 | See [github.com/google/yggdrasil-decision-forests](https://github.com/google/yggdrasil-decision-forests) for Yggdrasil Decision Forests. 6 | 7 | ## Pull Request Checklist 8 | 9 | Before sending pull requests, make sure to review the 10 | [developer manual](documentation/developer_manual.md). 11 | 12 | ## Contributor License Agreement 13 | 14 | Contributions to this project must be accompanied by a Contributor License 15 | Agreement. You (or your employer) retain the copyright to your contribution; 16 | this simply gives us permission to use and redistribute your contributions as 17 | part of the project. Head over to <https://cla.developers.google.com/> to see 18 | your current agreements on file or to sign a new one. 19 | 20 | You generally only need to submit a CLA once, so if you've already submitted one 21 | (even if it was for a different project), you probably don't need to do it 22 | again. 23 | 24 | ## Code reviews 25 | 26 | All submissions, including submissions by project members, require review. We 27 | use GitHub pull requests for this purpose. Consult 28 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 29 | information on using pull requests. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | **TensorFlow Decision Forests** (**TF-DF**) is a library to train, run and 6 | interpret [decision forest](https://ydf.readthedocs.io/en/latest/intro_df.html) 7 | models (e.g., Random Forests, Gradient Boosted Trees) in TensorFlow. TF-DF 8 | supports classification, regression and ranking. 9 | 10 | **TF-DF** is powered by 11 | [Yggdrasil Decision Forests](https://github.com/google/yggdrasil-decision-forests) 12 | (**YDF**, a library to train and use decision forests in C++, JavaScript, CLI, 13 | and Go). TF-DF models are 14 | [compatible](https://ydf.readthedocs.io/en/latest/convert_model.html#convert-a-a-tensorflow-decision-forests-model-to-a-yggdrasil-model) 15 | with YDF models, and vice versa. 16 | 17 | TensorFlow Decision Forests is available on Linux and Mac. Windows users can use 18 | the library through WSL+Linux. 19 | 20 | ## Usage example 21 | 22 | A minimal end-to-end run looks as follows: 23 | 24 | ```python 25 | import tensorflow_decision_forests as tfdf 26 | import pandas as pd 27 | 28 | # Load the dataset in a Pandas dataframe. 29 | train_df = pd.read_csv("project/train.csv") 30 | test_df = pd.read_csv("project/test.csv") 31 | 32 | # Convert the dataset into a TensorFlow dataset. 33 | train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="my_label") 34 | test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label="my_label") 35 | 36 | # Train the model 37 | model = tfdf.keras.RandomForestModel() 38 | model.fit(train_ds) 39 | 40 | # Look at the model. 41 | model.summary() 42 | 43 | # Evaluate the model. 44 | model.evaluate(test_ds) 45 | 46 | # Export to a TensorFlow SavedModel. 47 | # Note: the model is compatible with Yggdrasil Decision Forests. 48 | model.save("project/model") 49 | ``` 50 | 51 | ## Google I/O Presentation 52 | 53 |
54 | 55 | 56 | 57 |
58 | 59 | ## Documentation & Resources 60 | 61 | The following resources are available: 62 | 63 | - [TF-DF on TensorFlow.org](https://tensorflow.org/decision_forests) (API 64 | Reference, Guides and Tutorials) 65 | - [Tutorials](https://www.tensorflow.org/decision_forests/tutorials) (on 66 | tensorflow.org) 67 | - [YDF documentation](https://ydf.readthedocs.io) (also applicable to TF-DF) 68 | - [Issue tracker](https://github.com/tensorflow/decision-forests/issues) 69 | - [Known issues](documentation/known_issues.md) 70 | - [Changelog](CHANGELOG.md) 71 | - [More examples](documentation/more_examples.md) 72 | 73 | ## Installation 74 | 75 | To install TensorFlow Decision Forests, run: 76 | 77 | ```shell 78 | pip3 install tensorflow_decision_forests --upgrade 79 | ``` 80 | 81 | See the [installation](documentation/installation.md) page for more details, 82 | troubleshooting and alternative installation solutions. 83 | 84 | ## Contributing 85 | 86 | Contributions to TensorFlow Decision Forests and Yggdrasil Decision Forests are 87 | welcome. If you want to contribute, make sure to review the 88 | [developer manual](documentation/developer_manual.md) and 89 | [contribution guidelines](CONTRIBUTING.md). 90 | 91 | ## Citation 92 | 93 | If you use TensorFlow Decision Forests in a scientific publication, please cite 94 | the following paper: 95 | [Yggdrasil Decision Forests: A Fast and Extensible Decision Forests Library](https://doi.org/10.1145/3580305.3599933).
96 | 97 | **Bibtex** 98 | 99 | ``` 100 | @inproceedings{GBBSP23, 101 | author = {Mathieu Guillame{-}Bert and 102 | Sebastian Bruch and 103 | Richard Stotz and 104 | Jan Pfeifer}, 105 | title = {Yggdrasil Decision Forests: {A} Fast and Extensible Decision Forests 106 | Library}, 107 | booktitle = {Proceedings of the 29th {ACM} {SIGKDD} Conference on Knowledge Discovery 108 | and Data Mining, {KDD} 2023, Long Beach, CA, USA, August 6-10, 2023}, 109 | pages = {4068--4077}, 110 | year = {2023}, 111 | url = {https://doi.org/10.1145/3580305.3599933}, 112 | doi = {10.1145/3580305.3599933}, 113 | } 114 | ``` 115 | 116 | **Raw** 117 | 118 | Yggdrasil Decision Forests: A Fast and Extensible Decision Forests Library, 119 | Guillame-Bert et al., KDD 2023: 4068-4077. doi:10.1145/3580305.3599933 120 | 121 | ## Contact 122 | 123 | You can contact the core development team at 124 | [decision-forests-contact@google.com](mailto:decision-forests-contact@google.com). 125 | 126 | ## Credits 127 | 128 | TensorFlow Decision Forests was developed by: 129 | 130 | - Mathieu Guillame-Bert (gbm AT google DOT com) 131 | - Jan Pfeifer (janpf AT google DOT com) 132 | - Richard Stotz (richardstotz AT google DOT com) 133 | - Sebastian Bruch (sebastian AT bruch DOT io) 134 | - Arvind Srinivasan (arvnd AT google DOT com) 135 | 136 | ## License 137 | 138 | [Apache License 2.0](LICENSE) 139 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | workspace(name = "org_tensorflow_decision_forests") 2 | 3 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 4 | 5 | # rules_java is required for Tensorflow. 
6 | http_archive( 7 | name = "rules_java", 8 | sha256 = "c73336802d0b4882e40770666ad055212df4ea62cfa6edf9cb0f9d29828a0934", 9 | url = "https://github.com/bazelbuild/rules_java/releases/download/5.3.5/rules_java-5.3.5.tar.gz", 10 | ) 11 | 12 | # ========================================== 13 | # Start of TensorFlow and its dependencies 14 | # ========================================== 15 | 16 | # This version of TensorFlow is injected only to make sure we use the same dependencies as TensorFlow (protobuffer, grpc, absl). 17 | # TensorFlow is not compiled. 18 | 19 | # Note: The OPs dynamic library depends on symbols specific to the version of 20 | # absl used by tensorflow. 21 | http_archive( 22 | name = "org_tensorflow", 23 | sha256 = "4691b18e8c914cdf6759b80f1b3b7f3e17be41099607ed0143134f38836d058e", 24 | strip_prefix = "tensorflow-2.19.0", 25 | urls = ["https://github.com/tensorflow/tensorflow/archive/v2.19.0.tar.gz"], 26 | ) 27 | 28 | load("//tensorflow_decision_forests:tensorflow_decision_forests.bzl", "py_deps_profile") 29 | 30 | py_deps_profile( 31 | name = "release_or_nightly", 32 | deps_map = { 33 | "tensorflow": [ 34 | "tf-nightly", 35 | "tf_header_lib", 36 | "libtensorflow_framework", 37 | ], 38 | "tf-keras": ["tf-keras-nightly"], 39 | }, 40 | pip_repo_name = "pypi", 41 | requirements_in = "//configure:requirements.in", 42 | switch = { 43 | "IS_NIGHTLY": "nightly", 44 | }, 45 | ) 46 | 47 | # Initialize hermetic Python 48 | load("@org_tensorflow//third_party/py:python_init_rules.bzl", "python_init_rules") 49 | 50 | python_init_rules() 51 | 52 | load("@org_tensorflow//third_party/py:python_init_repositories.bzl", "python_init_repositories") 53 | 54 | python_init_repositories( 55 | default_python_version = "system", 56 | requirements = { 57 | "3.9": "//configure:requirements_lock_3_9.txt", 58 | "3.10": "//configure:requirements_lock_3_10.txt", 59 | "3.11": "//configure:requirements_lock_3_11.txt", 60 | }, 61 | ) 62 | 63 | 
load("@org_tensorflow//third_party/py:python_init_toolchains.bzl", "python_init_toolchains") 64 | 65 | python_init_toolchains() 66 | 67 | load("//third_party/tensorflow_pypi:tf_configure.bzl", "tf_configure") 68 | 69 | tf_configure() 70 | 71 | load("@pypi//:requirements.bzl", "install_deps") 72 | 73 | install_deps() 74 | 75 | # Inject tensorflow dependencies. 76 | # TensorFlow cannot anymore be injected from a sub-module. 77 | # Note: The order is important. 78 | load("@org_tensorflow//tensorflow:workspace3.bzl", tf1 = "workspace") 79 | 80 | tf1() 81 | 82 | load("@org_tensorflow//tensorflow:workspace2.bzl", tf2 = "workspace") 83 | 84 | tf2() 85 | 86 | load("@org_tensorflow//tensorflow:workspace1.bzl", tf3 = "workspace") 87 | 88 | tf3() 89 | 90 | load("@org_tensorflow//tensorflow:workspace0.bzl", tf4 = "workspace") 91 | 92 | tf4() 93 | 94 | load( 95 | "@local_tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", 96 | "cuda_json_init_repository", 97 | ) 98 | 99 | cuda_json_init_repository() 100 | 101 | load( 102 | "@cuda_redist_json//:distributions.bzl", 103 | "CUDA_REDISTRIBUTIONS", 104 | "CUDNN_REDISTRIBUTIONS", 105 | ) 106 | load( 107 | "@local_tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", 108 | "cuda_redist_init_repositories", 109 | "cudnn_redist_init_repository", 110 | ) 111 | 112 | cuda_redist_init_repositories( 113 | cuda_redistributions = CUDA_REDISTRIBUTIONS, 114 | ) 115 | 116 | cudnn_redist_init_repository( 117 | cudnn_redistributions = CUDNN_REDISTRIBUTIONS, 118 | ) 119 | 120 | load( 121 | "@local_tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", 122 | "cuda_configure", 123 | ) 124 | 125 | cuda_configure(name = "local_config_cuda") 126 | 127 | load( 128 | "@local_tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", 129 | "nccl_redist_init_repository", 130 | ) 131 | 132 | nccl_redist_init_repository() 133 | 134 | load( 135 | "@local_tsl//third_party/nccl/hermetic:nccl_configure.bzl", 136 | 
"nccl_configure", 137 | ) 138 | 139 | nccl_configure(name = "local_config_nccl") 140 | 141 | # ======================================== 142 | # End of TensorFlow and its dependencies 143 | # ======================================== 144 | 145 | # Third party libraries 146 | load("//third_party/absl:workspace.bzl", absl = "deps") 147 | load("//third_party/absl_py:workspace.bzl", absl_py = "deps") 148 | load("//third_party/benchmark:workspace.bzl", benchmark = "deps") 149 | load("//third_party/gtest:workspace.bzl", gtest = "deps") 150 | load("//third_party/protobuf:workspace.bzl", protobuf = "deps") 151 | 152 | absl() 153 | 154 | absl_py() 155 | 156 | benchmark() 157 | 158 | gtest() 159 | 160 | protobuf() 161 | 162 | # Yggdrasil Decision Forests 163 | load("//third_party/yggdrasil_decision_forests:workspace.bzl", yggdrasil_decision_forests = "deps") 164 | 165 | yggdrasil_decision_forests() 166 | 167 | load("@ydf//yggdrasil_decision_forests:library.bzl", ydf_load_deps = "load_dependencies") 168 | 169 | ydf_load_deps( 170 | exclude_repo = [ 171 | "absl", 172 | "protobuf", 173 | "zlib", 174 | "farmhash", 175 | "grpc", 176 | "eigen", 177 | "pybind11", 178 | "pybind11_abseil", 179 | "pybind11_protobuf", 180 | "tensorflow", 181 | ], 182 | repo_name = "@ydf", 183 | ) 184 | -------------------------------------------------------------------------------- /configure/BUILD: -------------------------------------------------------------------------------- 1 | load("@python//:defs.bzl", "compile_pip_requirements") 2 | load("@python_version_repo//:py_version.bzl", "REQUIREMENTS") 3 | 4 | licenses(["notice"]) 5 | 6 | compile_pip_requirements( 7 | name = "requirements", 8 | extra_args = [ 9 | "--allow-unsafe", 10 | "--build-isolation", 11 | "--rebuild", 12 | ], 13 | generate_hashes = True, 14 | requirements_in = "//configure:requirements.in", 15 | requirements_txt = REQUIREMENTS, 16 | ) 17 | 18 | py_binary( 19 | name = "tensorflow_build_info", 20 | srcs = ["tensorflow_build_info.py"], 
21 | deps = ["@pypi_tf_nightly//:pkg"], 22 | ) 23 | -------------------------------------------------------------------------------- /configure/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README 3 | include tensorflow_decision_forests/component/model_plotter/plotter.js 4 | recursive-include * *.so 5 | recursive-include * *.so.[0-9] 6 | recursive-include * *.dylib 7 | recursive-include * *.dll 8 | recursive-include * grpc_worker_main 9 | -------------------------------------------------------------------------------- /configure/requirements.in: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | tensorflow==2.19.0 4 | six 5 | absl_py 6 | wheel 7 | wurlitzer 8 | tf_keras 9 | ydf>=0.11.0 10 | scikit-learn -------------------------------------------------------------------------------- /configure/requirements_lock_3_10.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/configure/requirements_lock_3_10.txt -------------------------------------------------------------------------------- /configure/requirements_lock_3_11.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/configure/requirements_lock_3_11.txt -------------------------------------------------------------------------------- /configure/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Setup file for pip's build. 16 | 17 | This file is used by tools/build_pip_package.sh. For `bdist_wheel` on macOS, a `--plat-name` (macosx_12_0_arm64 or macosx_10_15_x86_64, chosen from platform.processor()) is inserted after "bdist_wheel" in sys.argv unless one is already given; any other processor raises ValueError. 18 | """ 19 | 20 | import platform 21 | import sys 22 | import setuptools 23 | from setuptools.command.install import install 24 | from setuptools.dist import Distribution 25 | 26 | _VERSION = "1.12.0" 27 | 28 | with open("README.md", "r", encoding="utf-8") as fh: # Used as the markdown long_description passed to setuptools.setup. 29 | long_description = fh.read() 30 | 31 | REQUIRED_PACKAGES = [ 32 | "numpy", 33 | "pandas", 34 | "tensorflow==2.19.0", # Exact pin: TF-DF's custom C++ ops are tied to a specific TensorFlow version. 35 | "six", 36 | "absl_py", 37 | "wheel", 38 | "wurlitzer", 39 | "tf_keras~=2.19", 40 | "ydf>=0.11.0", 41 | ] 42 | 43 | 44 | class InstallPlatlib(install): # Installs into install_platlib when the distribution has extension modules. 45 | 46 | def finalize_options(self): 47 | install.finalize_options(self) 48 | if self.distribution.has_ext_modules(): 49 | self.install_lib = self.install_platlib 50 | 51 | 52 | class BinaryDistribution(Distribution): # Always reports extension modules and a non-pure build (presumably so bdist_wheel emits a platform-specific wheel). 53 | 54 | def has_ext_modules(self): 55 | return True 56 | 57 | def is_pure(self): 58 | return False 59 | 60 | 61 | if "bdist_wheel" in sys.argv: # On macOS, inject --plat-name right after "bdist_wheel" unless the caller supplied one. 62 | if "--plat-name" not in sys.argv: 63 | if platform.system() == "Darwin": 64 | idx = sys.argv.index("bdist_wheel") + 1 65 | sys.argv.insert(idx, "--plat-name") 66 | if platform.processor() == "arm": 67 | sys.argv.insert(idx + 1, "macosx_12_0_arm64") 68 | elif platform.processor() == "i386": 69 | sys.argv.insert(idx + 1, "macosx_10_15_x86_64") 70 | else: 71 | raise ValueError(f"Unknown processor {platform.processor()}") 72 | else: 73 | print("Not on MacOS") 74 | else: 75 | print("PLAT-NAME Supplied") 76 | else: 77 | print("NO
BDIST_WHEEL") 78 | 79 | setuptools.setup( 80 | cmdclass={ 81 | "install": InstallPlatlib, 82 | }, 83 | name="tensorflow_decision_forests", 84 | version=_VERSION, 85 | author="Google Inc.", 86 | author_email="decision-forests-contact@google.com", 87 | description=( 88 | "Collection of training and inference decision forest algorithms." 89 | ), 90 | long_description=long_description, 91 | long_description_content_type="text/markdown", 92 | url="https://github.com/tensorflow/decision-forests", 93 | project_urls={ 94 | "Bug Tracker": "https://github.com/tensorflow/decision-forests/issues", 95 | }, 96 | classifiers=[ 97 | "Intended Audience :: Developers", 98 | "Intended Audience :: Education", 99 | "Intended Audience :: Science/Research", 100 | "License :: OSI Approved :: Apache Software License", 101 | "Programming Language :: Python :: 3", 102 | "Programming Language :: Python :: 3.9", 103 | "Programming Language :: Python :: 3.10", 104 | "Programming Language :: Python :: 3.11", 105 | "Programming Language :: Python :: 3.12", 106 | "Programming Language :: Python :: 3 :: Only", 107 | "Topic :: Scientific/Engineering", 108 | "Topic :: Scientific/Engineering :: Mathematics", 109 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 110 | "Topic :: Software Development", 111 | "Topic :: Software Development :: Libraries", 112 | "Topic :: Software Development :: Libraries :: Python Modules", 113 | ], 114 | distclass=BinaryDistribution, 115 | packages=setuptools.find_packages(), 116 | python_requires=">=3.9", 117 | license="Apache 2.0", 118 | keywords=( 119 | "tensorflow tensor machine learning decision forests random forest" 120 | " gradient boosted decision trees" 121 | ), 122 | install_requires=REQUIRED_PACKAGES, 123 | include_package_data=True, 124 | zip_safe=False, 125 | ) 126 | -------------------------------------------------------------------------------- /documentation/BUILD:
-------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = ["//visibility:public"], 3 | licenses = ["notice"], 4 | ) 5 | -------------------------------------------------------------------------------- /documentation/_book.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | upper_tabs: 16 | # Tabs left of dropdown menu 17 | - include: /_upper_tabs_left.yaml 18 | - include: /api_docs/_upper_tabs_api.yaml 19 | # Dropdown menu 20 | - name: Resources 21 | path: /resources 22 | is_default: true 23 | menu: 24 | - include: /resources/_menu_toc.yaml 25 | lower_tabs: 26 | # Subsite tabs 27 | other: 28 | - name: Guides 29 | contents: 30 | - title: Getting started 31 | path: /decision_forests/tutorials/beginner_colab 32 | - title: Migrating to YDF 33 | path: https://ydf.readthedocs.io/en/latest/tutorial/migrating_to_ydf/ 34 | status: external 35 | - title: Hyper-parameters 36 | path: https://ydf.readthedocs.io/en/latest/hyperparameters/ 37 | status: external 38 | - title: Known issues 39 | path: /decision_forests/known_issues 40 | 41 | - heading: Guide 42 | - title: Migrate from neural networks 43 | path: /decision_forests/migration 44 | - title: Learning to rank 45 | path: /decision_forests/tutorials/ranking_colab 46 | - title: Inspect and debug 47 | path: 
/decision_forests/tutorials/advanced_colab 48 | - title: Compose models 49 | path: /decision_forests/tutorials/model_composition_colab 50 | - title: Automatic tuning 51 | path: /decision_forests/tutorials/automatic_tuning_colab 52 | - title: Proximities and Prototypes 53 | path: /decision_forests/tutorials/proximities_colab 54 | - title: Uplift modeling 55 | path: /decision_forests/tutorials/uplift_colab 56 | - title: TensorFlow Serving 57 | path: /decision_forests/tensorflow_serving 58 | - title: Text and pre-trained embeddings 59 | path: /decision_forests/tutorials/intermediate_colab 60 | - title: Distributed Training 61 | path: /decision_forests/distributed_training 62 | - title: Making predictions 63 | path: /decision_forests/tutorials/predict_colab 64 | - title: Introduction 65 | path: /decision_forests/simple_ml_for_sheets/ 66 | 67 | - heading: Experimental 68 | - title: Run model in JavaScript 69 | path: https://ydf.readthedocs.io/en/latest/js_serving.html 70 | status: external 71 | - title: Run model with TensorFlow.js 72 | path: /decision_forests/tf_df_in_tf_js 73 | status: external 74 | 75 | - name: API Reference 76 | skip_translation: true 77 | contents: 78 | - title: All Symbols 79 | path: /decision_forests/api_docs/python/tfdf/all_symbols 80 | - include: /decision_forests/api_docs/python/tfdf/_toc.yaml 81 | 82 | - include: /_upper_tabs_right.yaml 83 | -------------------------------------------------------------------------------- /documentation/developer_manual.md: -------------------------------------------------------------------------------- 1 | # Developer Manual 2 | 3 | Internally, TF-DF relies on 4 | [Yggdrasil Decision Forests](https://github.com/google/yggdrasil-decision-forests) 5 | (YDF). Depending on the change, reading YDF's user and developer manuals might be 6 | beneficial. 7 | 8 | The library's dependency structure is organized in layers: 9 | 10 | 1. Keras 11 | 2. TensorFlow 12 | 3. Python utility 13 | 4.
Yggdrasil 14 | 15 | New logic should be implemented where relevant. When several layers are 16 | possibly relevant, the most generic layer should be favored. 17 | 18 | The directory structures of [TF-DF](directory_structure.md) and 19 | [YDF](https://github.com/google/yggdrasil-decision-forests/blob/main/documentation/directory_structure.md) 20 | are a good starting point. 21 | -------------------------------------------------------------------------------- /documentation/directory_structure.md: -------------------------------------------------------------------------------- 1 | # Directory Structure 2 | 3 | The project is organized as follows: 4 | 5 | ``` 6 | ├── configure: Project configuration. 7 | ├── documentation: User and developer documentation. Contains the colabs. 8 | ├── examples: Collection of usage examples. 9 | ├── tensorflow_decision_forests: The library. 10 | │ ├── component: Utilities. 11 | │ │ ├── builder: Create models "by hand". 12 | │ │ ├── inspector: Inspection of structure and meta-data of models. 13 | │ │ ├── model_plotter: Plotting of model tree structure. 14 | │ │ ├── py_tree: Representation of a decision tree as a python object. 15 | │ │ └── tuner: TF-DF's own hyper-parameter tuner. 16 | │ ├── contrib: Additional functionality outside the project's main scope. 17 | │ ├── keras: Keras logic. Depends on tensorflow logic. 18 | │ │ └── wrapper: Python code generator for Keras models. 19 | │ ├── tensorflow: TensorFlow logic. 20 | │ │ └── ops: Custom C++ ops. 21 | │ │ ├── inference: ... for inference. 22 | │ │ └── training: ... for training. 23 | │ └── test_data: Datasets for unit tests and benchmarks. 24 | ├── third_party: Bazel configuration for dependencies. 25 | └── tools: Tools for the management of the project and code.
26 | ``` 27 | -------------------------------------------------------------------------------- /documentation/image/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/documentation/image/logo.png -------------------------------------------------------------------------------- /documentation/image/ydf_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/documentation/image/ydf_logo.png -------------------------------------------------------------------------------- /documentation/known_issues.md: -------------------------------------------------------------------------------- 1 | # Known Issues 2 | 3 | ## Prefer YDF for new projects 4 | 5 | [YDF](https://github.com/google/yggdrasil-decision-forests) is Google's new 6 | library to train Decision Forests. 7 | 8 | YDF extends the power of TF-DF, offering new features, a simplified API, faster 9 | training times, updated documentation, and enhanced compatibility with popular 10 | ML libraries. 11 | 12 | Some of the issues mentioned below are fixed in YDF. 13 | 14 | ## Windows Pip package is not available 15 | 16 | TensorFlow Decision Forests is not yet available as a Windows Pip package. 17 | 18 | **Workarounds:** 19 | 20 | - *Solution #1:* Install 21 | [Windows Subsystem for Linux (WSL)](https://en.wikipedia.org/wiki/Windows_Subsystem_for_Linux) 22 | on your Windows machine and follow the Linux instructions. 23 | 24 | ## Incompatibility with Keras 3 25 | 26 | Compatibility with Keras 3 is not yet implemented. Use tf_keras or a TensorFlow 27 | version before 2.16. Alternatively, use [ydf](https://pypi.org/project/ydf/).
28 | 29 | ## Untested for conda 30 | 31 | While TF-DF might work with Conda, this is not tested and we currently do not 32 | maintain packages on conda-forge. 33 | 34 | ## Incompatibility with old or nightly versions of TensorFlow 35 | 36 | TensorFlow's [ABI](https://en.wikipedia.org/wiki/Application_binary_interface) 37 | is not compatible in between releases. Because TF-DF relies on custom TensorFlow 38 | C++ ops, each version of TF-DF is tied to a specific version of TensorFlow. The 39 | last released version of TF-DF is always tied to the last released version of 40 | TensorFlow. 41 | 42 | For these reasons, the current version of TF-DF might not be compatible with 43 | older versions or with the nightly build of TensorFlow. 44 | 45 | If using incompatible versions of TF and TF-DF, you will see cryptic errors such 46 | as: 47 | 48 | ``` 49 | tensorflow_decision_forests/tensorflow/ops/training/training.so: undefined symbol: _ZN10tensorflow11GetNodeAttrERKNS_9AttrSliceEN4absl14lts_2020_09_2311string_viewEPSs 50 | ``` 51 | 52 | - Use the version of TF-DF that is compatible with your version of TensorFlow. 53 | 54 | ### Compatibility table 55 | 56 | The following table shows the compatibility between 57 | `tensorflow_decision_forests` and its dependencies: 58 | 59 | tensorflow_decision_forests | tensorflow 60 | --------------------------- | --------------- 61 | 1.12.0 | 2.19.0 62 | 1.11.0 | 2.18.0 63 | 1.10.0 | 2.17.0 64 | 1.9.2 | 2.16.2 65 | 1.9.1 | 2.16.1 66 | 1.9.0 | 2.16.1 67 | 1.8.0 - 1.8.1 | 2.15.0 68 | 1.6.0 - 1.7.0 | 2.14.0 69 | 1.5.0 | 2.13.0 70 | 1.3.0 - 1.4.0 | 2.12.0 71 | 1.1.0 - 1.2.0 | 2.11.0 72 | 1.0.0 - 1.0.1 | 2.10.0 - 2.10.1 73 | 0.2.6 - 0.2.7 | 2.9.1 74 | 0.2.5 | 2.9 75 | 0.2.4 | 2.8 76 | 0.2.1 - 0.2.3 | 2.7 77 | 0.1.9 - 0.2.0 | 2.6 78 | 0.1.1 - 0.1.8 | 2.5 79 | 0.1.0 | 2.4 80 | 81 | - *Solution #2:* Wrap your preprocessing function into another function that 82 | [squeezes](https://www.tensorflow.org/api_docs/python/tf/squeeze) its 83 | inputs. 
84 | 85 | ## Not all models support distributed training and distribution strategies 86 | 87 | Unless otherwise specified, models are trained on a single machine and are not compatible 88 | with distribution strategies. For example, the `GradientBoostedTreesModel` does 89 | not support distributed training while `DistributedGradientBoostedTreesModel` 90 | does. 91 | 92 | **Workarounds:** 93 | 94 | - Use a model that supports distribution strategies (e.g. 95 | `DistributedGradientBoostedTreesModel`), or downsample your dataset so that 96 | it fits on a single machine. 97 | 98 | ## No support for GPU / TPU. 99 | 100 | TF-DF does not support GPU or TPU training. Compiling with AVX instructions, 101 | however, may speed up serving. 102 | 103 | ## No support for [model_to_estimator](https://www.tensorflow.org/api_docs/python/tf/keras/estimator/model_to_estimator) 104 | 105 | TF-DF does not implement the APIs required to convert a trained/untrained model 106 | to the estimator format. 107 | 108 | ## Loaded models behave differently than Python models. 109 | 110 | While abstracted by the Keras API, a model instantiated in Python (e.g., with 111 | `tfdf.keras.RandomForestModel()`) and a model loaded from disk (e.g., with 112 | `tf_keras.models.load_model()`) can behave differently. Notably, a Python 113 | instantiated model automatically applies necessary type conversions. For 114 | example, if a `float64` feature is fed to a model expecting a `float32` feature, 115 | this conversion is performed implicitly. However, such a conversion is not 116 | possible for models loaded from disk. It is therefore important that the 117 | training data and the inference data always have the exact same type. 118 | 119 | ## TensorFlow feature name sanitization 120 | 121 | TensorFlow sanitizes feature names and might, for instance, convert them to 122 | lowercase.
123 | -------------------------------------------------------------------------------- /documentation/more_examples.md: -------------------------------------------------------------------------------- 1 | # More examples 2 | 3 | This page lists examples created by TF-DF users. 4 | 5 | - [Minimal example](https://github.com/tensorflow/decision-forests/blob/main/examples/minimal.py) 6 | (on github.com) 7 | - [Running TF-DF models on Arduino](https://github.com/achoum/ardwino-tensorflow-decision-forests) 8 | (on github.com) 9 | - [Classification with TensorFlow Decision Forests](https://keras.io/examples/structured_data/classification_with_tfdf/) 10 | (on keras.io) 11 | -------------------------------------------------------------------------------- /documentation/simple_ml_for_sheets/image/simple_ml_owl_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/documentation/simple_ml_for_sheets/image/simple_ml_owl_1.png -------------------------------------------------------------------------------- /documentation/simple_ml_for_sheets/image/simple_ml_predict_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/documentation/simple_ml_for_sheets/image/simple_ml_predict_1.png -------------------------------------------------------------------------------- /documentation/tensorflow_serving.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Decision Forests and TensorFlow Serving 2 | 3 | [TensorFlow Serving](https://www.tensorflow.org/tfx/guide/serving) (TF Serving) 4 | is a tool to run TensorFlow models online in large production settings using an 5 | RPC or REST API. TensorFlow Decision Forests (TF-DF) is supported natively by TF 6 | Serving >=2.11.
7 | 8 | TF-DF models are directly compatible with TF Serving. Yggdrasil models can be 9 | used with TF Serving after being 10 | [converted](https://ydf.readthedocs.io/en/latest/convert_model.html#convert-a-yggdrasil-model-to-a-tensorflow-decision-forests-model) 11 | first. 12 | 13 | ## Limitations 14 | 15 | TensorFlow adds a significant amount of computation overhead. For small, latency-sensitive 16 | models (e.g., model inference time ~1µs), this overhead can be 17 | an order of magnitude larger than the time needed by the model itself. 18 | In this case, it is recommended to run the TF-DF models with 19 | [Yggdrasil Decision Forests](https://ydf.readthedocs.io). 20 | 21 | ## Usage example 22 | 23 | The following example shows how to run a TF-DF model in TF Serving: 24 | 25 | First, [install TF Serving](https://github.com/tensorflow/serving#set-up). In 26 | this example, we will use a pre-compiled version of TF-Serving + TF-DF. 27 | 28 | ```shell 29 | # Download TF Serving 30 | wget https://github.com/tensorflow/decision-forests/releases/download/serving-1.0.1/tensorflow_model_server_linux.zip 31 | unzip tensorflow_model_server_linux.zip 32 | 33 | # Check that TF Serving works. 34 | ./tensorflow_model_server --version 35 | ``` 36 | 37 | In this example, we use an already trained TF-DF model. 38 | 39 | ```shell 40 | # Get a TF-DF model 41 | git clone https://github.com/tensorflow/decision-forests.git 42 | MODEL_PATH=$(pwd)/decision-forests/tensorflow_decision_forests/test_data/model/saved_model_adult_rf 43 | 44 | echo "The TF-DF model is available at: ${MODEL_PATH}" 45 | ``` 46 | 47 | **Note:** TF-Serving requires the model's full path. This is why we use 48 | `$(pwd)`. 49 | 50 | TF-Serving supports model versioning. The model should be contained in a 51 | directory whose name is the version of the model. A model version is an integer, 52 | e.g., "1". Here is a typical directory layout for TF-Serving.
53 | 54 | - `/path/to/model` 55 | - `1` : Version 1 of the model 56 | - `5` : Version 5 of the model 57 | - `6` : Version 6 of the model 58 | 59 | For this example, we only need to put the model in a directory called "1". 60 | 61 | ```shell 62 | mkdir -p /tmp/tf_serving_model 63 | cp -R "${MODEL_PATH}" /tmp/tf_serving_model/1 64 | ``` 65 | 66 | Now, we can start TF-Serving on the model. 67 | 68 | ```shell 69 | ./tensorflow_model_server \ 70 | --rest_api_port=8502 \ 71 | --model_name=my_model \ 72 | --model_base_path=/tmp/tf_serving_model 73 | ``` 74 | 75 | Finally, you can send a request to TF Serving using the REST API. Two formats 76 | are available: predict+instances API and predict+inputs API. Here is an example 77 | of each of them: 78 | 79 | ```shell 80 | # Predictions with the predict+instances API. 81 | curl http://localhost:8502/v1/models/my_model:predict -X POST \ 82 | -d '{"instances": [{"age":39,"workclass":"State-gov","fnlwgt":77516,"education":"Bachelors","education_num":13,"marital_status":"Never-married","occupation":"Adm-clerical","relationship":"Not-in-family","race":"White","sex":"Male","capital_gain":2174,"capital_loss":0,"hours_per_week":40,"native_country":"United-States"}]}' 83 | ``` 84 | 85 | ```shell 86 | # Predictions with the predict+inputs API 87 | curl http://localhost:8502/v1/models/my_model:predict -X POST \ 88 | -d '{"inputs": {"age":[39],"workclass":["State-gov"],"fnlwgt":[77516],"education":["Bachelors"],"education_num":[13],"marital_status":["Never-married"],"occupation":["Adm-clerical"],"relationship":["Not-in-family"],"race":["White"],"sex":["Male"],"capital_gain":[2174],"capital_loss":[0],"hours_per_week":[40],"native_country":["United-States"]}}' 89 | ``` 90 | -------------------------------------------------------------------------------- /documentation/tf_df_in_tf_js.md: -------------------------------------------------------------------------------- 1 | # Running TensorFlow Decision Forests models with TensorFlow.js 2 | 3 |
These instructions explain how to train a TF-DF model and run it on the 4 | web using TensorFlow.js. 5 | 6 | ## Detailed instructions 7 | 8 | ### Train a model in TF-DF 9 | 10 | To try out this tutorial, you first need a TF-DF model. You can use your own 11 | model or train a model with the 12 | [Beginner's tutorial](https://www.tensorflow.org/decision_forests/tutorials/beginner_colab). 13 | 14 | If you simply want to quickly train a model in Google Colab, you can use the 15 | following code snippet. 16 | 17 | ```python 18 | !pip install tensorflow_decision_forests -U -qq 19 | import tensorflow as tf 20 | import tensorflow_decision_forests as tfdf 21 | import pandas as pd 22 | 23 | # Download the dataset, load it into a pandas dataframe and convert it to TensorFlow format. 24 | !wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv 25 | dataset_df = pd.read_csv("/tmp/penguins.csv") 26 | train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df, label="species") 27 | 28 | # Create and train the model 29 | model_1 = tfdf.keras.GradientBoostedTreesModel() 30 | model_1.fit(train_ds) 31 | ``` 32 | 33 | ### Convert the model 34 | 35 | The instructions going forward assume that you have saved your TF-DF model under 36 | the path `/tmp/my_saved_model`. Run the following snippet to convert the model 37 | to TensorFlow.js. 
38 | 39 | ```python 40 | !pip install tensorflow tensorflow_decision_forests 'tensorflowjs>=4.4.0' 41 | !pip install tf_keras 42 | 43 | # Prepare and load the model with TensorFlow 44 | import tensorflow as tf 45 | import tensorflowjs as tfjs 46 | from google.colab import files 47 | 48 | # Save the model in the SavedModel format 49 | tf.saved_model.save(model_1, "/tmp/my_saved_model") 50 | 51 | # Convert the SavedModel to TensorFlow.js and save as a zip file 52 | tfjs.converters.tf_saved_model_conversion_v2.convert_tf_saved_model("/tmp/my_saved_model", "./tfjs_model") 53 | 54 | # Download the converted TFJS model 55 | !zip -r tfjs_model.zip tfjs_model/ 56 | files.download("tfjs_model.zip") 57 | ``` 58 | 59 | When Google Colab finishes running, it downloads the converted TFJS model as a 60 | zip file. Unzip this file before using it in the next step. 61 | 62 | An unzipped TensorFlow.js model consists of a number of files. The example 63 | model contains the following: 64 | 65 | - assets.zip 66 | - group1-shard1of1.bin 67 | - model.json 68 | 69 | 70 | ### Use the TensorFlow.js model on the web 71 | 72 | Use this template to load TFJS dependencies and run the TF-DF model. Change the 73 | model path to where your model is served and modify the tensor given to 74 | `executeAsync`. 75 | 76 | ```html 77 | 78 | 79 | 98 | ``` 99 | 100 | ## Questions? 101 | 102 | Check out the 103 | [TensorFlow Decision Forests documentation](https://www.tensorflow.org/decision_forests) 104 | and the [TensorFlow.js documentation](https://www.tensorflow.org/js/tutorials).
105 | -------------------------------------------------------------------------------- /documentation/tutorials/BUILD: -------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = ["//visibility:public"], 3 | licenses = ["notice"], 4 | ) 5 | -------------------------------------------------------------------------------- /documentation/tutorials/images/beginner_tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/documentation/tutorials/images/beginner_tensorboard.png -------------------------------------------------------------------------------- /documentation/tutorials/images/model_composition_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/documentation/tutorials/images/model_composition_structure.png -------------------------------------------------------------------------------- /documentation/tutorials/index.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Decision Forests tutorials 2 | 3 | The following notebooks are available: 4 | 5 | - [Beginner Colab](beginner_colab.ipynb): Learn the basics of model 6 | training, evaluation and export. 7 | - [Ranking Colab](ranking_colab.ipynb): Learn about ranking with decision 8 | forests. 9 | - [Intermediate Colab](intermediate_colab.ipynb): How to consume text and 10 | combine decision forests with neural networks. 11 | - [Advanced Colab](advanced_colab.ipynb): How to inspect and create model 12 | structures directly. 13 | - [Uplift Colab](uplift_colab.ipynb): Learn about uplift modeling with 14 | decision forests.
15 | - [Model composition Colab](model_composition_colab.ipynb): How to compose 16 | decision forests and neural networks together. 17 | - [Proximities and Prototypes with Random Forests](proximities_colab.ipynb): 18 | Measure the distance between tabular examples and use it to understand a 19 | model and its predictions. 20 | - [Automatic hyper-parameter tuning](automatic_tuning_colab.ipynb): 21 | Automatically select the best hyper-parameters for a model. 22 | - [Making predictions](predict_colab.ipynb): List of options to make 23 | predictions with a model. 24 | -------------------------------------------------------------------------------- /examples/BUILD: -------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = ["//visibility:public"], 3 | licenses = ["notice"], 4 | ) 5 | 6 | py_binary( 7 | name = "minimal", 8 | srcs = ["minimal.py"], 9 | deps = [ 10 | "//tensorflow_decision_forests", 11 | "@pypi_absl_py//:pkg", # :app 12 | "@pypi_numpy//:pkg", 13 | "@pypi_pandas//:pkg", 14 | "@release_or_nightly//:tensorflow_pkg", 15 | "@release_or_nightly//:tf_keras_pkg", 16 | ], 17 | ) 18 | 19 | py_binary( 20 | name = "hyperparameter_optimization", 21 | srcs = ["hyperparameter_optimization.py"], 22 | deps = [ 23 | "//tensorflow_decision_forests", 24 | "@pypi_absl_py//:pkg", # :app 25 | "@pypi_numpy//:pkg", 26 | "@pypi_pandas//:pkg", 27 | "@release_or_nightly//:tensorflow_pkg", 28 | "@release_or_nightly//:tf_keras_pkg", 29 | ], 30 | ) 31 | 32 | py_binary( 33 | name = "distributed_training", 34 | srcs = ["distributed_training.py"], 35 | deps = [ 36 | "//tensorflow_decision_forests", 37 | "@pypi_absl_py//:pkg", # :app 38 | "@pypi_numpy//:pkg", 39 | "@pypi_pandas//:pkg", 40 | "@release_or_nightly//:tensorflow_pkg", 41 | "@release_or_nightly//:tf_keras_pkg", 42 | ], 43 | ) 44 | 45 | py_binary( 46 | name = "distributed_hyperparameter_optimization", 47 | srcs = ["distributed_hyperparameter_optimization.py"], 48 | deps 
= [ 49 | "//tensorflow_decision_forests", 50 | "@pypi_absl_py//:pkg", # :app 51 | "@release_or_nightly//:tensorflow_pkg", 52 | "@release_or_nightly//:tf_keras_pkg", 53 | ], 54 | ) 55 | -------------------------------------------------------------------------------- /examples/distributed_hyperparameter_optimization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Example of distributed hyper-parameter optimization with TF-DF. 16 | 17 | This example trains and exports a Gradient Boosted Tree model. 18 | 19 | Usage example: 20 | 21 | You need to configure TF Parameters servers. See: 22 | https://www.tensorflow.org/decision_forests/distributed_training 23 | https://www.tensorflow.org/tutorials/distribute/parameter_server_training 24 | 25 | TF_CONFIG = ... 26 | # Start the workers 27 | # ... 
28 | # Run the chief 29 | python3 distributed_hyperparameter_optimization.py 30 | """ 31 | 32 | from absl import app 33 | from absl import logging 34 | 35 | import tensorflow as tf 36 | import tensorflow_decision_forests as tfdf 37 | 38 | 39 | def main(argv): 40 | if len(argv) > 1: 41 | raise app.UsageError("Too many command-line arguments.") 42 | 43 | # Configure training 44 | logging.info("Configure training") 45 | cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver() 46 | strategy = tf.distribute.experimental.ParameterServerStrategy( 47 | cluster_resolver) 48 | with strategy.scope(): 49 | tuner = tfdf.tuner.RandomSearch( 50 | # 200 trials to find the best hyper-parameters. 51 | num_trials=200, 52 | # Use the pre-defined hyper-parameter space. 53 | use_predefined_hps=True, 54 | # Each model is trained on 4 threads. 55 | trial_num_threads=4) 56 | model = tfdf.keras.GradientBoostedTreesModel( 57 | tuner=tuner, 58 | temp_directory="/cns/bh-d/home/gbm/tmp/ttl=15d/tfdf_cache_dho3", 59 | # Number of threads available on each worker. 60 | num_threads=30, 61 | ) 62 | 63 | # Trains the model. 64 | logging.info("Start tuning") 65 | model.fit_on_dataset_path( 66 | train_path="/cns/is-d/home/gbm/ml_dataset_repository/others/adult/adult_train.csv", 67 | valid_path="/cns/is-d/home/gbm/ml_dataset_repository/others/adult/adult_test.csv", 68 | label_key="income", 69 | dataset_format="csv") 70 | 71 | logging.info("Trained model:") 72 | model.summary() 73 | 74 | # Access to model metrics. 75 | inspector = model.make_inspector() 76 | logging.info("Model self evaluation: %s", inspector.evaluation().to_dict()) 77 | 78 | # Exports the model to disk in the SavedModel format for later re-use. This 79 | # model can be used with TensorFlow Serving and Yggdrasil Decision Forests 80 | # (https://ydf.readthedocs.io/en/latest/serving_apis.html). 
81 | logging.info("Export model") 82 | model.save("/cns/bh-d/home/gbm/tmp/ttl=15d/tfdf_model_dho3") 83 | 84 | 85 | if __name__ == "__main__": 86 | app.run(main) 87 | -------------------------------------------------------------------------------- /examples/distributed_training.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Minimal usage example of distributed training with TF-DF. 16 | 17 | This example trains and exports a Gradient Boosted Tree model. 18 | 19 | Usage example: 20 | 21 | For this example, we need a large dataset. If you don't have such a dataset 22 | available, create a synthetic dataset following the instructions in the 23 | "Synthetic dataset for usage example" section below. 24 | 25 | You need to configure TF parameter servers. See: 26 | https://www.tensorflow.org/decision_forests/distributed_training 27 | https://www.tensorflow.org/tutorials/distribute/parameter_server_training 28 | 29 | TF_CONFIG = ... 30 | # Start the workers 31 | # ... 32 | # Run the chief 33 | python3 distributed_training.py 34 | 35 | Synthetic dataset for usage example: 36 | 37 | In this example, we use a synthetic dataset containing 1M examples. This 38 | dataset is small enough that it could be used without distributed training, 39 | but it serves as a good example.
40 | 41 | This dataset is generated with the "synthetic_dataset" tool of YDF. 42 | 43 | Create a file "config.pbtxt" with the content: 44 | num_examples:1000000 45 | num_examples_per_shards: 100 46 | num_numerical:100 47 | num_categorical:50 48 | num_categorical_set:0 49 | num_boolean:50 50 | categorical_vocab_size:100 51 | 52 | Then run 53 | 54 | bazel run -c opt \ 55 | //external/ydf/yggdrasil_decision_forests/cli/utils:synthetic_dataset -- \ 56 | --alsologtostderr \ 57 | --options=/config.pbtxt\ 58 | --train=recordio+tfe:/train@60 \ 59 | --valid=recordio+tfe:/valid@20 \ 60 | --test=recordio+tfe:/test@20 \ 61 | --ratio_valid=0.2 \ 62 | --ratio_test=0.2 63 | """ 64 | 65 | import os 66 | from absl import app 67 | from absl import logging 68 | 69 | import tensorflow as tf 70 | import tensorflow_decision_forests as tfdf 71 | 72 | 73 | def main(argv): 74 | if len(argv) > 1: 75 | raise app.UsageError("Too many command-line arguments.") 76 | 77 | # "work_directory" is used to store the temporary checkpoints as well as the 78 | # final model. "work_directory" should be accessible to both the chief and the 79 | # workers. 80 | 81 | work_directory = "/some/remote/directory" 82 | 83 | # Alternatively, You can use a local directory when testing distributed 84 | # training locally i.e. when running the workers in the same machine at the 85 | # chief. See "fake_distributed_training.sh". 86 | # work_directory = "/tmp/tfdf_model" 87 | 88 | # The dataset is provided as a set of sharded files. 89 | train_dataset_path = "/path/to/dataset/train@60" 90 | valid_dataset_path = "/path/to/dataset/valid@60" 91 | dataset_format = "recordio+tfe" 92 | 93 | # Alternatively, when testing distributed training locally, you can use a 94 | # non-sharded dataset. 
95 | # train_dataset_path = "external/ydf/yggdrasil_decision_forests/test_data/dataset/adult_train.csv" 96 | # valid_dataset_path = "external/ydf/yggdrasil_decision_forests/test_data/dataset/adult_test.csv" 97 | # dataset_format = "csv" 98 | 99 | # Configure training 100 | logging.info("Configure training") 101 | cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver( 102 | rpc_layer="grpc") 103 | strategy = tf.distribute.experimental.ParameterServerStrategy( 104 | cluster_resolver) 105 | with strategy.scope(): 106 | model = tfdf.keras.DistributedGradientBoostedTreesModel( 107 | # Speed-up training by discretizing numerical features. 108 | force_numerical_discretization=True, 109 | # Cache directory used to store checkpoints. 110 | temp_directory=os.path.join(work_directory, "work_dir"), 111 | # Number of threads on each worker. 112 | num_threads=30, 113 | ) 114 | model.compile(metrics=["accuracy"]) 115 | 116 | # Trains the model. 117 | logging.info("Start training") 118 | model.fit_on_dataset_path( 119 | train_path=train_dataset_path, 120 | valid_path=valid_dataset_path, 121 | label_key="income", 122 | dataset_format=dataset_format) 123 | 124 | logging.info("Trained model:") 125 | model.summary() 126 | 127 | # Access to model metrics. 128 | inspector = model.make_inspector() 129 | logging.info("Model self evaluation: %s", inspector.evaluation().to_dict()) 130 | logging.info("Model training logs: %s", inspector.training_logs()) 131 | inspector.export_to_tensorboard(os.path.join(work_directory, "tensorboard")) 132 | 133 | # Exports the model to disk in the SavedModel format for later re-use. This 134 | # model can be used with TensorFlow Serving and Yggdrasil Decision Forests 135 | # (https://ydf.readthedocs.io/en/latest/serving_apis.html). 
136 | logging.info("Export model") 137 | model.save(os.path.join(work_directory, "model")) 138 | 139 | 140 | if __name__ == "__main__": 141 | app.run(main) 142 | -------------------------------------------------------------------------------- /examples/fake_distributed_training.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2021 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | 18 | # Simulate distributed training locally. 19 | # 20 | 21 | echo "Warning: This script does not stop workers. After running it, you need to" 22 | echo "stop the workers manually (e.g. using kill -9)." 
23 | 24 | set -vex 25 | 26 | PYTHON=python3 27 | 28 | export TF_CONFIG='{ 29 | "cluster": { 30 | "worker": [ 31 | "localhost:4300", 32 | "localhost:4301"], 33 | "ps": ["localhost:4310"], 34 | "chief": ["localhost:4311"] 35 | } 36 | }' 37 | 38 | bazel build \ 39 | //third_party/tensorflow_decision_forests/examples:distributed_training \ 40 | //third_party/tensorflow_decision_forests/tensorflow/distribute:tensorflow_std_server_py 41 | 42 | WORKER=bazel-bin/third_party/tensorflow_decision_forests/tensorflow/distribute/tensorflow_std_server_py 43 | CHIEF=bazel-bin/third_party/tensorflow_decision_forests/examples/distributed_training 44 | 45 | # Start the workers 46 | ${WORKER} --alsologtostderr --job_name=worker --task_index=0 & 47 | ${WORKER} --alsologtostderr --job_name=worker --task_index=1 & 48 | ${WORKER} --alsologtostderr --job_name=ps --task_index=0 & 49 | 50 | # Start the chief 51 | ${CHIEF} --alsologtostderr 52 | -------------------------------------------------------------------------------- /examples/hyperparameter_optimization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Example of automated hyper-parameter tuning with TensorFlow Decision Forests. 16 | 17 | This example trains, displays, evaluates and exports a Gradient Boosted Tree 18 | model.
19 | 20 | Usage example: 21 | 22 | pip3 install tensorflow_decision_forests -U 23 | python3 hyperparameter_optimization.py 24 | 25 | Or 26 | 27 | bazel run -c opt \ 28 | //tensorflow_decision_forests/examples:hyperparameter_optimization 29 | \ 30 | -- --alsologtostderr 31 | """ 32 | 33 | from absl import app 34 | 35 | import numpy as np 36 | import pandas as pd 37 | import tensorflow as tf 38 | import tensorflow_decision_forests as tfdf 39 | import tf_keras 40 | 41 | def main(argv): 42 | if len(argv) > 1: 43 | raise app.UsageError("Too many command-line arguments.") 44 | 45 | # Download the Adult dataset. 46 | dataset_path = tf_keras.utils.get_file( 47 | "adult.csv", 48 | "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/" 49 | "main/yggdrasil_decision_forests/test_data/dataset/adult.csv") 50 | 51 | # Load a dataset into a Pandas Dataframe. 52 | dataset_df = pd.read_csv(dataset_path) # "df" for Pandas's DataFrame. 53 | 54 | print("First the first three examples:") 55 | print(dataset_df.head(3)) 56 | 57 | # Notice that the dataset contains a mix of numerical and categorical 58 | # features. TensorFlow Decision Forests handles them automatically (e.g. no 59 | # need for one-hot encoding or normalization; except for the label). 60 | 61 | # Split the dataset into a training and a testing dataset. 62 | test_indices = np.random.rand(len(dataset_df)) < 0.30 63 | test_ds_pd = dataset_df[test_indices] 64 | train_ds_pd = dataset_df[~test_indices] 65 | print(f"{len(train_ds_pd)} examples in training" 66 | f", {len(test_ds_pd)} examples for testing.") 67 | 68 | # Converts datasets from Pandas dataframe to TensorFlow dataset format. 69 | train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_ds_pd, label="income") 70 | test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_ds_pd, label="income") 71 | 72 | # Tune the model. 73 | # 74 | # The hyper-parameters to optimize are automatically set with 75 | # "use_predefined_hps=True". 
See 76 | # https://www.tensorflow.org/decision_forests/tutorials/automatic_tuning_colab 77 | # for an example where the hyper-parameter space is configured manually. 78 | tuner = tfdf.tuner.RandomSearch(num_trials=30, use_predefined_hps=True) 79 | model = tfdf.keras.GradientBoostedTreesModel(verbose=2, tuner=tuner) 80 | model.fit(train_ds) 81 | 82 | # Some information about the model. 83 | print(model.summary()) 84 | 85 | # Evaluates the model on the test dataset. 86 | model.compile(metrics=["accuracy"]) 87 | evaluation = model.evaluate(test_ds) 88 | print(f"BinaryCrossentropyloss: {evaluation[0]}") 89 | print(f"Accuracy: {evaluation[1]}") 90 | 91 | # Exports the model to disk in the SavedModel format for later re-use. This 92 | # model can be used with TensorFlow Serving and Yggdrasil Decision Forests 93 | # (https://ydf.readthedocs.io/en/latest/serving_apis.html). 94 | model.save("/tmp/my_saved_model") 95 | 96 | 97 | if __name__ == "__main__": 98 | app.run(main) 99 | -------------------------------------------------------------------------------- /examples/minimal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Minimal usage example of TensorFlow Decision Forests. 16 | 17 | This example trains, displays, evaluates and exports a Gradient Boosted Tree 18 | model. 
19 | 20 | Usage example: 21 | 22 | pip3 install tensorflow_decision_forests -U 23 | python3 minimal.py 24 | 25 | Or 26 | 27 | bazel run -c opt \ 28 | //tensorflow_decision_forests/examples:minimal \ 29 | -- --alsologtostderr 30 | """ 31 | 32 | from absl import app 33 | import numpy as np 34 | import pandas as pd 35 | import tensorflow as tf 36 | import tensorflow_decision_forests as tfdf 37 | import tf_keras 38 | 39 | 40 | def main(argv): 41 | if len(argv) > 1: 42 | raise app.UsageError("Too many command-line arguments.") 43 | 44 | # Download the Adult dataset. 45 | dataset_path = tf_keras.utils.get_file( 46 | "adult.csv", 47 | "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/" 48 | "main/yggdrasil_decision_forests/test_data/dataset/adult.csv", 49 | ) 50 | 51 | # Load a dataset into a Pandas Dataframe. 52 | dataset_df = pd.read_csv(dataset_path) # "df" for Pandas's DataFrame. 53 | 54 | print("First the first three examples:") 55 | print(dataset_df.head(3)) 56 | 57 | # Notice that the dataset contains a mix of numerical and categorical 58 | # features. TensorFlow Decision Forests handles them automatically (e.g. no 59 | # need for one-hot encoding or normalization; except for the label). 60 | 61 | # Split the dataset into a training and a testing dataset. 62 | test_indices = np.random.rand(len(dataset_df)) < 0.30 63 | test_ds_pd = dataset_df[test_indices] 64 | train_ds_pd = dataset_df[~test_indices] 65 | print( 66 | f"{len(train_ds_pd)} examples in training" 67 | f", {len(test_ds_pd)} examples for testing." 68 | ) 69 | 70 | # Converts datasets from Pandas dataframe to TensorFlow dataset format. 71 | train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_ds_pd, label="income") 72 | test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_ds_pd, label="income") 73 | 74 | # Trains the model. 75 | model = tfdf.keras.GradientBoostedTreesModel(verbose=2) 76 | model.fit(train_ds) 77 | 78 | # Some information about the model. 
79 | print(model.summary()) 80 | 81 | # Evaluates the model on the test dataset. 82 | model.compile(metrics=["accuracy"]) 83 | evaluation = model.evaluate(test_ds) 84 | print(f"BinaryCrossentropyloss: {evaluation[0]}") 85 | print(f"Accuracy: {evaluation[1]}") 86 | 87 | # Exports the model to disk in the SavedModel format for later re-use. This 88 | # model can be used with TensorFlow Serving and Yggdrasil Decision Forests 89 | # (https://ydf.readthedocs.io/en/latest/serving_apis.html). 90 | model.save("/tmp/my_saved_model") 91 | 92 | 93 | if __name__ == "__main__": 94 | app.run(main) 95 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/BUILD: -------------------------------------------------------------------------------- 1 | 2 | package( 3 | default_visibility = ["//visibility:public"], 4 | licenses = ["notice"], 5 | ) 6 | 7 | 8 | 9 | exports_files(["LICENSE"]) 10 | 11 | # Do not depend on this rule. Instead, depend on "//tensorflow_decision_forests". 12 | py_library( 13 | name = "tensorflow_decision_forests", 14 | srcs = ["__init__.py"], 15 | visibility = [ 16 | "//learning/lib/ami/simple_ml:__subpackages__", 17 | "//tensorflow_decision_forests:__subpackages__", 18 | "//tensorflow_decision_forests:__subpackages__", 19 | ], 20 | deps = [ 21 | "//tensorflow_decision_forests/component/builder", 22 | "//tensorflow_decision_forests/component/inspector", 23 | "//tensorflow_decision_forests/component/model_plotter", 24 | "//tensorflow_decision_forests/component/py_tree", 25 | "//tensorflow_decision_forests/component/tuner", 26 | "//tensorflow_decision_forests/keras", 27 | "//tensorflow_decision_forests/tensorflow:check_version", 28 | ], 29 | ) 30 | 31 | # If set, use the pre-compiled version of the OP wrappers. This significantly 32 | # speed-up the compilation of the project. 
33 | config_setting( 34 | name = "use_precompiled_wrappers", 35 | values = {"define": "use_precompiled_wrappers=1"}, 36 | ) 37 | 38 | # If set, stop the training op when receiving an interrupt signal. 39 | config_setting( 40 | name = "stop_training_on_interrupt", 41 | values = {"define": "stop_training_on_interrupt=1"}, 42 | ) 43 | 44 | # If "disable_tf_ps_distribution_strategy" is true, the TF Parameter Server 45 | # distribution strategy is not available for distributed training. 46 | # 47 | # Distribution with TF PS is currently NOT supported for OSS TF-DF with shared 48 | # build (monolithic build works however) and TF<2.7. In this case, the GRPC 49 | # Worker Server can be used instead. 50 | config_setting( 51 | name = "disable_tf_ps_distribution_strategy", 52 | values = {"define": "tf_ps_distribution_strategy=0"}, 53 | ) 54 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/builder/BUILD: -------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = ["//visibility:public"], 3 | licenses = ["notice"], 4 | ) 5 | 6 | # Libraries 7 | # ========= 8 | 9 | py_library( 10 | name = "builder", 11 | srcs = ["builder.py"], 12 | deps = [ 13 | "//tensorflow_decision_forests/component/inspector", 14 | "//tensorflow_decision_forests/component/inspector:blob_sequence", 15 | "//tensorflow_decision_forests/component/py_tree", 16 | "//tensorflow_decision_forests/keras:core_inference", 17 | "//tensorflow_decision_forests/tensorflow:core_inference", 18 | "@pypi_numpy//:pkg", 19 | "@pypi_six//:pkg", 20 | "@release_or_nightly//:tensorflow_pkg", 21 | "@ydf//yggdrasil_decision_forests/dataset:data_spec_py_proto", 22 | "@ydf//yggdrasil_decision_forests/model:abstract_model_py_proto", 23 | "@ydf//yggdrasil_decision_forests/model/gradient_boosted_trees:gradient_boosted_trees_py_proto", 24 | "@ydf//yggdrasil_decision_forests/model/random_forest:random_forest_py_proto", 25 | ], 26 | ) 27 | 28 | # Tests 29 | # ===== 30 | 31 | py_test( 32 | name = "builder_test", 33 | timeout = "long", 34 | srcs = ["builder_test.py"], 35 | data = ["@ydf//yggdrasil_decision_forests/test_data"], 36 | shard_count = 10, 37 | deps = [ 38 | ":builder", 39 | "//tensorflow_decision_forests/component/inspector", 40 | "//tensorflow_decision_forests/component/py_tree", 41 | "//tensorflow_decision_forests/keras", 42 | "@pypi_absl_py//:pkg", # /flags 43 | "@pypi_numpy//:pkg", 44 | "@pypi_pandas//:pkg", 45 | "@release_or_nightly//:tensorflow_pkg", 46 | ], 47 | ) 48 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/builder/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google 
LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/inspector/BUILD: -------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = ["//visibility:public"], 3 | licenses = ["notice"], 4 | ) 5 | 6 | # Libraries 7 | # ========= 8 | 9 | py_library( 10 | name = "inspector", 11 | srcs = ["inspector.py"], 12 | deps = [ 13 | ":blob_sequence", 14 | "//tensorflow_decision_forests/component/py_tree", 15 | "@pypi_pandas//:pkg", 16 | "@pypi_six//:pkg", 17 | "@release_or_nightly//:tensorflow_pkg", 18 | "@ydf//yggdrasil_decision_forests/dataset:data_spec_py_proto", 19 | "@ydf//yggdrasil_decision_forests/metric:metric_py_proto", 20 | "@ydf//yggdrasil_decision_forests/model:abstract_model_py_proto", 21 | "@ydf//yggdrasil_decision_forests/model:hyperparameter_py_proto", 22 | "@ydf//yggdrasil_decision_forests/model/decision_tree:decision_tree_py_proto", 23 | "@ydf//yggdrasil_decision_forests/model/gradient_boosted_trees:gradient_boosted_trees_py_proto", 24 | "@ydf//yggdrasil_decision_forests/model/multitasker:multitasker_py_proto", 25 | "@ydf//yggdrasil_decision_forests/model/random_forest:random_forest_py_proto", 26 | ], 27 | ) 28 | 29 | py_library( 30 | name = "blob_sequence", 31 | srcs = ["blob_sequence.py"], 32 | deps = [ 33 | 
"@release_or_nightly//:tensorflow_pkg", 34 | ], 35 | ) 36 | 37 | # Tests 38 | # ===== 39 | 40 | py_test( 41 | name = "inspector_test", 42 | timeout = "long", 43 | srcs = ["inspector_test.py"], 44 | data = ["@ydf//yggdrasil_decision_forests/test_data"], 45 | shard_count = 10, 46 | deps = [ 47 | ":inspector", 48 | "//tensorflow_decision_forests/component/py_tree", 49 | "//tensorflow_decision_forests/keras", 50 | "@pypi_absl_py//:pkg", # /flags 51 | "@pypi_numpy//:pkg", 52 | "@release_or_nightly//:tensorflow_pkg", 53 | "@ydf//yggdrasil_decision_forests/metric:metric_py_proto", 54 | "@ydf//yggdrasil_decision_forests/model/gradient_boosted_trees:gradient_boosted_trees_py_proto", 55 | ], 56 | ) 57 | 58 | py_test( 59 | name = "blob_sequence_test", 60 | srcs = ["blob_sequence_test.py"], 61 | data = ["@ydf//yggdrasil_decision_forests/test_data"], 62 | deps = [ 63 | ":blob_sequence", 64 | "@pypi_absl_py//:pkg", # /flags 65 | "@release_or_nightly//:tensorflow_pkg", 66 | "@ydf//yggdrasil_decision_forests/model/decision_tree:decision_tree_py_proto", 67 | ], 68 | ) 69 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/inspector/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/inspector/blob_sequence.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Blob Sequence reader and writer. 16 | 17 | A blob sequence is a stream (e.g. a file) containing a sequence of blob (i.e. 18 | chunk of bytes). It can be used to store sequence of serialized protos. 19 | 20 | See yggdrasil_decision_forests/utils/blob_sequence.h for more details. 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | from typing import Iterator, Optional 28 | 29 | import tensorflow as tf 30 | 31 | 32 | class Reader(object): 33 | """Reader of Blob Sequence files. 34 | 35 | Usage example: 36 | 37 | for blob in Reader(path): 38 | print(blob) 39 | """ 40 | 41 | def __init__(self, path: str): 42 | self.file_: tf.io.gfile.Gfile = None 43 | self.path_ = None 44 | 45 | if path: 46 | self.open(path) 47 | 48 | def open(self, path: str): 49 | """Open Blob sequence file.""" 50 | 51 | self.file_ = tf.io.gfile.GFile(path, "rb") 52 | self.path_ = path 53 | 54 | # Reader header. 
55 | magic = self.file_.read(2) 56 | if magic != b"BS": 57 | raise ValueError(f"Invalid blob sequence file {path}") 58 | version = int.from_bytes(self.file_.read(2), byteorder="little") 59 | if version == 0: 60 | reserved = self.file_.read(4) 61 | elif version == 1: 62 | compression = int.from_bytes(self.file_.read(1), byteorder="little") 63 | if compression != 0: 64 | return ValueError( 65 | "The TF-DF inspector does not support this format of model" 66 | " (blob-sequence-v1 with compression). Use the format" 67 | " blob-sequence-v1 without compression instead." 68 | ) 69 | reserved = self.file_.read(3) 70 | else: 71 | raise ValueError( 72 | f"Non supported blob sequence version {version} for file {path}. The" 73 | " model was created with a more recent vesion of YDF / TF-DF." 74 | ) 75 | del reserved 76 | 77 | def close(self): 78 | self.file_.close() 79 | self.path_ = None 80 | self.file_ = None 81 | 82 | def read(self) -> Optional[bytes]: 83 | """Reads and returns the next blob.""" 84 | 85 | raw_length = self.file_.read(4) 86 | if not raw_length: 87 | return None 88 | if len(raw_length) != 4: 89 | raise ValueError(f"Corrupted blob sequence {self.path_}") 90 | length = int.from_bytes(raw_length, byteorder="little") 91 | blob = self.file_.read(length) 92 | if len(blob) != length: 93 | raise ValueError(f"Truncated blob sequence {self.path_}") 94 | return blob 95 | 96 | def __enter__(self): 97 | return self 98 | 99 | def __exit__(self, exc_type, exc_val, exc_tb): 100 | self.file_.close() 101 | 102 | def __iter__(self) -> Iterator[bytes]: 103 | """Iterates overt the BS file content.""" 104 | 105 | # Read blobs 106 | while True: 107 | blob = self.read() 108 | if blob is None: 109 | break 110 | yield blob 111 | 112 | 113 | class Writer(object): 114 | """Writer of Blob Sequence files. 
115 | 116 | Usage example: 117 | 118 | bs = Writer(path) 119 | bs.write(b"Hello") 120 | bs.write(b"World") 121 | bs.close() 122 | """ 123 | 124 | def __enter__(self): 125 | return self 126 | 127 | def __exit__(self, exc_type, exc_val, exc_tb): 128 | self.file_.close() 129 | 130 | def __init__(self, path: Optional[str] = None): 131 | self.file_: tf.io.gfile.Gfile = None 132 | self.path_ = None 133 | 134 | if path: 135 | self.open(path) 136 | 137 | def open(self, path: str): 138 | self.file_ = tf.io.gfile.GFile(path, "wb") 139 | self.path_ = path 140 | 141 | self.file_.write(b"BS") 142 | version = 0 143 | self.file_.write(version.to_bytes(2, byteorder="little")) 144 | self.file_.write(b"\0\0\0\0") 145 | 146 | def write(self, blob: bytes): 147 | self.file_.write(len(blob).to_bytes(4, byteorder="little")) 148 | self.file_.write(blob) 149 | 150 | def close(self): 151 | self.file_.close() 152 | self.path_ = None 153 | self.file_ = None 154 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/inspector/blob_sequence_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | 21 | from absl import flags 22 | from absl import logging 23 | from absl.testing import parameterized 24 | import tensorflow as tf 25 | 26 | from tensorflow_decision_forests.component.inspector import blob_sequence 27 | from yggdrasil_decision_forests.model.decision_tree import decision_tree_pb2 28 | 29 | 30 | def data_root_path() -> str: 31 | return "" 32 | 33 | 34 | def test_data_path() -> str: 35 | return os.path.join(data_root_path(), 36 | "external/ydf/yggdrasil_decision_forests/test_data") 37 | 38 | 39 | def tmp_path() -> str: 40 | return flags.FLAGS.test_tmpdir 41 | 42 | 43 | def test_model_directory() -> str: 44 | return os.path.join(test_data_path(), "model") 45 | 46 | 47 | class BlogSequenceTest(parameterized.TestCase, tf.test.TestCase): 48 | 49 | def test_base(self): 50 | path = os.path.join(tmp_path(), "tmp.bs") 51 | 52 | writer = blob_sequence.Writer(path) 53 | writer.write(b"HELLO") 54 | writer.write(b"WORLD") 55 | writer.close() 56 | 57 | reader = blob_sequence.Reader(path) 58 | self.assertEqual([b"HELLO", b"WORLD"], list(reader)) 59 | 60 | def test_node(self): 61 | path = os.path.join(test_model_directory(), "adult_binary_class_rf", 62 | "nodes-00000-of-00001") 63 | num_nodes = 0 64 | for serialized_node in blob_sequence.Reader(path): 65 | node = decision_tree_pb2.Node.FromString(serialized_node) 66 | if num_nodes <= 2: 67 | logging.info("Node: %s", node) 68 | num_nodes += 1 69 | # Matches the number of nodes in "description.txt". 
70 | self.assertEqual(num_nodes, 125578) 71 | 72 | 73 | if __name__ == "__main__": 74 | tf.test.main() 75 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/model_plotter/BUILD: -------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = ["//visibility:public"], 3 | licenses = ["notice"], 4 | ) 5 | 6 | # Data 7 | # ==== 8 | 9 | filegroup( 10 | name = "assets", 11 | srcs = ["plotter.js"], 12 | ) 13 | 14 | # Libraries 15 | # ========= 16 | 17 | py_library( 18 | name = "model_plotter", 19 | srcs = ["model_plotter.py"], 20 | data = ["assets"], 21 | deps = [ 22 | "//tensorflow_decision_forests/component/inspector", 23 | "//tensorflow_decision_forests/component/py_tree:condition", 24 | "//tensorflow_decision_forests/component/py_tree:node", 25 | "//tensorflow_decision_forests/component/py_tree:tree", 26 | "//tensorflow_decision_forests/component/py_tree:value", 27 | "@release_or_nightly//:tensorflow_pkg", 28 | ], 29 | ) 30 | 31 | # Tests 32 | # ===== 33 | 34 | py_test( 35 | name = "model_plotter_test", 36 | srcs = ["model_plotter_test.py"], 37 | deps = [ 38 | ":model_plotter", 39 | "//tensorflow_decision_forests/component/py_tree:condition", 40 | "//tensorflow_decision_forests/component/py_tree:dataspec", 41 | "//tensorflow_decision_forests/component/py_tree:node", 42 | "//tensorflow_decision_forests/component/py_tree:tree", 43 | "//tensorflow_decision_forests/component/py_tree:value", 44 | "@pypi_absl_py//:pkg", # /flags 45 | "@release_or_nightly//:tensorflow_pkg", 46 | ], 47 | ) 48 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/model_plotter/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/model_plotter/model_plotter_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | from absl import flags 21 | from absl import logging 22 | from absl.testing import parameterized 23 | import tensorflow as tf 24 | 25 | from tensorflow_decision_forests.component.model_plotter import model_plotter 26 | from tensorflow_decision_forests.component.py_tree import condition as condition_lib 27 | from tensorflow_decision_forests.component.py_tree import dataspec as dataspec_lib 28 | from tensorflow_decision_forests.component.py_tree import node as node_lib 29 | from tensorflow_decision_forests.component.py_tree import tree as tree_lib 30 | from tensorflow_decision_forests.component.py_tree import value as value_lib 31 | 32 | 33 | class ModelPlotterTest(parameterized.TestCase, tf.test.TestCase): 34 | 35 | def _save_plot(self, plot): 36 | plot_path = os.path.join(self.get_temp_dir(), "plot.html") 37 | logging.info("Plot to %s", plot_path) 38 | with open(plot_path, "w") as f: 39 | f.write(plot) 40 | 41 | def test_empty_tree(self): 42 | tree = tree_lib.Tree(None) 43 | plot = model_plotter.plot_tree(tree=tree) 44 | self._save_plot(plot) 45 | 46 | def test_basic_tree(self): 47 | tree = tree_lib.Tree( 48 | node_lib.NonLeafNode( 49 | condition=condition_lib.NumericalHigherThanCondition( 50 | feature=dataspec_lib.SimpleColumnSpec( 51 | name="f1", type=dataspec_lib.ColumnType.NUMERICAL 52 | ), 53 | threshold=1.5, 54 | missing_evaluation=False, 55 | ), 56 | pos_child=node_lib.LeafNode( 57 | value=value_lib.RegressionValue( 58 | value=5.0, num_examples=10, standard_deviation=1.0 59 | ) 60 | ), 61 | neg_child=node_lib.LeafNode( 62 | value=value_lib.ProbabilityValue( 63 | probability=[0.5, 0.4, 0.1], num_examples=10 64 | ) 65 | ), 66 | ) 67 | ) 68 | plot = model_plotter.plot_tree( 69 | tree=tree, display_options=model_plotter.DisplayOptions(node_x_size=150) 70 | ) 71 | self._save_plot(plot) 72 | 73 | def 
test_basic_tree_with_label_classes(self): 74 | tree = tree_lib.Tree( 75 | node_lib.NonLeafNode( 76 | condition=condition_lib.NumericalHigherThanCondition( 77 | feature=dataspec_lib.SimpleColumnSpec( 78 | name="f1", type=dataspec_lib.ColumnType.NUMERICAL 79 | ), 80 | threshold=1.5, 81 | missing_evaluation=False, 82 | ), 83 | pos_child=node_lib.LeafNode( 84 | value=value_lib.RegressionValue( 85 | value=5.0, num_examples=10, standard_deviation=1.0 86 | ) 87 | ), 88 | neg_child=node_lib.LeafNode( 89 | value=value_lib.ProbabilityValue( 90 | probability=[0.5, 0.4, 0.1], num_examples=10 91 | ) 92 | ), 93 | ), 94 | label_classes=["x", "y", "z"], 95 | ) 96 | plot = model_plotter.plot_tree(tree=tree) 97 | self._save_plot(plot) 98 | 99 | def test_basic_tree_with_uplift(self): 100 | tree = tree_lib.Tree( 101 | node_lib.LeafNode( 102 | value=value_lib.UpliftValue( 103 | treatment_effect=[0.5, 0.4], num_examples=10 104 | ) 105 | ) 106 | ) 107 | plot = model_plotter.plot_tree(tree=tree) 108 | self._save_plot(plot) 109 | 110 | 111 | if __name__ == "__main__": 112 | tf.test.main() 113 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/py_tree/BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_python//python:defs.bzl", "py_library") 2 | 3 | package( 4 | default_visibility = ["//visibility:public"], 5 | licenses = ["notice"], 6 | ) 7 | 8 | # All the components 9 | # ========= 10 | 11 | py_library( 12 | name = "py_tree", 13 | srcs = ["__init__.py"], 14 | deps = [ 15 | ":condition", 16 | ":dataspec", 17 | ":node", 18 | ":objective", 19 | ":tree", 20 | ":value", 21 | ], 22 | ) 23 | 24 | # Libraries 25 | # ========= 26 | 27 | py_library( 28 | name = "tree", 29 | srcs = ["tree.py"], 30 | deps = [":node"], 31 | ) 32 | 33 | py_library( 34 | name = "node", 35 | srcs = ["node.py"], 36 | deps = [ 37 | ":condition", 38 | ":value", 39 | "@pypi_absl_py//:pkg", # /logging 40 
| "@pypi_numpy//:pkg", 41 | "@pypi_six//:pkg", 42 | "@ydf//yggdrasil_decision_forests/dataset:data_spec_py_proto", 43 | "@ydf//yggdrasil_decision_forests/model:abstract_model_py_proto", 44 | "@ydf//yggdrasil_decision_forests/model/decision_tree:decision_tree_py_proto", 45 | ], 46 | ) 47 | 48 | py_library( 49 | name = "condition", 50 | srcs = ["condition.py"], 51 | deps = [ 52 | ":dataspec", 53 | "@pypi_absl_py//:pkg", # /logging 54 | "@pypi_numpy//:pkg", 55 | "@pypi_six//:pkg", 56 | "@ydf//yggdrasil_decision_forests/dataset:data_spec_py_proto", 57 | "@ydf//yggdrasil_decision_forests/model/decision_tree:decision_tree_py_proto", 58 | ], 59 | ) 60 | 61 | py_library( 62 | name = "value", 63 | srcs = ["value.py"], 64 | deps = [ 65 | "@pypi_numpy//:pkg", 66 | "@pypi_six//:pkg", 67 | "@ydf//yggdrasil_decision_forests/model/decision_tree:decision_tree_py_proto", 68 | ], 69 | ) 70 | 71 | py_library( 72 | name = "dataspec", 73 | srcs = ["dataspec.py"], 74 | deps = [ 75 | "@ydf//yggdrasil_decision_forests/dataset:data_spec_py_proto", 76 | ], 77 | ) 78 | 79 | py_library( 80 | name = "objective", 81 | srcs = ["objective.py"], 82 | deps = [ 83 | "@pypi_six//:pkg", 84 | "@ydf//yggdrasil_decision_forests/model:abstract_model_py_proto", 85 | ], 86 | ) 87 | 88 | # Tests 89 | # ===== 90 | 91 | py_test( 92 | name = "tree_test", 93 | srcs = ["tree_test.py"], 94 | deps = [ 95 | ":condition", 96 | ":dataspec", 97 | ":node", 98 | ":tree", 99 | ":value", 100 | "@pypi_absl_py//:pkg", # /logging 101 | "@release_or_nightly//:tensorflow_pkg", 102 | ], 103 | ) 104 | 105 | py_test( 106 | name = "node_test", 107 | srcs = ["node_test.py"], 108 | deps = [ 109 | ":condition", 110 | ":dataspec", 111 | ":node", 112 | ":value", 113 | "@pypi_absl_py//:pkg", # /logging 114 | "@release_or_nightly//:tensorflow_pkg", 115 | "@ydf//yggdrasil_decision_forests/dataset:data_spec_py_proto", 116 | "@ydf//yggdrasil_decision_forests/model/decision_tree:decision_tree_py_proto", 117 | ], 118 | ) 119 | 120 | py_test( 
121 | name = "condition_test", 122 | srcs = ["condition_test.py"], 123 | deps = [ 124 | ":condition", 125 | ":dataspec", 126 | "@pypi_absl_py//:pkg", # /logging 127 | "@release_or_nightly//:tensorflow_pkg", 128 | "@ydf//yggdrasil_decision_forests/dataset:data_spec_py_proto", 129 | "@ydf//yggdrasil_decision_forests/model/decision_tree:decision_tree_py_proto", 130 | ], 131 | ) 132 | 133 | py_test( 134 | name = "dataspec_test", 135 | srcs = ["dataspec_test.py"], 136 | deps = [ 137 | ":dataspec", 138 | "@pypi_absl_py//:pkg", # /testing:parameterized 139 | "@release_or_nightly//:tensorflow_pkg", 140 | "@ydf//yggdrasil_decision_forests/dataset:data_spec_py_proto", 141 | ], 142 | ) 143 | 144 | py_test( 145 | name = "value_test", 146 | srcs = ["value_test.py"], 147 | deps = [ 148 | ":value", 149 | "@pypi_absl_py//:pkg", # /logging 150 | "@release_or_nightly//:tensorflow_pkg", 151 | "@ydf//yggdrasil_decision_forests/model/decision_tree:decision_tree_py_proto", 152 | ], 153 | ) 154 | 155 | py_test( 156 | name = "objective_test", 157 | srcs = ["objective_test.py"], 158 | deps = [ 159 | ":dataspec", 160 | ":objective", 161 | "@pypi_absl_py//:pkg", # /logging 162 | "@release_or_nightly//:tensorflow_pkg", 163 | ], 164 | ) 165 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/py_tree/README.md: -------------------------------------------------------------------------------- 1 | This directory contains a python implementation of decision trees, that can be 2 | used to inspect and manipulate the decision forests trained with TF-DF using 3 | python. 4 | 5 | The internal and efficient decision forest representation can be converted back 6 | and forth to this python representation for inspection and even arbitrary 7 | manipulation. 
8 |
9 | Note: as one would expect, this Python representation is more explicit but
10 | much less efficient than the Proto model (also referred to as **core model**)
11 | used internally by TF-DF. Complex algorithms should preferably be implemented
12 | in C++ on the Proto model.
13 |
14 | A `Tree` is composed of a single `AbstractNode` called the *root*. This node and
15 | its children recursively define a decision tree.
16 |
17 | If a node is a `LeafNode`, it contains an `AbstractValue` defining the
18 | output/prediction/value of the node. Depending on the tree type, the
19 | `AbstractValue` can be, for example, a `ProbabilityValue`, a `RegressionValue` or an `UpliftValue`.
20 |
21 | If a node is a `NonLeafNode`, it contains an `AbstractCondition`, two child
22 | nodes, and optionally an `AbstractValue` for model interpretation.
23 |
24 | Different implementations of `AbstractCondition` support different types of
25 | conditions, e.g. `NumericalHigherThanCondition` or `CategoricalIsInCondition`.
26 |
27 | All objects can be printed using `str(...)`.
28 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/py_tree/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Decision trees stored as python objects.
16 | 17 | To be used with the model inspector and model builder. 18 | """ 19 | 20 | from tensorflow_decision_forests.component.py_tree import condition 21 | from tensorflow_decision_forests.component.py_tree import dataspec 22 | from tensorflow_decision_forests.component.py_tree import node 23 | from tensorflow_decision_forests.component.py_tree import objective 24 | from tensorflow_decision_forests.component.py_tree import tree 25 | from tensorflow_decision_forests.component.py_tree import value 26 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/py_tree/condition_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from absl.testing import parameterized 20 | import tensorflow as tf 21 | 22 | from tensorflow_decision_forests.component.py_tree import condition as condition_lib 23 | from tensorflow_decision_forests.component.py_tree import dataspec as dataspec_lib 24 | from yggdrasil_decision_forests.dataset import data_spec_pb2 25 | from yggdrasil_decision_forests.model.decision_tree import decision_tree_pb2 26 | 27 | 28 | class ConditionTest(parameterized.TestCase, tf.test.TestCase): 29 | 30 | def test_column_spec_bitmap_to_items_integer(self): 31 | column_spec = data_spec_pb2.Column() 32 | column_spec.categorical.number_of_unique_values = 10 33 | column_spec.categorical.is_already_integerized = True 34 | # b1100101101 => 32Dh 35 | self.assertEqual( 36 | condition_lib.column_spec_bitmap_to_items(column_spec, b"\x2D\x03"), 37 | [0, 2, 3, 5, 8, 9]) 38 | 39 | def test_column_spec_bitmap_to_items_string(self): 40 | column_spec = data_spec_pb2.Column() 41 | column_spec.categorical.number_of_unique_values = 10 42 | for i in range(10): 43 | column_spec.categorical.items[f"item_{i}"].index = i 44 | column_spec.categorical.is_already_integerized = False 45 | # 1100101101b => 32Dh 46 | self.assertEqual( 47 | condition_lib.column_spec_bitmap_to_items(column_spec, b"\x2D\x03"), 48 | ["item_0", "item_2", "item_3", "item_5", "item_8", "item_9"]) 49 | 50 | def test_core_condition_to_condition_is_missing(self): 51 | core_condition = decision_tree_pb2.NodeCondition() 52 | core_condition.na_value = False 53 | core_condition.attribute = 0 54 | core_condition.condition.na_condition.SetInParent() 55 | 56 | dataspec = data_spec_pb2.DataSpecification() 57 | column_spec = dataspec.columns.add() 58 | column_spec.name = "a" 59 | column_spec.type = dataspec_lib.ColumnType.NUMERICAL 60 | 61 | attribute = dataspec_lib.SimpleColumnSpec("a", 62 | 
dataspec_lib.ColumnType.NUMERICAL, 63 | 0) 64 | self.assertEqual( 65 | condition_lib.core_condition_to_condition(core_condition, dataspec), 66 | condition_lib.IsMissingInCondition(attribute)) 67 | 68 | 69 | if __name__ == "__main__": 70 | tf.test.main() 71 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/py_tree/dataspec_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import math 20 | from absl.testing import parameterized 21 | import tensorflow as tf 22 | 23 | from tensorflow_decision_forests.component.py_tree import dataspec as dataspec_lib 24 | from yggdrasil_decision_forests.dataset import data_spec_pb2 25 | 26 | 27 | def toy_dataspec(): 28 | dataspec = data_spec_pb2.DataSpecification() 29 | 30 | f1 = dataspec.columns.add() 31 | f1.name = "f1" 32 | f1.type = data_spec_pb2.ColumnType.NUMERICAL 33 | 34 | f2 = dataspec.columns.add() 35 | f2.name = "f2" 36 | f2.type = data_spec_pb2.ColumnType.CATEGORICAL 37 | f2.categorical.number_of_unique_values = 3 38 | f2.categorical.items[""].index = 0 39 | f2.categorical.items["x"].index = 1 40 | f2.categorical.items["y"].index = 2 41 | 42 | f3 = dataspec.columns.add() 43 | f3.name = "f3" 44 | f3.type = data_spec_pb2.ColumnType.CATEGORICAL 45 | f3.categorical.number_of_unique_values = 3 46 | f3.categorical.is_already_integerized = True 47 | 48 | f4 = dataspec.columns.add() 49 | f4.name = "f4" 50 | f4.type = data_spec_pb2.ColumnType.DISCRETIZED_NUMERICAL 51 | f4.discretized_numerical.boundaries[:] = [0, 1, 2] 52 | return dataspec 53 | 54 | 55 | class DataspecTest(parameterized.TestCase, tf.test.TestCase): 56 | 57 | def test_make_simple_column_spec(self): 58 | self.assertEqual( 59 | dataspec_lib.make_simple_column_spec(toy_dataspec(), 0), 60 | dataspec_lib.SimpleColumnSpec( 61 | name="f1", type=data_spec_pb2.ColumnType.NUMERICAL, col_idx=0)) 62 | 63 | def test_categorical_value_idx_to_value(self): 64 | dataspec = toy_dataspec() 65 | self.assertEqual( 66 | dataspec_lib.categorical_value_idx_to_value(dataspec.columns[1], 1), 67 | "x") 68 | 69 | self.assertEqual( 70 | dataspec_lib.categorical_value_idx_to_value(dataspec.columns[2], 1), 1) 71 | 72 | def test_discretized_numerical_to_numerical(self): 73 | column_spec = toy_dataspec().columns[3] 74 | 
self.assertEqual( 75 | dataspec_lib.discretized_numerical_to_numerical(column_spec, 0), 0 - 1) 76 | self.assertEqual( 77 | dataspec_lib.discretized_numerical_to_numerical(column_spec, 1), 0.5) 78 | self.assertEqual( 79 | dataspec_lib.discretized_numerical_to_numerical(column_spec, 2), 1.5) 80 | self.assertEqual( 81 | dataspec_lib.discretized_numerical_to_numerical(column_spec, 3), 2 + 1) 82 | self.assertTrue( 83 | math.isnan( 84 | dataspec_lib.discretized_numerical_to_numerical(column_spec, 4))) 85 | 86 | def test_categorical_column_dictionary_to_list(self): 87 | dataspec = toy_dataspec() 88 | self.assertEqual( 89 | dataspec_lib.categorical_column_dictionary_to_list(dataspec.columns[1]), 90 | ["", "x", "y"]) 91 | 92 | def test_column_name_to_column_idx(self): 93 | dataspec = toy_dataspec() 94 | self.assertEqual(dataspec_lib.column_name_to_column_idx("f1", dataspec), 0) 95 | self.assertEqual(dataspec_lib.column_name_to_column_idx("f2", dataspec), 1) 96 | 97 | 98 | if __name__ == "__main__": 99 | tf.test.main() 100 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/py_tree/node_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from absl import logging 20 | from absl.testing import parameterized 21 | import tensorflow as tf 22 | 23 | from tensorflow_decision_forests.component.py_tree import condition as condition_lib 24 | from tensorflow_decision_forests.component.py_tree import dataspec as dataspec_lib 25 | from tensorflow_decision_forests.component.py_tree import node as node_lib 26 | from tensorflow_decision_forests.component.py_tree import value as value_lib 27 | 28 | from yggdrasil_decision_forests.dataset import data_spec_pb2 29 | from yggdrasil_decision_forests.model.decision_tree import decision_tree_pb2 30 | 31 | 32 | class NodeTest(parameterized.TestCase, tf.test.TestCase): 33 | 34 | def test_leaf_value_integer(self): 35 | node = node_lib.LeafNode(5) 36 | self.assertEqual(node.value, 5) 37 | self.assertIsNone(node.leaf_idx) 38 | 39 | def test_leaf_idx(self): 40 | node = node_lib.LeafNode(5, 2) 41 | self.assertEqual(node.value, 5) 42 | self.assertEqual(node.leaf_idx, 2) 43 | 44 | def test_leaf_regression_value(self): 45 | node = node_lib.LeafNode( 46 | value=value_lib.RegressionValue( 47 | value=5.0, num_examples=10, standard_deviation=1.0 48 | ) 49 | ) 50 | core_node = decision_tree_pb2.Node( 51 | regressor=decision_tree_pb2.NodeRegressorOutput(top_value=5.0) 52 | ) 53 | dist = core_node.regressor.distribution 54 | dist.count = 10.0 55 | dist.sum = 0 56 | dist.sum_squares = 10.0 57 | self.assertEqual( 58 | node_lib.node_to_core_node(node, data_spec_pb2.DataSpecification()), 59 | core_node, 60 | ) 61 | logging.info("node:\n%s", node) 62 | 63 | def test_non_leaf_without_children(self): 64 | node = node_lib.NonLeafNode( 65 | condition=condition_lib.NumericalHigherThanCondition( 66 | feature=dataspec_lib.SimpleColumnSpec( 67 | name="f1", type=dataspec_lib.ColumnType.NUMERICAL 68 | ), 69 | threshold=1.5, 70 | missing_evaluation=False, 71 | ) 72 | 
) 73 | logging.info("node:\n%s", node) 74 | 75 | def test_non_leaf_with_children(self): 76 | node = node_lib.NonLeafNode( 77 | condition=condition_lib.NumericalHigherThanCondition( 78 | feature=dataspec_lib.SimpleColumnSpec( 79 | name="f1", type=dataspec_lib.ColumnType.NUMERICAL 80 | ), 81 | threshold=1.5, 82 | missing_evaluation=False, 83 | ), 84 | pos_child=node_lib.LeafNode( 85 | value=value_lib.RegressionValue( 86 | value=5.0, num_examples=10, standard_deviation=1.0 87 | ) 88 | ), 89 | neg_child=node_lib.LeafNode( 90 | value=value_lib.ProbabilityValue( 91 | probability=[0.5, 0.4, 0.1], num_examples=10 92 | ) 93 | ), 94 | ) 95 | logging.info("node:\n%s", node) 96 | 97 | def test_condition_value_and_default_categorical_contains(self): 98 | a = node_lib.ScanStructureAccumulatorContainsInt() 99 | self.assertEqual(a.get_global_imutation_and_num_unique_values(), (1, 2)) 100 | 101 | a.add_not_allowed([1, 3]) 102 | self.assertEqual(a.get_global_imutation_and_num_unique_values(), (2, 4)) 103 | 104 | a.add_not_allowed([2]) 105 | self.assertEqual(a.get_global_imutation_and_num_unique_values(), (4, 5)) 106 | 107 | a.add_allowed([1, 6]) 108 | self.assertEqual(a.get_global_imutation_and_num_unique_values(), (6, 7)) 109 | 110 | 111 | if __name__ == "__main__": 112 | tf.test.main() 113 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/py_tree/objective_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from absl import logging 20 | from absl.testing import parameterized 21 | import tensorflow as tf 22 | 23 | from tensorflow_decision_forests.component.py_tree import objective as objective_lib 24 | 25 | # pylint: disable=g-long-lambda 26 | 27 | 28 | class ObjectiveTest(parameterized.TestCase, tf.test.TestCase): 29 | 30 | def test_classification(self): 31 | objective = objective_lib.ClassificationObjective( 32 | label="label", num_classes=5) 33 | logging.info("objective: %s", objective) 34 | 35 | objective = objective_lib.ClassificationObjective( 36 | label="label", classes=["a", "b"]) 37 | logging.info("objective: %s", objective) 38 | 39 | objective = objective_lib.ClassificationObjective( 40 | label="label", classes=["a", "b"]) 41 | logging.info("objective: %s", objective) 42 | 43 | objective = objective_lib.ClassificationObjective( 44 | label="label", classes=["a", "b"], num_classes=2) 45 | logging.info("objective: %s", objective) 46 | 47 | def test_classification_errors(self): 48 | self.assertRaises( 49 | ValueError, 50 | lambda: objective_lib.ClassificationObjective(label="label")) 51 | self.assertRaises( 52 | ValueError, 53 | lambda: objective_lib.ClassificationObjective(label="", num_classes=5)) 54 | self.assertRaises( 55 | ValueError, lambda: objective_lib.ClassificationObjective( 56 | label="label", classes=["a", "b"], num_classes=5)) 57 | self.assertRaises( 58 | ValueError, 
lambda: objective_lib.ClassificationObjective( 59 | label="label", classes=[])) 60 | 61 | def test_regression(self): 62 | objective = objective_lib.RegressionObjective(label="label") 63 | logging.info("objective: %s", objective) 64 | 65 | def test_ranking(self): 66 | objective = objective_lib.RankingObjective(label="label", group="group") 67 | logging.info("objective: %s", objective) 68 | 69 | def test_numerical_uplift(self): 70 | objective = objective_lib.NumericalUpliftObjective( 71 | label="label", treatment="treatment" 72 | ) 73 | logging.info("objective: %s", objective) 74 | 75 | def test_categorical_uplift(self): 76 | objective = objective_lib.CategoricalUpliftObjective( 77 | label="label", treatment="treatment" 78 | ) 79 | logging.info("objective: %s", objective) 80 | 81 | def test_uplift_objects_are_not_equal(self): 82 | numerical_objective = objective_lib.NumericalUpliftObjective( 83 | label="label", treatment="treatment" 84 | ) 85 | categorical_objective = objective_lib.CategoricalUpliftObjective( 86 | label="label", treatment="treatment" 87 | ) 88 | self.assertNotEqual(numerical_objective, categorical_objective) 89 | 90 | 91 | if __name__ == "__main__": 92 | tf.test.main() 93 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/py_tree/tree.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A decision tree.""" 16 | 17 | from typing import Optional, List 18 | 19 | from tensorflow_decision_forests.component.py_tree import node as node_lib 20 | 21 | 22 | class Tree(object): 23 | """A single decision tree.""" 24 | 25 | def __init__(self, 26 | root: Optional[node_lib.AbstractNode], 27 | label_classes: Optional[List[str]] = None): 28 | self._root = root 29 | self._label_classes = label_classes 30 | 31 | @property 32 | def root(self): 33 | return self._root 34 | 35 | @property 36 | def label_classes(self): 37 | return self._label_classes 38 | 39 | def __repr__(self): 40 | """Returns an inline string representation of a tree.""" 41 | 42 | root_str = repr(self._root) if self._root else "None" 43 | return f"Tree(root={root_str}, label_classes={self.label_classes})" 44 | 45 | def __str__(self): 46 | # This method targets users who try to debug or interpret trees. 47 | return self.pretty() 48 | 49 | def pretty(self, max_depth: Optional[int] = 4) -> str: 50 | """Returns a readable textual representation of the tree. 51 | 52 | Unlike `repr(tree)`, `tree.pretty()` formats the representation (line returns, 53 | margins, hidden class names) to improve readability. 54 | 55 | This representation may change, and code should not try to parse the 56 | output of `pretty`. To access the tree structure programmatically, use the 57 | `root` property. 58 | 59 | Args: 60 | max_depth: The maximum depth of the nodes to display. Deeper nodes are 61 | skipped and replaced by "...". Defaults to 4; if None, prints the entire tree. 62 | 63 | Returns: 64 | A pretty-string representing the tree.
65 | """ 66 | 67 | content = "" 68 | if self._root: 69 | content += self._root.pretty( 70 | prefix="", is_pos=None, depth=1, max_depth=max_depth) 71 | else: 72 | content += "No root\n" 73 | if self._label_classes is not None: 74 | content += f"Label classes: {self.label_classes}\n" 75 | return content 76 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/py_tree/tree_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from absl import logging 20 | from absl.testing import parameterized 21 | import tensorflow as tf 22 | 23 | from tensorflow_decision_forests.component.py_tree import condition as condition_lib 24 | from tensorflow_decision_forests.component.py_tree import dataspec as dataspec_lib 25 | from tensorflow_decision_forests.component.py_tree import node as node_lib 26 | from tensorflow_decision_forests.component.py_tree import tree as tree_lib 27 | from tensorflow_decision_forests.component.py_tree import value as value_lib 28 | 29 | 30 | class TreeTest(parameterized.TestCase, tf.test.TestCase): 31 | 32 | def test_empty_tree(self): 33 | tree = tree_lib.Tree(None) 34 | logging.info("Tree:\n%s", tree) 35 | 36 | def test_basic_tree(self): 37 | tree = tree_lib.Tree( 38 | node_lib.NonLeafNode( 39 | condition=condition_lib.NumericalHigherThanCondition( 40 | feature=dataspec_lib.SimpleColumnSpec( 41 | name="f1", type=dataspec_lib.ColumnType.NUMERICAL), 42 | threshold=1.5, 43 | missing_evaluation=False), 44 | pos_child=node_lib.LeafNode( 45 | value=value_lib.RegressionValue( 46 | value=5.0, num_examples=10, standard_deviation=1.0)), 47 | neg_child=node_lib.LeafNode( 48 | value=value_lib.ProbabilityValue( 49 | probability=[0.5, 0.4, 0.1], num_examples=10)))) 50 | 51 | tree_repr = repr(tree) 52 | logging.info("Tree repr:\n%s", tree_repr) 53 | # The "repr" is a single line that does not contain any line return. 
54 | self.assertNotIn("\n", tree_repr) 55 | 56 | logging.info("Tree str:\n%s", tree) 57 | 58 | pretty = tree.pretty() 59 | logging.info("Pretty:\n%s", pretty) 60 | 61 | self.assertEqual( 62 | pretty, """(f1 >= 1.5; miss=False, score=None) 63 | ├─(pos)─ RegressionValue(value=5.0,sd=1.0,n=10) 64 | └─(neg)─ ProbabilityValue([0.5, 0.4, 0.1],n=10) 65 | """) 66 | 67 | def test_stump_with_label_classes(self): 68 | tree = tree_lib.Tree( 69 | node_lib.LeafNode( 70 | value=value_lib.ProbabilityValue( 71 | probability=[0.5, 0.4, 0.1], num_examples=10)), 72 | label_classes=["a", "b", "c"]) 73 | logging.info("Tree:\n%s", tree) 74 | 75 | 76 | if __name__ == "__main__": 77 | tf.test.main() 78 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/py_tree/value_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from absl import logging 20 | from absl.testing import parameterized 21 | import tensorflow as tf 22 | 23 | from tensorflow_decision_forests.component.py_tree import value as value_lib 24 | from yggdrasil_decision_forests.model.decision_tree import decision_tree_pb2 25 | 26 | 27 | class ValueTest(parameterized.TestCase, tf.test.TestCase): 28 | 29 | def test_regression(self): 30 | value = value_lib.RegressionValue( 31 | value=5.0, num_examples=10, standard_deviation=1.0 32 | ) 33 | logging.info("value:\n%s", value) 34 | 35 | def test_probability(self): 36 | value = value_lib.ProbabilityValue( 37 | probability=[0.5, 0.4, 0.1], num_examples=10 38 | ) 39 | logging.info("value:\n%s", value) 40 | 41 | def test_uplift(self): 42 | value = value_lib.UpliftValue(treatment_effect=[1, 2], num_examples=10) 43 | logging.info("value:\n%s", value) 44 | 45 | def test_core_value_to_value_classifier(self): 46 | core_node = decision_tree_pb2.Node() 47 | core_node.classifier.distribution.counts[:] = [0.0, 8.0, 2.0] 48 | core_node.classifier.distribution.sum = 10.0 49 | self.assertEqual( 50 | value_lib.core_value_to_value(core_node), 51 | value_lib.ProbabilityValue(probability=[0.8, 0.2], num_examples=10), 52 | ) 53 | 54 | def test_core_value_to_value_regressor(self): 55 | core_node = decision_tree_pb2.Node() 56 | core_node.regressor.top_value = 1.0 57 | core_node.regressor.distribution.sum = 10.0 58 | core_node.regressor.distribution.sum_squares = 20.0 59 | core_node.regressor.distribution.count = 10.0 60 | self.assertEqual( 61 | value_lib.core_value_to_value(core_node), 62 | value_lib.RegressionValue( 63 | value=1.0, num_examples=10, standard_deviation=1.0 64 | ), 65 | ) 66 | 67 | def test_core_value_to_value_uplift(self): 68 | core_node = decision_tree_pb2.Node() 69 | core_node.uplift.treatment_effect[:] = [0.0, 8.0, 2.0] 70 | 
core_node.uplift.sum_weights = 10.0 71 | self.assertEqual( 72 | value_lib.core_value_to_value(core_node), 73 | value_lib.UpliftValue( 74 | treatment_effect=[0.0, 8.0, 2.0], num_examples=10.0 75 | ), 76 | ) 77 | 78 | 79 | if __name__ == "__main__": 80 | tf.test.main() 81 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/tuner/BUILD: -------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = ["//visibility:public"], 3 | licenses = ["notice"], 4 | ) 5 | 6 | # Libraries 7 | # ========= 8 | 9 | py_library( 10 | name = "tuner", 11 | srcs = ["tuner.py"], 12 | deps = [ 13 | "@ydf//yggdrasil_decision_forests/learner/hyperparameters_optimizer:hyperparameters_optimizer_py_proto", 14 | "@ydf//yggdrasil_decision_forests/learner/hyperparameters_optimizer/optimizers:random_py_proto", 15 | "@ydf//yggdrasil_decision_forests/model:hyperparameter_py_proto", 16 | ], 17 | ) 18 | 19 | # Tests 20 | # ===== 21 | 22 | py_test( 23 | name = "tuner_test", 24 | timeout = "long", 25 | srcs = ["tuner_test.py"], 26 | data = ["@ydf//yggdrasil_decision_forests/test_data"], 27 | shard_count = 4, 28 | deps = [ 29 | ":tuner", 30 | "//tensorflow_decision_forests/keras", 31 | "@com_google_protobuf//:python_srcs", 32 | "@pypi_absl_py//:pkg", # /flags 33 | "@pypi_pandas//:pkg", 34 | "@release_or_nightly//:tensorflow_pkg", 35 | "@ydf//yggdrasil_decision_forests/learner:abstract_learner_py_proto", 36 | ], 37 | ) 38 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/component/tuner/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/contrib/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/contrib/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Decision Forests Contributions 2 | 3 | The Contrib(ution) directory contains functionality that is useful for TF-DF 4 | users but which is not central to the TF-DF library. For example, it could 5 | contain implementations of decision forest scientific papers, utilities to interact 6 | with other libraries (e.g. decision forests, IO), and model-agnostic tools that 7 | have no better home elsewhere. 8 | 9 | Contrib libraries may not be as high-quality or up-to-date as the rest of the TF-DF 10 | core library, and updates may rely on contributors outside the TF-DF core team. 11 | 12 | **Rules** 13 | 14 | 1. There are no dependencies from the core library to the contrib libraries. 15 | 1. The contrib libraries are not loaded automatically with the TF-DF library. 16 | 1. A dependency cannot be registered in `configure/setup.py` if it is only used 17 | by a contrib library. 18 | 1. If a contrib library becomes important enough (decided by the TF-DF core 19 | team), it can be moved to the `component` directory. 20 | 1.
The contrib directory is not a place to store usage examples. 21 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/contrib/example_of_contrib/BUILD: -------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = ["//visibility:public"], 3 | licenses = ["notice"], 4 | ) 5 | 6 | # Libraries 7 | # ========= 8 | 9 | py_library( 10 | name = "example_of_contrib", 11 | srcs = [ 12 | "__init__.py", 13 | "example_of_contrib.py", 14 | ], 15 | deps = [ 16 | "//tensorflow_decision_forests", 17 | ], 18 | ) 19 | 20 | # Tests 21 | # ===== 22 | 23 | py_test( 24 | name = "example_of_contrib_test", 25 | srcs = ["example_of_contrib_test.py"], 26 | deps = [ 27 | ":example_of_contrib", 28 | "@release_or_nightly//:tensorflow_pkg", 29 | ], 30 | ) 31 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/contrib/example_of_contrib/README.md: -------------------------------------------------------------------------------- 1 | # Example of Contrib 2 | 3 | ## Introduction 4 | 5 | The "Example of Contrib" shows how to add a library in `contrib`. 6 | 7 | ## Usage example 8 | 9 | ```python 10 | import tensorflow_decision_forests as tfdf 11 | from tensorflow_decision_forests.contrib import example_of_contrib 12 | 13 | print(example_of_contrib.my_function()) 14 | ``` 15 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/contrib/example_of_contrib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Example of Contrib. 16 | 17 | Usage example: 18 | 19 | ```python 20 | import tensorflow_decision_forests as tfdf 21 | from tensorflow_decision_forests.contrib import example_of_contrib 22 | 23 | print(example_of_contrib.my_function()) 24 | ``` 25 | """ 26 | 27 | from tensorflow_decision_forests.contrib.example_of_contrib import example_of_contrib as lib 28 | 29 | my_function = lib.my_function 30 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/contrib/example_of_contrib/example_of_contrib.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import tensorflow_decision_forests as tfdf 16 | 17 | 18 | def my_function() -> str: 19 | print(f"The current version of tfdf is {tfdf.__version__}") 20 | return "Hello world" 21 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/contrib/example_of_contrib/example_of_contrib_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for example_of_contrib.""" 16 | 17 | import tensorflow as tf 18 | from tensorflow_decision_forests.contrib import example_of_contrib 19 | 20 | 21 | class ExampleOfContribTest(tf.test.TestCase): 22 | 23 | def test_my_function(self): 24 | self.assertEqual(example_of_contrib.my_function(), "Hello world") 25 | 26 | 27 | if __name__ == "__main__": 28 | tf.test.main() 29 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/contrib/scikit_learn_model_converter/BUILD: -------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = ["//visibility:public"], 3 | licenses = ["notice"], 4 | ) 5 | 6 | # Libraries 7 | # ========= 8 | 9 | py_library( 10 | name = "scikit_learn_model_converter", 11 | srcs = [ 12 | "__init__.py", 13 | "scikit_learn_model_converter.py", 14 | ], 15 | deps = [ 16 | "//tensorflow_decision_forests", 17 | "@pypi_scikit_learn//:pkg", 18 | "@release_or_nightly//:tensorflow_pkg", 19 | ], 20 | ) 21 | 22 | # Tests 23 | # ===== 24 | 25 | py_test( 26 | name = "scikit_learn_model_converter_test", 27 | srcs = ["scikit_learn_model_converter_test.py"], 28 | deps = [ 29 | ":scikit_learn_model_converter", 30 | "@pypi_absl_py//:pkg", # /testing:parameterized 31 | "@pypi_numpy//:pkg", 32 | "@pypi_scikit_learn//:pkg", 33 | "@release_or_nightly//:tensorflow_pkg", 34 | ], 35 | ) 36 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/contrib/scikit_learn_model_converter/README.md: -------------------------------------------------------------------------------- 1 | # Scikit-Learn Converter 2 | 3 | ## Introduction 4 | 5 | **Scikit-Learn Model Converter** converts Scikit-Learn tree-based models (e.g. 6 | Random Forests) into TensorFlow models compatible with the whole TensorFlow 7 | ecosystem (e.g. Keras composition, SavedModel format, TF-Serving). 
8 | 9 | The converted model is also a TensorFlow Decision Forests (TF-DF) model 10 | compatible with all of TF-DF functionalities (e.g. plotting, c++ inference API). 11 | 12 | ## Currently supported models 13 | 14 | * `sklearn.tree.DecisionTreeClassifier` 15 | * `sklearn.tree.DecisionTreeRegressor` 16 | * `sklearn.tree.ExtraTreeClassifier` 17 | * `sklearn.tree.ExtraTreeRegressor` 18 | * `sklearn.ensemble.RandomForestClassifier` 19 | * `sklearn.ensemble.RandomForestRegressor` 20 | * `sklearn.ensemble.ExtraTreesClassifier` 21 | * `sklearn.ensemble.ExtraTreesRegressor` 22 | * `sklearn.ensemble.GradientBoostingRegressor` 23 | 24 | ## Usage example 25 | 26 | ```python 27 | from sklearn import datasets 28 | from sklearn import tree 29 | import tensorflow as tf 30 | from tensorflow_decision_forests.contrib import scikit_learn_model_converter 31 | 32 | # Train your model in scikit-learn 33 | X, y = datasets.make_classification() 34 | sklearn_model = tree.DecisionTreeClassifier().fit(X, y) 35 | 36 | # Convert to tensorflow and predict 37 | tensorflow_model = scikit_learn_model_converter.convert(sklearn_model) 38 | y_pred = tensorflow_model.predict(tf.constant(X)) 39 | ``` 40 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/contrib/scikit_learn_model_converter/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tools for converting scikit-learn tree-based models to TFDF models. 16 | 17 | This module converts Scikit-Learn tree-based models (e.g. 18 | Random Forests) into TensorFlow models compatible with the whole TensorFlow 19 | ecosystem (e.g. Keras composition, SavedModel format, TF-Serving). 20 | 21 | The converted model is also a TensorFlow Decision Forests (TF-DF) model 22 | compatible with all of TF-DF functionalities (e.g. plotting, c++ inference API). 23 | 24 | Example usage: 25 | 26 | ``` 27 | from sklearn import datasets 28 | from sklearn import tree 29 | import tensorflow as tf 30 | from tensorflow_decision_forests.contrib import scikit_learn_model_converter 31 | 32 | # Train your model in scikit-learn 33 | X, y = datasets.make_classification() 34 | sklearn_model = tree.DecisionTreeClassifier().fit(X, y) 35 | 36 | # Convert to tensorflow and predict 37 | tensorflow_model = scikit_learn_model_converter.convert(sklearn_model) 38 | y_pred = tensorflow_model.predict(tf.constant(X)) 39 | ``` 40 | 41 | """ 42 | 43 | from tensorflow_decision_forests.contrib.scikit_learn_model_converter import scikit_learn_model_converter as lib 44 | 45 | convert = lib.convert 46 | convert_sklearn_tree_to_tfdf_pytree = lib.convert_sklearn_tree_to_tfdf_pytree 47 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/contrib/training_preprocessing/BUILD: -------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = ["//visibility:public"], 3 | licenses = ["notice"], 4 | ) 5 | 6 | # Libraries 7 | # ========= 8 | 9 | py_library( 10 | name = "training_preprocessing", 11 | srcs = [ 12 | "__init__.py", 13 | "training_preprocessing.py", 14 | ], 15 | deps = [ 16 | "//tensorflow_decision_forests", 17 | "@release_or_nightly//:tensorflow_pkg", 18 | ], 19 | ) 20 
| 21 | # Tests 22 | # ===== 23 | 24 | py_test( 25 | name = "training_preprocessing_test", 26 | srcs = ["training_preprocessing_test.py"], 27 | deps = [ 28 | ":training_preprocessing", 29 | "//tensorflow_decision_forests", 30 | "@pypi_absl_py//:pkg", # /logging 31 | "@pypi_numpy//:pkg", 32 | "@release_or_nightly//:tensorflow_pkg", 33 | ], 34 | ) 35 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/contrib/training_preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for jointly preprocessing training data, labels and sample weights.""" 16 | 17 | from tensorflow_decision_forests.contrib.training_preprocessing import training_preprocessing 18 | 19 | multitask_label_presence_processing = ( 20 | training_preprocessing.multitask_label_presence_processing 21 | ) 22 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/keras/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Decision Forest in a Keras Model. 16 | 17 | Usage example: 18 | 19 | ```python 20 | import tensorflow_decision_forests as tfdf 21 | import pandas as pd 22 | 23 | # Load the dataset in a Pandas dataframe. 24 | train_df = pd.read_csv("project/train.csv") 25 | test_df = pd.read_csv("project/test.csv") 26 | 27 | # Convert the dataset into a TensorFlow dataset. 28 | train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="my_label") 29 | test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label="my_label") 30 | 31 | # Train the model. 32 | model = tfdf.keras.RandomForestModel() 33 | model.fit(train_ds) 34 | 35 | # Evaluate the model on another dataset. 36 | model.evaluate(test_ds) 37 | 38 | # Show information about the model 39 | model.summary() 40 | 41 | # Export the model with the TF.SavedModel format. 42 | model.save("/path/to/my/model") 43 | 44 | # ... 45 | 46 | # Load a model: it loads as a generic keras model. 47 | loaded_model = tf_keras.models.load_model("/path/to/my/model") 48 | ``` 49 | """ 50 | 51 | from typing import Callable, List 52 | 53 | from tensorflow_decision_forests.keras import core 54 | from tensorflow_decision_forests.keras import wrappers 55 | 56 | # Utility classes 57 | CoreModel = core.CoreModel 58 | FeatureSemantic = core.FeatureSemantic 59 | Task = core.Task 60 | FeatureUsage = core.FeatureUsage 61 | AdvancedArguments = core.AdvancedArguments 62 | MultiTaskItem = core.MultiTaskItem 63 | Monotonic = core.Monotonic 64 | 65 | # Learning algorithm (called Models in Keras). 
66 | 67 | 68 | class RandomForestModel(wrappers.RandomForestModel): 69 | pass 70 | 71 | 72 | class GradientBoostedTreesModel(wrappers.GradientBoostedTreesModel): 73 | pass 74 | 75 | 76 | class CartModel(wrappers.CartModel): 77 | pass 78 | 79 | 80 | class DistributedGradientBoostedTreesModel( 81 | wrappers.DistributedGradientBoostedTreesModel): 82 | pass 83 | 84 | 85 | def get_all_models() -> List[Callable[[], CoreModel]]: 86 | """Gets the lists of all the available models.""" 87 | return [ 88 | RandomForestModel, GradientBoostedTreesModel, CartModel, 89 | DistributedGradientBoostedTreesModel 90 | ] 91 | 92 | 93 | # Utilities 94 | pd_dataframe_to_tf_dataset = core.pd_dataframe_to_tf_dataset 95 | get_worker_idx_and_num_workers = core.get_worker_idx_and_num_workers 96 | build_default_input_model_signature = core.tf_core.build_default_input_model_signature 97 | build_default_feature_signature = core.tf_core.build_default_feature_signature 98 | set_training_logs_redirection = core.cc_logging.set_training_logs_redirection 99 | yggdrasil_model_to_keras_model = core.yggdrasil_model_to_keras_model 100 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/keras/keras_internal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Access to Keras function with a different internal and external path.""" 16 | 17 | from tf_keras.src.engine import data_adapter as _data_adapter 18 | from tf_keras.src.models import Functional 19 | from tf_keras.layers import DenseFeatures 20 | from tf_keras.src.utils.dataset_creator import DatasetCreator 21 | 22 | 23 | unpack_x_y_sample_weight = _data_adapter.unpack_x_y_sample_weight 24 | get_data_handler = _data_adapter.get_data_handler 25 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/keras/keras_internal_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import tensorflow as tf 16 | from tensorflow_decision_forests.keras import keras_internal 17 | 18 | 19 | # Does nothing. Ensures keras_internal can be loaded. 20 | 21 | if __name__ == "__main__": 22 | tf.test.main() 23 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/keras/test_runner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Loads and runs a SavedModel generated by :keras_test on the Adult dataset. 16 | 17 | This binary tests the model in a clean tensorflow instance (without pre-loaded 18 | functions and resources) and does not require non-canonical tf ops. 19 | """ 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import enum 25 | from typing import Dict 26 | 27 | from absl import app 28 | from absl import flags 29 | from absl import logging 30 | import numpy as np 31 | import pandas as pd 32 | import tensorflow as tf 33 | import tf_keras 34 | 35 | # Inject the inference ops only. 36 | from tensorflow_decision_forests.tensorflow.ops.inference import op # pylint: disable=unused-import 37 | 38 | flags.DEFINE_string("model_path", "", "Path to the SavedModel model to load.") 39 | 40 | flags.DEFINE_string("dataset_path", "", "Path to adult dataset.") 41 | 42 | FLAGS = flags.FLAGS 43 | 44 | 45 | def main(argv): 46 | del argv 47 | 48 | logging.info("Load model") 49 | model = tf_keras.models.load_model(FLAGS.model_path) 50 | 51 | logging.info("Load dataset") 52 | pd_dataset = load_dataset() 53 | tf_dataset = dataframe_to_tf_dataset(pd_dataset, "income") 54 | 55 | logging.info("Evaluate model") 56 | evaluation = model.evaluate(tf_dataset) 57 | logging.info("evaluation: %s", evaluation) 58 | 59 | logging.info("Done (unload model)") 60 | 61 | 62 | # The following code is similar to the simple_ml_keras_test. 
63 | 64 | 65 | class Semantic(enum.Enum): 66 | NUMERICAL = 1 67 | CATEGORICAL = 2 68 | 69 | 70 | def dataframe_to_tf_dataset(dataframe: pd.DataFrame, 71 | label: str) -> tf.data.Dataset: 72 | 73 | dataset = tf.data.Dataset.from_tensor_slices( 74 | (dict(dataframe.drop(label, axis=1)), dataframe[label].values)) 75 | return dataset.batch(64) 76 | 77 | 78 | def load_dataset() -> pd.DataFrame: 79 | """Loads the dataset in a pandas dataframe.""" 80 | 81 | dataset = pd.read_csv(flags.FLAGS.dataset_path) 82 | label = "income" 83 | 84 | def clean(ds): 85 | ds[label] = np.where(ds[label] == ">50K", 1, 0) 86 | return ds 87 | 88 | dataset = clean(dataset) 89 | return auto_dataset(dataset, label) 90 | 91 | 92 | def infer_semantic_from_dataframe(dataset: pd.DataFrame) -> Dict[str, Semantic]: 93 | """Infers the semantic of the columns in a pandas dataframe.""" 94 | 95 | semantics = {} 96 | for col in dataset.columns: 97 | dtype = dataset[col].dtype 98 | if dtype in [float, np.int16, np.int32, np.int64, int]: 99 | semantics[col] = Semantic.NUMERICAL 100 | elif dtype in [str, object]: 101 | semantics[col] = Semantic.CATEGORICAL 102 | else: 103 | raise Exception( 104 | "Cannot infer semantic for column \"{}\" with dtype={}".format( 105 | col, dtype)) 106 | 107 | return semantics 108 | 109 | 110 | def auto_dataset(dataset, label) -> pd.DataFrame: 111 | """Prepares a dataset object.""" 112 | 113 | semantics = infer_semantic_from_dataframe(dataset) 114 | 115 | def extract_by_semantic(semantic): 116 | return [k for k, s in semantics.items() if s == semantic and k != label] 117 | 118 | categorical_features = extract_by_semantic(Semantic.CATEGORICAL) 119 | numerical_features = extract_by_semantic(Semantic.NUMERICAL) 120 | 121 | def clean(dataset): 122 | for key in categorical_features: 123 | dataset[key] = dataset[key].fillna("") 124 | return dataset[categorical_features + numerical_features + [label]] 125 | 126 | return clean(dataset) 127 | 128 | 129 | if __name__ == "__main__": 130 | 
app.run(main) 131 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/keras/wrapper/BUILD: -------------------------------------------------------------------------------- 1 | load("@bazel_skylib//:bzl_library.bzl", "bzl_library") 2 | load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_test") 3 | 4 | package( 5 | default_visibility = ["//visibility:public"], 6 | licenses = ["notice"], 7 | ) 8 | 9 | exports_files(["wrapper_main.cc"]) 10 | 11 | # Bzl Libraries 12 | # ============= 13 | 14 | bzl_library( 15 | name = "wrapper_bzl", 16 | srcs = ["wrapper.bzl"], 17 | deps = [ 18 | "//tensorflow_decision_forests/tensorflow:utils_bzl", 19 | "@org_tensorflow//tensorflow:tensorflow_bzl", 20 | ], 21 | ) 22 | 23 | # Binaries 24 | # ======== 25 | 26 | # Libraries 27 | # ========= 28 | 29 | cc_library( 30 | name = "wrapper", 31 | srcs = ["wrapper.cc"], 32 | hdrs = ["wrapper.h"], 33 | deps = [ 34 | "@com_google_absl//absl/status:statusor", 35 | "@com_google_absl//absl/strings", 36 | "@ydf//yggdrasil_decision_forests/learner:learner_library", 37 | "@ydf//yggdrasil_decision_forests/utils:status_macros", 38 | ], 39 | ) 40 | 41 | # Tests 42 | # ===== 43 | 44 | tf_cc_test( 45 | name = "wrapper_test", 46 | srcs = ["wrapper_test.cc"], 47 | linkstatic = True, 48 | deps = [ 49 | ":wrapper", 50 | "@com_google_googletest//:gtest_main", 51 | "@ydf//yggdrasil_decision_forests/learner/gradient_boosted_trees", 52 | "@ydf//yggdrasil_decision_forests/learner/random_forest", 53 | ], 54 | ) 55 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/keras/wrapper/wrapper.bzl: -------------------------------------------------------------------------------- 1 | """Rule generation utilities.""" 2 | 3 | load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_not_windows", "tf_binary_additional_srcs", "tf_cc_binary", "tf_copts") 4 | 
load("//tensorflow_decision_forests/tensorflow:utils.bzl", "rpath_linkopts_to_tensorflow") 5 | 6 | def py_wrap_yggdrasil_learners( 7 | name = None, 8 | learner_deps = []): 9 | """Creates Keras wrappers around Yggdrasil Decision Forest (YDF) learners. 10 | 11 | Creates a py_library called "{name}" and containing the file "{name}.py". 12 | This library introduces a TensorFlow Decision Forests (TFDF) Keras class 13 | wrapping for each YDF learner defined in "learner_deps". The constructor of 14 | these classes contains a argument for the learner generic hyper-parameter. 15 | 16 | For example, if "learner_deps" contains a c++ dependency that register a 17 | learner with a key equal to "RANDOM_FOREST", the wrapper will create a 18 | python class called "RandomForestModel" deriving the base TFDF model class. 19 | 20 | Args: 21 | name: Name of the rule. 22 | learner_deps: List of dependencies linking Yggdrasil Decision Forest 23 | learners. 24 | """ 25 | 26 | # Absolute path to the wrapper generator directory. 27 | wrapper_package = "//tensorflow_decision_forests/keras/wrapper" 28 | 29 | # Filename of the wrapper generator source code in the user package. 30 | local_cc_main = name + "_wrapper_main.cc" 31 | 32 | # Target name of the wrapper generator binary. 33 | wrapper_name = name + "_wrapper_main" 34 | 35 | # Target name of the command running the wrapper generator. 36 | run_wrapper_name = name + "_run_wrapper" 37 | 38 | # Copy the wrapper main source code to the user package. 39 | native.genrule( 40 | name = name + "_copy_cc_main", 41 | outs = [local_cc_main], 42 | srcs = [wrapper_package + ":wrapper_main.cc"], 43 | cmd = "cp $< $@", 44 | ) 45 | 46 | # Compiles the wrapper binary. 47 | # TODO: Find way to link from pypi. 48 | # Note: This rule will compile a small part of TF. 
49 | tf_cc_binary( 50 | name = wrapper_name, 51 | copts = tf_copts(), 52 | linkopts = if_not_windows(["-lm", "-Wl,-ldl"]) + rpath_linkopts_to_tensorflow(wrapper_name), 53 | srcs = [":" + local_cc_main], 54 | deps = [ 55 | wrapper_package + ":wrapper", 56 | ] + learner_deps, 57 | linkstatic = 1, 58 | ) 59 | 60 | # Runs the wrapper binary and generate the wrapper .py source code. 61 | native.genrule( 62 | name = run_wrapper_name, 63 | srcs = [], 64 | outs = [name + ".py"], 65 | cmd = "$(location " + wrapper_name + ") > \"$@\"", 66 | tools = [":" + wrapper_name] + tf_binary_additional_srcs(), 67 | ) 68 | 69 | # Python library around the generated .py source code. 70 | native.py_library( 71 | name = name, 72 | srcs = [name + ".py"], 73 | srcs_version = "PY3", 74 | deps = [ 75 | "//tensorflow_decision_forests/keras:core", 76 | "//tensorflow_decision_forests/component/tuner", 77 | "@release_or_nightly//:tensorflow_pkg", 78 | "@ydf//yggdrasil_decision_forests/model:abstract_model_py_proto", 79 | "@ydf//yggdrasil_decision_forests/learner:abstract_learner_py_proto", 80 | ], 81 | data = [":" + run_wrapper_name, ":" + wrapper_name], 82 | ) 83 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/keras/wrapper/wrapper.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 Google LLC. 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * https://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | #ifndef TENSORFLOW_DECISION_FORESTS_KERAS_WRAPPER_WRAPPER_H_ 17 | #define TENSORFLOW_DECISION_FORESTS_KERAS_WRAPPER_WRAPPER_H_ 18 | 19 | #include 20 | 21 | #include "absl/status/statusor.h" 22 | 23 | namespace tensorflow { 24 | namespace decision_forests { 25 | 26 | // Creates the python code source that contains wrapping classes for all linked 27 | // learners. 28 | absl::StatusOr GenKerasPythonWrapper(); 29 | 30 | // Returns Python class name from the learner key. 31 | std::string LearnerKeyToClassName(absl::string_view key); 32 | 33 | // Formats a text into python documentation. 34 | // 35 | // Do the following transformations: 36 | // - Wraps lines to "max_char_per_lines" characters (while carrying leading 37 | // space). 38 | // - Add leading spaces (leading_spaces_first_line and 39 | // leading_spaces_next_lines). 40 | // - Detect and format bullet lists. 41 | std::string FormatDocumentation(absl::string_view raw, 42 | int leading_spaces_first_line, 43 | int leading_spaces_next_lines, 44 | int max_char_per_lines = 80); 45 | 46 | // Gets the number of leading spaces of a string. 47 | int NumLeadingSpaces(absl::string_view text); 48 | 49 | } // namespace decision_forests 50 | } // namespace tensorflow 51 | 52 | #endif // TENSORFLOW_DECISION_FORESTS_KERAS_WRAPPER_WRAPPER_H_ 53 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/keras/wrapper/wrapper_main.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 Google LLC. 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 
5 | * You may obtain a copy of the License at 6 | * 7 | * https://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | // Outputs python source code defining python Keras class wrapper around 17 | // Yggdrasil Decision Forest learners. 18 | 19 | #include "tensorflow_decision_forests/keras/wrapper/wrapper.h" 20 | 21 | int main(int argc, char* argv[]) { 22 | std::cout << tensorflow::decision_forests::GenKerasPythonWrapper().value(); 23 | return 0; 24 | } 25 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/keras/wrapper/wrapper_test.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 Google LLC. 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * https://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | #include "tensorflow_decision_forests/keras/wrapper/wrapper.h" 17 | 18 | #include "gtest/gtest.h" 19 | 20 | namespace tensorflow { 21 | namespace decision_forests { 22 | namespace { 23 | 24 | TEST(KerasLearnerWrappers, LearnerKeyToClassName) { 25 | EXPECT_EQ(LearnerKeyToClassName("RANDOM_FOREST"), "RandomForestModel"); 26 | } 27 | 28 | TEST(KerasLearnerWrappers, Base) { 29 | std::cout << GenKerasPythonWrapper().value() << std::endl; 30 | } 31 | 32 | TEST(KerasLearnerWrappers, FormatDocumentation) { 33 | const auto formatted = FormatDocumentation(R"(AAA AAA AAA AAA AAA. 34 | AAA AAA AAA AAA. 35 | - AAA AAA AAA AAA. 36 | - AAA AAA AAA AAA. 37 | AAA AAA AAA AAA. 38 | AAA AAA AAA AAA.)", 39 | /*leading_spaces_first_line=*/4, 40 | /*leading_spaces_next_lines=*/6, 41 | /*max_char_per_lines=*/20); 42 | EXPECT_EQ(formatted, R"( AAA AAA AAA AAA 43 | AAA. 44 | AAA AAA AAA 45 | AAA. 46 | - AAA AAA AAA 47 | AAA. 48 | - AAA AAA AAA 49 | AAA. 50 | AAA AAA AAA 51 | AAA. 52 | AAA AAA 53 | AAA AAA. 54 | )"); 55 | } 56 | 57 | TEST(KerasLearnerWrappers, NumLeadingSpaces) { 58 | EXPECT_EQ(NumLeadingSpaces(""), 0); 59 | EXPECT_EQ(NumLeadingSpaces(" "), 1); 60 | EXPECT_EQ(NumLeadingSpaces(" "), 2); 61 | EXPECT_EQ(NumLeadingSpaces(" HELLO "), 2); 62 | } 63 | 64 | } // namespace 65 | } // namespace decision_forests 66 | } // namespace tensorflow 67 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/check_version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Check that version of TensorFlow is compatible with TF-DF.""" 16 | 17 | import logging 18 | import tensorflow as tf 19 | 20 | 21 | def check_version(tf_df_version, 22 | compatible_tf_versions, 23 | tf_version=None, 24 | external_logic=False): 25 | """Checks the compatibility of the TF version. 26 | 27 | Prints a warning message and return False in care of likely incompatible 28 | versions. 
29 | """ 30 | 31 | if not external_logic: 32 | pass 33 | 34 | if tf_version is None: 35 | tf_version = tf.__version__ 36 | if tf_version not in compatible_tf_versions: 37 | logging.warning( 38 | "TensorFlow Decision Forests %s is compatible with the following " 39 | "TensorFlow Versions: %s. However, TensorFlow %s was detected. " 40 | "This can cause issues with the TF API and symbols in the custom C++ " 41 | "ops. See the TF and TF-DF compatibility table at " 42 | "https://github.com/tensorflow/decision-forests/blob/main/documentation/known_issues.md#compatibility-table.", 43 | tf_df_version, compatible_tf_versions, tf_version) 44 | return False 45 | return True 46 | 47 | 48 | def info_fail_to_load_custom_op(exception, filename): 49 | logging.warning( 50 | "Failure to load the %s custom c++ tensorflow ops. " 51 | "This error is likely caused the version of TensorFlow and " 52 | "TensorFlow Decision Forests are not compatible. Full error:" 53 | "%s", filename, exception) 54 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/check_version_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import tensorflow as tf 16 | 17 | from tensorflow_decision_forests.tensorflow import check_version 18 | 19 | 20 | class CheckVersionTest(tf.test.TestCase): 21 | 22 | def test_base(self): 23 | 24 | tf_df_version = "1.2.3" # Does not matter. 25 | self.assertTrue( 26 | check_version.check_version( 27 | tf_df_version, ["2.6.0", "2.6.1"], "2.6.0", external_logic=True)) 28 | self.assertFalse( 29 | check_version.check_version( 30 | tf_df_version, ["2.6.0", "2.6.1"], 31 | "2.8.0-dev20211105", 32 | external_logic=True)) 33 | self.assertFalse( 34 | check_version.check_version( 35 | tf_df_version, ["2.6.0", "2.6.1"], "2.7.0-rc1", 36 | external_logic=True)) 37 | 38 | 39 | if __name__ == "__main__": 40 | tf.test.main() 41 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/distribute/BUILD: -------------------------------------------------------------------------------- 1 | # Implementation of the Yggdrasil Distribute API using TensorFlow Distribution Strategies. 2 | 3 | package( 4 | default_visibility = ["//visibility:public"], 5 | licenses = ["notice"], 6 | ) 7 | 8 | # Worker binaries 9 | # =============== 10 | 11 | py_binary( 12 | name = "tensorflow_std_server_py", 13 | srcs = ["tensorflow_std_server.py"], 14 | main = "tensorflow_std_server.py", 15 | deps = [ 16 | "//tensorflow_decision_forests", 17 | "@pypi_absl_py//:pkg", # :app 18 | "@release_or_nightly//:tensorflow_pkg", 19 | ], 20 | ) 21 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/distribute/tensorflow_std_server.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Python worker for ParameterServerStrategy. 16 | 17 | This binary is a distributed training worker. When using distributed training, 18 | this worker (or the c++ worker defined in the same BUILD file) should run on 19 | worker machines. The configuration of the worker and chief is done 20 | though environement variable. See examples/distributed_training.py 21 | 22 | When possible, use the c++ worker instead. 23 | """ 24 | 25 | from absl import app 26 | from absl import flags 27 | from absl import logging 28 | 29 | import tensorflow as tf 30 | import tensorflow_decision_forests as tfdf # pylint: disable=unused-import 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | flags.DEFINE_string("job_name", "worker", "") 35 | flags.DEFINE_string("protocol", "grpc", "") 36 | flags.DEFINE_integer("task_index", 0, "") 37 | 38 | 39 | def main(argv): 40 | if len(argv) > 1: 41 | raise app.UsageError("Too many command-line arguments.") 42 | 43 | logging.info("Starting worker") 44 | cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver() 45 | server = tf.distribute.Server( 46 | cluster_resolver.cluster_spec(), 47 | protocol=FLAGS.protocol, 48 | job_name=FLAGS.job_name, 49 | task_index=FLAGS.task_index) 50 | 51 | logging.info("Worker is running") 52 | server.join() 53 | 54 | logging.info("Shutting down worker") 55 | 56 | 57 | if __name__ == "__main__": 58 | app.run(main) 59 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/ops/BUILD: 
-------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = ["//visibility:public"], 3 | licenses = ["notice"], 4 | ) 5 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/ops/inference/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/ops/inference/op.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from tensorflow_decision_forests.tensorflow.ops.inference.op_dynamic import * 16 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/ops/inference/op_dynamic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from tensorflow_decision_forests.tensorflow import check_version 16 | import tensorflow as tf 17 | from tensorflow.python.platform import resource_loader 18 | import sys 19 | 20 | try: 21 | ops = tf.load_op_library(resource_loader.get_path_to_datafile("inference.so")) 22 | except Exception as e: 23 | check_version.info_fail_to_load_custom_op(e, "inference.so") 24 | raise e 25 | 26 | # Importing all the symbols. 27 | module = sys.modules[__name__] 28 | for name, value in ops.__dict__.items(): 29 | if "__" in name: 30 | continue 31 | setattr(module, name, value) 32 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/ops/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/ops/training/kernel.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 Google LLC. 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 
5 | * You may obtain a copy of the License at 6 | * 7 | * https://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #ifndef TENSORFLOW_DECISION_FORESTS_TENSORFLOW_OPS_TRAINING_TRAINING_H_ 17 | #define TENSORFLOW_DECISION_FORESTS_TENSORFLOW_OPS_TRAINING_TRAINING_H_ 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | #include "absl/status/status.h" 26 | #include "absl/status/statusor.h" 27 | #include "absl/strings/string_view.h" 28 | #include "tensorflow/core/framework/op_kernel.h" 29 | #include "tensorflow/core/framework/resource_mgr.h" 30 | #include "yggdrasil_decision_forests/model/abstract_model.h" 31 | 32 | namespace tensorflow_decision_forests { 33 | namespace ops { 34 | 35 | // Status of long running operations. 36 | enum LongRunningProcessStatus { 37 | kInProgress = 0, 38 | kSuccess = 1, 39 | }; 40 | 41 | // Starts a long running process. Returns the id of the process. 42 | absl::StatusOr StartLongRunningProcess( 43 | ::tensorflow::OpKernelContext* ctx, std::function&& call); 44 | 45 | // Checks the status of a long running process. If the returned status of a 46 | // process is success or failure, the status of the process should not be 47 | // queried again. 48 | absl::StatusOr GetLongRunningProcessStatus( 49 | ::tensorflow::OpKernelContext* ctx, int32_t process_id); 50 | 51 | // If TFDF_STOP_TRAINING_ON_INTERRUPT is set, model training is interrupted with 52 | // the "set_stop_training_trigger()" API when the process receives a "program 53 | // interrupt" signal (i.e. SIGINT). 54 | // 55 | // An interrupted model is a valid model.
But its quality is likely inferior to 56 | // a fully trained model. See "set_stop_training_trigger" for details about the 57 | // effect of the interruption. 58 | // 59 | #ifdef TFDF_STOP_TRAINING_ON_INTERRUPT 60 | 61 | // The logic in this namespace is used to interrupt the training of a model (by 62 | // setting "stop_training_trigger=True") when receiving an interruption (e.g. 63 | // the user pressed ctrl+c). 64 | namespace interruption { 65 | // Should the current training learners be stopped? 66 | inline std::atomic stop_training; 67 | 68 | // The interruption signal handler to restore when all the learners are done 69 | // training. 70 | inline void (*previous_signal_handler)(int); 71 | 72 | // Number of learners currently training. 73 | inline std::atomic active_learners{0}; 74 | // Signal handler: requests the training learners to stop by setting "stop_training". 75 | inline void StopTrainingSignalHandler(int signal) { stop_training = true; } 76 | 77 | // Enables the interruption listener. 78 | absl::Status EnableUserInterruption(); 79 | 80 | // Disables the interruption listener. 81 | absl::Status DisableUserInterruption(); 82 | 83 | } // namespace interruption 84 | #endif 85 | 86 | // TF resource holding a non-compiled Yggdrasil model (see LoadModel()). 87 | class YggdrasilModelContainer : public tensorflow::ResourceBase { 88 | public: 89 | ~YggdrasilModelContainer() override = default; 90 | 91 | std::string DebugString() const override { return "YggdrasilModelContainer"; } 92 | 93 | absl::Status LoadModel(absl::string_view model_path); 94 | // Reports the approximate in-memory model size through tensorflow::ResourceBase. 95 | tensorflow::int64 MemoryUsed() const override { 96 | return approximate_model_size_in_memory_; 97 | } 98 | // Mutable access to the owned model pointer. 99 | std::unique_ptr* 100 | mutable_model() { 101 | return &model_; 102 | } 103 | // NOTE(review): dereferences model_ without a null check; assumes a model is set. 104 | const yggdrasil_decision_forests::model::AbstractModel& model() { 105 | return *model_; 106 | } 107 | 108 | int num_label_classes() const { return num_label_classes_; } 109 | 110 | const std::vector& output_class_representation() const { 111 | return output_class_representation_; 112 | } 113 | 114 | private: 115 | // The model.
116 | std::unique_ptr model_; 117 | 118 | // Number of output classes. This information is contained in the model, but 119 | // cached for fast access. 120 | int num_label_classes_ = -1; 121 | 122 | // String representation of the output classes. 123 | std::vector output_class_representation_; 124 | 125 | // Approximation of the model size in memory. 126 | int64_t approximate_model_size_in_memory_ = 0; 127 | }; 128 | 129 | } // namespace ops 130 | } // namespace tensorflow_decision_forests 131 | 132 | #endif  // TENSORFLOW_DECISION_FORESTS_TENSORFLOW_OPS_TRAINING_TRAINING_H_ 133 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/ops/training/op.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from tensorflow_decision_forests.tensorflow.ops.training.op_dynamic import * 16 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/ops/training/op_dynamic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from tensorflow_decision_forests.tensorflow import check_version 16 | import tensorflow as tf 17 | from tensorflow.python.platform import resource_loader 18 | import sys 19 | 20 | try: 21 | ops = tf.load_op_library(resource_loader.get_path_to_datafile("training.so")) 22 | except Exception as e: 23 | check_version.info_fail_to_load_custom_op(e, "training.so") 24 | raise e 25 | 26 | # Importing all the symbols. 27 | module = sys.modules[__name__] 28 | for name, value in ops.__dict__.items(): 29 | if "__" in name: 30 | continue 31 | setattr(module, name, value) 32 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/ops/training/op_grpc_worker.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 Google LLC. 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * https://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | #include "tensorflow/core/framework/common_shape_fns.h" 17 | #include "tensorflow/core/framework/op.h" 18 | #include "tensorflow/core/framework/shape_inference.h" 19 | 20 | namespace tensorflow { 21 | 22 | // Creates a YDF GRPC Worker server. If a server with the same key already 23 | // exists, this is a no-op. Returns the port of the server. 24 | // 25 | // Args: 26 | // "key": Key of the server. Only one server with a given key can exist in a 27 | // session. 28 | // "force_ydf_port": Port for YDF to use. The chief and the workers should 29 | // be able to communicate through this port. If -1, an available port 30 | // is automatically selected. 31 | // 32 | // Output: 33 | // port: Port of the GRPC server. If force_ydf_port is set, returns 34 | // "force_ydf_port". 35 | REGISTER_OP("SimpleMLCreateYDFGRPCWorker") 36 | .SetIsStateful() 37 | .Attr("key: int") 38 | .Attr("force_ydf_port: int = -1") 39 | .Output("port: int32") 40 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 41 | c->set_output(0, c->Scalar()); 42 | return absl::OkStatus(); 43 | }); 44 | 45 | // Indicates to the GRPC manager that the address of a worker has changed. 46 | // This op takes care of updating the other workers for inter-worker 47 | // communication. 48 | // 49 | // Args: 50 | // "key": Key of the server. 51 | // "worker_idx": Index of the worker to modify. 52 | // "new_address": New address of the worker. 53 | // 54 | // Output: 55 | // None.
56 | REGISTER_OP("SimpleMLUpdateGRPCWorkerAddress") 57 | .SetIsStateful() 58 | .Attr("key: int") 59 | .Input("worker_idx: int32") 60 | .Input("new_address: string") 61 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 62 | c->set_input_handle_shapes_and_types( 63 | 0, {shape_inference::ShapeAndType(c->Scalar(), DataType::DT_INT32)}); 64 | c->set_input_handle_shapes_and_types( 65 | 1, {shape_inference::ShapeAndType(c->Scalar(), DataType::DT_STRING)}); 66 | return absl::OkStatus(); 67 | }); 68 | 69 | // Stops the YDF GRPC Worker server identified by "key". 70 | // 71 | // Args: 72 | // "key": Key of the server. 73 | // 74 | REGISTER_OP("SimpleMLStopYDFGRPCWorker") 75 | .SetIsStateful() 76 | .Attr("key: int") 77 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 78 | return absl::OkStatus(); 79 | }); 80 | 81 | } // namespace tensorflow 82 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/ops/training/op_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Tests for op.""" 16 | 17 | import logging 18 | 19 | import tensorflow as tf 20 | 21 | from tensorflow_decision_forests.tensorflow.ops.training import op 22 | 23 | 24 | class OpTest(tf.test.TestCase): 25 | """Tests the YDF GRPC worker ops (create is keyed and idempotent; stop).""" 26 | def test_grpc_workers(self): 27 | port = op.SimpleMLCreateYDFGRPCWorker(key=1) 28 | logging.info("port: %d", port) 29 | 30 | port_again = op.SimpleMLCreateYDFGRPCWorker(key=1) 31 | logging.info("port_again: %d", port_again) 32 | 33 | self.assertEqual(port, port_again) 34 | 35 | port_other_server = op.SimpleMLCreateYDFGRPCWorker(key=2) 36 | logging.info("port_other_server: %d", port_other_server) 37 | 38 | self.assertNotEqual(port, port_other_server) 39 | 40 | op.SimpleMLStopYDFGRPCWorker(key=1) 41 | op.SimpleMLStopYDFGRPCWorker(key=2) 42 | 43 | 44 | if __name__ == "__main__": 45 | tf.test.main() 46 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/tf1_compatibility.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tensorflow 1 compatibility utilities. 16 | 17 | Resolves the TF tracking symbols from `tf.__internal__` when available, and from the private `tensorflow.python.trackable` modules otherwise.
18 | """ 19 | 20 | import tensorflow as tf 21 | 22 | if hasattr(tf, '__internal__'): 23 | Trackable = tf.__internal__.tracking.Trackable 24 | AutoTrackable = tf.__internal__.tracking.AutoTrackable 25 | TrackableResource = tf.saved_model.experimental.TrackableResource 26 | no_automatic_dependency_tracking = tf.__internal__.tracking.no_automatic_dependency_tracking 27 | else: 28 | # pylint: disable=g-direct-tensorflow-import, disable=g-import-not-at-top 29 | from tensorflow.python.trackable import autotrackable 30 | from tensorflow.python.trackable import base as trackable_base 31 | from tensorflow.python.trackable import resource 32 | from tensorflow.python.trackable import base as base_tracking 33 | # pylint: enable=g-direct-tensorflow-import, g-import-not-at-top 34 | 35 | Trackable = trackable_base.Trackable 36 | AutoTrackable = autotrackable.AutoTrackable 37 | TrackableResource = resource.TrackableResource 38 | no_automatic_dependency_tracking = base_tracking.no_automatic_dependency_tracking 39 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/tf_logging.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Logging primitives. 16 | 17 | Replacement of absl's logging primitives that are always visible to the user. 
18 | """ 19 | 20 | 21 | from typing import Any, List 22 | from absl import logging 23 | 24 | 25 | def info(msg: str, *args: List[Any]) -> None: 26 | """Print an info message visible to the user. 27 | 28 | To use instead of absl.logging.info (to be visible in colabs). 29 | 30 | Usage example: 31 | logging_info("Hello %s", "world") 32 | 33 | Args: 34 | msg: String message with replacement placeholders e.g. %s. 35 | *args: Placeholder replacement values. 36 | """ 37 | 38 | print(msg % args, flush=True) 39 | logging.info(msg, *args) 40 | 41 | 42 | def warning(msg: str, *args: List[Any]) -> None: 43 | """Print a warning message visible to the user. 44 | 45 | To use instead of absl.logging.info (to be visible in colabs). 46 | 47 | Usage example: 48 | logging_warning("Hello %s", "world") 49 | 50 | Args: 51 | msg: String message with replacement placeholders e.g. %s. 52 | *args: Placeholder replacement values. 53 | """ 54 | 55 | print("Warning:", msg % args, flush=True) 56 | logging.warning(msg, *args) 57 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow/utils.bzl: -------------------------------------------------------------------------------- 1 | """Utilities for the compilation of tensorflow code.""" 2 | 3 | load("@org_tensorflow//tensorflow:tensorflow.bzl", "check_deps", "tf_binary_additional_srcs", "tf_cc_shared_object", "tf_copts", "tf_custom_op_library_additional_deps") 4 | 5 | def _make_search_paths(prefix, search_level): 6 | return "-rpath,%s/%s/external/org_tensorflow/tensorflow" % (prefix, "/".join([".."] * search_level)) 7 | 8 | def rpath_linkopts_to_tensorflow(name): 9 | """Create a rpath linkopts flag to include tensorflow .so's directory.""" 10 | 11 | levels_to_root = native.package_name().count("/") + name.count("/") 12 | return select({ 13 | "@org_tensorflow//tensorflow:macos": [ 14 | "-Wl,%s" % (_make_search_paths("@loader_path", levels_to_root),), 15 | 
"-Wl,-rename_section,__TEXT,text_env,__TEXT,__text", 16 | ], 17 | "@org_tensorflow//tensorflow:windows": [], 18 | "//conditions:default": [ 19 | "-Wl,%s" % (_make_search_paths("$$ORIGIN", levels_to_root + 1),), 20 | ], 21 | }) 22 | 23 | def tf_custom_op_library_external(name, srcs = [], deps = [], linkopts = [], copts = [], **kwargs): 24 | """Helper to build a dynamic library (.so) from the sources containing implementations of custom ops and kernels. 25 | 26 | Similar to "tf_custom_op_library" in tensorflow, but also work for external libraries and windows. 27 | """ 28 | 29 | # Rely on the TF in "pywrap_tensorflow_import_lib". 30 | deps = deps + tf_custom_op_library_additional_deps() 31 | 32 | check_deps( 33 | name = name + "_check_deps", 34 | disallowed_deps = [ 35 | "@org_tensorflow//tensorflow/core:framework", 36 | "@org_tensorflow//tensorflow/core:lib", 37 | ], 38 | deps = deps, 39 | ) 40 | 41 | tf_cc_shared_object( 42 | name = name, 43 | srcs = srcs, 44 | deps = deps, 45 | framework_so = tf_binary_additional_srcs() + [ 46 | # Rely on the TF in "tensorflow.dll". 47 | #"@org_tensorflow//tensorflow:tensorflow_dll_import_lib", 48 | # Rely on the TF in "tensorflow_cc.dll". 
49 | #"@org_tensorflow//tensorflow:tensorflow_cc_dll_import_lib", 50 | ], 51 | copts = copts + tf_copts(is_external = True), 52 | linkopts = linkopts + select({ 53 | "//conditions:default": ["-lm"], 54 | "@org_tensorflow//tensorflow:windows": [], 55 | "@org_tensorflow//tensorflow:macos": [], 56 | }), 57 | **kwargs 58 | ) 59 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/tensorflow_decision_forests.bzl: -------------------------------------------------------------------------------- 1 | """Helpers for building with Hermetic Python.""" 2 | 3 | def _py_deps_profile_impl(ctx): 4 | is_switch = False 5 | for var_name, var_val in ctx.attr.switch.items(): 6 | is_switch = is_switch or ctx.os.environ.get(var_name, "") == var_val 7 | 8 | prefix = ctx.attr.pip_repo_name 9 | requirements_name = ctx.attr.requirements_in.name 10 | requirements_in_substitutions = {} 11 | build_content = ['exports_files(["{}"])'.format(requirements_name)] 12 | for k, v in ctx.attr.deps_map.items(): 13 | repo_name = v[0] if is_switch else k 14 | requirements_in_substitutions[k + "\n"] = repo_name + "\n" 15 | requirements_in_substitutions[k + "\r\n"] = repo_name + "\r\n" 16 | aliased_targets = ["pkg"] + v[1:] 17 | norm_repo_name = repo_name.replace("-", "_") 18 | norm_alias_name = k.replace("-", "_") 19 | for target in aliased_targets: 20 | alias_name = "{}_{}".format(norm_alias_name, target) 21 | alias_value = "@{}_{}//:{}".format(prefix, norm_repo_name, target) 22 | build_content.append(""" 23 | alias( 24 | name = "{}", 25 | actual = "{}", 26 | visibility = ["//visibility:public"] 27 | ) 28 | """.format(alias_name, alias_value)) 29 | 30 | ctx.file("BUILD", "".join(build_content)) 31 | ctx.template( 32 | requirements_name, 33 | ctx.attr.requirements_in, 34 | executable = False, 35 | substitutions = requirements_in_substitutions, 36 | ) 37 | 38 | py_deps_profile = repository_rule( 39 | implementation = _py_deps_profile_impl, 40 | attrs = { 41
| "requirements_in": attr.label(mandatory = True), 42 | "deps_map": attr.string_list_dict(mandatory = True), 43 | "pip_repo_name": attr.string(mandatory = True), 44 | "switch": attr.string_dict(mandatory = True), 45 | }, 46 | local = True, 47 | ) 48 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/test_data/BUILD: -------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = ["//visibility:public"], 3 | licenses = ["notice"], 4 | ) 5 | 6 | filegroup( 7 | name = "test_data", 8 | srcs = glob(["**"]), 9 | ) 10 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/test_data/README.md: -------------------------------------------------------------------------------- 1 | # Data for the Unit Tests 2 | 3 | This directory only contains small files used for the unit tests. 4 | 5 | This directory is organized as follows: 6 | - `model/`: Pre-trained models. 7 | - `model/saved_model_adult_rf/`: TF SavedModel containing a TF-DF Random Forest model trained with Keras.
8 | -------------------------------------------------------------------------------- /tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/data_spec.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/data_spec.pb -------------------------------------------------------------------------------- /tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/done: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/done -------------------------------------------------------------------------------- /tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/header.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/header.pb -------------------------------------------------------------------------------- /tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/nodes-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/nodes-00000-of-00001 -------------------------------------------------------------------------------- /tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/random_forest_header.pb: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/random_forest_header.pb -------------------------------------------------------------------------------- /tensorflow_decision_forests/test_data/model/saved_model_adult_rf/keras_metadata.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/keras_metadata.pb -------------------------------------------------------------------------------- /tensorflow_decision_forests/test_data/model/saved_model_adult_rf/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/saved_model.pb -------------------------------------------------------------------------------- /tensorflow_decision_forests/test_data/model/saved_model_adult_rf/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /tensorflow_decision_forests/test_data/model/saved_model_adult_rf/variables/variables.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/variables/variables.index -------------------------------------------------------------------------------- /tensorflow_decision_forests/tools/run_e2e_tfdf_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2021 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | 18 | # Converts a non-submitted CL to a standalone Bazel project in a local 19 | # directory, compiles the project and runs the tests. 20 | # 21 | # Usage example: 22 | # third_party/tensorflow_decision_forests/tools/run_e2e_tfdf_test.sh 23 | 24 | set -vex 25 | 26 | LOCAL_DIR="/usr/local/google/home/${USER}/git/decision-forests" 27 | 28 | CL=$(hg exportedcl) 29 | echo "Current CL: ${CL}" 30 | echo "Make sure the CL is synced!" 31 | 32 | function export_project() { 33 | COPYBARA="/google/bin/releases/copybara/public/copybara/copybara" 34 | 35 | # Test the Copybara configuration.
36 | bazel test third_party/tensorflow_decision_forests:copybara_test 37 | 38 | echo "Export a Bazel project locally" 39 | echo "==============================" 40 | 41 | rm -fr ${LOCAL_DIR} 42 | ${COPYBARA} third_party/tensorflow_decision_forests/copy.bara.sky presubmit_piper_to_gerrit ${CL} \ 43 | --dry-run --init-history --squash --force \ 44 | --git-destination-path ${LOCAL_DIR} --ignore-noop 45 | 46 | /google/bin/releases/opensource/thirdparty/cross/cross ${LOCAL_DIR} 47 | } 48 | 49 | echo "Test the project" 50 | echo "================" 51 | 52 | run_all() { 53 | cd ${LOCAL_DIR} 54 | 55 | # Start the Docker 56 | sudo ./tools/start_compile_docker.sh /bin/bash 57 | 58 | # In the docker, you can now trigger the builder with the following line in 59 | # the docker: 60 | # RUN_TESTS=1 PY_VERSION=3.9 TF_VERSION=2.16.1 ./tools/test_bazel.sh 61 | 62 | # Alternatively, you can trigger the build directly with: 63 | # sudo ./tools/start_compile_docker.sh "RUN_TESTS=1 PY_VERSION=3.8 TF_VERSION=2.10.0 ./tools/test_bazel.sh && chmod -R a+xrw . && /bin/bash" 64 | } 65 | 66 | export_project 67 | run_all 68 | -------------------------------------------------------------------------------- /third_party/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/absl/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/absl/workspace.bzl: -------------------------------------------------------------------------------- 1 | """Absl project.""" 2 | 3 | def deps(): 4 | # We use the absl injected by the tensorflow dependency. 5 | # Note: TensorFlow relies on possibly old versions of absl. 
6 | pass 7 | -------------------------------------------------------------------------------- /third_party/absl_py/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/absl_py/workspace.bzl: -------------------------------------------------------------------------------- 1 | """Absl project.""" 2 | 3 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 4 | 5 | def deps(): 6 | http_archive( 7 | name = "com_google_absl_py", 8 | urls = ["https://github.com/abseil/abseil-py/archive/master.zip"], 9 | strip_prefix = "abseil-py-master", 10 | ) 11 | -------------------------------------------------------------------------------- /third_party/benchmark/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/benchmark/workspace.bzl: -------------------------------------------------------------------------------- 1 | """Benchmark project.""" 2 | 3 | def deps(): 4 | # We use the benchmark injected by the tensorflow dependency. 
5 | pass 6 | -------------------------------------------------------------------------------- /third_party/gtest/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/gtest/workspace.bzl: -------------------------------------------------------------------------------- 1 | """Google Test project.""" 2 | 3 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 4 | 5 | def deps(): 6 | http_archive( 7 | name = "com_google_googletest", 8 | urls = ["https://github.com/google/googletest/archive/011959aafddcd30611003de96cfd8d7a7685c700.zip"], 9 | strip_prefix = "googletest-011959aafddcd30611003de96cfd8d7a7685c700", 10 | sha256 = "6a5d7d63cd6e0ad2a7130471105a3b83799a7a2b14ef7ec8d742b54f01a4833c", 11 | ) 12 | -------------------------------------------------------------------------------- /third_party/protobuf/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/protobuf/workspace.bzl: -------------------------------------------------------------------------------- 1 | """Protobuf project.""" 2 | 3 | # load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 4 | 5 | def deps(): 6 | # We use the protobuf linked in tensorflow. 
7 | pass 8 | -------------------------------------------------------------------------------- /third_party/tensorflow/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/tensorflow/tf-216.patch: -------------------------------------------------------------------------------- 1 | diff --git a/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl b/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl 2 | index 9261a652f9c..0557e345ab1 100644 3 | --- a/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl 4 | +++ b/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl 5 | @@ -2,7 +2,7 @@ 6 | 7 | load("//tensorflow/tools/toolchains:cpus/aarch64/aarch64.bzl", "remote_aarch64_configure") 8 | load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure") 9 | -load("//third_party/py:python_configure.bzl", "remote_python_configure") 10 | +load("//third_party/py/non_hermetic:python_configure.bzl", "remote_python_configure") 11 | 12 | def ml2014_tf_aarch64_configs(name_container_map, env): 13 | for name, container in name_container_map.items(): 14 | diff --git a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl 15 | index 9f71a414bf7..57f70752323 100644 16 | --- a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl 17 | +++ b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl 18 | @@ -1,6 +1,6 @@ 19 | """Macro that creates external repositories for remote config.""" 20 | 21 | -load("//third_party/py:python_configure.bzl", "local_python_configure", "remote_python_configure") 22 | +load("//third_party/py/non_hermetic:python_configure.bzl", "local_python_configure", "remote_python_configure") 23 | load("//third_party/gpus:cuda_configure.bzl", 
"remote_cuda_configure") 24 | load("//third_party/nccl:nccl_configure.bzl", "remote_nccl_configure") 25 | load("//third_party/gpus:rocm_configure.bzl", "remote_rocm_configure") 26 | diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl 27 | index 056df85ffdb..7422baf8c59 100644 28 | --- a/tensorflow/workspace2.bzl 29 | +++ b/tensorflow/workspace2.bzl 30 | @@ -37,7 +37,7 @@ load("//third_party/nasm:workspace.bzl", nasm = "repo") 31 | load("//third_party/nccl:nccl_configure.bzl", "nccl_configure") 32 | load("//third_party/opencl_headers:workspace.bzl", opencl_headers = "repo") 33 | load("//third_party/pasta:workspace.bzl", pasta = "repo") 34 | -load("//third_party/py:python_configure.bzl", "python_configure") 35 | +load("//third_party/py/non_hermetic:python_configure.bzl", "python_configure") 36 | load("//third_party/py/ml_dtypes:workspace.bzl", ml_dtypes = "repo") 37 | load("//third_party/pybind11_abseil:workspace.bzl", pybind11_abseil = "repo") 38 | load("//third_party/pybind11_bazel:workspace.bzl", pybind11_bazel = "repo") 39 | diff --git a/third_party/py/non_hermetic/python_configure.bzl b/third_party/py/non_hermetic/python_configure.bzl 40 | index 89732c3e33d..4ac1c8f5c04 100644 41 | --- a/third_party/py/non_hermetic/python_configure.bzl 42 | +++ b/third_party/py/non_hermetic/python_configure.bzl 43 | @@ -203,7 +203,7 @@ def _create_local_python_repository(repository_ctx): 44 | # Resolve all labels before doing any real work. Resolving causes the 45 | # function to be restarted with all previous state being lost. This 46 | # can easily lead to a O(n^2) runtime in the number of labels. 
47 | - build_tpl = repository_ctx.path(Label("//third_party/py:BUILD.tpl")) 48 | + build_tpl = repository_ctx.path(Label("//third_party/py/non_hermetic:BUILD.tpl")) 49 | 50 | python_bin = get_python_bin(repository_ctx) 51 | _check_python_bin(repository_ctx, python_bin) 52 | -------------------------------------------------------------------------------- /third_party/tensorflow/tf.patch: -------------------------------------------------------------------------------- 1 | diff --git a/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl b/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl 2 | index a2bdd6a7eed..ec25c23d8d4 100644 3 | --- a/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl 4 | +++ b/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl 5 | @@ -2,7 +2,7 @@ 6 | 7 | load("//tensorflow/tools/toolchains:cpus/aarch64/aarch64.bzl", "remote_aarch64_configure") 8 | load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure") 9 | -load("//third_party/py:python_configure.bzl", "remote_python_configure") 10 | +load("//third_party/py/non_hermetic:python_configure.bzl", "remote_python_configure") 11 | 12 | def ml2014_tf_aarch64_configs(name_container_map, env): 13 | for name, container in name_container_map.items(): 14 | diff --git a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl 15 | index 9f71a414bf7..57f70752323 100644 16 | --- a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl 17 | +++ b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl 18 | @@ -1,6 +1,6 @@ 19 | """Macro that creates external repositories for remote config.""" 20 | 21 | -load("//third_party/py:python_configure.bzl", "local_python_configure", "remote_python_configure") 22 | +load("//third_party/py/non_hermetic:python_configure.bzl", "local_python_configure", "remote_python_configure") 23 | 
load("//third_party/gpus:cuda_configure.bzl", "remote_cuda_configure") 24 | load("//third_party/nccl:nccl_configure.bzl", "remote_nccl_configure") 25 | load("//third_party/gpus:rocm_configure.bzl", "remote_rocm_configure") 26 | diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl 27 | index 953e1d1bea6..664608592a5 100644 28 | --- a/tensorflow/workspace2.bzl 29 | +++ b/tensorflow/workspace2.bzl 30 | @@ -8,7 +8,7 @@ load("//third_party/gpus:rocm_configure.bzl", "rocm_configure") 31 | load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure") 32 | load("//third_party/nccl:nccl_configure.bzl", "nccl_configure") 33 | load("//third_party/git:git_configure.bzl", "git_configure") 34 | -load("//third_party/py:python_configure.bzl", "python_configure") 35 | +load("//third_party/py/non_hermetic:python_configure.bzl", "python_configure") 36 | load("//third_party/systemlibs:syslibs_configure.bzl", "syslibs_configure") 37 | load("//tensorflow/tools/toolchains:cpus/aarch64/aarch64_compiler_configure.bzl", "aarch64_compiler_configure") 38 | load("//tensorflow/tools/toolchains:cpus/arm/arm_compiler_configure.bzl", "arm_compiler_configure") 39 | diff --git a/third_party/py/non_hermetic/python_configure.bzl b/third_party/py/non_hermetic/python_configure.bzl 40 | index 300cbfb6c71..09d98505dd9 100644 41 | --- a/third_party/py/non_hermetic/python_configure.bzl 42 | +++ b/third_party/py/non_hermetic/python_configure.bzl 43 | @@ -206,7 +206,7 @@ def _create_local_python_repository(repository_ctx): 44 | # Resolve all labels before doing any real work. Resolving causes the 45 | # function to be restarted with all previous state being lost. This 46 | # can easily lead to a O(n^2) runtime in the number of labels. 
47 | - build_tpl = repository_ctx.path(Label("//third_party/py:BUILD.tpl")) 48 | + build_tpl = repository_ctx.path(Label("//third_party/py/non_hermetic:BUILD.tpl")) 49 | 50 | python_bin = get_python_bin(repository_ctx) 51 | _check_python_bin(repository_ctx, python_bin) 52 | -------------------------------------------------------------------------------- /third_party/tensorflow/workspace.bzl: -------------------------------------------------------------------------------- 1 | """TensorFlow project.""" 2 | 3 | def deps(): 4 | # TensorFlow dependencies are injected directly in the WORKSPACE. 5 | pass 6 | -------------------------------------------------------------------------------- /third_party/tensorflow_pypi/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/tensorflow_pypi/BUILD.tpl: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//visibility:public"]) 2 | 3 | cc_library( 4 | name = "tf_header_lib", 5 | hdrs = [":tf_header_include"], 6 | includes = ["include"], 7 | visibility = ["//visibility:public"], 8 | ) 9 | 10 | cc_library( 11 | name = "libtensorflow_framework", 12 | srcs = [":libtensorflow_framework.so"], 13 | #data = ["lib/libtensorflow_framework.so"], 14 | visibility = ["//visibility:public"], 15 | ) 16 | 17 | %{TF_HEADER_GENRULE} 18 | %{TF_SHARED_LIBRARY_GENRULE} -------------------------------------------------------------------------------- /third_party/tensorflow_pypi/README: -------------------------------------------------------------------------------- 1 | TensorFlow injected through the .h, .so and .dll available in the TensorFlow 2 | PyPi package. 
3 | 4 | This directory contains a copy of: 5 | https://github.com/tensorflow/custom-op/tree/master/tf 6 | 7 | With the following edits: 8 | "//tf:" -> "//third_party/tensorflow_pypi:" 9 | -------------------------------------------------------------------------------- /third_party/tensorflow_pypi/tf_configure.bzl: -------------------------------------------------------------------------------- 1 | """Setup TensorFlow as external dependency""" 2 | 3 | load("@python//:defs.bzl", "interpreter") 4 | load("@python_version_repo//:py_version.bzl", "REQUIREMENTS_WITH_LOCAL_WHEELS") 5 | load("@rules_python//python:pip.bzl", "package_annotation", "pip_parse") 6 | 7 | def tf_configure(): 8 | tensorflow_annotation = """ 9 | cc_library( 10 | name = "tf_header_lib", 11 | hdrs = glob(["site-packages/tensorflow/include/**/*"]), 12 | strip_include_prefix="site-packages/tensorflow/include/", 13 | visibility = ["//visibility:public"], 14 | ) 15 | cc_library( 16 | name = "libtensorflow_framework", 17 | srcs = select({ 18 | "//conditions:default": ["site-packages/tensorflow/libtensorflow_framework.so.2"], 19 | "@bazel_tools//src/conditions:darwin":["site-packages/tensorflow/libtensorflow_framework.2.dylib"], 20 | "@bazel_tools//src/conditions:darwin_x86_64": ["site-packages/tensorflow/libtensorflow_framework.2.dylib"], 21 | }), 22 | visibility = ["//visibility:public"], 23 | ) 24 | """ 25 | pip_parse( 26 | name = "pypi", 27 | annotations = { 28 | "numpy": package_annotation( 29 | additive_build_content = """ 30 | cc_library( 31 | name = "numpy_headers_2", 32 | hdrs = glob(["site-packages/numpy/_core/include/**/*.h"]), 33 | strip_include_prefix="site-packages/numpy/_core/include/", 34 | ) 35 | cc_library( 36 | name = "numpy_headers_1", 37 | hdrs = glob(["site-packages/numpy/core/include/**/*.h"]), 38 | strip_include_prefix="site-packages/numpy/core/include/", 39 | ) 40 | cc_library( 41 | name = "numpy_headers", 42 | deps = [":numpy_headers_2", ":numpy_headers_1"], 43 | ) 44 | """, 45 | 
), 46 | "tensorflow": package_annotation( 47 | additive_build_content = tensorflow_annotation, 48 | ), 49 | "tf-nightly": package_annotation( 50 | additive_build_content = tensorflow_annotation, 51 | ), 52 | }, 53 | python_interpreter_target = interpreter, 54 | requirements_lock = REQUIREMENTS_WITH_LOCAL_WHEELS, 55 | ) 56 | -------------------------------------------------------------------------------- /third_party/yggdrasil_decision_forests/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/yggdrasil_decision_forests/workspace.bzl: -------------------------------------------------------------------------------- 1 | """Yggdrasil Decision Forests project.""" 2 | 3 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 4 | 5 | def deps(from_git_repo = True): 6 | if from_git_repo: 7 | YDF_VERSION = "1.11.0" 8 | YDF_SHA = "8553a7bfcb96dcdf19f4e9ce7bc5aca1a72df38bd29dfff53e9a58b317bba0c0" 9 | http_archive( 10 | name = "ydf", 11 | urls = ["https://github.com/google/yggdrasil-decision-forests/archive/refs/tags/v{version}.tar.gz".format(version = YDF_VERSION)], 12 | strip_prefix = "yggdrasil-decision-forests-{version}".format(version = YDF_VERSION), 13 | sha256 = YDF_SHA, 14 | patch_args = ["-p1"], 15 | patches = ["//third_party/yggdrasil_decision_forests:ydf.patch"], 16 | ) 17 | else: 18 | # You can also clone the YDF repository manually. 19 | # Note that you need to manually apply the patch for Tensorflow >= 2.16 or nightly. 
20 | native.local_repository( 21 | name = "ydf", 22 | # When downloading from Github, you might need - instead of _ as folder name 23 | path = "../yggdrasil_decision_forests", 24 | ) 25 | -------------------------------------------------------------------------------- /third_party/yggdrasil_decision_forests/ydf.patch: -------------------------------------------------------------------------------- 1 | diff --git a/yggdrasil_decision_forests/learner/decision_tree/BUILD b/yggdrasil_decision_forests/learner/decision_tree/BUILD 2 | index 9c021b04..0a8746a0 100644 3 | --- a/yggdrasil_decision_forests/learner/decision_tree/BUILD 4 | +++ b/yggdrasil_decision_forests/learner/decision_tree/BUILD 5 | @@ -63,7 +63,7 @@ cc_library_ydf( 6 | "@com_google_absl//absl/time", 7 | "@com_google_absl//absl/types:optional", 8 | "@com_google_absl//absl/types:span", 9 | - "@eigen_archive//:eigen3_internal", 10 | + "@eigen_archive//:eigen3", 11 | ] + select({ 12 | "//conditions:default": [ 13 | ], 14 | @@ -150,7 +150,7 @@ cc_library_ydf( 15 | "@com_google_absl//absl/time", 16 | "@com_google_absl//absl/types:optional", 17 | "@com_google_absl//absl/types:span", 18 | - "@eigen_archive//:eigen3_internal", 19 | + "@eigen_archive//:eigen3", 20 | ] + select({ 21 | "//conditions:default": [ 22 | ], 23 | -------------------------------------------------------------------------------- /tools/BUILD: -------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = ["//visibility:public"], 3 | licenses = ["notice"], 4 | ) 5 | -------------------------------------------------------------------------------- /tools/build_docs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Export the source code comments into the API Reference website.""" 16 | 17 | import os 18 | 19 | from absl import app 20 | from absl import flags 21 | 22 | from tensorflow_docs.api_generator import generate_lib 23 | from tensorflow_docs.api_generator import public_api 24 | import tensorflow_decision_forests as tfdf 25 | 26 | FLAGS = flags.FLAGS 27 | 28 | flags.DEFINE_string("output_dir", "/tmp/tfdf_api", "Where to output the docs") 29 | 30 | flags.DEFINE_string("code_url_prefix", "", "The url prefix for links to code.") 31 | 32 | flags.DEFINE_bool("search_hints", True, 33 | "Include metadata search hints in the generated files") 34 | 35 | flags.DEFINE_string("site_path", "/decision_forests/api_docs/python", 36 | "Path prefix in the _toc.yaml") 37 | 38 | 39 | def main(argv): 40 | if len(argv) > 1: 41 | raise app.UsageError("Too many command-line arguments.") 42 | 43 | doc_generator = generate_lib.DocGenerator( 44 | root_title="TensorFlow Decision Forests", 45 | py_modules=[("tfdf", tfdf)], # ("tfdf.keras", tfdf.keras) 46 | base_dir=os.path.dirname(tfdf.__file__), 47 | code_url_prefix=FLAGS.code_url_prefix, 48 | search_hints=FLAGS.search_hints, 49 | site_path=FLAGS.site_path, 50 | callbacks=[public_api.explicit_package_contents_filter]) 51 | doc_generator.build(FLAGS.output_dir) 52 | 53 | 54 | if __name__ == "__main__": 55 | app.run(main) 56 | -------------------------------------------------------------------------------- /tools/build_pip_package.bat: 
-------------------------------------------------------------------------------- 1 | :: Copyright 2021 Google LLC. 2 | :: 3 | :: Licensed under the Apache License, Version 2.0 (the "License"); 4 | :: you may not use this file except in compliance with the License. 5 | :: You may obtain a copy of the License at 6 | :: 7 | :: https://www.apache.org/licenses/LICENSE-2.0 8 | :: 9 | :: Unless required by applicable law or agreed to in writing, software 10 | :: distributed under the License is distributed on an "AS IS" BASIS, 11 | :: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | :: See the License for the specific language governing permissions and 13 | :: limitations under the License. 14 | 15 | 16 | set PATH=C:\Python38;c:\tools\msys64\usr\bin;%PATH% 17 | SET BASH=C:\tools\msys64\usr\bin\bash.exe 18 | DOSKEY python3=C:\Python38\python.exe $* 19 | 20 | %BASH% tools\build_pip_package.sh 21 | -------------------------------------------------------------------------------- /tools/start_compile_docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2021 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | 18 | # Start a docker able to compile TF-DF. 
19 | # 20 | # Usage example 21 | # 22 | # # Create the pip packages for publication 23 | # 24 | # # Download and start the docker (this script) 25 | # ./tools/start_compile_docker.sh 26 | # 27 | # # Compile and test TF-DF. 28 | # RUN_TESTS=1 PY_VERSION=3.9 TF_VERSION=2.15.0 ./tools/test_bazel.sh 29 | # 30 | # # Create a Pip package for a specific version of Python. 31 | # ./tools/build_pip_package.sh python3.9 32 | # 33 | # # Install the other versions of Python (the docker only has py3.8). 34 | # sudo apt-get update 35 | # sudo apt-get install python3.9 python3-pip python3.10 36 | # 37 | # # Create the Pip packages for the other versions of Python 38 | # ./tools/build_pip_package.sh python3.9 39 | # ./tools/build_pip_package.sh python3.10 40 | # 41 | # # Make the results of the docker world-accessible (in case the docker runs 42 | # # as root). 43 | # chmod -R a+xrw . 44 | # 45 | # # Exit the docker 46 | # exit 47 | # 48 | # # Publish the pip packages 49 | # ./tools/submit_pip_package.sh 50 | # 51 | # Alternative ending 52 | # 53 | # # Create a Pip package for all the compatible versions of Python using pyenv. 54 | # ./tools/build_pip_package.sh ALL_VERSIONS 55 | # 56 | # # Create a Pip package for all the compatible versions of Python using the 57 | # # previous build_pip_package call results (i.e. the "tmp_package" directory) 58 | # ./tools/build_pip_package.sh ALL_VERSIONS_ALREADY_ASSEMBLED 59 | # 60 | # https://hub.docker.com/r/tensorflow/build/tags?page=1 61 | 62 | # Current directory 63 | # Useful if Yggdrasil Decision Forests is available locally in a neighbor 64 | # directory.
65 | TFDF_DIRNAME=${PWD##*/} 66 | 67 | DOCKER_IMAGE=tensorflow/build:2.19-python3.9 68 | DOCKER_CONTAINER=compile_tfdf 69 | 70 | echo "Available containers:" 71 | sudo docker container ls -a --size 72 | 73 | set +e # Ignore error if the container already exists 74 | CREATE_DOCKER_FLAGS="-i -t -p 8889:8889 --network host -v ${PWD}/..:/working_dir -w /working_dir/${TFDF_DIRNAME}" 75 | sudo docker create ${CREATE_DOCKER_FLAGS} --name ${DOCKER_CONTAINER} ${DOCKER_IMAGE} 76 | sudo docker start ${DOCKER_CONTAINER} 77 | set -e 78 | 79 | # Start docker 80 | sudo docker exec -it ${DOCKER_CONTAINER} /bin/bash -c $@ 81 | -------------------------------------------------------------------------------- /tools/submit_pip_package.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2021 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | 18 | python3 -m twine upload -u __token__ dist/*manylinux2014_x86_64.whl 19 | python3 -m twine upload -u __token__ dist/*macosx*.whl 20 | echo "The results is available at https://pypi.org/project/tensorflow-decision-forests" 21 | -------------------------------------------------------------------------------- /tools/test_bazel.bat: -------------------------------------------------------------------------------- 1 | :: Copyright 2021 Google LLC.
2 | :: 3 | :: Licensed under the Apache License, Version 2.0 (the "License"); 4 | :: you may not use this file except in compliance with the License. 5 | :: You may obtain a copy of the License at 6 | :: 7 | :: https://www.apache.org/licenses/LICENSE-2.0 8 | :: 9 | :: Unless required by applicable law or agreed to in writing, software 10 | :: distributed under the License is distributed on an "AS IS" BASIS, 11 | :: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | :: See the License for the specific language governing permissions and 13 | :: limitations under the License. 14 | 15 | :: Compiles and runs the unit tests. 16 | set BAZEL=bazel-3.7.2-windows-x86_64.exe 17 | 18 | :: TensorFlow bazelrc is required for the distributed compilation using 19 | :: RBE i.e. a remote server (fast). 20 | :: https://raw.githubusercontent.com/tensorflow/tensorflow/master/.bazelrc 21 | SET TENSORFLOW_BAZELRC=tensorflow_bazelrc 22 | 23 | %BAZEL% --bazelrc=%TENSORFLOW_BAZELRC% build^ 24 | //tensorflow_decision_forests/...:all^ 25 | --config=windows^ 26 | --config=rbe_win^ 27 | --config=rbe_win_py38^ 28 | --config=tensorflow_testing_rbe_win^ 29 | --flaky_test_attempts=1 30 | 31 | :: Use --output_user_root to specify a quick-to-access location. 32 | 33 | if %errorlevel% neq 0 exit /b %errorlevel% 34 | -------------------------------------------------------------------------------- /tools/tf_serving/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Decision Forests and TensorFlow Serving 2 | 3 | `build_tf_serving_with_tf_df.sh` is a script that compiles TF-Serving with 4 | TF-Decision Forests. See the 5 | [TensorFlow Decision Forests and TensorFlow Serving guide](https://www.tensorflow.org/decision_forests/tensorflow_serving) 6 | for background information.
Results of this script are published as 7 | `tf_serving_linux.zip` on the 8 | [TF-DF GitHub release page](https://github.com/tensorflow/decision-forests/releases). 9 | 10 | **Compilation** 11 | 12 | *Note:* Make sure Docker is installed. 13 | 14 | ```shell 15 | # In a fresh directory 16 | git clone https://github.com/tensorflow/decision-forests.git 17 | git clone https://github.com/tensorflow/serving.git 18 | decision-forests/tools/tf_serving/build_tf_serving_with_tf_df.sh 19 | 20 | # Or sudo decision-forests/tools/tf_serving/build_tf_serving_with_tf_df.sh if 21 | # Docker needs to be run as sudo. 22 | ``` 23 | 24 | **Usage example:** 25 | 26 | ```shell 27 | # Make sure TF-DF is installed (this is only necessary to train the model) 28 | pip3 install tensorflow_decision_forests -U 29 | 30 | # Train a Random Forest model on the adult dataset 31 | # Save the model to "/tmp/my_saved_model" 32 | python3 decision-forests/examples/minimal.py 33 | 34 | # Add a version number to the model (this is required for TF Serving) 35 | mkdir -p /tmp/my_saved_model_with_version 36 | cp -r /tmp/my_saved_model /tmp/my_saved_model_with_version/1 37 | 38 | # Compile or download TF-Serving, and set the TFSERVING variable accordingly. 39 | # TFServing binary. 40 | TFSERVING="./tensorflow_model_server" 41 | 42 | # Configure the model path and name. 43 | MODEL_PATH=/tmp/my_saved_model_with_version 44 | MODEL_NAME=my_model 45 | 46 | # Start a TF Serving server 47 | # Note: This command is blocking. You need to run it in a separate terminal (or 48 | # in the background with &). 49 | ${TFSERVING} \ 50 | --rest_api_port=8501 \ 51 | --model_name=${MODEL_NAME} \ 52 | --model_base_path=${MODEL_PATH} 53 | 54 | # Send requests to the model: 55 | # See https://www.tensorflow.org/tfx/serving/api_rest for the various solutions. 56 | 57 | # Predictions with the predict+instances API.
58 | curl http://localhost:8501/v1/models/${MODEL_NAME}:predict -X POST \ 59 | -d '{"instances": [{"age":39,"workclass":"State-gov","fnlwgt":77516,"education":"Bachelors","education_num":13,"marital_status":"Never-married","occupation":"Adm-clerical","relationship":"Not-in-family","race":"White","sex":"Male","capital_gain":2174,"capital_loss":0,"hours_per_week":40,"native_country":"United-States"}]}' 60 | 61 | # Predictions with the predict+inputs API 62 | curl http://localhost:8501/v1/models/${MODEL_NAME}:predict -X POST \ 63 | -d '{"inputs": {"age":[39],"workclass":["State-gov"],"fnlwgt":[77516],"education":["Bachelors"],"education_num":[13],"marital_status":["Never-married"],"occupation":["Adm-clerical"],"relationship":["Not-in-family"],"race":["White"],"sex":["Male"],"capital_gain":[2174],"capital_loss":[0],"hours_per_week":[40],"native_country":["United-States"]}}' 64 | ``` 65 | -------------------------------------------------------------------------------- /tools/tf_serving/build_tf_serving_with_tf_df.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2021 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | 18 | # Compile TensorFlow Servo with TensorFlow Decision Forests. 
19 | # 20 | # This script is equivalent to the instructions at: 21 | # https://www.tensorflow.org/decision_forests/tensorflow_serving 22 | # 23 | # This script is standalone: it compiles the TF-DF release pinned in 24 | # build_tf_serving_workspace_extra.txt (downloaded from GitHub). 25 | # 26 | # Usage example: 27 | # # In a fresh directory 28 | # git clone https://github.com/tensorflow/decision-forests.git 29 | # git clone https://github.com/tensorflow/serving.git 30 | # decision-forests/tools/tf_serving/build_tf_serving_with_tf_df.sh 31 | # 32 | set -e 33 | 34 | # Add TF-DF as a dependency. 35 | 36 | # WORKSPACE: insert the extra repository rules after line 3. The .bk backup marks the file as already edited, so re-running the script does not insert them twice. 37 | WORKSPACE_PATH="serving/WORKSPACE" 38 | if [ ! -f "${WORKSPACE_PATH}.bk" ]; then 39 | echo "Edit WORKSPACE: ${WORKSPACE_PATH}" 40 | cp "${WORKSPACE_PATH}" "${WORKSPACE_PATH}.bk" 41 | sed -i '3rdecision-forests/tools/tf_serving/build_tf_serving_workspace_extra.txt' "${WORKSPACE_PATH}" 42 | else 43 | echo "WORKSPACE already edited: ${WORKSPACE_PATH}" 44 | fi 45 | 46 | # BUILD: add the TF-DF inference "kernel_and_op" target to the model server's if_v2([]) dependencies (same .bk idempotency guard as above). 47 | BUILD_PATH="serving/tensorflow_serving/model_servers/BUILD" 48 | if [ ! -f "${BUILD_PATH}.bk" ]; then 49 | echo "Edit BUILD: ${BUILD_PATH}" 50 | cp "${BUILD_PATH}" "${BUILD_PATH}.bk" 51 | sed -i 's|if_v2(\[\])|if_v2([\n "@org_tensorflow_decision_forests//tensorflow_decision_forests/tensorflow/ops/inference:kernel_and_op"\n])|g' "${BUILD_PATH}" 52 | else 53 | echo "BUILD already edited: ${BUILD_PATH}" 54 | fi 55 | 56 | # Compile TF Serving ("Servo") in Docker. Because of -mavx2/-mfma, the binary may require a CPU with AVX2 and FMA support. 57 | echo "Compile Servo" 58 | (cd serving && tools/run_in_docker.sh bazel build -c opt \ 59 | --copt=-mfma \ 60 | --define use_tensorflow_io=1 \ 61 | --copt=-mavx2 tensorflow_serving/model_servers:tensorflow_model_server ) 62 | 63 | # Pack the binary in a zip. This is the zip distributed on GitHub.
64 | BINARY="serving/bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server" 65 | zip -j tensorflow_model_server_linux.zip "${BINARY}" decision-forests/tools/tf_serving/README.md # README.md lives next to this script, not in the working directory; -j drops directory paths inside the zip. 66 | 67 | # If you have a model, you can run it in TF Serving using the following command: 68 | # ${BINARY} \ 69 | # --rest_api_port=8501 \ 70 | # --model_name=my_model \ 71 | # --model_base_path=/path/to/tfdf/model 72 | -------------------------------------------------------------------------------- /tools/tf_serving/build_tf_serving_workspace_extra.txt: -------------------------------------------------------------------------------- 1 | # Instructions to add to the TF Serving WORKSPACE file to add support for 2 | # TensorFlow Decision Forests. 3 | 4 | # Import Yggdrasil Decision Forests. 5 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 6 | http_archive( 7 | name="ydf", 8 | urls=[ 9 | "https://github.com/google/yggdrasil-decision-forests/archive/refs/tags/0.2.3.zip"], 10 | strip_prefix="yggdrasil-decision-forests-0.2.3", 11 | ) 12 | 13 | # Load the YDF dependencies, skipping the ones already provided by 14 | # TensorFlow. 15 | load("@ydf//yggdrasil_decision_forests:library.bzl", 16 | ydf_load_deps="load_dependencies") 17 | ydf_load_deps( 18 | exclude_repo=[ 19 | "absl", 20 | "protobuf", 21 | "zlib", 22 | "farmhash", 23 | "gtest", 24 | "tensorflow", 25 | "grpc" 26 | ], 27 | repo_name="@ydf", 28 | ) 29 | 30 | # Import TensorFlow Decision Forests. NOTE(review): tensorflow_http_archive is loaded but unused; the plain http_archive is used below. 31 | load("//tensorflow_serving:repo.bzl", "tensorflow_http_archive") 32 | http_archive( 33 | name="org_tensorflow_decision_forests", 34 | urls=[ 35 | "https://github.com/tensorflow/decision-forests/archive/refs/tags/0.2.3.zip"], 36 | strip_prefix="decision-forests-0.2.3", 37 | ) 38 | --------------------------------------------------------------------------------