├── .bazelignore
├── .bazelrc
├── .bazelversion
├── CHANGELOG.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── WORKSPACE
├── configure
├── BUILD
├── MANIFEST.in
├── requirements.in
├── requirements_lock_3_10.txt
├── requirements_lock_3_11.txt
├── requirements_lock_3_9.txt
└── setup.py
├── documentation
├── BUILD
├── _book.yaml
├── _index.yaml
├── developer_manual.md
├── directory_structure.md
├── distributed_training.md
├── image
│ ├── logo.png
│ └── ydf_logo.png
├── installation.md
├── known_issues.md
├── migration.md
├── more_examples.md
├── simple_ml_for_sheets
│ ├── image
│ │ ├── simple_ml_owl_1.png
│ │ └── simple_ml_predict_1.png
│ └── index.md
├── tensorflow_serving.md
├── text_features.md
├── tf_df_in_tf_js.md
└── tutorials
│ ├── BUILD
│ ├── advanced_colab.ipynb
│ ├── automatic_tuning_colab.ipynb
│ ├── beginner_colab.ipynb
│ ├── dtreeviz_colab.ipynb
│ ├── images
│ ├── beginner_tensorboard.png
│ └── model_composition_structure.png
│ ├── index.md
│ ├── intermediate_colab.ipynb
│ ├── model_composition_colab.ipynb
│ ├── predict_colab.ipynb
│ ├── proximities_colab.ipynb
│ ├── ranking_colab.ipynb
│ └── uplift_colab.ipynb
├── examples
├── BUILD
├── distributed_hyperparameter_optimization.py
├── distributed_training.py
├── fake_distributed_training.sh
├── hyperparameter_optimization.py
└── minimal.py
├── tensorflow_decision_forests
├── BUILD
├── __init__.py
├── component
│ ├── BUILD
│ ├── __init__.py
│ ├── builder
│ │ ├── BUILD
│ │ ├── __init__.py
│ │ ├── builder.py
│ │ └── builder_test.py
│ ├── inspector
│ │ ├── BUILD
│ │ ├── __init__.py
│ │ ├── blob_sequence.py
│ │ ├── blob_sequence_test.py
│ │ ├── inspector.py
│ │ └── inspector_test.py
│ ├── model_plotter
│ │ ├── BUILD
│ │ ├── __init__.py
│ │ ├── model_plotter.py
│ │ ├── model_plotter_test.py
│ │ └── plotter.js
│ ├── py_tree
│ │ ├── BUILD
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── condition.py
│ │ ├── condition_test.py
│ │ ├── dataspec.py
│ │ ├── dataspec_test.py
│ │ ├── node.py
│ │ ├── node_test.py
│ │ ├── objective.py
│ │ ├── objective_test.py
│ │ ├── tree.py
│ │ ├── tree_test.py
│ │ ├── value.py
│ │ └── value_test.py
│ └── tuner
│ │ ├── BUILD
│ │ ├── __init__.py
│ │ ├── tuner.py
│ │ └── tuner_test.py
├── contrib
│ ├── BUILD
│ ├── README.md
│ ├── example_of_contrib
│ │ ├── BUILD
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── example_of_contrib.py
│ │ └── example_of_contrib_test.py
│ ├── scikit_learn_model_converter
│ │ ├── BUILD
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── scikit_learn_model_converter.py
│ │ └── scikit_learn_model_converter_test.py
│ └── training_preprocessing
│ │ ├── BUILD
│ │ ├── __init__.py
│ │ ├── training_preprocessing.py
│ │ └── training_preprocessing_test.py
├── keras
│ ├── BUILD
│ ├── __init__.py
│ ├── core.py
│ ├── core_inference.py
│ ├── keras_distributed_test.py
│ ├── keras_internal.py
│ ├── keras_internal_test.py
│ ├── keras_test.py
│ ├── keras_tuner_test.py
│ ├── test_runner.py
│ ├── wrapper
│ │ ├── BUILD
│ │ ├── wrapper.bzl
│ │ ├── wrapper.cc
│ │ ├── wrapper.h
│ │ ├── wrapper_main.cc
│ │ └── wrapper_test.cc
│ └── wrappers_pre_generated.py
├── tensorflow
│ ├── BUILD
│ ├── __init__.py
│ ├── cc_logging.py
│ ├── check_version.py
│ ├── check_version_test.py
│ ├── core.py
│ ├── core_inference.py
│ ├── core_test.py
│ ├── distribute
│ │ ├── BUILD
│ │ └── tensorflow_std_server.py
│ ├── ops
│ │ ├── BUILD
│ │ ├── __init__.py
│ │ ├── inference
│ │ │ ├── BUILD
│ │ │ ├── __init__.py
│ │ │ ├── api.py
│ │ │ ├── kernel.cc
│ │ │ ├── op.cc
│ │ │ ├── op.py
│ │ │ ├── op_dynamic.py
│ │ │ ├── test_utils.py
│ │ │ ├── tf1_test.py
│ │ │ └── tf2_test.py
│ │ └── training
│ │ │ ├── BUILD
│ │ │ ├── __init__.py
│ │ │ ├── feature_on_file.cc
│ │ │ ├── feature_on_file.h
│ │ │ ├── features.h
│ │ │ ├── kernel.cc
│ │ │ ├── kernel.h
│ │ │ ├── kernel_grpc_worker.cc
│ │ │ ├── kernel_long_process.cc
│ │ │ ├── kernel_on_file.cc
│ │ │ ├── op.cc
│ │ │ ├── op.py
│ │ │ ├── op_dynamic.py
│ │ │ ├── op_grpc_worker.cc
│ │ │ └── op_test.py
│ ├── tf1_compatibility.py
│ ├── tf_logging.py
│ └── utils.bzl
├── tensorflow_decision_forests.bzl
├── test_data
│ ├── BUILD
│ ├── README.md
│ └── model
│ │ └── saved_model_adult_rf
│ │ ├── assets
│ │ ├── data_spec.pb
│ │ ├── done
│ │ ├── header.pb
│ │ ├── nodes-00000-of-00001
│ │ └── random_forest_header.pb
│ │ ├── keras_metadata.pb
│ │ ├── saved_model.pb
│ │ └── variables
│ │ ├── variables.data-00000-of-00001
│ │ └── variables.index
└── tools
│ └── run_e2e_tfdf_test.sh
├── third_party
├── BUILD
├── absl
│ ├── BUILD
│ └── workspace.bzl
├── absl_py
│ ├── BUILD
│ └── workspace.bzl
├── benchmark
│ ├── BUILD
│ └── workspace.bzl
├── gtest
│ ├── BUILD
│ └── workspace.bzl
├── protobuf
│ ├── BUILD
│ └── workspace.bzl
├── tensorflow
│ ├── BUILD
│ ├── tf-216.patch
│ ├── tf.patch
│ └── workspace.bzl
├── tensorflow_pypi
│ ├── BUILD
│ ├── BUILD.tpl
│ ├── README
│ └── tf_configure.bzl
└── yggdrasil_decision_forests
│ ├── BUILD
│ ├── workspace.bzl
│ └── ydf.patch
└── tools
├── BUILD
├── build_docs.py
├── build_pip_package.bat
├── build_pip_package.sh
├── start_compile_docker.sh
├── submit_pip_package.sh
├── test_bazel.bat
├── test_bazel.sh
└── tf_serving
├── README.md
├── build_tf_serving_with_tf_df.sh
└── build_tf_serving_workspace_extra.txt
/.bazelignore:
--------------------------------------------------------------------------------
1 | tensorflow_decision_forests/contrib/temporal_feature_processor
2 |
--------------------------------------------------------------------------------
/.bazelrc:
--------------------------------------------------------------------------------
1 | # Bazel configuration for TensorFlow Decision Forests
2 | #
3 | # OPTIONS
4 | #
5 | # Linux
6 | #
7 | # linux: Linux build.
8 | #
9 | # Windows
10 | #
11 | # windows: Windows build (unsupported).
12 | #
13 | # MacOS
14 | #
15 | # macos: MacOS build.
16 | #
17 |
18 | # Common flags.
19 | common --experimental_repo_remote_exec
20 | build --announce_rc
21 | build -c opt
22 |
23 | # Allow the use of TensorFlow internal methods.
24 | build --nocheck_visibility
25 |
26 | # Yggdrasil uses tensorflow for IO.
27 | build --define=use_tensorflow_io=1
28 |
29 | # If 1, use the pre-compiled wrappers.
30 | build --define=use_precompiled_wrappers=0
31 |
32 | # The user can stop the training with Ctrl+C.
33 | build --define=stop_training_on_interrupt=1
34 |
35 | # Yggdrasil uses the tensorflow headers specialized for dynamic linking.
36 | build --define=tensorflow_with_header_lib=1
37 |
38 | # Nice print
39 | build:linux --copt=-fdiagnostics-color=always
40 | build --color=yes
41 |
42 | # Suppress C++ compiler warnings.
43 | build:linux --copt=-w
44 | build:linux --host_copt=-w
45 | build:macos --copt=-w
46 | build:windows --copt=/W0
47 |
48 | # Common build options.
49 | build --define=framework_shared_object=true
50 | build --enable_platform_specific_config
51 | build --define=dynamic_loaded_kernels=true
52 | build --copt=-DAUTOLOAD_DYNAMIC_KERNELS
53 | build --spawn_strategy=standalone
54 |
55 | build --strategy=Genrule=standalone
56 | build --action_env TF_NEED_CUDA="0"
57 |
58 | # Host specifics options.
59 | build:macos --features=-supports_dynamic_linker
60 | # On Intel machines, AVX2 is available
61 | build:macos_intel --config=avx_linux
62 | # Absl synchronization primitives currently do not work on MacOS.
63 | # TODO
64 | build:macos --define std_synchronization_primitives=1
65 |
66 | build:windows --config=release_cpu_windows
67 |
--------------------------------------------------------------------------------
/.bazelversion:
--------------------------------------------------------------------------------
1 | 6.5.0
2 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | Contributions for TensorFlow Decision Forests and its engine Yggdrasil Decision Forests are welcome.
4 |
5 | See [github.com/google/yggdrasil-decision-forests](https://github.com/google/yggdrasil-decision-forests) for Yggdrasil Decision Forests.
6 |
7 | ## Pull Request Checklist
8 |
9 | Before sending pull requests, make sure to review the
10 | [developer manual](documentation/developer_manual.md).
11 |
12 | ## Contributor License Agreement
13 |
14 | Contributions to this project must be accompanied by a Contributor License
15 | Agreement. You (or your employer) retain the copyright to your contribution;
16 | this simply gives us permission to use and redistribute your contributions as
17 | part of the project. Head over to <https://cla.developers.google.com/> to see
18 | your current agreements on file or to sign a new one.
19 |
20 | You generally only need to submit a CLA once, so if you've already submitted one
21 | (even if it was for a different project), you probably don't need to do it
22 | again.
23 |
24 | ## Code reviews
25 |
26 | All submissions, including submissions by project members, require review. We
27 | use GitHub pull requests for this purpose. Consult
28 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
29 | information on using pull requests.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | **TensorFlow Decision Forests** (**TF-DF**) is a library to train, run and
6 | interpret [decision forest](https://ydf.readthedocs.io/en/latest/intro_df.html)
7 | models (e.g., Random Forests, Gradient Boosted Trees) in TensorFlow. TF-DF
8 | supports classification, regression and ranking.
9 |
10 | **TF-DF** is powered by
11 | [Yggdrasil Decision Forests](https://github.com/google/yggdrasil-decision-forests)
12 | (**YDF**), a library to train and use decision forests in C++, JavaScript, CLI,
13 | and Go. TF-DF models are
14 | [compatible](https://ydf.readthedocs.io/en/latest/convert_model.html#convert-a-a-tensorflow-decision-forests-model-to-a-yggdrasil-model)
15 | with YDF models, and vice versa.
16 |
17 | TensorFlow Decision Forests is available on Linux and Mac. Windows users can use
18 | the library through WSL+Linux.
19 |
20 | ## Usage example
21 |
22 | A minimal end-to-end run looks as follows:
23 |
24 | ```python
25 | import tensorflow_decision_forests as tfdf
26 | import pandas as pd
27 |
28 | # Load the dataset in a Pandas dataframe.
29 | train_df = pd.read_csv("project/train.csv")
30 | test_df = pd.read_csv("project/test.csv")
31 |
32 | # Convert the dataset into a TensorFlow dataset.
33 | train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="my_label")
34 | test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label="my_label")
35 |
36 | # Train the model
37 | model = tfdf.keras.RandomForestModel()
38 | model.fit(train_ds)
39 |
40 | # Look at the model.
41 | model.summary()
42 |
43 | # Evaluate the model.
44 | model.evaluate(test_ds)
45 |
46 | # Export to a TensorFlow SavedModel.
47 | # Note: the model is compatible with Yggdrasil Decision Forests.
48 | model.save("project/model")
49 | ```
50 |
51 | ## Google I/O Presentation
52 |
53 |
58 |
59 | ## Documentation & Resources
60 |
61 | The following resources are available:
62 |
63 | - [TF-DF on TensorFlow.org](https://tensorflow.org/decision_forests) (API
64 | Reference, Guides and Tutorials)
65 | - [Tutorials](https://www.tensorflow.org/decision_forests/tutorials) (on
66 | tensorflow.org)
67 | - [YDF documentation](https://ydf.readthedocs.io) (also applicable to TF-DF)
68 | - [Issue tracker](https://github.com/tensorflow/decision-forests/issues)
69 | - [Known issues](documentation/known_issues.md)
70 | - [Changelog](CHANGELOG.md)
71 | - [More examples](documentation/more_examples.md)
72 |
73 | ## Installation
74 |
75 | To install TensorFlow Decision Forests, run:
76 |
77 | ```shell
78 | pip3 install tensorflow_decision_forests --upgrade
79 | ```
80 |
81 | See the [installation](documentation/installation.md) page for more details,
82 | troubleshooting and alternative installation solutions.
83 |
84 | ## Contributing
85 |
86 | Contributions to TensorFlow Decision Forests and Yggdrasil Decision Forests are
87 | welcome. If you want to contribute, make sure to review the
88 | [developer manual](documentation/developer_manual.md) and
89 | [contribution guidelines](CONTRIBUTING.md).
90 |
91 | ## Citation
92 |
93 | If you use TensorFlow Decision Forests in a scientific publication, please cite
94 | the following paper:
95 | [Yggdrasil Decision Forests: A Fast and Extensible Decision Forests Library](https://doi.org/10.1145/3580305.3599933).
96 |
97 | **Bibtex**
98 |
99 | ```
100 | @inproceedings{GBBSP23,
101 | author = {Mathieu Guillame{-}Bert and
102 | Sebastian Bruch and
103 | Richard Stotz and
104 | Jan Pfeifer},
105 | title = {Yggdrasil Decision Forests: {A} Fast and Extensible Decision Forests
106 | Library},
107 | booktitle = {Proceedings of the 29th {ACM} {SIGKDD} Conference on Knowledge Discovery
108 | and Data Mining, {KDD} 2023, Long Beach, CA, USA, August 6-10, 2023},
109 | pages = {4068--4077},
110 | year = {2023},
111 | url = {https://doi.org/10.1145/3580305.3599933},
112 | doi = {10.1145/3580305.3599933},
113 | }
114 | ```
115 |
116 | **Raw**
117 |
118 | Yggdrasil Decision Forests: A Fast and Extensible Decision Forests Library,
119 | Guillame-Bert et al., KDD 2023: 4068-4077. doi:10.1145/3580305.3599933
120 |
121 | ## Contact
122 |
123 | You can contact the core development team at
124 | [decision-forests-contact@google.com](mailto:decision-forests-contact@google.com).
125 |
126 | ## Credits
127 |
128 | TensorFlow Decision Forests was developed by:
129 |
130 | - Mathieu Guillame-Bert (gbm AT google DOT com)
131 | - Jan Pfeifer (janpf AT google DOT com)
132 | - Richard Stotz (richardstotz AT google DOT com)
133 | - Sebastian Bruch (sebastian AT bruch DOT io)
134 | - Arvind Srinivasan (arvnd AT google DOT com)
135 |
136 | ## License
137 |
138 | [Apache License 2.0](LICENSE)
139 |
--------------------------------------------------------------------------------
/WORKSPACE:
--------------------------------------------------------------------------------
1 | workspace(name = "org_tensorflow_decision_forests")
2 |
3 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
4 |
5 | # rules_java is required for Tensorflow.
6 | http_archive(
7 | name = "rules_java",
8 | sha256 = "c73336802d0b4882e40770666ad055212df4ea62cfa6edf9cb0f9d29828a0934",
9 | url = "https://github.com/bazelbuild/rules_java/releases/download/5.3.5/rules_java-5.3.5.tar.gz",
10 | )
11 |
12 | # ==========================================
13 | # Start of TensorFlow and its dependencies
14 | # ==========================================
15 |
16 | # This version of TensorFlow is injected only to make sure we use the same dependencies as TensorFlow (protobuf, gRPC, absl).
17 | # TensorFlow is not compiled.
18 |
19 | # Note: The OPs dynamic library depends on symbols specific to the version of
20 | # absl used by tensorflow.
21 | http_archive(
22 | name = "org_tensorflow",
23 | sha256 = "4691b18e8c914cdf6759b80f1b3b7f3e17be41099607ed0143134f38836d058e",
24 | strip_prefix = "tensorflow-2.19.0",
25 | urls = ["https://github.com/tensorflow/tensorflow/archive/v2.19.0.tar.gz"],
26 | )
27 |
28 | load("//tensorflow_decision_forests:tensorflow_decision_forests.bzl", "py_deps_profile")
29 |
30 | py_deps_profile(
31 | name = "release_or_nightly",
32 | deps_map = {
33 | "tensorflow": [
34 | "tf-nightly",
35 | "tf_header_lib",
36 | "libtensorflow_framework",
37 | ],
38 | "tf-keras": ["tf-keras-nightly"],
39 | },
40 | pip_repo_name = "pypi",
41 | requirements_in = "//configure:requirements.in",
42 | switch = {
43 | "IS_NIGHTLY": "nightly",
44 | },
45 | )
46 |
47 | # Initialize hermetic Python
48 | load("@org_tensorflow//third_party/py:python_init_rules.bzl", "python_init_rules")
49 |
50 | python_init_rules()
51 |
52 | load("@org_tensorflow//third_party/py:python_init_repositories.bzl", "python_init_repositories")
53 |
54 | python_init_repositories(
55 | default_python_version = "system",
56 | requirements = {
57 | "3.9": "//configure:requirements_lock_3_9.txt",
58 | "3.10": "//configure:requirements_lock_3_10.txt",
59 | "3.11": "//configure:requirements_lock_3_11.txt",
60 | },
61 | )
62 |
63 | load("@org_tensorflow//third_party/py:python_init_toolchains.bzl", "python_init_toolchains")
64 |
65 | python_init_toolchains()
66 |
67 | load("//third_party/tensorflow_pypi:tf_configure.bzl", "tf_configure")
68 |
69 | tf_configure()
70 |
71 | load("@pypi//:requirements.bzl", "install_deps")
72 |
73 | install_deps()
74 |
75 | # Inject tensorflow dependencies.
76 | # TensorFlow can no longer be injected from a sub-module.
77 | # Note: The order is important.
78 | load("@org_tensorflow//tensorflow:workspace3.bzl", tf1 = "workspace")
79 |
80 | tf1()
81 |
82 | load("@org_tensorflow//tensorflow:workspace2.bzl", tf2 = "workspace")
83 |
84 | tf2()
85 |
86 | load("@org_tensorflow//tensorflow:workspace1.bzl", tf3 = "workspace")
87 |
88 | tf3()
89 |
90 | load("@org_tensorflow//tensorflow:workspace0.bzl", tf4 = "workspace")
91 |
92 | tf4()
93 |
94 | load(
95 | "@local_tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
96 | "cuda_json_init_repository",
97 | )
98 |
99 | cuda_json_init_repository()
100 |
101 | load(
102 | "@cuda_redist_json//:distributions.bzl",
103 | "CUDA_REDISTRIBUTIONS",
104 | "CUDNN_REDISTRIBUTIONS",
105 | )
106 | load(
107 | "@local_tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
108 | "cuda_redist_init_repositories",
109 | "cudnn_redist_init_repository",
110 | )
111 |
112 | cuda_redist_init_repositories(
113 | cuda_redistributions = CUDA_REDISTRIBUTIONS,
114 | )
115 |
116 | cudnn_redist_init_repository(
117 | cudnn_redistributions = CUDNN_REDISTRIBUTIONS,
118 | )
119 |
120 | load(
121 | "@local_tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
122 | "cuda_configure",
123 | )
124 |
125 | cuda_configure(name = "local_config_cuda")
126 |
127 | load(
128 | "@local_tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
129 | "nccl_redist_init_repository",
130 | )
131 |
132 | nccl_redist_init_repository()
133 |
134 | load(
135 | "@local_tsl//third_party/nccl/hermetic:nccl_configure.bzl",
136 | "nccl_configure",
137 | )
138 |
139 | nccl_configure(name = "local_config_nccl")
140 |
141 | # ========================================
142 | # End of TensorFlow and its dependencies
143 | # ========================================
144 |
145 | # Third party libraries
146 | load("//third_party/absl:workspace.bzl", absl = "deps")
147 | load("//third_party/absl_py:workspace.bzl", absl_py = "deps")
148 | load("//third_party/benchmark:workspace.bzl", benchmark = "deps")
149 | load("//third_party/gtest:workspace.bzl", gtest = "deps")
150 | load("//third_party/protobuf:workspace.bzl", protobuf = "deps")
151 |
152 | absl()
153 |
154 | absl_py()
155 |
156 | benchmark()
157 |
158 | gtest()
159 |
160 | protobuf()
161 |
162 | # Yggdrasil Decision Forests
163 | load("//third_party/yggdrasil_decision_forests:workspace.bzl", yggdrasil_decision_forests = "deps")
164 |
165 | yggdrasil_decision_forests()
166 |
167 | load("@ydf//yggdrasil_decision_forests:library.bzl", ydf_load_deps = "load_dependencies")
168 |
169 | ydf_load_deps(
170 | exclude_repo = [
171 | "absl",
172 | "protobuf",
173 | "zlib",
174 | "farmhash",
175 | "grpc",
176 | "eigen",
177 | "pybind11",
178 | "pybind11_abseil",
179 | "pybind11_protobuf",
180 | "tensorflow",
181 | ],
182 | repo_name = "@ydf",
183 | )
184 |
--------------------------------------------------------------------------------
/configure/BUILD:
--------------------------------------------------------------------------------
1 | load("@python//:defs.bzl", "compile_pip_requirements")
2 | load("@python_version_repo//:py_version.bzl", "REQUIREMENTS")
3 |
4 | licenses(["notice"])
5 |
6 | compile_pip_requirements(
7 | name = "requirements",
8 | extra_args = [
9 | "--allow-unsafe",
10 | "--build-isolation",
11 | "--rebuild",
12 | ],
13 | generate_hashes = True,
14 | requirements_in = "//configure:requirements.in",
15 | requirements_txt = REQUIREMENTS,
16 | )
17 |
18 | py_binary(
19 | name = "tensorflow_build_info",
20 | srcs = ["tensorflow_build_info.py"],
21 | deps = ["@pypi_tf_nightly//:pkg"],
22 | )
23 |
--------------------------------------------------------------------------------
/configure/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | include README
3 | include tensorflow_decision_forests/component/model_plotter/plotter.js
4 | recursive-include * *.so
5 | recursive-include * *.so.[0-9]
6 | recursive-include * *.dylib
7 | recursive-include * *.dll
8 | recursive-include * grpc_worker_main
9 |
--------------------------------------------------------------------------------
/configure/requirements.in:
--------------------------------------------------------------------------------
1 | numpy
2 | pandas
3 | tensorflow==2.19.0
4 | six
5 | absl_py
6 | wheel
7 | wurlitzer
8 | tf_keras
9 | ydf>=0.11.0
10 | scikit-learn
--------------------------------------------------------------------------------
/configure/requirements_lock_3_10.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/configure/requirements_lock_3_10.txt
--------------------------------------------------------------------------------
/configure/requirements_lock_3_11.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/configure/requirements_lock_3_11.txt
--------------------------------------------------------------------------------
/configure/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Setup file for pip's build.
16 |
17 | This file is used by tools/build_pip_package.sh.
18 | """
19 |
20 | import platform
21 | import sys
22 | import setuptools
23 | from setuptools.command.install import install
24 | from setuptools.dist import Distribution
25 |
26 | _VERSION = "1.12.0"
27 |
28 | with open("README.md", "r", encoding="utf-8") as fh:
29 | long_description = fh.read()
30 |
31 | REQUIRED_PACKAGES = [
32 | "numpy",
33 | "pandas",
34 | "tensorflow==2.19.0",
35 | "six",
36 | "absl_py",
37 | "wheel",
38 | "wurlitzer",
39 | "tf_keras~=2.19",
40 | "ydf>=0.11.0",
41 | ]
42 |
43 |
44 | class InstallPlatlib(install):  # setuptools `install` command variant used via cmdclass in setup().
45 | # After the base class finalizes its options, point install_lib at install_platlib when ext modules are reported.
46 | def finalize_options(self):
47 | install.finalize_options(self)
48 | if self.distribution.has_ext_modules():  # BinaryDistribution (the distclass passed to setup()) always returns True here.
49 | self.install_lib = self.install_platlib
50 |
51 |
52 | class BinaryDistribution(Distribution):  # Distribution subclass passed as `distclass` to setup().
53 | # Always report extension modules; presumably so the wheel is tagged as platform-specific -- TODO confirm.
54 | def has_ext_modules(self):
55 | return True
56 | # Always report the distribution as not pure-Python.
57 | def is_pure(self):
58 | return False
59 |
60 |
61 | if "bdist_wheel" in sys.argv:  # Only wheel builds get a platform tag injected.
62 | if "--plat-name" not in sys.argv:  # An explicitly supplied --plat-name is left untouched.
63 | if platform.system() == "Darwin":
64 | idx = sys.argv.index("bdist_wheel") + 1  # Position right after the bdist_wheel command.
65 | sys.argv.insert(idx, "--plat-name")
66 | if platform.processor() == "arm":
67 | sys.argv.insert(idx + 1, "macosx_12_0_arm64")  # Value for the --plat-name flag inserted above.
68 | elif platform.processor() == "i386":
69 | sys.argv.insert(idx + 1, "macosx_10_15_x86_64")
70 | else:
71 | raise ValueError(f"Unknown processor {platform.processor()}")  # Any other processor string aborts the build.
72 | else:
73 | print("Not on MacOS")  # Non-Darwin systems: sys.argv is not modified.
74 | else:
75 | print("PLAT-NAME Supplied")
76 | else:
77 | print("NO BDIST_WHEEL")
78 |
79 | setuptools.setup(
80 | cmdclass={
81 | "install": InstallPlatlib,
82 | },
83 | name="tensorflow_decision_forests",
84 | version=_VERSION,
85 | author="Google Inc.",
86 | author_email="decision-forests-contact@google.com",
87 | description=(
88 | "Collection of training and inference decision forest algorithms."
89 | ),
90 | long_description=long_description,
91 | long_description_content_type="text/markdown",
92 | url="https://github.com/tensorflow/decision-forests",
93 | project_urls={
94 | "Bug Tracker": "https://github.com/tensorflow/decision-forests/issues",
95 | },
96 | classifiers=[
97 | "Intended Audience :: Developers",
98 | "Intended Audience :: Education",
99 | "Intended Audience :: Science/Research",
100 | "License :: OSI Approved :: Apache Software License",
101 | "Programming Language :: Python :: 3",
102 | "Programming Language :: Python :: 3.9",
103 | "Programming Language :: Python :: 3.10",
104 | "Programming Language :: Python :: 3.11",
105 | "Programming Language :: Python :: 3.12",
106 | "Programming Language :: Python :: 3 :: Only",
107 | "Topic :: Scientific/Engineering",
108 | "Topic :: Scientific/Engineering :: Mathematics",
109 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
110 | "Topic :: Software Development",
111 | "Topic :: Software Development :: Libraries",
112 | "Topic :: Software Development :: Libraries :: Python Modules",
113 | ],
114 | distclass=BinaryDistribution,
115 | packages=setuptools.find_packages(),
116 | python_requires=">=3.9",
117 | license="Apache 2.0",
118 | keywords=(
119 | "tensorflow tensor machine learning decision forests random forest"
120 | " gradient boosted decision trees"
121 | ),
122 | install_requires=REQUIRED_PACKAGES,
123 | include_package_data=True,
124 | zip_safe=False,
125 | )
126 |
--------------------------------------------------------------------------------
/documentation/BUILD:
--------------------------------------------------------------------------------
1 | package(
2 | default_visibility = ["//visibility:public"],
3 | licenses = ["notice"],
4 | )
5 |
--------------------------------------------------------------------------------
/documentation/_book.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | upper_tabs:
16 | # Tabs left of dropdown menu
17 | - include: /_upper_tabs_left.yaml
18 | - include: /api_docs/_upper_tabs_api.yaml
19 | # Dropdown menu
20 | - name: Resources
21 | path: /resources
22 | is_default: true
23 | menu:
24 | - include: /resources/_menu_toc.yaml
25 | lower_tabs:
26 | # Subsite tabs
27 | other:
28 | - name: Guides
29 | contents:
30 | - title: Getting started
31 | path: /decision_forests/tutorials/beginner_colab
32 | - title: Migrating to YDF
33 | path: https://ydf.readthedocs.io/en/latest/tutorial/migrating_to_ydf/
34 | status: external
35 | - title: Hyper-parameters
36 | path: https://ydf.readthedocs.io/en/latest/hyperparameters/
37 | status: external
38 | - title: Known issues
39 | path: /decision_forests/known_issues
40 |
41 | - heading: Guide
42 | - title: Migrate from neural networks
43 | path: /decision_forests/migration
44 | - title: Learning to rank
45 | path: /decision_forests/tutorials/ranking_colab
46 | - title: Inspect and debug
47 | path: /decision_forests/tutorials/advanced_colab
48 | - title: Compose models
49 | path: /decision_forests/tutorials/model_composition_colab
50 | - title: Automatic tuning
51 | path: /decision_forests/tutorials/automatic_tuning_colab
52 | - title: Proximities and Prototypes
53 | path: /decision_forests/tutorials/proximities_colab
54 | - title: Uplift modeling
55 | path: /decision_forests/tutorials/uplift_colab
56 | - title: TensorFlow Serving
57 | path: /decision_forests/tensorflow_serving
58 | - title: Text and pre-trained embeddings
59 | path: /decision_forests/tutorials/intermediate_colab
60 | - title: Distributed Training
61 | path: /decision_forests/distributed_training
62 | - title: Making predictions
63 | path: /decision_forests/tutorials/predict_colab
64 | - title: Introduction
65 | path: /decision_forests/simple_ml_for_sheets/
66 |
67 | - heading: Experimental
68 | - title: Run model in JavaScript
69 | path: https://ydf.readthedocs.io/en/latest/js_serving.html
70 | status: external
71 | - title: Run model with TensorFlow.js
72 | path: /decision_forests/tf_df_in_tf_js
73 | status: external
74 |
75 | - name: API Reference
76 | skip_translation: true
77 | contents:
78 | - title: All Symbols
79 | path: /decision_forests/api_docs/python/tfdf/all_symbols
80 | - include: /decision_forests/api_docs/python/tfdf/_toc.yaml
81 |
82 | - include: /_upper_tabs_right.yaml
83 |
--------------------------------------------------------------------------------
/documentation/developer_manual.md:
--------------------------------------------------------------------------------
1 | # Developer Manual
2 |
3 | Internally, TF-DF relies on
4 | [Yggdrasil Decision Forests](https://github.com/google/yggdrasil-decision-forests)
5 | (YDF). Depending on the change, reading YDF's user and developer manual might be
6 | beneficial.
7 |
8 | The library's dependency structure is organized in layers:
9 |
10 | 1. Keras
11 | 2. TensorFlow
12 | 3. Python utility
13 | 4. Yggdrasil
14 |
15 | New logic should be implemented where relevant. When several layers are
16 | possibly relevant, the most generic layer should be favored.
17 |
18 | The directory structure of [TF-DF](directory_structure.md) and
19 | [YDF](https://github.com/google/yggdrasil-decision-forests/blob/main/documentation/directory_structure.md)
20 | is a good start.
21 |
--------------------------------------------------------------------------------
/documentation/directory_structure.md:
--------------------------------------------------------------------------------
1 | # Directory Structure
2 |
3 | The project is organised as follows:
4 |
5 | ```
6 | ├── configure: Project configuration.
7 | ├── documentation: User and developer documentation. Contains the colabs.
8 | ├── examples: Collection of usage examples.
9 | ├── tensorflow_decision_forests: The library
10 | │ ├── component: Utilities.
11 | │ │ ├── builder: Create models "by hand".
12 | │ │ ├── inspector: Inspection of structure and meta-data of models.
13 | │ │ ├── model_plotter: Plotting of model tree structure.
14 | │ │ ├── inspector: Inspection of structure and meta-data of models.
15 | │ │ ├── py_tree: Representation of a decision tree as a python object.
16 | │ │ └── tuner: TF-DF's own hyper-parameter tuner.
17 | │ ├── contrib: Additional functionality outside the project's main scope.
18 | │ ├── keras: Keras logic. Depends on tensorflow logic.
19 | │ │ └── wrapper: Python code generator for Keras models.
20 | │ ├── tensorflow: TensorFlow logic.
21 | │ │ └── ops: Custom C++ ops.
22 | │ │ ├── inference: ... for inference.
23 | │ │ └── training: ... for training.
24 | │ └── test_data: Datasets for unit tests and benchmarks.
25 | ├── third_party: Bazel configuration for dependencies.
26 | └── tools: Tools for the management of the project and code.
27 | ```
28 |
--------------------------------------------------------------------------------
/documentation/image/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/documentation/image/logo.png
--------------------------------------------------------------------------------
/documentation/image/ydf_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/documentation/image/ydf_logo.png
--------------------------------------------------------------------------------
/documentation/known_issues.md:
--------------------------------------------------------------------------------
1 | # Known Issues
2 |
3 | ## Prefer YDF for new projects
4 |
5 | [YDF](https://github.com/google/yggdrasil-decision-forests) is Google's new
6 | library to train Decision Forests.
7 |
8 | YDF extends the power of TF-DF, offering new features, a simplified API, faster
9 | training times, updated documentation, and enhanced compatibility with popular
10 | ML libraries.
11 |
12 | Some of the issues mentioned below are fixed in YDF.
13 |
14 | ## Windows Pip package is not available
15 |
16 | TensorFlow Decision Forests is not yet available as a Windows Pip package.
17 |
18 | **Workarounds:**
19 |
20 | - *Solution #1:* Install
21 | [Windows Subsystem for Linux (WSL)](https://en.wikipedia.org/wiki/Windows_Subsystem_for_Linux)
22 | on your Windows machine and follow the Linux instructions.
23 |
24 | ## Incompatibility with Keras 3
25 |
26 | Compatibility with Keras 3 is not yet implemented. Use tf_keras or a TensorFlow
27 | version before 2.16. Alternatively, use [ydf](https://pypi.org/project/ydf/).
28 |
29 | ## Untested for conda
30 |
31 | While TF-DF might work with Conda, this is not tested and we currently do not
32 | maintain packages on conda-forge.
33 |
34 | ## Incompatibility with old or nightly versions of TensorFlow
35 |
36 | TensorFlow's [ABI](https://en.wikipedia.org/wiki/Application_binary_interface)
37 | is not compatible in between releases. Because TF-DF relies on custom TensorFlow
38 | C++ ops, each version of TF-DF is tied to a specific version of TensorFlow. The
39 | last released version of TF-DF is always tied to the last released version of
40 | TensorFlow.
41 |
42 | For these reasons, the current version of TF-DF might not be compatible with
43 | older versions or with the nightly build of TensorFlow.
44 |
45 | If using incompatible versions of TF and TF-DF, you will see cryptic errors such
46 | as:
47 |
48 | ```
49 | tensorflow_decision_forests/tensorflow/ops/training/training.so: undefined symbol: _ZN10tensorflow11GetNodeAttrERKNS_9AttrSliceEN4absl14lts_2020_09_2311string_viewEPSs
50 | ```
51 |
52 | - Use the version of TF-DF that is compatible with your version of TensorFlow.
53 |
54 | ### Compatibility table
55 |
56 | The following table shows the compatibility between
57 | `tensorflow_decision_forests` and its dependencies:
58 |
59 | tensorflow_decision_forests | tensorflow
60 | --------------------------- | ---------------
61 | 1.12.0 | 2.19.0
62 | 1.11.0 | 2.18.0
63 | 1.10.0 | 2.17.0
64 | 1.9.2 | 2.16.2
65 | 1.9.1 | 2.16.1
66 | 1.9.0 | 2.16.1
67 | 1.8.0 - 1.8.1 | 2.15.0
68 | 1.6.0 - 1.7.0 | 2.14.0
69 | 1.5.0 | 2.13.0
70 | 1.3.0 - 1.4.0 | 2.12.0
71 | 1.1.0 - 1.2.0 | 2.11.0
72 | 1.0.0 - 1.0.1 | 2.10.0 - 2.10.1
73 | 0.2.6 - 0.2.7 | 2.9.1
74 | 0.2.5 | 2.9
75 | 0.2.4 | 2.8
76 | 0.2.1 - 0.2.3 | 2.7
77 | 0.1.9 - 0.2.0 | 2.6
78 | 0.1.1 - 0.1.8 | 2.5
79 | 0.1.0 | 2.4
80 |
81 | - *Solution #2:* Wrap your preprocessing function into another function that
82 | [squeezes](https://www.tensorflow.org/api_docs/python/tf/squeeze) its
83 | inputs.
84 |
85 | ## Not all models support distributed training and distribution strategies
86 |
87 | Unless specified, models are trained on a single machine and are not compatible
88 | with distribution strategies. For example the `GradientBoostedTreesModel` does
89 | not support distributed training while `DistributedGradientBoostedTreesModel`
90 | does.
91 |
92 | **Workarounds:**
93 |
94 | - Use a model that supports distribution strategies (e.g.
95 | `DistributedGradientBoostedTreesModel`), or downsample your dataset so that
96 | it fits on a single machine.
97 |
98 | ## No support for GPU / TPU.
99 |
100 | TF-DF does not support GPU or TPU training. Compiling with AVX instructions,
101 | however, may speed up serving.
102 |
103 | ## No support for [model_to_estimator](https://www.tensorflow.org/api_docs/python/tf/keras/estimator/model_to_estimator)
104 |
105 | TF-DF does not implement the APIs required to convert a trained/untrained model
106 | to the estimator format.
107 |
108 | ## Loaded models behave differently than Python models.
109 |
110 | While abstracted by the Keras API, a model instantiated in Python (e.g., with
111 | `tfdf.keras.RandomForestModel()`) and a model loaded from disk (e.g., with
112 | `tf_keras.models.load_model()`) can behave differently. Notably, a Python
113 | instantiated model automatically applies necessary type conversions. For
114 | example, if a `float64` feature is fed to a model expecting a `float32` feature,
115 | this conversion is performed implicitly. However, such a conversion is not
116 | possible for models loaded from disk. It is therefore important that the
117 | training data and the inference data always have the exact same type.
118 |
119 | ## TensorFlow feature name sanitization
120 |
121 | TensorFlow sanitizes feature names and might, for instance, convert them to
122 | lowercase.
123 |
--------------------------------------------------------------------------------
/documentation/more_examples.md:
--------------------------------------------------------------------------------
1 | # More examples
2 |
3 | This page lists examples created by TF-DF users.
4 |
5 | - [Minimal example](https://github.com/tensorflow/decision-forests/blob/main/examples/minimal.py)
6 | (on github.com)
7 | - [Running TF-DF models on Arduino](https://github.com/achoum/ardwino-tensorflow-decision-forests)
8 | (on github.com)
9 | - [Classification with TensorFlow Decision Forests](https://keras.io/examples/structured_data/classification_with_tfdf/)
10 | (on keras.io)
11 |
--------------------------------------------------------------------------------
/documentation/simple_ml_for_sheets/image/simple_ml_owl_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/documentation/simple_ml_for_sheets/image/simple_ml_owl_1.png
--------------------------------------------------------------------------------
/documentation/simple_ml_for_sheets/image/simple_ml_predict_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/documentation/simple_ml_for_sheets/image/simple_ml_predict_1.png
--------------------------------------------------------------------------------
/documentation/tensorflow_serving.md:
--------------------------------------------------------------------------------
1 | # TensorFlow Decision Forests and TensorFlow Serving
2 |
3 | [TensorFlow Serving](https://www.tensorflow.org/tfx/guide/serving) (TF Serving)
4 | is a tool to run TensorFlow models online in large production settings using an
5 | RPC or REST API. TensorFlow Decision Forests (TF-DF) is supported natively by TF
6 | Serving >=2.11.
7 |
8 | TF-DF models are directly compatible with TF Serving. Yggdrasil models can be
9 | used with TF Serving after being
10 | [converted](https://ydf.readthedocs.io/en/latest/convert_model.html#convert-a-yggdrasil-model-to-a-tensorflow-decision-forests-model)
11 | first.
12 |
13 | ## Limitations
14 |
15 | TensorFlow adds a significant amount of computation overhead. For small, latency
16 | sensitive models (e.g., model inference time ~1µs), this overhead can be
17 | an order of magnitude larger than the time needed by the model itself.
18 | In this case, it is recommended to run the TF-DF models with
19 | [Yggdrasil Decision Forests](https://ydf.readthedocs.io).
20 |
21 | ## Usage example
22 |
23 | The following example shows how to run a TF-DF model in TF Serving:
24 |
25 | First, [install TF Serving](https://github.com/tensorflow/serving#set-up). In
26 | this example, we will use a pre-compiled version of TF-Serving + TF-DF.
27 |
28 | ```shell
29 | # Download TF Serving
30 | wget https://github.com/tensorflow/decision-forests/releases/download/serving-1.0.1/tensorflow_model_server_linux.zip
31 | unzip tensorflow_model_server_linux.zip
32 |
33 | # Check that TF Serving works.
34 | ./tensorflow_model_server --version
35 | ```
36 |
37 | In this example, we use an already trained TF-DF model.
38 |
39 | ```shell
40 | # Get a TF-DF model
41 | git clone https://github.com/tensorflow/decision-forests.git
42 | MODEL_PATH=$(pwd)/decision-forests/tensorflow_decision_forests/test_data/model/saved_model_adult_rf
43 |
44 | echo "The TF-DF model is available at: ${MODEL_PATH}"
45 | ```
46 |
47 | **Notes:** TF-Serving requires the model's full path. This is why we use
48 | `$(pwd)`.
49 |
50 | TF-Serving supports model versioning. The model should be contained in a
51 | directory whose name is the version of the model. A model version is an integer
52 | e.g., "1". Here is a typical directory for TF-Serving.
53 |
54 | - `/path/to/model`
55 | - `1` : Version 1 of the model
56 | - `5` : Version 5 of the model
57 | - `6` : Version 6 of the model
58 |
59 | For this example, we only need to put the model in a directory called "1".
60 |
61 | ```shell
62 | mkdir -p /tmp/tf_serving_model
63 | cp -R "${MODEL_PATH}" /tmp/tf_serving_model/1
64 | ```
65 |
66 | Now, we can start TF-Serving on the model.
67 |
68 | ```shell
69 | ./tensorflow_model_server \
70 | --rest_api_port=8502 \
71 | --model_name=my_model \
72 | --model_base_path=/tmp/tf_serving_model
73 | ```
74 |
75 | Finally, you can send a request to TF Serving using the REST API. Two formats
76 | are available: predict+instances API and predict+inputs API. Here is an example
77 | of each of them:
78 |
79 | ```shell
80 | # Predictions with the predict+instances API.
81 | curl http://localhost:8502/v1/models/my_model:predict -X POST \
82 | -d '{"instances": [{"age":39,"workclass":"State-gov","fnlwgt":77516,"education":"Bachelors","education_num":13,"marital_status":"Never-married","occupation":"Adm-clerical","relationship":"Not-in-family","race":"White","sex":"Male","capital_gain":2174,"capital_loss":0,"hours_per_week":40,"native_country":"United-States"}]}'
83 | ```
84 |
85 | ```shell
86 | # Predictions with the predict+inputs API
87 | curl http://localhost:8502/v1/models/my_model:predict -X POST \
88 | -d '{"inputs": {"age":[39],"workclass":["State-gov"],"fnlwgt":[77516],"education":["Bachelors"],"education_num":[13],"marital_status":["Never-married"],"occupation":["Adm-clerical"],"relationship":["Not-in-family"],"race":["White"],"sex":["Male"],"capital_gain":[2174],"capital_loss":[0],"hours_per_week":[40],"native_country":["United-States"]}}'
89 | ```
90 |
--------------------------------------------------------------------------------
/documentation/tf_df_in_tf_js.md:
--------------------------------------------------------------------------------
1 | # Running TensorFlow Decision Forests models with TensorFlow.js
2 |
3 | These instructions explain how to train a TF-DF model and run it on the
4 | web using TensorFlow.js.
5 |
6 | ## Detailed instructions
7 |
8 | ### Train a model in TF-DF
9 |
10 | To try out this tutorial, you first need a TF-DF model. You can use your own
11 | model or train a model with the
12 | [Beginner's tutorial](https://www.tensorflow.org/decision_forests/tutorials/beginner_colab).
13 |
14 | If you simply want to quickly train a model in Google Colab, you can use the
15 | following code snippet.
16 |
17 | ```python
18 | !pip install tensorflow_decision_forests -U -qq
19 | import tensorflow as tf
20 | import tensorflow_decision_forests as tfdf
21 | import pandas as pd
22 |
23 | # Download the dataset, load it into a pandas dataframe and convert it to TensorFlow format.
24 | !wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv
25 | dataset_df = pd.read_csv("/tmp/penguins.csv")
26 | train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df, label="species")
27 |
28 | # Create and train the model
29 | model_1 = tfdf.keras.GradientBoostedTreesModel()
30 | model_1.fit(train_ds)
31 | ```
32 |
33 | ### Convert the model
34 |
35 | The instructions going forward assume that you have saved your TF-DF model under
36 | the path `/tmp/my_saved_model`. Run the following snippet to convert the model
37 | to TensorFlow.js.
38 |
39 | ```python
40 | !pip install tensorflow tensorflow_decision_forests 'tensorflowjs>=4.4.0'
41 | !pip install tf_keras
42 |
43 | # Prepare and load the model with TensorFlow
44 | import tensorflow as tf
45 | import tensorflowjs as tfjs
46 | from google.colab import files
47 |
48 | # Save the model in the SavedModel format
49 | tf.saved_model.save(model_1, "/tmp/my_saved_model")
50 |
51 | # Convert the SavedModel to TensorFlow.js and save as a zip file
52 | tfjs.converters.tf_saved_model_conversion_v2.convert_tf_saved_model("/tmp/my_saved_model", "./tfjs_model")
53 |
54 | # Download the converted TFJS model
55 | !zip -r tfjs_model.zip tfjs_model/
56 | files.download("tfjs_model.zip")
57 | ```
58 |
59 | When Google Colab finishes running, it downloads the converted TFJS model as a
60 | zip file. Unzip this file before using it in the next step.
61 |
62 | An unzipped TensorFlow.js model consists of a number of files. The example
63 | model contains the following:
64 |
65 | - assets.zip
66 | - group1-shard1of1.bin
67 | - model.json
68 |
69 |
70 | ### Use the TensorFlow.js model on the web
71 |
72 | Use this template to load TFJS dependencies and run the TFDF model. Change the
73 | model path to where your model is served and modify the tensor given to
74 | executeAsync.
75 |
76 | ```html
77 |
78 |
79 |
98 | ```
99 |
100 | ## Questions?
101 |
102 | Check out the
103 | [TensorFlow Decision Forests documentation](https://www.tensorflow.org/decision_forests)
104 | and the [TensorFlow.js documentation](https://www.tensorflow.org/js/tutorials).
105 |
--------------------------------------------------------------------------------
/documentation/tutorials/BUILD:
--------------------------------------------------------------------------------
1 | package(
2 | default_visibility = ["//visibility:public"],
3 | licenses = ["notice"],
4 | )
5 |
--------------------------------------------------------------------------------
/documentation/tutorials/images/beginner_tensorboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/documentation/tutorials/images/beginner_tensorboard.png
--------------------------------------------------------------------------------
/documentation/tutorials/images/model_composition_structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/documentation/tutorials/images/model_composition_structure.png
--------------------------------------------------------------------------------
/documentation/tutorials/index.md:
--------------------------------------------------------------------------------
1 | # TensorFlow Decision Forests tutorials
2 |
3 | The following notebooks are available:
4 |
5 | - [Beginner Colab](beginner_colab.ipynb): Learn the basics of model
6 | training, evaluation and export.
7 | - [Ranking Colab](ranking_colab.ipynb): Learn about ranking with decision
8 | forests.
9 | - [Intermediate Colab](intermediate_colab.ipynb): How to consume text and
10 | combine decision forests with neural networks.
11 | - [Advanced Colab](advanced_colab.ipynb): How to inspect and create model
12 | structures directly.
13 | - [Uplifting Colab](uplift_colab.ipynb): Learn about uplift modeling with
14 | decision forests.
15 | - [Model composition Colab](model_composition_colab.ipynb): How to compose
16 | decision forests and neural networks together.
17 | - [Proximities and Prototypes with Random Forests](proximities_colab.ipynb):
18 | Measure the distance between tabular examples and use it to understand a
19 | model and its predictions.
20 | - [Automatic hyper-parameter tuning](automatic_tuning_colab.ipynb):
21 | Automatically select the best hyper-parameters for a model.
22 | - [Making predictions](predict_colab.ipynb): List of options to make
23 | predictions with a model.
24 |
--------------------------------------------------------------------------------
/examples/BUILD:
--------------------------------------------------------------------------------
1 | package(
2 | default_visibility = ["//visibility:public"],
3 | licenses = ["notice"],
4 | )
5 | # Runnable target for minimal.py.
6 | py_binary(
7 | name = "minimal",
8 | srcs = ["minimal.py"],
9 | deps = [
10 | "//tensorflow_decision_forests",
11 | "@pypi_absl_py//:pkg", # :app
12 | "@pypi_numpy//:pkg",
13 | "@pypi_pandas//:pkg",
14 | "@release_or_nightly//:tensorflow_pkg",
15 | "@release_or_nightly//:tf_keras_pkg",
16 | ],
17 | )
18 | # Runnable target for hyperparameter_optimization.py.
19 | py_binary(
20 | name = "hyperparameter_optimization",
21 | srcs = ["hyperparameter_optimization.py"],
22 | deps = [
23 | "//tensorflow_decision_forests",
24 | "@pypi_absl_py//:pkg", # :app
25 | "@pypi_numpy//:pkg",
26 | "@pypi_pandas//:pkg",
27 | "@release_or_nightly//:tensorflow_pkg",
28 | "@release_or_nightly//:tf_keras_pkg",
29 | ],
30 | )
31 | # Runnable target for distributed_training.py.
32 | py_binary(
33 | name = "distributed_training",
34 | srcs = ["distributed_training.py"],
35 | deps = [
36 | "//tensorflow_decision_forests",
37 | "@pypi_absl_py//:pkg", # :app
38 | "@pypi_numpy//:pkg",  # NOTE(review): distributed_training.py has no direct numpy import; confirm this dep is needed.
39 | "@pypi_pandas//:pkg",  # NOTE(review): distributed_training.py has no direct pandas import; confirm this dep is needed.
40 | "@release_or_nightly//:tensorflow_pkg",
41 | "@release_or_nightly//:tf_keras_pkg",
42 | ],
43 | )
44 | # Runnable target for distributed_hyperparameter_optimization.py.
45 | py_binary(
46 | name = "distributed_hyperparameter_optimization",
47 | srcs = ["distributed_hyperparameter_optimization.py"],
48 | deps = [
49 | "//tensorflow_decision_forests",
50 | "@pypi_absl_py//:pkg", # :app
51 | "@release_or_nightly//:tensorflow_pkg",
52 | "@release_or_nightly//:tf_keras_pkg",
53 | ],
54 | )
55 |
--------------------------------------------------------------------------------
/examples/distributed_hyperparameter_optimization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Example of distributed hyper-parameter optimization with TF-DF.
16 |
17 | This example trains and exports a Gradient Boosted Tree model.
18 |
19 | Usage example:
20 |
21 | You need to configure TF parameter servers. See:
22 | https://www.tensorflow.org/decision_forests/distributed_training
23 | https://www.tensorflow.org/tutorials/distribute/parameter_server_training
24 |
25 | TF_CONFIG = ...
26 | # Start the workers
27 | # ...
28 | # Run the chief
29 | python3 distributed_hyperparameter_optimization.py
30 | """
31 |
32 | from absl import app
33 | from absl import logging
34 |
35 | import tensorflow as tf
36 | import tensorflow_decision_forests as tfdf
37 |
38 |
39 | def main(argv):  # Tunes a Gradient Boosted Trees model with random search under a parameter-server strategy, then saves it.
40 | if len(argv) > 1:
41 | raise app.UsageError("Too many command-line arguments.")
42 |
43 | # Configure training: resolve the cluster, then create the tuner and the model.
44 | logging.info("Configure training")
45 | cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
46 | strategy = tf.distribute.experimental.ParameterServerStrategy(
47 | cluster_resolver)
48 | with strategy.scope():
49 | tuner = tfdf.tuner.RandomSearch(
50 | # 200 trials to find the best hyper-parameters.
51 | num_trials=200,
52 | # Use the pre-defined hyper-parameter space.
53 | use_predefined_hps=True,
54 | # Each model is trained on 4 threads.
55 | trial_num_threads=4)
56 | model = tfdf.keras.GradientBoostedTreesModel(
57 | tuner=tuner,
58 | temp_directory="/cns/bh-d/home/gbm/tmp/ttl=15d/tfdf_cache_dho3",  # NOTE(review): hard-coded, site-specific path; presumably must be replaced by a directory reachable by the chief and the workers — confirm.
59 | # Number of threads available on each worker.
60 | num_threads=30,
61 | )
62 |
63 | # Tunes and trains the model directly from CSV files given by path.
64 | # NOTE(review): both paths below are hard-coded and site-specific; the validation set is the file named "adult_test.csv".
65 | model.fit_on_dataset_path(
66 | train_path="/cns/is-d/home/gbm/ml_dataset_repository/others/adult/adult_train.csv",
67 | valid_path="/cns/is-d/home/gbm/ml_dataset_repository/others/adult/adult_test.csv",
68 | label_key="income",
69 | dataset_format="csv")
70 |
71 | logging.info("Trained model:")
72 | model.summary()
73 |
74 | # Log the model's self-evaluation, read through the model inspector.
75 | inspector = model.make_inspector()
76 | logging.info("Model self evaluation: %s", inspector.evaluation().to_dict())
77 |
78 | # Exports the model to disk in the SavedModel format for later re-use. This
79 | # model can be used with TensorFlow Serving and Yggdrasil Decision Forests
80 | # (https://ydf.readthedocs.io/en/latest/serving_apis.html).
81 | logging.info("Export model")
82 | model.save("/cns/bh-d/home/gbm/tmp/ttl=15d/tfdf_model_dho3")
83 |
84 |
85 | if __name__ == "__main__":
86 | app.run(main)
87 |
--------------------------------------------------------------------------------
/examples/distributed_training.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Minimal usage example of Distributed training with TF-DF.
16 |
17 | This example trains and exports a Gradient Boosted Tree model.
18 |
19 | Usage example:
20 |
21 | For this example, we need a large dataset. If you don't have such dataset
22 | available, create a synthetic dataset following the instructions in the
23 | "Synthetic dataset for usage example" below.
24 |
25 | You need to configure TF parameter servers. See:
26 | https://www.tensorflow.org/decision_forests/distributed_training
27 | https://www.tensorflow.org/tutorials/distribute/parameter_server_training
28 |
29 | TF_CONFIG = ...
30 | # Start the workers
31 | # ...
32 | # Run the chief
33 | python3 distributed_training.py
34 |
35 | Synthetic dataset for usage example:
36 |
37 | In this example, we use a synthetic dataset containing 1M examples. This
38 | dataset is small enough that it could be used without distributed training,
39 | but this is a good example.
40 |
41 | This dataset is generated with the "synthetic_dataset" tool of YDF.
42 |
43 | Create a file "config.pbtxt" with the content:
44 | num_examples:1000000
45 | num_examples_per_shards: 100
46 | num_numerical:100
47 | num_categorical:50
48 | num_categorical_set:0
49 | num_boolean:50
50 | categorical_vocab_size:100
51 |
52 | Then run
53 |
54 | bazel run -c opt \
55 | //external/ydf/yggdrasil_decision_forests/cli/utils:synthetic_dataset -- \
56 | --alsologtostderr \
57 | --options=/config.pbtxt \
58 | --train=recordio+tfe:/train@60 \
59 | --valid=recordio+tfe:/valid@20 \
60 | --test=recordio+tfe:/test@20 \
61 | --ratio_valid=0.2 \
62 | --ratio_test=0.2
63 | """
64 |
65 | import os
66 | from absl import app
67 | from absl import logging
68 |
69 | import tensorflow as tf
70 | import tensorflow_decision_forests as tfdf
71 |
72 |
73 | def main(argv):  # Trains a distributed Gradient Boosted Trees model from dataset paths, logs its metrics and saves it.
74 | if len(argv) > 1:
75 | raise app.UsageError("Too many command-line arguments.")
76 |
77 | # "work_directory" is used to store the temporary checkpoints as well as the
78 | # final model. "work_directory" should be accessible to both the chief and the
79 | # workers.
80 |
81 | work_directory = "/some/remote/directory"
82 |
83 | # Alternatively, you can use a local directory when testing distributed
84 | # training locally, i.e. when running the workers on the same machine as the
85 | # chief. See "fake_distributed_training.sh".
86 | # work_directory = "/tmp/tfdf_model"
87 |
88 | # The dataset is provided as a set of sharded files.
89 | train_dataset_path = "/path/to/dataset/train@60"
90 | valid_dataset_path = "/path/to/dataset/valid@60"  # NOTE(review): the module docstring generates "valid@20"; confirm the intended shard count.
91 | dataset_format = "recordio+tfe"
92 |
93 | # Alternatively, when testing distributed training locally, you can use a
94 | # non-sharded dataset.
95 | # train_dataset_path = "external/ydf/yggdrasil_decision_forests/test_data/dataset/adult_train.csv"
96 | # valid_dataset_path = "external/ydf/yggdrasil_decision_forests/test_data/dataset/adult_test.csv"
97 | # dataset_format = "csv"
98 |
99 | # Configure training: resolve the cluster (gRPC) and create the model under the parameter-server strategy.
100 | logging.info("Configure training")
101 | cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver(
102 | rpc_layer="grpc")
103 | strategy = tf.distribute.experimental.ParameterServerStrategy(
104 | cluster_resolver)
105 | with strategy.scope():
106 | model = tfdf.keras.DistributedGradientBoostedTreesModel(
107 | # Speed-up training by discretizing numerical features.
108 | force_numerical_discretization=True,
109 | # Cache directory used to store checkpoints.
110 | temp_directory=os.path.join(work_directory, "work_dir"),
111 | # Number of threads on each worker.
112 | num_threads=30,
113 | )
114 | model.compile(metrics=["accuracy"])
115 |
116 | # Trains the model directly from the dataset paths configured above.
117 | logging.info("Start training")
118 | model.fit_on_dataset_path(
119 | train_path=train_dataset_path,
120 | valid_path=valid_dataset_path,
121 | label_key="income",
122 | dataset_format=dataset_format)
123 |
124 | logging.info("Trained model:")
125 | model.summary()
126 |
127 | # Log the self-evaluation and the training logs, then export them for TensorBoard.
128 | inspector = model.make_inspector()
129 | logging.info("Model self evaluation: %s", inspector.evaluation().to_dict())
130 | logging.info("Model training logs: %s", inspector.training_logs())
131 | inspector.export_to_tensorboard(os.path.join(work_directory, "tensorboard"))
132 |
133 | # Exports the model to disk in the SavedModel format for later re-use. This
134 | # model can be used with TensorFlow Serving and Yggdrasil Decision Forests
135 | # (https://ydf.readthedocs.io/en/latest/serving_apis.html).
136 | logging.info("Export model")
137 | model.save(os.path.join(work_directory, "model"))
138 |
139 |
140 | if __name__ == "__main__":
141 | app.run(main)
142 |
--------------------------------------------------------------------------------
/examples/fake_distributed_training.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2021 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 |
18 | # Simulate distributed training locally.
19 | # Launches two workers and one parameter server in the background, then runs the chief in the foreground.
20 |
21 | echo "Warning: This script does not stop workers. After running it, you need to"
22 | echo "stop the workers manually (e.g. using kill -9)."
23 |
24 | set -vex  # Echo commands as read (-v) and as executed (-x); exit on the first failing command (-e).
25 |
26 | PYTHON=python3  # NOTE(review): not referenced anywhere below in this script.
27 |
28 | export TF_CONFIG='{
29 | "cluster": {
30 | "worker": [
31 | "localhost:4300",
32 | "localhost:4301"],
33 | "ps": ["localhost:4310"],
34 | "chief": ["localhost:4311"]
35 | }
36 | }'  # Cluster layout: two workers, one parameter server and one chief, all on localhost.
37 | # Build the chief binary (the training example) and the server binary used for the workers and the parameter server.
38 | bazel build \
39 | //third_party/tensorflow_decision_forests/examples:distributed_training \
40 | //third_party/tensorflow_decision_forests/tensorflow/distribute:tensorflow_std_server_py
41 |
42 | WORKER=bazel-bin/third_party/tensorflow_decision_forests/tensorflow/distribute/tensorflow_std_server_py
43 | CHIEF=bazel-bin/third_party/tensorflow_decision_forests/examples/distributed_training
44 |
45 | # Start the workers and the parameter server ("ps") as background jobs.
46 | ${WORKER} --alsologtostderr --job_name=worker --task_index=0 &
47 | ${WORKER} --alsologtostderr --job_name=worker --task_index=1 &
48 | ${WORKER} --alsologtostderr --job_name=ps --task_index=0 &
49 |
50 | # Start the chief in the foreground; the background jobs above are not stopped by this script.
51 | ${CHIEF} --alsologtostderr
52 |
--------------------------------------------------------------------------------
/examples/hyperparameter_optimization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Example of automated hyper-parameter tuning with TensorFlow Decision Forests.
16 |
17 | This example trains, displays, evaluates and exports a Gradient Boosted Tree
18 | model.
19 |
20 | Usage example:
21 |
22 | pip3 install tensorflow_decision_forests -U
23 | python3 hyperparameter_optimization.py
24 |
25 | Or
26 |
27 | bazel run -c opt \
28 | //tensorflow_decision_forests/examples:hyperparameter_optimization
29 | \
30 | -- --alsologtostderr
31 | """
32 |
33 | from absl import app
34 |
35 | import numpy as np
36 | import pandas as pd
37 | import tensorflow as tf
38 | import tensorflow_decision_forests as tfdf
39 | import tf_keras
40 |
41 | def main(argv):  # Downloads the Adult dataset, tunes a Gradient Boosted Trees model, evaluates it and saves it.
42 | if len(argv) > 1:
43 | raise app.UsageError("Too many command-line arguments.")
44 |
45 | # Download the Adult dataset.
46 | dataset_path = tf_keras.utils.get_file(
47 | "adult.csv",
48 | "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/"
49 | "main/yggdrasil_decision_forests/test_data/dataset/adult.csv")
50 |
51 | # Load the dataset into a Pandas DataFrame.
52 | dataset_df = pd.read_csv(dataset_path) # "df" for Pandas's DataFrame.
53 |
54 | print("First the first three examples:")  # NOTE(review): message probably meant "The first three examples:".
55 | print(dataset_df.head(3))
56 |
57 | # Notice that the dataset contains a mix of numerical and categorical
58 | # features. TensorFlow Decision Forests handles them automatically (e.g. no
59 | # need for one-hot encoding or normalization; except for the label).
60 |
61 | # Randomly split the dataset: roughly 30% of the rows for testing, the rest for training.
62 | test_indices = np.random.rand(len(dataset_df)) < 0.30
63 | test_ds_pd = dataset_df[test_indices]
64 | train_ds_pd = dataset_df[~test_indices]
65 | print(f"{len(train_ds_pd)} examples in training"
66 | f", {len(test_ds_pd)} examples for testing.")
67 |
68 | # Converts datasets from Pandas dataframe to TensorFlow dataset format.
69 | train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_ds_pd, label="income")
70 | test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_ds_pd, label="income")
71 |
72 | # Tune the model.
73 | #
74 | # The hyper-parameters to optimize are automatically set with
75 | # "use_predefined_hps=True". See
76 | # https://www.tensorflow.org/decision_forests/tutorials/automatic_tuning_colab
77 | # for an example where the hyper-parameter space is configured manually.
78 | tuner = tfdf.tuner.RandomSearch(num_trials=30, use_predefined_hps=True)
79 | model = tfdf.keras.GradientBoostedTreesModel(verbose=2, tuner=tuner)
80 | model.fit(train_ds)
81 |
82 | # Some information about the model.
83 | print(model.summary())  # NOTE(review): the other examples call model.summary() without print(); this may also print its return value — confirm.
84 |
85 | # Evaluates the model on the test dataset, with accuracy as the compiled metric.
86 | model.compile(metrics=["accuracy"])
87 | evaluation = model.evaluate(test_ds)
88 | print(f"BinaryCrossentropyloss: {evaluation[0]}")
89 | print(f"Accuracy: {evaluation[1]}")
90 |
91 | # Exports the model to disk in the SavedModel format for later re-use. This
92 | # model can be used with TensorFlow Serving and Yggdrasil Decision Forests
93 | # (https://ydf.readthedocs.io/en/latest/serving_apis.html).
94 | model.save("/tmp/my_saved_model")
95 |
96 |
97 | if __name__ == "__main__":
98 | app.run(main)
99 |
--------------------------------------------------------------------------------
/examples/minimal.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Minimal usage example of TensorFlow Decision Forests.
16 |
17 | This example trains, displays, evaluates and exports a Gradient Boosted Tree
18 | model.
19 |
20 | Usage example:
21 |
22 | pip3 install tensorflow_decision_forests -U
23 | python3 minimal.py
24 |
25 | Or
26 |
27 | bazel run -c opt \
28 | //tensorflow_decision_forests/examples:minimal \
29 | -- --alsologtostderr
30 | """
31 |
32 | from absl import app
33 | import numpy as np
34 | import pandas as pd
35 | import tensorflow as tf
36 | import tensorflow_decision_forests as tfdf
37 | import tf_keras
38 |
39 |
def main(argv):
  """Trains, evaluates and exports a Gradient Boosted Trees model.

  The Adult dataset is downloaded, randomly split into a training and a
  testing part, and used to train a model with default hyper-parameters. The
  trained model is evaluated and then saved to "/tmp/my_saved_model".

  Args:
    argv: Command-line arguments forwarded by `app.run`. Only the program name
      is accepted.

  Raises:
    app.UsageError: If extra command-line arguments are provided.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  # Download the Adult dataset.
  dataset_path = tf_keras.utils.get_file(
      "adult.csv",
      "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/"
      "main/yggdrasil_decision_forests/test_data/dataset/adult.csv",
  )

  # Load a dataset into a Pandas Dataframe.
  dataset_df = pd.read_csv(dataset_path)  # "df" for Pandas's DataFrame.

  print("First the first three examples:")
  print(dataset_df.head(3))

  # Notice that the dataset contains a mix of numerical and categorical
  # features. TensorFlow Decision Forests handles them automatically (e.g. no
  # need for one-hot encoding or normalization; except for the label).

  # Split the dataset into a training and a testing dataset. Each example is
  # independently assigned to the test set with probability 0.30.
  test_indices = np.random.rand(len(dataset_df)) < 0.30
  test_ds_pd = dataset_df[test_indices]
  train_ds_pd = dataset_df[~test_indices]
  print(
      f"{len(train_ds_pd)} examples in training"
      f", {len(test_ds_pd)} examples for testing."
  )

  # Converts datasets from Pandas dataframe to TensorFlow dataset format.
  train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_ds_pd, label="income")
  test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_ds_pd, label="income")

  # Trains the model.
  model = tfdf.keras.GradientBoostedTreesModel(verbose=2)
  model.fit(train_ds)

  # Some information about the model. "summary()" prints the report itself and
  # returns None, so its result is not passed to "print" (that would output a
  # stray "None" line).
  model.summary()

  # Evaluates the model on the test dataset.
  model.compile(metrics=["accuracy"])
  evaluation = model.evaluate(test_ds)
  print(f"BinaryCrossentropyloss: {evaluation[0]}")
  print(f"Accuracy: {evaluation[1]}")

  # Exports the model to disk in the SavedModel format for later re-use. This
  # model can be used with TensorFlow Serving and Yggdrasil Decision Forests
  # (https://ydf.readthedocs.io/en/latest/serving_apis.html).
  model.save("/tmp/my_saved_model")


if __name__ == "__main__":
  app.run(main)
95 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/BUILD:
--------------------------------------------------------------------------------
1 |
package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)

exports_files(["LICENSE"])

# Do not depend on this rule. Instead, depend on "//tensorflow_decision_forests".
#
# Aggregates the Python components of the library behind the top-level
# "__init__.py".
py_library(
    name = "tensorflow_decision_forests",
    srcs = ["__init__.py"],
    # Each package group is listed once; a repeated label adds nothing.
    visibility = [
        "//learning/lib/ami/simple_ml:__subpackages__",
        "//tensorflow_decision_forests:__subpackages__",
    ],
    deps = [
        "//tensorflow_decision_forests/component/builder",
        "//tensorflow_decision_forests/component/inspector",
        "//tensorflow_decision_forests/component/model_plotter",
        "//tensorflow_decision_forests/component/py_tree",
        "//tensorflow_decision_forests/component/tuner",
        "//tensorflow_decision_forests/keras",
        "//tensorflow_decision_forests/tensorflow:check_version",
    ],
)

# If set, use the pre-compiled version of the OP wrappers. This significantly
# speeds up the compilation of the project.
# Enabled with "--define use_precompiled_wrappers=1".
config_setting(
    name = "use_precompiled_wrappers",
    values = {"define": "use_precompiled_wrappers=1"},
)

# If set, stop the training op when receiving an interrupt signal.
# Enabled with "--define stop_training_on_interrupt=1".
config_setting(
    name = "stop_training_on_interrupt",
    values = {"define": "stop_training_on_interrupt=1"},
)

# If "disable_tf_ps_distribution_strategy" is true, the TF Parameter Server
# distribution strategy is not available for distributed training.
#
# Distribution with TF PS is currently NOT supported for OSS TF-DF with shared
# build (monolithic build works however) and TF<2.7. In this case, the GRPC
# Worker Server can be used instead.
# Matched with "--define tf_ps_distribution_strategy=0".
config_setting(
    name = "disable_tf_ps_distribution_strategy",
    values = {"define": "tf_ps_distribution_strategy=0"},
)
54 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/BUILD:
--------------------------------------------------------------------------------
# License declaration for this package. No target is defined here.
licenses(["notice"])
2 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/builder/BUILD:
--------------------------------------------------------------------------------
package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)

# Libraries
# =========

# Library built from "builder.py".
py_library(
    name = "builder",
    srcs = ["builder.py"],
    deps = [
        "//tensorflow_decision_forests/component/inspector",
        "//tensorflow_decision_forests/component/inspector:blob_sequence",
        "//tensorflow_decision_forests/component/py_tree",
        "//tensorflow_decision_forests/keras:core_inference",
        "//tensorflow_decision_forests/tensorflow:core_inference",
        "@pypi_numpy//:pkg",
        "@pypi_six//:pkg",
        "@release_or_nightly//:tensorflow_pkg",
        "@ydf//yggdrasil_decision_forests/dataset:data_spec_py_proto",
        "@ydf//yggdrasil_decision_forests/model:abstract_model_py_proto",
        "@ydf//yggdrasil_decision_forests/model/gradient_boosted_trees:gradient_boosted_trees_py_proto",
        "@ydf//yggdrasil_decision_forests/model/random_forest:random_forest_py_proto",
    ],
)

# Tests
# =====

# Sharded, long-timeout test of the builder; reads the YDF test data.
py_test(
    name = "builder_test",
    timeout = "long",
    srcs = ["builder_test.py"],
    data = ["@ydf//yggdrasil_decision_forests/test_data"],
    shard_count = 10,
    deps = [
        ":builder",
        "//tensorflow_decision_forests/component/inspector",
        "//tensorflow_decision_forests/component/py_tree",
        "//tensorflow_decision_forests/keras",
        "@pypi_absl_py//:pkg",  # /flags
        "@pypi_numpy//:pkg",
        "@pypi_pandas//:pkg",
        "@release_or_nightly//:tensorflow_pkg",
    ],
)
48 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/builder/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/inspector/BUILD:
--------------------------------------------------------------------------------
package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)

# Libraries
# =========

# Library built from "inspector.py".
py_library(
    name = "inspector",
    srcs = ["inspector.py"],
    deps = [
        ":blob_sequence",
        "//tensorflow_decision_forests/component/py_tree",
        "@pypi_pandas//:pkg",
        "@pypi_six//:pkg",
        "@release_or_nightly//:tensorflow_pkg",
        "@ydf//yggdrasil_decision_forests/dataset:data_spec_py_proto",
        "@ydf//yggdrasil_decision_forests/metric:metric_py_proto",
        "@ydf//yggdrasil_decision_forests/model:abstract_model_py_proto",
        "@ydf//yggdrasil_decision_forests/model:hyperparameter_py_proto",
        "@ydf//yggdrasil_decision_forests/model/decision_tree:decision_tree_py_proto",
        "@ydf//yggdrasil_decision_forests/model/gradient_boosted_trees:gradient_boosted_trees_py_proto",
        "@ydf//yggdrasil_decision_forests/model/multitasker:multitasker_py_proto",
        "@ydf//yggdrasil_decision_forests/model/random_forest:random_forest_py_proto",
    ],
)

# Reader and writer of blob sequence files ("blob_sequence.py").
py_library(
    name = "blob_sequence",
    srcs = ["blob_sequence.py"],
    deps = [
        "@release_or_nightly//:tensorflow_pkg",
    ],
)

# Tests
# =====

# Sharded, long-timeout test of the inspector; reads the YDF test data.
py_test(
    name = "inspector_test",
    timeout = "long",
    srcs = ["inspector_test.py"],
    data = ["@ydf//yggdrasil_decision_forests/test_data"],
    shard_count = 10,
    deps = [
        ":inspector",
        "//tensorflow_decision_forests/component/py_tree",
        "//tensorflow_decision_forests/keras",
        "@pypi_absl_py//:pkg",  # /flags
        "@pypi_numpy//:pkg",
        "@release_or_nightly//:tensorflow_pkg",
        "@ydf//yggdrasil_decision_forests/metric:metric_py_proto",
        "@ydf//yggdrasil_decision_forests/model/gradient_boosted_trees:gradient_boosted_trees_py_proto",
    ],
)

# Test of the blob sequence reader and writer; reads the YDF test data.
py_test(
    name = "blob_sequence_test",
    srcs = ["blob_sequence_test.py"],
    data = ["@ydf//yggdrasil_decision_forests/test_data"],
    deps = [
        ":blob_sequence",
        "@pypi_absl_py//:pkg",  # /flags
        "@release_or_nightly//:tensorflow_pkg",
        "@ydf//yggdrasil_decision_forests/model/decision_tree:decision_tree_py_proto",
    ],
)
69 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/inspector/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/inspector/blob_sequence.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Blob Sequence reader and writer.
16 |
17 | A blob sequence is a stream (e.g. a file) containing a sequence of blobs (i.e.
18 | chunks of bytes). It can be used to store a sequence of serialized protos.
19 |
20 | See yggdrasil_decision_forests/utils/blob_sequence.h for more details.
21 | """
22 |
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | from typing import Iterator, Optional
28 |
29 | import tensorflow as tf
30 |
31 |
class Reader(object):
  """Reader of Blob Sequence files.

  The file starts with an 8-byte header: the magic bytes b"BS", a 2-byte
  little-endian version, and 4 more bytes whose meaning depends on the version
  (version 0: 4 reserved bytes; version 1: 1 compression byte followed by 3
  reserved bytes). Each blob that follows is stored as a 4-byte little-endian
  length and then the blob content.

  Usage example:

    for blob in Reader(path):
      print(blob)
  """

  def __init__(self, path: str):
    """Creates the reader and opens `path` if it is non-empty."""
    self.file_: Optional[tf.io.gfile.GFile] = None
    self.path_: Optional[str] = None

    if path:
      self.open(path)

  def open(self, path: str):
    """Opens a blob sequence file and consumes its header.

    Args:
      path: Path to the blob sequence file.

    Raises:
      ValueError: If the file does not start with the blob sequence magic, if
        its version is unknown, or if it is a compressed version 1 file.
    """

    self.file_ = tf.io.gfile.GFile(path, "rb")
    self.path_ = path

    try:
      self._read_header(path)
    except Exception:
      # Do not leak the file handle when the header is rejected.
      self.close()
      raise

  def _read_header(self, path: str) -> None:
    """Reads and validates the header of the currently opened file."""

    magic = self.file_.read(2)
    if magic != b"BS":
      raise ValueError(f"Invalid blob sequence file {path}")
    version = int.from_bytes(self.file_.read(2), byteorder="little")
    if version == 0:
      # 4 reserved bytes.
      self.file_.read(4)
    elif version == 1:
      compression = int.from_bytes(self.file_.read(1), byteorder="little")
      if compression != 0:
        # The error has to be raised, not returned: "open" is called from
        # the constructor, which would silently drop a returned exception.
        raise ValueError(
            "The TF-DF inspector does not support this format of model"
            " (blob-sequence-v1 with compression). Use the format"
            " blob-sequence-v1 without compression instead."
        )
      # 3 reserved bytes.
      self.file_.read(3)
    else:
      raise ValueError(
          f"Non supported blob sequence version {version} for file {path}. The"
          " model was created with a more recent vesion of YDF / TF-DF."
      )

  def close(self):
    """Closes the file, if any. Calling it on a closed reader is a no-op."""
    if self.file_ is not None:
      self.file_.close()
    self.path_ = None
    self.file_ = None

  def read(self) -> Optional[bytes]:
    """Reads and returns the next blob.

    Returns:
      The content of the next blob, or None if the end of the file is reached.

    Raises:
      ValueError: If the length prefix or the blob content is incomplete.
    """

    raw_length = self.file_.read(4)
    if not raw_length:
      # Clean end of file: no more blobs.
      return None
    if len(raw_length) != 4:
      raise ValueError(f"Corrupted blob sequence {self.path_}")
    length = int.from_bytes(raw_length, byteorder="little")
    blob = self.file_.read(length)
    if len(blob) != length:
      raise ValueError(f"Truncated blob sequence {self.path_}")
    return blob

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    self.close()

  def __iter__(self) -> Iterator[bytes]:
    """Iterates over the remaining blobs of the file."""

    while True:
      blob = self.read()
      if blob is None:
        break
      yield blob
111 |
112 |
class Writer(object):
  """Writer of Blob Sequence files.

  The writer emits a version 0 file: the magic bytes b"BS", a 2-byte
  little-endian version (0) and 4 zero bytes, followed by the blobs. Each blob
  is stored as a 4-byte little-endian length and then the blob content.

  Usage example:

    bs = Writer(path)
    bs.write(b"Hello")
    bs.write(b"World")
    bs.close()
  """

  def __init__(self, path: Optional[str] = None):
    """Creates the writer and opens `path` if it is provided and non-empty."""
    self.file_: Optional[tf.io.gfile.GFile] = None
    self.path_: Optional[str] = None

    if path:
      self.open(path)

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    self.close()

  def open(self, path: str):
    """Opens `path` for writing and writes the blob sequence header."""
    self.file_ = tf.io.gfile.GFile(path, "wb")
    self.path_ = path

    self.file_.write(b"BS")
    version = 0
    self.file_.write(version.to_bytes(2, byteorder="little"))
    # Bytes following the version in a version 0 header.
    self.file_.write(b"\0\0\0\0")

  def write(self, blob: bytes):
    """Appends a blob, prefixed by its length on 4 little-endian bytes."""
    self.file_.write(len(blob).to_bytes(4, byteorder="little"))
    self.file_.write(blob)

  def close(self):
    """Closes the file, if any. Calling it on a closed writer is a no-op."""
    if self.file_ is not None:
      self.file_.close()
    self.path_ = None
    self.file_ = None
154 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/inspector/blob_sequence_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import os
20 |
21 | from absl import flags
22 | from absl import logging
23 | from absl.testing import parameterized
24 | import tensorflow as tf
25 |
26 | from tensorflow_decision_forests.component.inspector import blob_sequence
27 | from yggdrasil_decision_forests.model.decision_tree import decision_tree_pb2
28 |
29 |
def data_root_path() -> str:
  """Returns the prefix prepended to the test data path (empty string)."""
  root = ""
  return root


def test_data_path() -> str:
  """Returns the directory of the YDF test data, relative to the data root."""
  relative_dir = "external/ydf/yggdrasil_decision_forests/test_data"
  return os.path.join(data_root_path(), relative_dir)


def tmp_path() -> str:
  """Returns the directory given by the `test_tmpdir` flag."""
  tmp_dir = flags.FLAGS.test_tmpdir
  return tmp_dir


def test_model_directory() -> str:
  """Returns the "model" sub-directory of the test data directory."""
  data_dir = test_data_path()
  return os.path.join(data_dir, "model")
45 |
46 |
class BlogSequenceTest(parameterized.TestCase, tf.test.TestCase):
  """Tests of the blob sequence reader and writer."""

  def test_base(self):
    """Blobs written by a Writer are read back, in order, by a Reader."""
    path = os.path.join(tmp_path(), "tmp.bs")

    writer = blob_sequence.Writer(path)
    writer.write(b"HELLO")
    writer.write(b"WORLD")
    writer.close()

    # The reader is used as a context manager so that the file is closed.
    with blob_sequence.Reader(path) as reader:
      self.assertEqual([b"HELLO", b"WORLD"], list(reader))

  def test_node(self):
    """Counts the nodes stored in the node file of a test model."""
    path = os.path.join(test_model_directory(), "adult_binary_class_rf",
                        "nodes-00000-of-00001")
    num_nodes = 0
    with blob_sequence.Reader(path) as reader:
      for serialized_node in reader:
        node = decision_tree_pb2.Node.FromString(serialized_node)
        # Only log the first three nodes.
        if num_nodes <= 2:
          logging.info("Node: %s", node)
        num_nodes += 1
    # Matches the number of nodes in "description.txt".
    self.assertEqual(num_nodes, 125578)


if __name__ == "__main__":
  tf.test.main()
75 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/model_plotter/BUILD:
--------------------------------------------------------------------------------
package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)

# Data
# ====

# JavaScript asset shipped as runtime data of the plotter library.
filegroup(
    name = "assets",
    srcs = ["plotter.js"],
)

# Libraries
# =========

# Library built from "model_plotter.py"; carries "assets" as data.
py_library(
    name = "model_plotter",
    srcs = ["model_plotter.py"],
    data = ["assets"],
    deps = [
        "//tensorflow_decision_forests/component/inspector",
        "//tensorflow_decision_forests/component/py_tree:condition",
        "//tensorflow_decision_forests/component/py_tree:node",
        "//tensorflow_decision_forests/component/py_tree:tree",
        "//tensorflow_decision_forests/component/py_tree:value",
        "@release_or_nightly//:tensorflow_pkg",
    ],
)

# Tests
# =====

py_test(
    name = "model_plotter_test",
    srcs = ["model_plotter_test.py"],
    deps = [
        ":model_plotter",
        "//tensorflow_decision_forests/component/py_tree:condition",
        "//tensorflow_decision_forests/component/py_tree:dataspec",
        "//tensorflow_decision_forests/component/py_tree:node",
        "//tensorflow_decision_forests/component/py_tree:tree",
        "//tensorflow_decision_forests/component/py_tree:value",
        "@pypi_absl_py//:pkg",  # /flags
        "@release_or_nightly//:tensorflow_pkg",
    ],
)
48 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/model_plotter/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/model_plotter/model_plotter_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import os
20 | from absl import flags
21 | from absl import logging
22 | from absl.testing import parameterized
23 | import tensorflow as tf
24 |
25 | from tensorflow_decision_forests.component.model_plotter import model_plotter
26 | from tensorflow_decision_forests.component.py_tree import condition as condition_lib
27 | from tensorflow_decision_forests.component.py_tree import dataspec as dataspec_lib
28 | from tensorflow_decision_forests.component.py_tree import node as node_lib
29 | from tensorflow_decision_forests.component.py_tree import tree as tree_lib
30 | from tensorflow_decision_forests.component.py_tree import value as value_lib
31 |
32 |
class ModelPlotterTest(parameterized.TestCase, tf.test.TestCase):
  """Each test builds a small tree, plots it and saves the resulting plot."""

  def _save_plot(self, plot):
    """Writes `plot` to "plot.html" in the test temporary directory."""
    destination = os.path.join(self.get_temp_dir(), "plot.html")
    logging.info("Plot to %s", destination)
    with open(destination, "w") as html_file:
      html_file.write(plot)

  def _make_split_root(self):
    """Builds a root splitting on "f1 >= 1.5" with two leaf children.

    The positive child holds a regression value and the negative child holds
    a three-class probability value.
    """
    feature = dataspec_lib.SimpleColumnSpec(
        name="f1", type=dataspec_lib.ColumnType.NUMERICAL
    )
    split = condition_lib.NumericalHigherThanCondition(
        feature=feature,
        threshold=1.5,
        missing_evaluation=False,
    )
    positive_leaf = node_lib.LeafNode(
        value=value_lib.RegressionValue(
            value=5.0, num_examples=10, standard_deviation=1.0
        )
    )
    negative_leaf = node_lib.LeafNode(
        value=value_lib.ProbabilityValue(
            probability=[0.5, 0.4, 0.1], num_examples=10
        )
    )
    return node_lib.NonLeafNode(
        condition=split,
        pos_child=positive_leaf,
        neg_child=negative_leaf,
    )

  def test_empty_tree(self):
    empty_tree = tree_lib.Tree(None)
    self._save_plot(model_plotter.plot_tree(tree=empty_tree))

  def test_basic_tree(self):
    basic_tree = tree_lib.Tree(self._make_split_root())
    options = model_plotter.DisplayOptions(node_x_size=150)
    rendered = model_plotter.plot_tree(tree=basic_tree, display_options=options)
    self._save_plot(rendered)

  def test_basic_tree_with_label_classes(self):
    labelled_tree = tree_lib.Tree(
        self._make_split_root(),
        label_classes=["x", "y", "z"],
    )
    self._save_plot(model_plotter.plot_tree(tree=labelled_tree))

  def test_basic_tree_with_uplift(self):
    uplift_leaf = node_lib.LeafNode(
        value=value_lib.UpliftValue(
            treatment_effect=[0.5, 0.4], num_examples=10
        )
    )
    uplift_tree = tree_lib.Tree(uplift_leaf)
    self._save_plot(model_plotter.plot_tree(tree=uplift_tree))


if __name__ == "__main__":
  tf.test.main()
113 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/py_tree/BUILD:
--------------------------------------------------------------------------------
load("@rules_python//python:defs.bzl", "py_library")

package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)

# All the components
# =========

# Umbrella target: "__init__.py" plus every library of this package.
py_library(
    name = "py_tree",
    srcs = ["__init__.py"],
    deps = [
        ":condition",
        ":dataspec",
        ":node",
        ":objective",
        ":tree",
        ":value",
    ],
)

# Libraries
# =========

py_library(
    name = "tree",
    srcs = ["tree.py"],
    deps = [":node"],
)

py_library(
    name = "node",
    srcs = ["node.py"],
    deps = [
        ":condition",
        ":value",
        "@pypi_absl_py//:pkg",  # /logging
        "@pypi_numpy//:pkg",
        "@pypi_six//:pkg",
        "@ydf//yggdrasil_decision_forests/dataset:data_spec_py_proto",
        "@ydf//yggdrasil_decision_forests/model:abstract_model_py_proto",
        "@ydf//yggdrasil_decision_forests/model/decision_tree:decision_tree_py_proto",
    ],
)

py_library(
    name = "condition",
    srcs = ["condition.py"],
    deps = [
        ":dataspec",
        "@pypi_absl_py//:pkg",  # /logging
        "@pypi_numpy//:pkg",
        "@pypi_six//:pkg",
        "@ydf//yggdrasil_decision_forests/dataset:data_spec_py_proto",
        "@ydf//yggdrasil_decision_forests/model/decision_tree:decision_tree_py_proto",
    ],
)

py_library(
    name = "value",
    srcs = ["value.py"],
    deps = [
        "@pypi_numpy//:pkg",
        "@pypi_six//:pkg",
        "@ydf//yggdrasil_decision_forests/model/decision_tree:decision_tree_py_proto",
    ],
)

py_library(
    name = "dataspec",
    srcs = ["dataspec.py"],
    deps = [
        "@ydf//yggdrasil_decision_forests/dataset:data_spec_py_proto",
    ],
)

py_library(
    name = "objective",
    srcs = ["objective.py"],
    deps = [
        "@pypi_six//:pkg",
        "@ydf//yggdrasil_decision_forests/model:abstract_model_py_proto",
    ],
)

# Tests
# =====

# One test target per library above.

py_test(
    name = "tree_test",
    srcs = ["tree_test.py"],
    deps = [
        ":condition",
        ":dataspec",
        ":node",
        ":tree",
        ":value",
        "@pypi_absl_py//:pkg",  # /logging
        "@release_or_nightly//:tensorflow_pkg",
    ],
)

py_test(
    name = "node_test",
    srcs = ["node_test.py"],
    deps = [
        ":condition",
        ":dataspec",
        ":node",
        ":value",
        "@pypi_absl_py//:pkg",  # /logging
        "@release_or_nightly//:tensorflow_pkg",
        "@ydf//yggdrasil_decision_forests/dataset:data_spec_py_proto",
        "@ydf//yggdrasil_decision_forests/model/decision_tree:decision_tree_py_proto",
    ],
)

py_test(
    name = "condition_test",
    srcs = ["condition_test.py"],
    deps = [
        ":condition",
        ":dataspec",
        "@pypi_absl_py//:pkg",  # /logging
        "@release_or_nightly//:tensorflow_pkg",
        "@ydf//yggdrasil_decision_forests/dataset:data_spec_py_proto",
        "@ydf//yggdrasil_decision_forests/model/decision_tree:decision_tree_py_proto",
    ],
)

py_test(
    name = "dataspec_test",
    srcs = ["dataspec_test.py"],
    deps = [
        ":dataspec",
        "@pypi_absl_py//:pkg",  # /testing:parameterized
        "@release_or_nightly//:tensorflow_pkg",
        "@ydf//yggdrasil_decision_forests/dataset:data_spec_py_proto",
    ],
)

py_test(
    name = "value_test",
    srcs = ["value_test.py"],
    deps = [
        ":value",
        "@pypi_absl_py//:pkg",  # /logging
        "@release_or_nightly//:tensorflow_pkg",
        "@ydf//yggdrasil_decision_forests/model/decision_tree:decision_tree_py_proto",
    ],
)

py_test(
    name = "objective_test",
    srcs = ["objective_test.py"],
    deps = [
        ":dataspec",
        ":objective",
        "@pypi_absl_py//:pkg",  # /logging
        "@release_or_nightly//:tensorflow_pkg",
    ],
)
165 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/py_tree/README.md:
--------------------------------------------------------------------------------
1 | This directory contains a python implementation of decision trees, that can be
2 | used to inspect and manipulate the decision forests trained with TF-DF using
3 | python.
4 |
5 | The internal and efficient decision forest representation can be converted back
6 | and forth to this python representation for inspection and even arbitrary
7 | manipulation.
8 |
9 | Note: As one can expect, this Python representation is more explicit but much
10 | less efficient than the Proto model (also referred to as the **core model**)
11 | used internally by TF-DF. Complex algorithms should preferably be implemented
12 | in C++ on the Proto model.
13 |
14 | A `Tree` is composed of a single `AbstractNode` called the *root*. This node and
15 | its children recursively define a decision tree.
16 |
17 | If a node is a `LeafNode`, it contains an `AbstractValue` defining the
18 | output/prediction/value of the node. Depending on the tree type, the
19 | `AbstractValue` can be a `ProbabilityValue` or a `RegressionValue`.
20 |
21 | If a node is a `NonLeafNode`, it contains an `AbstractCondition`, two child
22 | nodes, and optionally an `AbstractValue` for model interpretation.
23 |
24 | Different implementations of `AbstractCondition` support different types of
25 | conditions e.g. `NumericalHigherThanCondition`, `CategoricalIsInCondition`.
26 |
27 | All objects can be printed using `str(...)`.
28 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/py_tree/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Decision trees stored as python objects.
16 |
17 | To be used with the model inspector and model builder.
18 | """
19 |
20 | from tensorflow_decision_forests.component.py_tree import condition
21 | from tensorflow_decision_forests.component.py_tree import dataspec
22 | from tensorflow_decision_forests.component.py_tree import node
23 | from tensorflow_decision_forests.component.py_tree import objective
24 | from tensorflow_decision_forests.component.py_tree import tree
25 | from tensorflow_decision_forests.component.py_tree import value
26 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/py_tree/condition_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from absl.testing import parameterized
20 | import tensorflow as tf
21 |
22 | from tensorflow_decision_forests.component.py_tree import condition as condition_lib
23 | from tensorflow_decision_forests.component.py_tree import dataspec as dataspec_lib
24 | from yggdrasil_decision_forests.dataset import data_spec_pb2
25 | from yggdrasil_decision_forests.model.decision_tree import decision_tree_pb2
26 |
27 |
class ConditionTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for bitmap decoding and core-condition conversion in `condition`."""

  def test_column_spec_bitmap_to_items_integer(self):
    # Integerized categorical column with 10 possible values.
    spec = data_spec_pb2.Column()
    categorical = spec.categorical
    categorical.is_already_integerized = True
    categorical.number_of_unique_values = 10

    # The bytes 0x2D, 0x03 carry the bitmap 1100101101b (= 32Dh).
    items = condition_lib.column_spec_bitmap_to_items(spec, b"\x2D\x03")
    self.assertEqual(items, [0, 2, 3, 5, 8, 9])

  def test_column_spec_bitmap_to_items_string(self):
    # Categorical column with a string dictionary "item_0" ... "item_9".
    spec = data_spec_pb2.Column()
    categorical = spec.categorical
    categorical.number_of_unique_values = 10
    categorical.is_already_integerized = False
    for index in range(10):
      categorical.items[f"item_{index}"].index = index

    # The bytes 0x2D, 0x03 carry the bitmap 1100101101b (= 32Dh).
    items = condition_lib.column_spec_bitmap_to_items(spec, b"\x2D\x03")
    self.assertEqual(
        items, ["item_0", "item_2", "item_3", "item_5", "item_8", "item_9"])

  def test_core_condition_to_condition_is_missing(self):
    # Dataspec with a single numerical column named "a".
    dataspec = data_spec_pb2.DataSpecification()
    column = dataspec.columns.add()
    column.type = dataspec_lib.ColumnType.NUMERICAL
    column.name = "a"

    # Core "is missing" condition on attribute 0.
    core_condition = decision_tree_pb2.NodeCondition()
    core_condition.attribute = 0
    core_condition.na_value = False
    core_condition.condition.na_condition.SetInParent()

    converted = condition_lib.core_condition_to_condition(
        core_condition, dataspec)
    expected = condition_lib.IsMissingInCondition(
        dataspec_lib.SimpleColumnSpec(
            "a", dataspec_lib.ColumnType.NUMERICAL, 0))
    self.assertEqual(converted, expected)
67 |
68 |
69 | if __name__ == "__main__":
70 | tf.test.main()
71 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/py_tree/dataspec_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import math
20 | from absl.testing import parameterized
21 | import tensorflow as tf
22 |
23 | from tensorflow_decision_forests.component.py_tree import dataspec as dataspec_lib
24 | from yggdrasil_decision_forests.dataset import data_spec_pb2
25 |
26 |
def toy_dataspec():
  """Builds the small four-column dataspec used by the tests in this file.

  Columns, in order:
    f1: NUMERICAL.
    f2: CATEGORICAL with the dictionary {"": 0, "x": 1, "y": 2}.
    f3: CATEGORICAL, already integerized, with 3 unique values.
    f4: DISCRETIZED_NUMERICAL with boundaries [0, 1, 2].

  Returns:
    The populated `data_spec_pb2.DataSpecification`.
  """
  column_type = data_spec_pb2.ColumnType
  spec = data_spec_pb2.DataSpecification()

  spec.columns.add(name="f1", type=column_type.NUMERICAL)

  with_dictionary = spec.columns.add(name="f2", type=column_type.CATEGORICAL)
  with_dictionary.categorical.number_of_unique_values = 3
  for index, item in enumerate(("", "x", "y")):
    with_dictionary.categorical.items[item].index = index

  integerized = spec.columns.add(name="f3", type=column_type.CATEGORICAL)
  integerized.categorical.number_of_unique_values = 3
  integerized.categorical.is_already_integerized = True

  discretized = spec.columns.add(
      name="f4", type=column_type.DISCRETIZED_NUMERICAL)
  discretized.discretized_numerical.boundaries[:] = [0, 1, 2]

  return spec
53 |
54 |
class DataspecTest(parameterized.TestCase, tf.test.TestCase):
  """Exercises the `dataspec` helpers on the columns of `toy_dataspec()`."""

  def test_make_simple_column_spec(self):
    simple_spec = dataspec_lib.make_simple_column_spec(toy_dataspec(), 0)
    expected = dataspec_lib.SimpleColumnSpec(
        name="f1", type=data_spec_pb2.ColumnType.NUMERICAL, col_idx=0)
    self.assertEqual(simple_spec, expected)

  def test_categorical_value_idx_to_value(self):
    spec = toy_dataspec()
    to_value = dataspec_lib.categorical_value_idx_to_value

    # Column "f2" has a string dictionary.
    self.assertEqual(to_value(spec.columns[1], 1), "x")
    # Column "f3" is already integerized.
    self.assertEqual(to_value(spec.columns[2], 1), 1)

  def test_discretized_numerical_to_numerical(self):
    # Column "f4" has the boundaries [0, 1, 2].
    column = toy_dataspec().columns[3]
    to_numerical = dataspec_lib.discretized_numerical_to_numerical

    expected_by_index = ((0, 0 - 1), (1, 0.5), (2, 1.5), (3, 2 + 1))
    for index, expected in expected_by_index:
      self.assertEqual(to_numerical(column, index), expected)

    # Index 4 is expected to map to NaN.
    self.assertTrue(math.isnan(to_numerical(column, 4)))

  def test_categorical_column_dictionary_to_list(self):
    spec = toy_dataspec()
    items = dataspec_lib.categorical_column_dictionary_to_list(spec.columns[1])
    self.assertEqual(items, ["", "x", "y"])

  def test_column_name_to_column_idx(self):
    spec = toy_dataspec()
    name_to_idx = dataspec_lib.column_name_to_column_idx
    self.assertEqual(name_to_idx("f1", spec), 0)
    self.assertEqual(name_to_idx("f2", spec), 1)
96 |
97 |
98 | if __name__ == "__main__":
99 | tf.test.main()
100 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/py_tree/node_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from absl import logging
20 | from absl.testing import parameterized
21 | import tensorflow as tf
22 |
23 | from tensorflow_decision_forests.component.py_tree import condition as condition_lib
24 | from tensorflow_decision_forests.component.py_tree import dataspec as dataspec_lib
25 | from tensorflow_decision_forests.component.py_tree import node as node_lib
26 | from tensorflow_decision_forests.component.py_tree import value as value_lib
27 |
28 | from yggdrasil_decision_forests.dataset import data_spec_pb2
29 | from yggdrasil_decision_forests.model.decision_tree import decision_tree_pb2
30 |
31 |
class NodeTest(parameterized.TestCase, tf.test.TestCase):
  """Tests node construction, core-node export and the contains accumulator."""

  def test_leaf_value_integer(self):
    # No leaf index is given, so `leaf_idx` is expected to be None.
    leaf = node_lib.LeafNode(5)
    self.assertIsNone(leaf.leaf_idx)
    self.assertEqual(leaf.value, 5)

  def test_leaf_idx(self):
    leaf = node_lib.LeafNode(5, 2)
    self.assertEqual(leaf.leaf_idx, 2)
    self.assertEqual(leaf.value, 5)

  def test_leaf_regression_value(self):
    # Core node the python leaf is expected to be converted into.
    expected_core_node = decision_tree_pb2.Node(
        regressor=decision_tree_pb2.NodeRegressorOutput(top_value=5.0))
    distribution = expected_core_node.regressor.distribution
    distribution.sum = 0
    distribution.sum_squares = 10.0
    distribution.count = 10.0

    leaf = node_lib.LeafNode(
        value=value_lib.RegressionValue(
            value=5.0, num_examples=10, standard_deviation=1.0))
    converted = node_lib.node_to_core_node(
        leaf, data_spec_pb2.DataSpecification())

    self.assertEqual(converted, expected_core_node)
    logging.info("node:\n%s", leaf)

  def test_non_leaf_without_children(self):
    # Only checks that a child-less non-leaf node can be built and logged.
    feature = dataspec_lib.SimpleColumnSpec(
        name="f1", type=dataspec_lib.ColumnType.NUMERICAL)
    condition = condition_lib.NumericalHigherThanCondition(
        feature=feature, threshold=1.5, missing_evaluation=False)
    non_leaf = node_lib.NonLeafNode(condition=condition)
    logging.info("node:\n%s", non_leaf)

  def test_non_leaf_with_children(self):
    # Only checks that a non-leaf node with two leaves can be built and logged.
    feature = dataspec_lib.SimpleColumnSpec(
        name="f1", type=dataspec_lib.ColumnType.NUMERICAL)
    condition = condition_lib.NumericalHigherThanCondition(
        feature=feature, threshold=1.5, missing_evaluation=False)
    positive_leaf = node_lib.LeafNode(
        value=value_lib.RegressionValue(
            value=5.0, num_examples=10, standard_deviation=1.0))
    negative_leaf = node_lib.LeafNode(
        value=value_lib.ProbabilityValue(
            probability=[0.5, 0.4, 0.1], num_examples=10))
    non_leaf = node_lib.NonLeafNode(
        condition=condition, pos_child=positive_leaf, neg_child=negative_leaf)
    logging.info("node:\n%s", non_leaf)

  def test_condition_value_and_default_categorical_contains(self):
    accumulator = node_lib.ScanStructureAccumulatorContainsInt()

    def current_state():
      return accumulator.get_global_imutation_and_num_unique_values()

    self.assertEqual(current_state(), (1, 2))

    accumulator.add_not_allowed([1, 3])
    self.assertEqual(current_state(), (2, 4))

    accumulator.add_not_allowed([2])
    self.assertEqual(current_state(), (4, 5))

    accumulator.add_allowed([1, 6])
    self.assertEqual(current_state(), (6, 7))
109 |
110 |
111 | if __name__ == "__main__":
112 | tf.test.main()
113 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/py_tree/objective_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from absl import logging
20 | from absl.testing import parameterized
21 | import tensorflow as tf
22 |
23 | from tensorflow_decision_forests.component.py_tree import objective as objective_lib
24 |
25 | # pylint: disable=g-long-lambda
26 |
27 |
class ObjectiveTest(parameterized.TestCase, tf.test.TestCase):
  """Construction and argument-validation tests for the objective classes."""

  def test_classification(self):
    # Classes described by their count only.
    objective = objective_lib.ClassificationObjective(
        label="label", num_classes=5)
    logging.info("objective: %s", objective)

    # Classes described by an explicit list only.
    objective = objective_lib.ClassificationObjective(
        label="label", classes=["a", "b"])
    logging.info("objective: %s", objective)

    # Explicit list together with a matching count.
    objective = objective_lib.ClassificationObjective(
        label="label", classes=["a", "b"], num_classes=2)
    logging.info("objective: %s", objective)

  def test_classification_errors(self):
    # Neither `classes` nor `num_classes` is given.
    with self.assertRaises(ValueError):
      objective_lib.ClassificationObjective(label="label")

    # Empty label name.
    with self.assertRaises(ValueError):
      objective_lib.ClassificationObjective(label="", num_classes=5)

    # `num_classes` disagrees with the number of listed classes.
    with self.assertRaises(ValueError):
      objective_lib.ClassificationObjective(
          label="label", classes=["a", "b"], num_classes=5)

    # Empty list of classes.
    with self.assertRaises(ValueError):
      objective_lib.ClassificationObjective(label="label", classes=[])

  def test_regression(self):
    objective = objective_lib.RegressionObjective(label="label")
    logging.info("objective: %s", objective)

  def test_ranking(self):
    objective = objective_lib.RankingObjective(label="label", group="group")
    logging.info("objective: %s", objective)

  def test_numerical_uplift(self):
    objective = objective_lib.NumericalUpliftObjective(
        label="label", treatment="treatment")
    logging.info("objective: %s", objective)

  def test_categorical_uplift(self):
    objective = objective_lib.CategoricalUpliftObjective(
        label="label", treatment="treatment")
    logging.info("objective: %s", objective)

  def test_uplift_objects_are_not_equal(self):
    # Same label and treatment, but different uplift objective types.
    numerical_objective = objective_lib.NumericalUpliftObjective(
        label="label", treatment="treatment")
    categorical_objective = objective_lib.CategoricalUpliftObjective(
        label="label", treatment="treatment")
    self.assertNotEqual(numerical_objective, categorical_objective)
89 |
90 |
91 | if __name__ == "__main__":
92 | tf.test.main()
93 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/py_tree/tree.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A decision tree."""
16 |
17 | from typing import Optional, List
18 |
19 | from tensorflow_decision_forests.component.py_tree import node as node_lib
20 |
21 |
class Tree(object):
  """A single decision tree.

  A tree is fully described by its root node. It can optionally carry the list
  of label classes, which is appended to the output of `pretty()`.
  """

  def __init__(self,
               root: Optional["node_lib.AbstractNode"],
               label_classes: Optional[List[str]] = None):
    """Initializes the tree.

    Args:
      root: Root node of the tree, or None for a tree without any node.
      label_classes: Optional list of label classes attached to the tree.
    """
    self._root = root
    self._label_classes = label_classes

  @property
  def root(self) -> Optional["node_lib.AbstractNode"]:
    """Root node of the tree, or None if the tree has no root."""
    return self._root

  @property
  def label_classes(self) -> Optional[List[str]]:
    """Label classes given at construction, or None."""
    return self._label_classes

  def __repr__(self) -> str:
    """Returns an inline string representation of a tree."""

    # `repr(None)` is already "None". The root is deliberately not tested for
    # truthiness, so a root object that evaluates to False is still shown.
    return f"Tree(root={self._root!r}, label_classes={self.label_classes})"

  def __str__(self) -> str:
    # This method targets users who try to debug or interpret trees.
    return self.pretty()

  def pretty(self, max_depth: Optional[int] = 4) -> str:
    """Returns a readable textual representation of the tree.

    Unlike `repr(tree)`, `tree.pretty()` formats the representation (line
    return, margin, hide class names) to improve readability.

    This representation can be changed and code should not try to parse the
    output of `pretty`. To access the tree structure programmatically, use the
    `root` property.

    Args:
      max_depth: The maximum depth of the nodes to display. Deeper nodes are
        skipped and replaced by "...". Defaults to 4. If None, the entire tree
        is printed. The value is forwarded unchanged to the root node.

    Returns:
      A pretty-string representing the tree.
    """

    # Compare against None explicitly: "no root" means the root is absent, not
    # that the root object happens to evaluate to False.
    if self._root is not None:
      content = self._root.pretty(
          prefix="", is_pos=None, depth=1, max_depth=max_depth)
    else:
      content = "No root\n"
    if self._label_classes is not None:
      content += f"Label classes: {self.label_classes}\n"
    return content
76 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/py_tree/tree_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from absl import logging
20 | from absl.testing import parameterized
21 | import tensorflow as tf
22 |
23 | from tensorflow_decision_forests.component.py_tree import condition as condition_lib
24 | from tensorflow_decision_forests.component.py_tree import dataspec as dataspec_lib
25 | from tensorflow_decision_forests.component.py_tree import node as node_lib
26 | from tensorflow_decision_forests.component.py_tree import tree as tree_lib
27 | from tensorflow_decision_forests.component.py_tree import value as value_lib
28 |
29 |
30 | class TreeTest(parameterized.TestCase, tf.test.TestCase):  # Builds, logs and pretty-prints `tree_lib.Tree` objects.
31 |
32 | def test_empty_tree(self):  # Smoke test: builds a tree with a None root and logs it.
33 | tree = tree_lib.Tree(None)
34 | logging.info("Tree:\n%s", tree)
35 |
36 | def test_basic_tree(self):  # One numerical split with a regression leaf and a probability leaf.
37 | tree = tree_lib.Tree(
38 | node_lib.NonLeafNode(
39 | condition=condition_lib.NumericalHigherThanCondition(
40 | feature=dataspec_lib.SimpleColumnSpec(
41 | name="f1", type=dataspec_lib.ColumnType.NUMERICAL),
42 | threshold=1.5,
43 | missing_evaluation=False),
44 | pos_child=node_lib.LeafNode(
45 | value=value_lib.RegressionValue(
46 | value=5.0, num_examples=10, standard_deviation=1.0)),
47 | neg_child=node_lib.LeafNode(
48 | value=value_lib.ProbabilityValue(
49 | probability=[0.5, 0.4, 0.1], num_examples=10))))
50 |
51 | tree_repr = repr(tree)
52 | logging.info("Tree repr:\n%s", tree_repr)
53 | # The "repr" must stay on a single line: it must not contain any line break.
54 | self.assertNotIn("\n", tree_repr)
55 |
56 | logging.info("Tree str:\n%s", tree)
57 |
58 | pretty = tree.pretty()  # Called without arguments, so the default `max_depth` applies.
59 | logging.info("Pretty:\n%s", pretty)
60 |
61 | self.assertEqual(  # Exact comparison: the expected text below is whitespace-sensitive.
62 | pretty, """(f1 >= 1.5; miss=False, score=None)
63 | ├─(pos)─ RegressionValue(value=5.0,sd=1.0,n=10)
64 | └─(neg)─ ProbabilityValue([0.5, 0.4, 0.1],n=10)
65 | """)
66 |
67 | def test_stump_with_label_classes(self):  # Single-leaf tree carrying label classes; only logged, nothing asserted.
68 | tree = tree_lib.Tree(
69 | node_lib.LeafNode(
70 | value=value_lib.ProbabilityValue(
71 | probability=[0.5, 0.4, 0.1], num_examples=10)),
72 | label_classes=["a", "b", "c"])
73 | logging.info("Tree:\n%s", tree)
74 |
75 |
76 | if __name__ == "__main__":
77 | tf.test.main()
78 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/py_tree/value_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from absl import logging
20 | from absl.testing import parameterized
21 | import tensorflow as tf
22 |
23 | from tensorflow_decision_forests.component.py_tree import value as value_lib
24 | from yggdrasil_decision_forests.model.decision_tree import decision_tree_pb2
25 |
26 |
class ValueTest(parameterized.TestCase, tf.test.TestCase):
  """Tests value construction and conversion from core (proto) nodes."""

  def test_regression(self):
    # Only checks that the value can be built and logged.
    regression = value_lib.RegressionValue(
        value=5.0, num_examples=10, standard_deviation=1.0)
    logging.info("value:\n%s", regression)

  def test_probability(self):
    # Only checks that the value can be built and logged.
    probability = value_lib.ProbabilityValue(
        probability=[0.5, 0.4, 0.1], num_examples=10)
    logging.info("value:\n%s", probability)

  def test_uplift(self):
    # Only checks that the value can be built and logged.
    uplift = value_lib.UpliftValue(treatment_effect=[1, 2], num_examples=10)
    logging.info("value:\n%s", uplift)

  def test_core_value_to_value_classifier(self):
    core_node = decision_tree_pb2.Node()
    distribution = core_node.classifier.distribution
    distribution.sum = 10.0
    distribution.counts[:] = [0.0, 8.0, 2.0]

    expected = value_lib.ProbabilityValue(
        probability=[0.8, 0.2], num_examples=10)
    self.assertEqual(value_lib.core_value_to_value(core_node), expected)

  def test_core_value_to_value_regressor(self):
    core_node = decision_tree_pb2.Node()
    regressor = core_node.regressor
    regressor.top_value = 1.0
    regressor.distribution.count = 10.0
    regressor.distribution.sum = 10.0
    regressor.distribution.sum_squares = 20.0

    expected = value_lib.RegressionValue(
        value=1.0, num_examples=10, standard_deviation=1.0)
    self.assertEqual(value_lib.core_value_to_value(core_node), expected)

  def test_core_value_to_value_uplift(self):
    core_node = decision_tree_pb2.Node()
    uplift = core_node.uplift
    uplift.sum_weights = 10.0
    uplift.treatment_effect[:] = [0.0, 8.0, 2.0]

    expected = value_lib.UpliftValue(
        treatment_effect=[0.0, 8.0, 2.0], num_examples=10.0)
    self.assertEqual(value_lib.core_value_to_value(core_node), expected)
77 |
78 |
79 | if __name__ == "__main__":
80 | tf.test.main()
81 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/tuner/BUILD:
--------------------------------------------------------------------------------
1 | package(
2 | default_visibility = ["//visibility:public"],
3 | licenses = ["notice"],
4 | )
5 |
6 | # Libraries
7 | # =========
8 |
9 | py_library(
10 | name = "tuner",
11 | srcs = ["tuner.py"],
12 | deps = [
13 | "@ydf//yggdrasil_decision_forests/learner/hyperparameters_optimizer:hyperparameters_optimizer_py_proto",
14 | "@ydf//yggdrasil_decision_forests/learner/hyperparameters_optimizer/optimizers:random_py_proto",
15 | "@ydf//yggdrasil_decision_forests/model:hyperparameter_py_proto",
16 | ],
17 | )
18 |
19 | # Tests
20 | # =====
21 |
22 | py_test(
23 | name = "tuner_test",
24 | timeout = "long",
25 | srcs = ["tuner_test.py"],
26 | data = ["@ydf//yggdrasil_decision_forests/test_data"],
27 | shard_count = 4,
28 | deps = [
29 | ":tuner",
30 | "//tensorflow_decision_forests/keras",
31 | "@com_google_protobuf//:python_srcs",
32 | "@pypi_absl_py//:pkg", # /flags
33 | "@pypi_pandas//:pkg",
34 | "@release_or_nightly//:tensorflow_pkg",
35 | "@ydf//yggdrasil_decision_forests/learner:abstract_learner_py_proto",
36 | ],
37 | )
38 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/component/tuner/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/contrib/BUILD:
--------------------------------------------------------------------------------
1 | licenses(["notice"])
2 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/contrib/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow Decision Forests Contributions
2 |
3 | The Contrib(ution) directory contains functionalities that are useful for TF-DF
4 | users but which are not central to the TF-DF library. For example, it could
5 | contain decision forest scientific paper implementations, utilities to interact
6 | with other libraries (e.g. decision forests, IO), and model agnostic tools that
7 | do not have another nice place to be.
8 |
9 | Contrib libraries may not be as high-quality or up-to-date as the rest of the TF-DF
10 | core library, and updates may rely on contributors outside the TF-DF core team.
11 |
12 | **Rules**
13 |
14 | 1. There are no dependencies from the core library to the contrib libraries.
15 | 1. The contrib libraries are not loaded automatically with the TF-DF library.
16 | 1. A dependency cannot be registered in `configure/setup.py` if it is only used
17 | by a contrib library.
18 | 1. If a contrib library becomes important enough (decided by the TF-DF core
19 | team), it can be moved to the `component` directory.
20 | 1. The contrib directory is not a place to store usage examples.
21 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/contrib/example_of_contrib/BUILD:
--------------------------------------------------------------------------------
1 | package(
2 | default_visibility = ["//visibility:public"],
3 | licenses = ["notice"],
4 | )
5 |
6 | # Libraries
7 | # =========
8 |
9 | py_library(
10 | name = "example_of_contrib",
11 | srcs = [
12 | "__init__.py",
13 | "example_of_contrib.py",
14 | ],
15 | deps = [
16 | "//tensorflow_decision_forests",
17 | ],
18 | )
19 |
20 | # Tests
21 | # =====
22 |
23 | py_test(
24 | name = "example_of_contrib_test",
25 | srcs = ["example_of_contrib_test.py"],
26 | deps = [
27 | ":example_of_contrib",
28 | "@release_or_nightly//:tensorflow_pkg",
29 | ],
30 | )
31 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/contrib/example_of_contrib/README.md:
--------------------------------------------------------------------------------
1 | # Example of Contrib
2 |
3 | ## Introduction
4 |
5 | The "Example of Contrib" shows how to add a library in `contrib`.
6 |
7 | ## Usage example
8 |
9 | ```python
10 | import tensorflow_decision_forests as tfdf
11 | from tensorflow_decision_forests.contrib import example_of_contrib
12 |
13 | print(example_of_contrib.my_function())
14 | ```
15 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/contrib/example_of_contrib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Example of Contrib.
16 |
17 | Usage example:
18 |
19 | ```python
20 | import tensorflow_decision_forests as tfdf
21 | from tensorflow_decision_forests.contrib import example_of_contrib
22 |
23 | print(example_of_contrib.my_function())
24 | ```
25 | """
26 |
27 | from tensorflow_decision_forests.contrib.example_of_contrib import example_of_contrib as lib
28 |
# Re-export the implementation so that it is reachable as
# `example_of_contrib.my_function`.
my_function = lib.my_function
30 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/contrib/example_of_contrib/example_of_contrib.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import tensorflow_decision_forests as tfdf
16 |
17 |
def my_function() -> str:
  """Prints the installed TF-DF version and returns a greeting.

  Returns:
    The string "Hello world".
  """
  version = tfdf.__version__
  print(f"The current version of tfdf is {version}")
  return "Hello world"
21 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/contrib/example_of_contrib/example_of_contrib_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for example_of_contrib."""
16 |
17 | import tensorflow as tf
18 | from tensorflow_decision_forests.contrib import example_of_contrib
19 |
20 |
class ExampleOfContribTest(tf.test.TestCase):
  """Checks the public API of the `example_of_contrib` package."""

  def test_my_function(self):
    # Only the returned value is checked.
    returned = example_of_contrib.my_function()
    self.assertEqual(returned, "Hello world")


if __name__ == "__main__":
  tf.test.main()
29 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/contrib/scikit_learn_model_converter/BUILD:
--------------------------------------------------------------------------------
# Bazel package of the scikit-learn model converter contribution.
package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)

# Libraries
# =========

# Converter library. Depends on TF-DF, scikit-learn and TensorFlow.
py_library(
    name = "scikit_learn_model_converter",
    srcs = [
        "__init__.py",
        "scikit_learn_model_converter.py",
    ],
    deps = [
        "//tensorflow_decision_forests",
        "@pypi_scikit_learn//:pkg",
        "@release_or_nightly//:tensorflow_pkg",
    ],
)

# Tests
# =====

# Unit test of the converter library.
py_test(
    name = "scikit_learn_model_converter_test",
    srcs = ["scikit_learn_model_converter_test.py"],
    deps = [
        ":scikit_learn_model_converter",
        "@pypi_absl_py//:pkg",  # /testing:parameterized
        "@pypi_numpy//:pkg",
        "@pypi_scikit_learn//:pkg",
        "@release_or_nightly//:tensorflow_pkg",
    ],
)
36 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/contrib/scikit_learn_model_converter/README.md:
--------------------------------------------------------------------------------
1 | # Scikit-Learn Converter
2 |
3 | ## Introduction
4 |
5 | **Scikit-Learn Model Converter** converts Scikit-Learn tree-based models (e.g.
6 | Random Forests) into TensorFlow models compatible with the whole TensorFlow
7 | ecosystem (e.g. Keras composition, SavedModel format, TF-Serving).
8 |
9 | The converted model is also a TensorFlow Decision Forests (TF-DF) model
10 | compatible with all of TF-DF functionalities (e.g. plotting, c++ inference API).
11 |
12 | ## Currently supported models
13 |
14 | * `sklearn.tree.DecisionTreeClassifier`
15 | * `sklearn.tree.DecisionTreeRegressor`
16 | * `sklearn.tree.ExtraTreeClassifier`
17 | * `sklearn.tree.ExtraTreeRegressor`
18 | * `sklearn.ensemble.RandomForestClassifier`
19 | * `sklearn.ensemble.RandomForestRegressor`
20 | * `sklearn.ensemble.ExtraTreesClassifier`
21 | * `sklearn.ensemble.ExtraTreesRegressor`
22 | * `sklearn.ensemble.GradientBoostingRegressor`
23 |
24 | ## Usage example
25 |
26 | ```python
27 | from sklearn import datasets
28 | from sklearn import tree
29 | import tensorflow as tf
30 | from tensorflow_decision_forests.contrib import scikit_learn_model_converter
31 |
32 | # Train your model in scikit-learn
33 | X, y = datasets.make_classification()
34 | sklearn_model = tree.DecisionTreeClassifier().fit(X, y)
35 |
36 | # Convert to tensorflow and predict
37 | tensorflow_model = scikit_learn_model_converter.convert(sklearn_model)
38 | y_pred = tensorflow_model.predict(tf.constant(X))
39 | ```
40 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/contrib/scikit_learn_model_converter/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tools for converting scikit-learn tree-based models to TFDF models.
16 |
17 | This module converts Scikit-Learn tree-based models (e.g.
18 | Random Forests) into TensorFlow models compatible with the whole TensorFlow
19 | ecosystem (e.g. Keras composition, SavedModel format, TF-Serving).
20 |
21 | The converted model is also a TensorFlow Decision Forests (TF-DF) model
22 | compatible with all of TF-DF functionalities (e.g. plotting, c++ inference API).
23 |
24 | Example usage:
25 |
26 | ```
27 | from sklearn import datasets
28 | from sklearn import tree
29 | import tensorflow as tf
30 | from tensorflow_decision_forests.contrib import scikit_learn_model_converter
31 |
32 | # Train your model in scikit-learn
33 | X, y = datasets.make_classification()
34 | sklearn_model = tree.DecisionTreeClassifier().fit(X, y)
35 |
36 | # Convert to tensorflow and predict
37 | tensorflow_model = scikit_learn_model_converter.convert(sklearn_model)
38 | y_pred = tensorflow_model.predict(tf.constant(X))
39 | ```
40 |
41 | """
42 |
43 | from tensorflow_decision_forests.contrib.scikit_learn_model_converter import scikit_learn_model_converter as lib
44 |
# Public API: re-export the conversion functions of the implementation module
# at package level.
convert = lib.convert
convert_sklearn_tree_to_tfdf_pytree = lib.convert_sklearn_tree_to_tfdf_pytree
47 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/contrib/training_preprocessing/BUILD:
--------------------------------------------------------------------------------
# Bazel package of the training preprocessing contribution.
package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)

# Libraries
# =========

# Preprocessing library. Depends on TF-DF and TensorFlow.
py_library(
    name = "training_preprocessing",
    srcs = [
        "__init__.py",
        "training_preprocessing.py",
    ],
    deps = [
        "//tensorflow_decision_forests",
        "@release_or_nightly//:tensorflow_pkg",
    ],
)

# Tests
# =====

# Unit test of the preprocessing library.
py_test(
    name = "training_preprocessing_test",
    srcs = ["training_preprocessing_test.py"],
    deps = [
        ":training_preprocessing",
        "//tensorflow_decision_forests",
        "@pypi_absl_py//:pkg",  # /logging
        "@pypi_numpy//:pkg",
        "@release_or_nightly//:tensorflow_pkg",
    ],
)
35 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/contrib/training_preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities for jointly preprocessing training data, labels and sample weights."""
16 |
17 | from tensorflow_decision_forests.contrib.training_preprocessing import training_preprocessing
18 |
# Public API: re-export the preprocessing function of the implementation module
# at package level.
multitask_label_presence_processing = (
    training_preprocessing.multitask_label_presence_processing
)
22 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/keras/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Decision Forest in a Keras Model.
16 |
17 | Usage example:
18 |
19 | ```python
20 | import tensorflow_decision_forests as tfdf
import pandas as pd
import tf_keras
22 |
23 | # Load the dataset in a Pandas dataframe.
24 | train_df = pd.read_csv("project/train.csv")
25 | test_df = pd.read_csv("project/test.csv")
26 |
27 | # Convert the dataset into a TensorFlow dataset.
28 | train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="my_label")
29 | test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label="my_label")
30 |
31 | # Train the model.
32 | model = tfdf.keras.RandomForestModel()
33 | model.fit(train_ds)
34 |
35 | # Evaluate the model on another dataset.
36 | model.evaluate(test_ds)
37 |
38 | # Show information about the model
39 | model.summary()
40 |
41 | # Export the model with the TF.SavedModel format.
42 | model.save("/path/to/my/model")
43 |
44 | # ...
45 |
46 | # Load a model: it loads as a generic keras model.
47 | loaded_model = tf_keras.models.load_model("/path/to/my/model")
48 | ```
49 | """
50 |
51 | from typing import Callable, List
52 |
53 | from tensorflow_decision_forests.keras import core
54 | from tensorflow_decision_forests.keras import wrappers
55 |
# Utility classes
#
# Aliases re-exporting symbols of `core` in this package's namespace.
CoreModel = core.CoreModel
FeatureSemantic = core.FeatureSemantic
Task = core.Task
FeatureUsage = core.FeatureUsage
AdvancedArguments = core.AdvancedArguments
MultiTaskItem = core.MultiTaskItem
Monotonic = core.Monotonic
64 |
# Learning algorithm (called Models in Keras).
#
# Each class below subclasses the class of the same name in `wrappers` without
# adding or overriding anything. No docstring is set on purpose: a docstring
# here would take precedence over the parent class documentation when the
# class is inspected with `help()` / `inspect.getdoc()`.


class RandomForestModel(wrappers.RandomForestModel):
  pass


class GradientBoostedTreesModel(wrappers.GradientBoostedTreesModel):
  pass


class CartModel(wrappers.CartModel):
  pass


class DistributedGradientBoostedTreesModel(
    wrappers.DistributedGradientBoostedTreesModel):
  pass
83 |
84 |
def get_all_models() -> List[Callable[[], CoreModel]]:
  """Lists the model classes exposed by this module.

  Returns:
    A new list containing, in this order, `RandomForestModel`,
    `GradientBoostedTreesModel`, `CartModel` and
    `DistributedGradientBoostedTreesModel`.
  """
  available_models = [
      RandomForestModel,
      GradientBoostedTreesModel,
      CartModel,
      DistributedGradientBoostedTreesModel,
  ]
  return available_models
91 |
92 |
# Utilities
#
# Functions re-exported from `core`; some of them are reached through its
# `tf_core` and `cc_logging` attributes.
pd_dataframe_to_tf_dataset = core.pd_dataframe_to_tf_dataset
get_worker_idx_and_num_workers = core.get_worker_idx_and_num_workers
build_default_input_model_signature = core.tf_core.build_default_input_model_signature
build_default_feature_signature = core.tf_core.build_default_feature_signature
set_training_logs_redirection = core.cc_logging.set_training_logs_redirection
yggdrasil_model_to_keras_model = core.yggdrasil_model_to_keras_model
100 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/keras/keras_internal.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Access to Keras function with a different internal and external path."""
16 |
17 | from tf_keras.src.engine import data_adapter as _data_adapter
18 | from tf_keras.src.models import Functional
19 | from tf_keras.layers import DenseFeatures
20 | from tf_keras.src.utils.dataset_creator import DatasetCreator
21 |
22 |
# Functions of the Keras `data_adapter` module, re-exported at module level so
# that the rest of TF-DF does not import the internal Keras path directly.
unpack_x_y_sample_weight = _data_adapter.unpack_x_y_sample_weight
get_data_handler = _data_adapter.get_data_handler
25 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/keras/keras_internal_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import tensorflow as tf
16 | from tensorflow_decision_forests.keras import keras_internal
17 |
18 |
# No test case is defined on purpose: importing `keras_internal` above is the
# check. It ensures that the module, and the Keras symbols it imports, can be
# loaded.

if __name__ == "__main__":
  tf.test.main()
23 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/keras/test_runner.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Loads and runs a SavedModel generated by :keras_test on the Adult dataset.
16 |
17 | This binary tests the model in a clean tensorflow instance (without pre-loaded
18 | functions and resources) and does not require non-canonical tf ops.
19 | """
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import enum
25 | from typing import Dict
26 |
27 | from absl import app
28 | from absl import flags
29 | from absl import logging
30 | import numpy as np
31 | import pandas as pd
32 | import tensorflow as tf
33 | import tf_keras
34 |
35 | # Inject the inference ops only.
36 | from tensorflow_decision_forests.tensorflow.ops.inference import op # pylint: disable=unused-import
37 |
# Command-line flags of the binary. Both default to an empty string.
flags.DEFINE_string("model_path", "", "Path to the SavedModel model to load.")

flags.DEFINE_string("dataset_path", "", "Path to adult dataset.")

# Shortcut to the global flag values.
FLAGS = flags.FLAGS
43 |
44 |
def main(argv):
  """Loads the SavedModel and evaluates it on the dataset given by the flags.

  The model is read from `--model_path` and the dataset from `--dataset_path`.
  The evaluation result is only logged.

  Args:
    argv: Command-line arguments. Unused.
  """
  del argv  # Unused.

  logging.info("Load model")
  loaded_model = tf_keras.models.load_model(FLAGS.model_path)

  logging.info("Load dataset")
  tf_dataset = dataframe_to_tf_dataset(load_dataset(), "income")

  logging.info("Evaluate model")
  logging.info("evaluation: %s", loaded_model.evaluate(tf_dataset))

  logging.info("Done (unload model)")
60 |
61 |
62 | # The following code is similar to the simple_ml_keras_test.
63 |
64 |
class Semantic(enum.Enum):
  """Semantic assigned to a dataframe column by this script."""

  # Column with an integer or floating point dtype.
  NUMERICAL = 1
  # Column with a string or object dtype.
  CATEGORICAL = 2
68 |
69 |
def dataframe_to_tf_dataset(dataframe: pd.DataFrame,
                            label: str) -> tf.data.Dataset:
  """Converts a dataframe into a batched TensorFlow dataset.

  Args:
    dataframe: Dataframe containing the feature columns and the label column.
    label: Name of the label column.

  Returns:
    A dataset of (dictionary of feature columns, label values) pairs, grouped
    in batches of 64 examples.
  """
  features = dict(dataframe.drop(label, axis=1))
  labels = dataframe[label].values
  examples = tf.data.Dataset.from_tensor_slices((features, labels))
  return examples.batch(64)
76 |
77 |
def load_dataset() -> pd.DataFrame:
  """Loads the dataset in a pandas dataframe.

  The csv file is read from `--dataset_path`. The "income" label column is
  binarized: 1 where the value is exactly ">50K", 0 otherwise.

  Returns:
    The dataframe prepared by `auto_dataset`.
  """
  label = "income"
  raw = pd.read_csv(flags.FLAGS.dataset_path)
  raw[label] = np.where(raw[label] == ">50K", 1, 0)
  return auto_dataset(raw, label)
90 |
91 |
def infer_semantic_from_dataframe(dataset: pd.DataFrame) -> Dict[str, Semantic]:
  """Infers the semantic of the columns in a pandas dataframe.

  Columns with an integer or floating point dtype of any width (including the
  pandas nullable ones) are NUMERICAL. Columns with an object or string dtype
  are CATEGORICAL.

  Args:
    dataset: Dataframe to analyse.

  Returns:
    A dictionary mapping each column name to its semantic.

  Raises:
    ValueError: If the dtype of a column is neither numerical nor
      object/string (e.g. boolean or datetime columns).
  """

  # The pandas dtype predicates cover every integer / float width and the
  # pandas extension dtypes, which a list of hard-coded dtypes cannot.
  dtype_api = pd.api.types

  semantics = {}
  for col in dataset.columns:
    dtype = dataset[col].dtype
    if dtype_api.is_integer_dtype(dtype) or dtype_api.is_float_dtype(dtype):
      semantics[col] = Semantic.NUMERICAL
    elif dtype_api.is_object_dtype(dtype) or dtype_api.is_string_dtype(dtype):
      semantics[col] = Semantic.CATEGORICAL
    else:
      # ValueError is a subclass of the Exception raised previously.
      raise ValueError(
          "Cannot infer semantic for column \"{}\" with dtype={}".format(
              col, dtype))

  return semantics
108 |
109 |
def auto_dataset(dataset, label) -> pd.DataFrame:
  """Prepares a dataset object.

  Keeps the categorical features, then the numerical features, then the label
  column, in that order. Missing values of the categorical features are
  replaced by empty strings. The input dataframe is not modified.

  Args:
    dataset: Dataframe containing the features and the label.
    label: Name of the label column.

  Returns:
    A new dataframe with the selected and cleaned columns.
  """

  semantics = infer_semantic_from_dataframe(dataset)

  def extract_by_semantic(semantic):
    return [k for k, s in semantics.items() if s == semantic and k != label]

  categorical_features = extract_by_semantic(Semantic.CATEGORICAL)
  numerical_features = extract_by_semantic(Semantic.NUMERICAL)

  # Work on a copy so that the caller's dataframe is left untouched.
  selected_columns = categorical_features + numerical_features + [label]
  prepared = dataset[selected_columns].copy()
  for key in categorical_features:
    prepared[key] = prepared[key].fillna("")
  return prepared
127 |
128 |
# Script entry point: `main` is run through absl's `app.run`.
if __name__ == "__main__":
  app.run(main)
131 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/keras/wrapper/BUILD:
--------------------------------------------------------------------------------
1 | load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
2 | load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_test")
3 |
# Bazel package of the generator of Keras wrappers around the YDF learners.
package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)

# "wrapper_main.cc" is exported because the "py_wrap_yggdrasil_learners" macro
# (wrapper.bzl) copies it into the package that calls the macro.
exports_files(["wrapper_main.cc"])

# Bzl Libraries
# =============

bzl_library(
    name = "wrapper_bzl",
    srcs = ["wrapper.bzl"],
    deps = [
        "//tensorflow_decision_forests/tensorflow:utils_bzl",
        "@org_tensorflow//tensorflow:tensorflow_bzl",
    ],
)

# Binaries
# ========

# Libraries
# =========

# Wrapper generation library.
cc_library(
    name = "wrapper",
    srcs = ["wrapper.cc"],
    hdrs = ["wrapper.h"],
    deps = [
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@ydf//yggdrasil_decision_forests/learner:learner_library",
        "@ydf//yggdrasil_decision_forests/utils:status_macros",
    ],
)

# Tests
# =====

# The test links two learners (gradient boosted trees and random forest).
tf_cc_test(
    name = "wrapper_test",
    srcs = ["wrapper_test.cc"],
    linkstatic = True,
    deps = [
        ":wrapper",
        "@com_google_googletest//:gtest_main",
        "@ydf//yggdrasil_decision_forests/learner/gradient_boosted_trees",
        "@ydf//yggdrasil_decision_forests/learner/random_forest",
    ],
)
55 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/keras/wrapper/wrapper.bzl:
--------------------------------------------------------------------------------
1 | """Rule generation utilities."""
2 |
3 | load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_not_windows", "tf_binary_additional_srcs", "tf_cc_binary", "tf_copts")
4 | load("//tensorflow_decision_forests/tensorflow:utils.bzl", "rpath_linkopts_to_tensorflow")
5 |
def py_wrap_yggdrasil_learners(
        name = None,
        learner_deps = []):
    """Creates Keras wrappers around Yggdrasil Decision Forest (YDF) learners.

    Creates a py_library called "{name}" and containing the file "{name}.py".
    This library introduces a TensorFlow Decision Forests (TFDF) Keras class
    wrapping each YDF learner defined in "learner_deps". The constructor of
    these classes contains an argument for the learner generic hyper-parameters.

    For example, if "learner_deps" contains a c++ dependency that registers a
    learner with a key equal to "RANDOM_FOREST", the wrapper will create a
    python class called "RandomForestModel" deriving the base TFDF model class.

    The macro creates four targets: a genrule copying the generator main into
    the calling package, a binary linking this main with the learners, a
    genrule running the binary to produce "{name}.py", and the final
    py_library.

    Args:
        name: Name of the rule.
            NOTE(review): the default is None, but the value is concatenated
            with strings below, so a name is effectively required.
        learner_deps: List of dependencies linking Yggdrasil Decision Forest
            learners.
    """

    # Absolute path to the wrapper generator directory.
    wrapper_package = "//tensorflow_decision_forests/keras/wrapper"

    # Filename of the wrapper generator source code in the user package.
    local_cc_main = name + "_wrapper_main.cc"

    # Target name of the wrapper generator binary.
    wrapper_name = name + "_wrapper_main"

    # Target name of the command running the wrapper generator.
    run_wrapper_name = name + "_run_wrapper"

    # Copy the wrapper main source code to the user package.
    native.genrule(
        name = name + "_copy_cc_main",
        outs = [local_cc_main],
        srcs = [wrapper_package + ":wrapper_main.cc"],
        cmd = "cp $< $@",
    )

    # Compiles the wrapper binary.
    # TODO: Find way to link from pypi.
    # Note: This rule will compile a small part of TF.
    tf_cc_binary(
        name = wrapper_name,
        copts = tf_copts(),
        linkopts = if_not_windows(["-lm", "-Wl,-ldl"]) + rpath_linkopts_to_tensorflow(wrapper_name),
        srcs = [":" + local_cc_main],
        deps = [
            wrapper_package + ":wrapper",
        ] + learner_deps,
        linkstatic = 1,
    )

    # Runs the wrapper binary and generate the wrapper .py source code.
    # The binary writes the python code on its standard output, which is
    # redirected to the output file.
    native.genrule(
        name = run_wrapper_name,
        srcs = [],
        outs = [name + ".py"],
        cmd = "$(location " + wrapper_name + ") > \"$@\"",
        tools = [":" + wrapper_name] + tf_binary_additional_srcs(),
    )

    # Python library around the generated .py source code.
    native.py_library(
        name = name,
        srcs = [name + ".py"],
        srcs_version = "PY3",
        deps = [
            "//tensorflow_decision_forests/keras:core",
            "//tensorflow_decision_forests/component/tuner",
            "@release_or_nightly//:tensorflow_pkg",
            "@ydf//yggdrasil_decision_forests/model:abstract_model_py_proto",
            "@ydf//yggdrasil_decision_forests/learner:abstract_learner_py_proto",
        ],
        data = [":" + run_wrapper_name, ":" + wrapper_name],
    )
83 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/keras/wrapper/wrapper.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2021 Google LLC.
3 | * Licensed under the Apache License, Version 2.0 (the "License");
4 | * you may not use this file except in compliance with the License.
5 | * You may obtain a copy of the License at
6 | *
7 | * https://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS,
11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | * See the License for the specific language governing permissions and
13 | * limitations under the License.
14 | */
15 |
16 | #ifndef TENSORFLOW_DECISION_FORESTS_KERAS_WRAPPER_WRAPPER_H_
17 | #define TENSORFLOW_DECISION_FORESTS_KERAS_WRAPPER_WRAPPER_H_
18 |
19 | #include
20 |
21 | #include "absl/status/statusor.h"
22 |
23 | namespace tensorflow {
24 | namespace decision_forests {
25 |
26 | // Creates the python code source that contains wrapping classes for all linked
27 | // learners.
28 | absl::StatusOr GenKerasPythonWrapper();
29 |
30 | // Returns Python class name from the learner key.
31 | std::string LearnerKeyToClassName(absl::string_view key);
32 |
33 | // Formats a text into python documentation.
34 | //
35 | // Do the following transformations:
36 | // - Wraps lines to "max_char_per_lines" characters (while carrying leading
37 | // space).
38 | // - Add leading spaces (leading_spaces_first_line and
39 | // leading_spaces_next_lines).
40 | // - Detect and format bullet lists.
41 | std::string FormatDocumentation(absl::string_view raw,
42 | int leading_spaces_first_line,
43 | int leading_spaces_next_lines,
44 | int max_char_per_lines = 80);
45 |
46 | // Gets the number of leading spaces of a string.
47 | int NumLeadingSpaces(absl::string_view text);
48 |
49 | } // namespace decision_forests
50 | } // namespace tensorflow
51 |
52 | #endif // TENSORFLOW_DECISION_FORESTS_KERAS_WRAPPER_WRAPPER_H_
53 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/keras/wrapper/wrapper_main.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2021 Google LLC.
3 | * Licensed under the Apache License, Version 2.0 (the "License");
4 | * you may not use this file except in compliance with the License.
5 | * You may obtain a copy of the License at
6 | *
7 | * https://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS,
11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | * See the License for the specific language governing permissions and
13 | * limitations under the License.
14 | */
15 |
16 | // Outputs python source code defining python Keras class wrapper around
17 | // Yggdrasil Decision Forest learners.
18 |
19 | #include "tensorflow_decision_forests/keras/wrapper/wrapper.h"
20 |
21 | int main(int argc, char* argv[]) {
22 | std::cout << tensorflow::decision_forests::GenKerasPythonWrapper().value();
23 | return 0;
24 | }
25 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/keras/wrapper/wrapper_test.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2021 Google LLC.
3 | * Licensed under the Apache License, Version 2.0 (the "License");
4 | * you may not use this file except in compliance with the License.
5 | * You may obtain a copy of the License at
6 | *
7 | * https://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS,
11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | * See the License for the specific language governing permissions and
13 | * limitations under the License.
14 | */
15 |
16 | #include "tensorflow_decision_forests/keras/wrapper/wrapper.h"
17 |
18 | #include "gtest/gtest.h"
19 |
20 | namespace tensorflow {
21 | namespace decision_forests {
22 | namespace {
23 |
// An upper-case, underscore-separated learner key maps to a camel-case class
// name with the "Model" suffix.
TEST(KerasLearnerWrappers, LearnerKeyToClassName) {
  EXPECT_EQ(LearnerKeyToClassName("RANDOM_FOREST"), "RandomForestModel");
}
27 |
28 | TEST(KerasLearnerWrappers, Base) {
29 | std::cout << GenKerasPythonWrapper().value() << std::endl;
30 | }
31 |
// Formats a short text containing a bullet list with 4 leading spaces on the
// first line, 6 on the following ones, and a limit of 20 characters per line.
//
// NOTE(review): both raw string literals are whitespace-sensitive. Their
// leading spaces look collapsed in the copy this comment was written from;
// confirm the exact indentation against the version-controlled file.
TEST(KerasLearnerWrappers, FormatDocumentation) {
  const auto formatted = FormatDocumentation(R"(AAA AAA AAA AAA AAA.
AAA AAA AAA AAA.
- AAA AAA AAA AAA.
- AAA AAA AAA AAA.
AAA AAA AAA AAA.
AAA AAA AAA AAA.)",
                                             /*leading_spaces_first_line=*/4,
                                             /*leading_spaces_next_lines=*/6,
                                             /*max_char_per_lines=*/20);
  EXPECT_EQ(formatted, R"( AAA AAA AAA AAA
AAA.
AAA AAA AAA
AAA.
- AAA AAA AAA
AAA.
- AAA AAA AAA
AAA.
AAA AAA AAA
AAA.
AAA AAA
AAA AAA.
)");
}
56 |
57 | TEST(KerasLearnerWrappers, NumLeadingSpaces) {
58 | EXPECT_EQ(NumLeadingSpaces(""), 0);
59 | EXPECT_EQ(NumLeadingSpaces(" "), 1);
60 | EXPECT_EQ(NumLeadingSpaces(" "), 2);
61 | EXPECT_EQ(NumLeadingSpaces(" HELLO "), 2);
62 | }
63 |
64 | } // namespace
65 | } // namespace decision_forests
66 | } // namespace tensorflow
67 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/check_version.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Check that version of TensorFlow is compatible with TF-DF."""
16 |
17 | import logging
18 | import tensorflow as tf
19 |
20 |
21 | def check_version(tf_df_version,
22 | compatible_tf_versions,
23 | tf_version=None,
24 | external_logic=False):
25 | """Checks the compatibility of the TF version.
26 |
27 | Prints a warning message and returns False in case of likely incompatible
28 | versions.
29 | """
30 |
31 | if not external_logic:
32 | pass
33 |
34 | if tf_version is None:
35 | tf_version = tf.__version__
36 | if tf_version not in compatible_tf_versions:
37 | logging.warning(
38 | "TensorFlow Decision Forests %s is compatible with the following "
39 | "TensorFlow Versions: %s. However, TensorFlow %s was detected. "
40 | "This can cause issues with the TF API and symbols in the custom C++ "
41 | "ops. See the TF and TF-DF compatibility table at "
42 | "https://github.com/tensorflow/decision-forests/blob/main/documentation/known_issues.md#compatibility-table.",
43 | tf_df_version, compatible_tf_versions, tf_version)
44 | return False
45 | return True
46 |
47 |
48 | def info_fail_to_load_custom_op(exception, filename):
49 | logging.warning(
50 | "Failure to load the %s custom c++ tensorflow ops. "
51 | "This error is likely caused the version of TensorFlow and "
52 | "TensorFlow Decision Forests are not compatible. Full error:"
53 | "%s", filename, exception)
54 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/check_version_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import tensorflow as tf
16 |
17 | from tensorflow_decision_forests.tensorflow import check_version
18 |
19 |
20 | class CheckVersionTest(tf.test.TestCase):
21 |
22 | def test_base(self):
23 |
24 | tf_df_version = "1.2.3" # Does not matter.
25 | self.assertTrue(
26 | check_version.check_version(
27 | tf_df_version, ["2.6.0", "2.6.1"], "2.6.0", external_logic=True))
28 | self.assertFalse(
29 | check_version.check_version(
30 | tf_df_version, ["2.6.0", "2.6.1"],
31 | "2.8.0-dev20211105",
32 | external_logic=True))
33 | self.assertFalse(
34 | check_version.check_version(
35 | tf_df_version, ["2.6.0", "2.6.1"], "2.7.0-rc1",
36 | external_logic=True))
37 |
38 |
39 | if __name__ == "__main__":
40 | tf.test.main()
41 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/distribute/BUILD:
--------------------------------------------------------------------------------
1 | # Implementation of the Yggdrasil Distribute API using TensorFlow Distribution Strategies.
2 |
3 | package(
4 | default_visibility = ["//visibility:public"],
5 | licenses = ["notice"],
6 | )
7 |
8 | # Worker binaries
9 | # ===============
10 |
11 | py_binary(
12 | name = "tensorflow_std_server_py",
13 | srcs = ["tensorflow_std_server.py"],
14 | main = "tensorflow_std_server.py",
15 | deps = [
16 | "//tensorflow_decision_forests",
17 | "@pypi_absl_py//:pkg", # :app
18 | "@release_or_nightly//:tensorflow_pkg",
19 | ],
20 | )
21 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/distribute/tensorflow_std_server.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Python worker for ParameterServerStrategy.
16 |
17 | This binary is a distributed training worker. When using distributed training,
18 | this worker (or the c++ worker defined in the same BUILD file) should run on
19 | worker machines. The configuration of the worker and chief is done
20 | through environment variables. See examples/distributed_training.py
21 |
22 | When possible, use the c++ worker instead.
23 | """
24 |
25 | from absl import app
26 | from absl import flags
27 | from absl import logging
28 |
29 | import tensorflow as tf
30 | import tensorflow_decision_forests as tfdf # pylint: disable=unused-import
31 |
32 | FLAGS = flags.FLAGS
33 |
34 | flags.DEFINE_string("job_name", "worker", "")
35 | flags.DEFINE_string("protocol", "grpc", "")
36 | flags.DEFINE_integer("task_index", 0, "")
37 |
38 |
39 | def main(argv):
40 | if len(argv) > 1:
41 | raise app.UsageError("Too many command-line arguments.")
42 |
43 | logging.info("Starting worker")
44 | cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
45 | server = tf.distribute.Server(
46 | cluster_resolver.cluster_spec(),
47 | protocol=FLAGS.protocol,
48 | job_name=FLAGS.job_name,
49 | task_index=FLAGS.task_index)
50 |
51 | logging.info("Worker is running")
52 | server.join()
53 |
54 | logging.info("Shutting down worker")
55 |
56 |
57 | if __name__ == "__main__":
58 | app.run(main)
59 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/ops/BUILD:
--------------------------------------------------------------------------------
1 | package(
2 | default_visibility = ["//visibility:public"],
3 | licenses = ["notice"],
4 | )
5 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/ops/inference/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/ops/inference/op.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from tensorflow_decision_forests.tensorflow.ops.inference.op_dynamic import *
16 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/ops/inference/op_dynamic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from tensorflow_decision_forests.tensorflow import check_version
16 | import tensorflow as tf
17 | from tensorflow.python.platform import resource_loader
18 | import sys
19 |
20 | try:
21 | ops = tf.load_op_library(resource_loader.get_path_to_datafile("inference.so"))
22 | except Exception as e:
23 | check_version.info_fail_to_load_custom_op(e, "inference.so")
24 | raise e
25 |
26 | # Importing all the symbols.
27 | module = sys.modules[__name__]
28 | for name, value in ops.__dict__.items():
29 | if "__" in name:
30 | continue
31 | setattr(module, name, value)
32 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/ops/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/ops/training/kernel.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2021 Google LLC.
3 | * Licensed under the Apache License, Version 2.0 (the "License");
4 | * you may not use this file except in compliance with the License.
5 | * You may obtain a copy of the License at
6 | *
7 | * https://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS,
11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | * See the License for the specific language governing permissions and
13 | * limitations under the License.
14 | */
15 |
16 | #ifndef TENSORFLOW_DECISION_FORESTS_TENSORFLOW_OPS_TRAINING_TRAINING_H_
17 | #define TENSORFLOW_DECISION_FORESTS_TENSORFLOW_OPS_TRAINING_TRAINING_H_
18 |
19 | #include
20 | #include
21 | #include
22 | #include
23 | #include
24 |
25 | #include "absl/status/status.h"
26 | #include "absl/status/statusor.h"
27 | #include "absl/strings/string_view.h"
28 | #include "tensorflow/core/framework/op_kernel.h"
29 | #include "tensorflow/core/framework/resource_mgr.h"
30 | #include "yggdrasil_decision_forests/model/abstract_model.h"
31 |
32 | namespace tensorflow_decision_forests {
33 | namespace ops {
34 |
35 | // Status of long running operations.
36 | enum LongRunningProcessStatus {
37 | kInProgress = 0,
38 | kSuccess = 1,
39 | };
40 |
41 | // Starts a long running process. Returns the id of the process.
42 | absl::StatusOr StartLongRunningProcess(
43 | ::tensorflow::OpKernelContext* ctx, std::function&& call);
44 |
45 | // Checks the status of a long running process. If the returned status of a
46 | // process is success or failure, the status of the process should not be
47 | // queried again.
48 | absl::StatusOr GetLongRunningProcessStatus(
49 | ::tensorflow::OpKernelContext* ctx, int32_t process_id);
50 |
51 | // If TFDF_STOP_TRAINING_ON_INTERRUPT is set, model training is interrupted with
52 | // the "set_stop_training_trigger()" API when the process receives a "program
53 | // interrupt" signal (i.e. SIGINT).
54 | //
55 | // An interrupted model is a valid model. But its quality is likely inferior to
56 | // a fully trained model. See "set_stop_training_trigger" for details about the
57 | // effect of the interruption.
58 | //
59 | #ifdef TFDF_STOP_TRAINING_ON_INTERRUPT
60 |
61 | // The logic in this namespace is used to interrupt the training of a model (by
62 | // setting "stop_training_trigger=True") when receiving an interruption (e.g.
63 | // the user pressed ctrl+c).
64 | namespace interruption {
65 | // Should the current training learners be stopped?
66 | inline std::atomic stop_training;
67 |
68 | // The interruption signal handler to restore when all the learners are done
69 | // training.
70 | inline void (*previous_signal_handler)(int);
71 |
72 | // Number of learners training.
73 | inline std::atomic active_learners{0};
74 |
75 | inline void StopTrainingSignalHandler(int signal) { stop_training = true; }
76 |
77 | // Enables the interruption listener.
78 | absl::Status EnableUserInterruption();
79 |
80 | // Disable the interruption listener
81 | absl::Status DisableUserInterruption();
82 |
83 | } // namespace interruption
84 | #endif
85 |
86 | // A non-compiled Yggdrasil model.
87 | class YggdrasilModelContainer : public tensorflow::ResourceBase {
88 | public:
89 | ~YggdrasilModelContainer() override = default;
90 |
91 | std::string DebugString() const override { return "YggdrasilModelContainer"; }
92 |
93 | absl::Status LoadModel(absl::string_view model_path);
94 |
95 | tensorflow::int64 MemoryUsed() const override {
96 | return approximate_model_size_in_memory_;
97 | }
98 |
99 | std::unique_ptr*
100 | mutable_model() {
101 | return &model_;
102 | }
103 |
104 | const yggdrasil_decision_forests::model::AbstractModel& model() {
105 | return *model_;
106 | }
107 |
108 | int num_label_classes() const { return num_label_classes_; }
109 |
110 | const std::vector& output_class_representation() const {
111 | return output_class_representation_;
112 | }
113 |
114 | private:
115 | // The model.
116 | std::unique_ptr model_;
117 |
118 | // Number of output classes. This information is contained in the model, but
119 | // cached for fast access.
120 | int num_label_classes_ = -1;
121 |
122 | // String representation of the output classes.
123 | std::vector output_class_representation_;
124 |
125 | // Approximation of the model size in memory.
126 | int64_t approximate_model_size_in_memory_ = 0;
127 | };
128 |
129 | } // namespace ops
130 | } // namespace tensorflow_decision_forests
131 |
132 | #endif // TENSORFLOW_DECISION_FORESTS_TENSORFLOW_OPS_TRAINING_TRAINING_H_
133 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/ops/training/op.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from tensorflow_decision_forests.tensorflow.ops.training.op_dynamic import *
16 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/ops/training/op_dynamic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from tensorflow_decision_forests.tensorflow import check_version
16 | import tensorflow as tf
17 | from tensorflow.python.platform import resource_loader
18 | import sys
19 |
20 | try:
21 | ops = tf.load_op_library(resource_loader.get_path_to_datafile("training.so"))
22 | except Exception as e:
23 | check_version.info_fail_to_load_custom_op(e, "training.so")
24 | raise e
25 |
26 | # Importing all the symbols.
27 | module = sys.modules[__name__]
28 | for name, value in ops.__dict__.items():
29 | if "__" in name:
30 | continue
31 | setattr(module, name, value)
32 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/ops/training/op_grpc_worker.cc:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2021 Google LLC.
3 | * Licensed under the Apache License, Version 2.0 (the "License");
4 | * you may not use this file except in compliance with the License.
5 | * You may obtain a copy of the License at
6 | *
7 | * https://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS,
11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | * See the License for the specific language governing permissions and
13 | * limitations under the License.
14 | */
15 |
16 | #include "tensorflow/core/framework/common_shape_fns.h"
17 | #include "tensorflow/core/framework/op.h"
18 | #include "tensorflow/core/framework/shape_inference.h"
19 |
20 | namespace tensorflow {
21 |
22 | // Creates a YDF GRPC Worker server. If a server with the same key already
23 | // exists, this is a no-op. Returns the port of the server.
24 | //
25 | // Args:
26 | // "key": Key of the server. Only one server with a given key can exist in a
27 | // session.
28 | // "force_ydf_port": Port for YDF to use. The chief and the workers should
29 | // be able to communicate through this port. If -1, an available port
30 | // is automatically selected.
31 | //
32 | // Output:
33 | // port: Port of the GRPC server. If force_ydf_port is set, returns
34 | // "force_ydf_port".
35 | REGISTER_OP("SimpleMLCreateYDFGRPCWorker")
36 | .SetIsStateful()
37 | .Attr("key: int")
38 | .Attr("force_ydf_port: int = -1")
39 | .Output("port: int32")
40 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
41 | c->set_output(0, c->Scalar());
42 | return absl::OkStatus();
43 | });
44 |
45 | // Indicates to the GRPC manager that the address of a worker has changed.
46 | // This function takes care of updating the other workers for inter-worker
47 | // communication.
48 | //
49 | // Args:
50 | // "key": Key of the server.
51 | // "worker_idx": Index of the worker to modify.
52 | // "new_address": New address of the worker.
53 | //
54 | // Output:
55 | // None. This op does not register any output.
56 | REGISTER_OP("SimpleMLUpdateGRPCWorkerAddress")
57 | .SetIsStateful()
58 | .Attr("key: int")
59 | .Input("worker_idx: int32")
60 | .Input("new_address: string")
61 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
62 | c->set_input_handle_shapes_and_types(
63 | 0, {shape_inference::ShapeAndType(c->Scalar(), DataType::DT_INT32)});
64 | c->set_input_handle_shapes_and_types(
65 | 1, {shape_inference::ShapeAndType(c->Scalar(), DataType::DT_STRING)});
66 | return absl::OkStatus();
67 | });
68 |
69 | // Stop any running YDF GRPC Worker server.
70 | //
71 | // Args:
72 | // "key": Key of the server.
73 | //
74 | REGISTER_OP("SimpleMLStopYDFGRPCWorker")
75 | .SetIsStateful()
76 | .Attr("key: int")
77 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
78 | return absl::OkStatus();
79 | });
80 |
81 | } // namespace tensorflow
82 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/ops/training/op_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for op."""
16 |
17 | import logging
18 |
19 | import tensorflow as tf
20 |
21 | from tensorflow_decision_forests.tensorflow.ops.training import op
22 |
23 |
24 | class OpTest(tf.test.TestCase):
25 |
26 | def test_grpc_workers(self):
27 | port = op.SimpleMLCreateYDFGRPCWorker(key=1)
28 | logging.info("port: %d", port)
29 |
30 | port_again = op.SimpleMLCreateYDFGRPCWorker(key=1)
31 | logging.info("port_again: %d", port_again)
32 |
33 | self.assertEqual(port, port_again)
34 |
35 | port_other_server = op.SimpleMLCreateYDFGRPCWorker(key=2)
36 | logging.info("port_other_server: %d", port_other_server)
37 |
38 | self.assertNotEqual(port, port_other_server)
39 |
40 | op.SimpleMLStopYDFGRPCWorker(key=1)
41 | op.SimpleMLStopYDFGRPCWorker(key=2)
42 |
43 |
44 | if __name__ == "__main__":
45 | tf.test.main()
46 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/tf1_compatibility.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tensorflow 1 compatibility utilities.
16 |
17 | Replaces some TF2 libraries with their TF1 counterpart where necessary.
18 | """
19 |
20 | import tensorflow as tf
21 |
22 | if hasattr(tf, '__internal__'):
23 | Trackable = tf.__internal__.tracking.Trackable
24 | AutoTrackable = tf.__internal__.tracking.AutoTrackable
25 | TrackableResource = tf.saved_model.experimental.TrackableResource
26 | no_automatic_dependency_tracking = tf.__internal__.tracking.no_automatic_dependency_tracking
27 | else:
28 | # pylint: disable=g-direct-tensorflow-import, disable=g-import-not-at-top
29 | from tensorflow.python.trackable import autotrackable
30 | from tensorflow.python.trackable import base as trackable_base
31 | from tensorflow.python.trackable import resource
32 | from tensorflow.python.trackable import base as base_tracking
33 | # pylint: enable=g-direct-tensorflow-import, g-import-not-at-top
34 |
35 | Trackable = trackable_base.Trackable
36 | AutoTrackable = autotrackable.AutoTrackable
37 | TrackableResource = resource.TrackableResource
38 | no_automatic_dependency_tracking = base_tracking.no_automatic_dependency_tracking
39 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/tf_logging.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Logging primitives.
16 |
17 | Replacement of absl's logging primitives that are always visible to the user.
18 | """
19 |
20 |
21 | from typing import Any, List
22 | from absl import logging
23 |
24 |
25 | def info(msg: str, *args: List[Any]) -> None:
26 | """Print an info message visible to the user.
27 |
28 | To use instead of absl.logging.info (to be visible in colabs).
29 |
30 | Usage example:
31 | logging_info("Hello %s", "world")
32 |
33 | Args:
34 | msg: String message with replacement placeholders e.g. %s.
35 | *args: Placeholder replacement values.
36 | """
37 |
38 | print(msg % args, flush=True)
39 | logging.info(msg, *args)
40 |
41 |
42 | def warning(msg: str, *args: List[Any]) -> None:
43 | """Print a warning message visible to the user.
44 |
45 | To use instead of absl.logging.warning (to be visible in colabs).
46 |
47 | Usage example:
48 | logging_warning("Hello %s", "world")
49 |
50 | Args:
51 | msg: String message with replacement placeholders e.g. %s.
52 | *args: Placeholder replacement values.
53 | """
54 |
55 | print("Warning:", msg % args, flush=True)
56 | logging.warning(msg, *args)
57 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow/utils.bzl:
--------------------------------------------------------------------------------
1 | """Utilities for the compilation of tensorflow code."""
2 |
3 | load("@org_tensorflow//tensorflow:tensorflow.bzl", "check_deps", "tf_binary_additional_srcs", "tf_cc_shared_object", "tf_copts", "tf_custom_op_library_additional_deps")
4 |
5 | def _make_search_paths(prefix, search_level):
6 | return "-rpath,%s/%s/external/org_tensorflow/tensorflow" % (prefix, "/".join([".."] * search_level))
7 |
8 | def rpath_linkopts_to_tensorflow(name):
9 | """Create a rpath linkopts flag to include tensorflow .so's directory."""
10 |
11 | levels_to_root = native.package_name().count("/") + name.count("/")
12 | return select({
13 | "@org_tensorflow//tensorflow:macos": [
14 | "-Wl,%s" % (_make_search_paths("@loader_path", levels_to_root),),
15 | "-Wl,-rename_section,__TEXT,text_env,__TEXT,__text",
16 | ],
17 | "@org_tensorflow//tensorflow:windows": [],
18 | "//conditions:default": [
19 | "-Wl,%s" % (_make_search_paths("$$ORIGIN", levels_to_root + 1),),
20 | ],
21 | })
22 |
23 | def tf_custom_op_library_external(name, srcs = [], deps = [], linkopts = [], copts = [], **kwargs):
24 | """Helper to build a dynamic library (.so) from the sources containing implementations of custom ops and kernels.
25 |
26 | Similar to "tf_custom_op_library" in tensorflow, but also works for external libraries and windows.
27 | """
28 |
29 | # Rely on the TF in "pywrap_tensorflow_import_lib".
30 | deps = deps + tf_custom_op_library_additional_deps()
31 |
32 | check_deps(
33 | name = name + "_check_deps",
34 | disallowed_deps = [
35 | "@org_tensorflow//tensorflow/core:framework",
36 | "@org_tensorflow//tensorflow/core:lib",
37 | ],
38 | deps = deps,
39 | )
40 |
41 | tf_cc_shared_object(
42 | name = name,
43 | srcs = srcs,
44 | deps = deps,
45 | framework_so = tf_binary_additional_srcs() + [
46 | # Rely on the TF in "tensorflow.dll".
47 | #"@org_tensorflow//tensorflow:tensorflow_dll_import_lib",
48 | # Rely on the TF in "tensorflow_cc.dll".
49 | #"@org_tensorflow//tensorflow:tensorflow_cc_dll_import_lib",
50 | ],
51 | copts = copts + tf_copts(is_external = True),
52 | linkopts = linkopts + select({
53 | "//conditions:default": ["-lm"],
54 | "@org_tensorflow//tensorflow:windows": [],
55 | "@org_tensorflow//tensorflow:macos": [],
56 | }),
57 | **kwargs
58 | )
59 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tensorflow_decision_forests.bzl:
--------------------------------------------------------------------------------
1 | """Helpers for building with Hermetic Python."""
2 |
3 | def _py_deps_profile_impl(ctx):
4 | is_switch = False
5 | for var_name, var_val in ctx.attr.switch.items():
6 | is_switch = is_switch or ctx.os.environ.get(var_name, "") == var_val
7 |
8 | prefix = ctx.attr.pip_repo_name
9 | reqirements_name = ctx.attr.requirements_in.name
10 | requirements_in_substitutions = {}
11 | build_content = ['exports_files(["{}"])'.format(reqirements_name)]
12 | for k, v in ctx.attr.deps_map.items():
13 | repo_name = v[0] if is_switch else k
14 | requirements_in_substitutions[k + "\n"] = repo_name + "\n"
15 | requirements_in_substitutions[k + "\r\n"] = repo_name + "\r\n"
16 | aliased_targets = ["pkg"] + v[1:]
17 | norm_repo_name = repo_name.replace("-", "_")
18 | norm_alas_name = k.replace("-", "_")
19 | for target in aliased_targets:
20 | alias_name = "{}_{}".format(norm_alas_name, target)
21 | alias_value = "@{}_{}//:{}".format(prefix, norm_repo_name, target)
22 | build_content.append("""
23 | alias(
24 | name = "{}",
25 | actual = "{}",
26 | visibility = ["//visibility:public"]
27 | )
28 | """.format(alias_name, alias_value))
29 |
30 | ctx.file("BUILD", "".join(build_content))
31 | ctx.template(
32 | reqirements_name,
33 | ctx.attr.requirements_in,
34 | executable = False,
35 | substitutions = requirements_in_substitutions,
36 | )
37 |
# Repository rule exposing a requirements file plus alias targets whose actual
# packages depend on the environment variables listed in `switch`.
py_deps_profile = repository_rule(
    implementation = _py_deps_profile_impl,
    local = True,
    attrs = {
        # Default package name -> [replacement package, extra alias targets...].
        "deps_map": attr.string_list_dict(mandatory = True),
        # Prefix of the repositories the generated aliases point to.
        "pip_repo_name": attr.string(mandatory = True),
        # Requirements file used as the template for the generated copy.
        "requirements_in": attr.label(mandatory = True),
        # Environment variable name -> value that switches the profile.
        "switch": attr.string_dict(mandatory = True),
    },
)
48 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/test_data/BUILD:
--------------------------------------------------------------------------------
1 | package(
2 | default_visibility = ["//visibility:public"],
3 | licenses = ["notice"],
4 | )
5 |
6 | filegroup(
7 | name = "test_data",
8 | srcs = glob(["**"]),
9 | )
10 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/test_data/README.md:
--------------------------------------------------------------------------------
1 | # Data for the Unit Tests
2 |
3 | This directory only contains small files used for the unit tests.
4 |
5 | This directory structure is organized as follows:
6 | - `model`: Pre-trained models.
7 |   - `saved_model_adult_rf`: TF SavedModel containing a TF-DF Random Forest model trained with Keras.
8 |
--------------------------------------------------------------------------------
/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/data_spec.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/data_spec.pb
--------------------------------------------------------------------------------
/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/done:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/done
--------------------------------------------------------------------------------
/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/header.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/header.pb
--------------------------------------------------------------------------------
/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/nodes-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/nodes-00000-of-00001
--------------------------------------------------------------------------------
/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/random_forest_header.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/assets/random_forest_header.pb
--------------------------------------------------------------------------------
/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/keras_metadata.pb
--------------------------------------------------------------------------------
/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/saved_model.pb
--------------------------------------------------------------------------------
/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/decision-forests/55faf7e80b5717cbcc2a64989b31d0068feb7d4c/tensorflow_decision_forests/test_data/model/saved_model_adult_rf/variables/variables.index
--------------------------------------------------------------------------------
/tensorflow_decision_forests/tools/run_e2e_tfdf_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2021 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 |
18 | # Converts a non-submitted CL to a standalone Bazel project in a local
19 | # directory, compiles the project and runs the tests.
20 | #
21 | # Usage example:
22 | # third_party/tensorflow_decision_forests/tools/run_e2e_tfdf_test.sh
23 |
24 | set -vex
25 |
26 | LOCAL_DIR="/usr/local/google/home/${USER}/git/decision-forests"
27 |
28 | CL=$(hg exportedcl)
29 | echo "Current CL: ${CL}"
30 | echo "Make sure the CL is synced!"
31 |
# Exports the current CL as a standalone Bazel project into ${LOCAL_DIR}.
#
# Reads the global variables LOCAL_DIR (destination directory, deleted and
# recreated) and CL (the change to export).
function export_project() {
  local COPYBARA="/google/bin/releases/copybara/public/copybara/copybara"

  # Test the Copybara configuration.
  bazel test third_party/tensorflow_decision_forests:copybara_test

  echo "Export a Bazel project locally"
  echo "=============================="

  # The destination is quoted everywhere so that a path containing spaces or
  # glob characters is never split or expanded; ":?" aborts instead of running
  # "rm -fr" on an empty path.
  rm -fr -- "${LOCAL_DIR:?}"
  "${COPYBARA}" third_party/tensorflow_decision_forests/copy.bara.sky presubmit_piper_to_gerrit ${CL} \
    --dry-run --init-history --squash --force \
    --git-destination-path "${LOCAL_DIR}" --ignore-noop

  /google/bin/releases/opensource/thirdparty/cross/cross "${LOCAL_DIR}"
}
48 |
# Opens the build Docker from the exported project directory (${LOCAL_DIR}).
run_all() {
  # The banner is printed here, right before the test step, rather than at
  # script load time where it would precede the export output.
  echo "Test the project"
  echo "================"

  cd "${LOCAL_DIR}"

  # Start the Docker
  sudo ./tools/start_compile_docker.sh /bin/bash

  # In the docker, you can now trigger the builder with the following line in
  # the docker:
  # RUN_TESTS=1 PY_VERSION=3.9 TF_VERSION=2.16.1 ./tools/test_bazel.sh

  # Alternatively, you can trigger the build directly with:
  # sudo ./tools/start_compile_docker.sh "RUN_TESTS=1 PY_VERSION=3.8 TF_VERSION=2.10.0 ./tools/test_bazel.sh && chmod -R a+xrw . && /bin/bash"
}

export_project
run_all
68 |
--------------------------------------------------------------------------------
/third_party/BUILD:
--------------------------------------------------------------------------------
1 | licenses(["notice"])
2 |
--------------------------------------------------------------------------------
/third_party/absl/BUILD:
--------------------------------------------------------------------------------
1 | licenses(["notice"])
2 |
--------------------------------------------------------------------------------
/third_party/absl/workspace.bzl:
--------------------------------------------------------------------------------
1 | """Absl project."""
2 |
def deps():
    """Declares nothing: absl is injected by the TensorFlow dependency.

    Note: TensorFlow relies on possibly old versions of absl.
    """
    pass
7 |
--------------------------------------------------------------------------------
/third_party/absl_py/BUILD:
--------------------------------------------------------------------------------
1 | licenses(["notice"])
2 |
--------------------------------------------------------------------------------
/third_party/absl_py/workspace.bzl:
--------------------------------------------------------------------------------
1 | """Absl project."""
2 |
3 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
4 |
def deps():
    """Registers the `com_google_absl_py` repository (abseil-py).

    NOTE(review): the archive tracks the `master` branch and declares no
    `sha256`, so the downloaded content is not pinned.
    """
    branch = "master"
    http_archive(
        name = "com_google_absl_py",
        strip_prefix = "abseil-py-" + branch,
        urls = ["https://github.com/abseil/abseil-py/archive/" + branch + ".zip"],
    )
11 |
--------------------------------------------------------------------------------
/third_party/benchmark/BUILD:
--------------------------------------------------------------------------------
1 | licenses(["notice"])
2 |
--------------------------------------------------------------------------------
/third_party/benchmark/workspace.bzl:
--------------------------------------------------------------------------------
1 | """Benchmark project."""
2 |
def deps():
    """Declares nothing: benchmark is injected by the TensorFlow dependency."""
    pass
6 |
--------------------------------------------------------------------------------
/third_party/gtest/BUILD:
--------------------------------------------------------------------------------
1 | licenses(["notice"])
2 |
--------------------------------------------------------------------------------
/third_party/gtest/workspace.bzl:
--------------------------------------------------------------------------------
1 | """Google Test project."""
2 |
3 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
4 |
def deps():
    """Registers the `com_google_googletest` repository at a fixed commit."""
    commit = "011959aafddcd30611003de96cfd8d7a7685c700"
    http_archive(
        name = "com_google_googletest",
        sha256 = "6a5d7d63cd6e0ad2a7130471105a3b83799a7a2b14ef7ec8d742b54f01a4833c",
        strip_prefix = "googletest-" + commit,
        urls = ["https://github.com/google/googletest/archive/{}.zip".format(commit)],
    )
12 |
--------------------------------------------------------------------------------
/third_party/protobuf/BUILD:
--------------------------------------------------------------------------------
1 | licenses(["notice"])
2 |
--------------------------------------------------------------------------------
/third_party/protobuf/workspace.bzl:
--------------------------------------------------------------------------------
1 | """Protobuf project."""
2 |
3 | # load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
4 |
def deps():
    """Declares nothing: the protobuf linked in TensorFlow is used."""
    pass
8 |
--------------------------------------------------------------------------------
/third_party/tensorflow/BUILD:
--------------------------------------------------------------------------------
1 | licenses(["notice"])
2 |
--------------------------------------------------------------------------------
/third_party/tensorflow/tf-216.patch:
--------------------------------------------------------------------------------
1 | diff --git a/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl b/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl
2 | index 9261a652f9c..0557e345ab1 100644
3 | --- a/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl
4 | +++ b/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl
5 | @@ -2,7 +2,7 @@
6 |
7 | load("//tensorflow/tools/toolchains:cpus/aarch64/aarch64.bzl", "remote_aarch64_configure")
8 | load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure")
9 | -load("//third_party/py:python_configure.bzl", "remote_python_configure")
10 | +load("//third_party/py/non_hermetic:python_configure.bzl", "remote_python_configure")
11 |
12 | def ml2014_tf_aarch64_configs(name_container_map, env):
13 | for name, container in name_container_map.items():
14 | diff --git a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl
15 | index 9f71a414bf7..57f70752323 100644
16 | --- a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl
17 | +++ b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl
18 | @@ -1,6 +1,6 @@
19 | """Macro that creates external repositories for remote config."""
20 |
21 | -load("//third_party/py:python_configure.bzl", "local_python_configure", "remote_python_configure")
22 | +load("//third_party/py/non_hermetic:python_configure.bzl", "local_python_configure", "remote_python_configure")
23 | load("//third_party/gpus:cuda_configure.bzl", "remote_cuda_configure")
24 | load("//third_party/nccl:nccl_configure.bzl", "remote_nccl_configure")
25 | load("//third_party/gpus:rocm_configure.bzl", "remote_rocm_configure")
26 | diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl
27 | index 056df85ffdb..7422baf8c59 100644
28 | --- a/tensorflow/workspace2.bzl
29 | +++ b/tensorflow/workspace2.bzl
30 | @@ -37,7 +37,7 @@ load("//third_party/nasm:workspace.bzl", nasm = "repo")
31 | load("//third_party/nccl:nccl_configure.bzl", "nccl_configure")
32 | load("//third_party/opencl_headers:workspace.bzl", opencl_headers = "repo")
33 | load("//third_party/pasta:workspace.bzl", pasta = "repo")
34 | -load("//third_party/py:python_configure.bzl", "python_configure")
35 | +load("//third_party/py/non_hermetic:python_configure.bzl", "python_configure")
36 | load("//third_party/py/ml_dtypes:workspace.bzl", ml_dtypes = "repo")
37 | load("//third_party/pybind11_abseil:workspace.bzl", pybind11_abseil = "repo")
38 | load("//third_party/pybind11_bazel:workspace.bzl", pybind11_bazel = "repo")
39 | diff --git a/third_party/py/non_hermetic/python_configure.bzl b/third_party/py/non_hermetic/python_configure.bzl
40 | index 89732c3e33d..4ac1c8f5c04 100644
41 | --- a/third_party/py/non_hermetic/python_configure.bzl
42 | +++ b/third_party/py/non_hermetic/python_configure.bzl
43 | @@ -203,7 +203,7 @@ def _create_local_python_repository(repository_ctx):
44 | # Resolve all labels before doing any real work. Resolving causes the
45 | # function to be restarted with all previous state being lost. This
46 | # can easily lead to a O(n^2) runtime in the number of labels.
47 | - build_tpl = repository_ctx.path(Label("//third_party/py:BUILD.tpl"))
48 | + build_tpl = repository_ctx.path(Label("//third_party/py/non_hermetic:BUILD.tpl"))
49 |
50 | python_bin = get_python_bin(repository_ctx)
51 | _check_python_bin(repository_ctx, python_bin)
52 |
--------------------------------------------------------------------------------
/third_party/tensorflow/tf.patch:
--------------------------------------------------------------------------------
1 | diff --git a/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl b/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl
2 | index a2bdd6a7eed..ec25c23d8d4 100644
3 | --- a/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl
4 | +++ b/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl
5 | @@ -2,7 +2,7 @@
6 |
7 | load("//tensorflow/tools/toolchains:cpus/aarch64/aarch64.bzl", "remote_aarch64_configure")
8 | load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure")
9 | -load("//third_party/py:python_configure.bzl", "remote_python_configure")
10 | +load("//third_party/py/non_hermetic:python_configure.bzl", "remote_python_configure")
11 |
12 | def ml2014_tf_aarch64_configs(name_container_map, env):
13 | for name, container in name_container_map.items():
14 | diff --git a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl
15 | index 9f71a414bf7..57f70752323 100644
16 | --- a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl
17 | +++ b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl
18 | @@ -1,6 +1,6 @@
19 | """Macro that creates external repositories for remote config."""
20 |
21 | -load("//third_party/py:python_configure.bzl", "local_python_configure", "remote_python_configure")
22 | +load("//third_party/py/non_hermetic:python_configure.bzl", "local_python_configure", "remote_python_configure")
23 | load("//third_party/gpus:cuda_configure.bzl", "remote_cuda_configure")
24 | load("//third_party/nccl:nccl_configure.bzl", "remote_nccl_configure")
25 | load("//third_party/gpus:rocm_configure.bzl", "remote_rocm_configure")
26 | diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl
27 | index 953e1d1bea6..664608592a5 100644
28 | --- a/tensorflow/workspace2.bzl
29 | +++ b/tensorflow/workspace2.bzl
30 | @@ -8,7 +8,7 @@ load("//third_party/gpus:rocm_configure.bzl", "rocm_configure")
31 | load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
32 | load("//third_party/nccl:nccl_configure.bzl", "nccl_configure")
33 | load("//third_party/git:git_configure.bzl", "git_configure")
34 | -load("//third_party/py:python_configure.bzl", "python_configure")
35 | +load("//third_party/py/non_hermetic:python_configure.bzl", "python_configure")
36 | load("//third_party/systemlibs:syslibs_configure.bzl", "syslibs_configure")
37 | load("//tensorflow/tools/toolchains:cpus/aarch64/aarch64_compiler_configure.bzl", "aarch64_compiler_configure")
38 | load("//tensorflow/tools/toolchains:cpus/arm/arm_compiler_configure.bzl", "arm_compiler_configure")
39 | diff --git a/third_party/py/non_hermetic/python_configure.bzl b/third_party/py/non_hermetic/python_configure.bzl
40 | index 300cbfb6c71..09d98505dd9 100644
41 | --- a/third_party/py/non_hermetic/python_configure.bzl
42 | +++ b/third_party/py/non_hermetic/python_configure.bzl
43 | @@ -206,7 +206,7 @@ def _create_local_python_repository(repository_ctx):
44 | # Resolve all labels before doing any real work. Resolving causes the
45 | # function to be restarted with all previous state being lost. This
46 | # can easily lead to a O(n^2) runtime in the number of labels.
47 | - build_tpl = repository_ctx.path(Label("//third_party/py:BUILD.tpl"))
48 | + build_tpl = repository_ctx.path(Label("//third_party/py/non_hermetic:BUILD.tpl"))
49 |
50 | python_bin = get_python_bin(repository_ctx)
51 | _check_python_bin(repository_ctx, python_bin)
52 |
--------------------------------------------------------------------------------
/third_party/tensorflow/workspace.bzl:
--------------------------------------------------------------------------------
1 | """TensorFlow project."""
2 |
def deps():
    """Declares nothing: TensorFlow dependencies are injected in the WORKSPACE."""
    pass
6 |
--------------------------------------------------------------------------------
/third_party/tensorflow_pypi/BUILD:
--------------------------------------------------------------------------------
1 | licenses(["notice"])
2 |
--------------------------------------------------------------------------------
/third_party/tensorflow_pypi/BUILD.tpl:
--------------------------------------------------------------------------------
1 | package(default_visibility = ["//visibility:public"])
2 |
3 | cc_library(
4 | name = "tf_header_lib",
5 | hdrs = [":tf_header_include"],
6 | includes = ["include"],
7 | visibility = ["//visibility:public"],
8 | )
9 |
10 | cc_library(
11 | name = "libtensorflow_framework",
12 | srcs = [":libtensorflow_framework.so"],
13 | #data = ["lib/libtensorflow_framework.so"],
14 | visibility = ["//visibility:public"],
15 | )
16 |
17 | %{TF_HEADER_GENRULE}
18 | %{TF_SHARED_LIBRARY_GENRULE}
--------------------------------------------------------------------------------
/third_party/tensorflow_pypi/README:
--------------------------------------------------------------------------------
1 | TensorFlow injected through the .h, .so and .dll available in the TensorFlow
2 | PyPi package.
3 |
4 | This directory contains a copy of:
5 | https://github.com/tensorflow/custom-op/tree/master/tf
6 |
7 | With the following edits:
8 | "//tf:" -> "//third_party/tensorflow_pypi:"
9 |
--------------------------------------------------------------------------------
/third_party/tensorflow_pypi/tf_configure.bzl:
--------------------------------------------------------------------------------
1 | """Setup TensorFlow as external dependency"""
2 |
3 | load("@python//:defs.bzl", "interpreter")
4 | load("@python_version_repo//:py_version.bzl", "REQUIREMENTS_WITH_LOCAL_WHEELS")
5 | load("@rules_python//python:pip.bzl", "package_annotation", "pip_parse")
6 |
def tf_configure():
    """Declares the `pypi` pip repository with extra C++ targets.

    The `numpy` package gains header libraries (`numpy_headers`), while the
    `tensorflow` and `tf-nightly` packages both gain the `tf_header_lib` and
    `libtensorflow_framework` libraries.
    """

    # Extra BUILD content shared by the "tensorflow" and "tf-nightly" packages.
    tf_build_content = """
cc_library(
    name = "tf_header_lib",
    hdrs = glob(["site-packages/tensorflow/include/**/*"]),
    strip_include_prefix="site-packages/tensorflow/include/",
    visibility = ["//visibility:public"],
)
cc_library(
    name = "libtensorflow_framework",
    srcs = select({
        "//conditions:default": ["site-packages/tensorflow/libtensorflow_framework.so.2"],
        "@bazel_tools//src/conditions:darwin":["site-packages/tensorflow/libtensorflow_framework.2.dylib"],
        "@bazel_tools//src/conditions:darwin_x86_64": ["site-packages/tensorflow/libtensorflow_framework.2.dylib"],
    }),
    visibility = ["//visibility:public"],
)
"""

    # Header libraries for both include layouts (`_core` and `core`).
    numpy_build_content = """
cc_library(
    name = "numpy_headers_2",
    hdrs = glob(["site-packages/numpy/_core/include/**/*.h"]),
    strip_include_prefix="site-packages/numpy/_core/include/",
)
cc_library(
    name = "numpy_headers_1",
    hdrs = glob(["site-packages/numpy/core/include/**/*.h"]),
    strip_include_prefix="site-packages/numpy/core/include/",
)
cc_library(
    name = "numpy_headers",
    deps = [":numpy_headers_2", ":numpy_headers_1"],
)
"""

    annotations = {
        "numpy": package_annotation(
            additive_build_content = numpy_build_content,
        ),
    }
    for tf_package in ["tensorflow", "tf-nightly"]:
        annotations[tf_package] = package_annotation(
            additive_build_content = tf_build_content,
        )

    pip_parse(
        name = "pypi",
        annotations = annotations,
        python_interpreter_target = interpreter,
        requirements_lock = REQUIREMENTS_WITH_LOCAL_WHEELS,
    )
56 |
--------------------------------------------------------------------------------
/third_party/yggdrasil_decision_forests/BUILD:
--------------------------------------------------------------------------------
1 | licenses(["notice"])
2 |
--------------------------------------------------------------------------------
/third_party/yggdrasil_decision_forests/workspace.bzl:
--------------------------------------------------------------------------------
1 | """Yggdrasil Decision Forests project."""
2 |
3 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
4 |
def deps(from_git_repo = True):
    """Registers the `ydf` (Yggdrasil Decision Forests) repository.

    Args:
        from_git_repo: If true, fetch the release archive from GitHub and
            apply `ydf.patch`. Otherwise, use the local directory
            `../yggdrasil_decision_forests`.
    """
    if not from_git_repo:
        # You can also clone the YDF repository manually.
        # Note that you need to manually apply the patch for Tensorflow >= 2.16 or nightly.
        native.local_repository(
            name = "ydf",
            # When downloading from Github, you might need - instead of _ as folder name
            path = "../yggdrasil_decision_forests",
        )
        return

    version = "1.11.0"
    checksum = "8553a7bfcb96dcdf19f4e9ce7bc5aca1a72df38bd29dfff53e9a58b317bba0c0"
    http_archive(
        name = "ydf",
        urls = ["https://github.com/google/yggdrasil-decision-forests/archive/refs/tags/v" + version + ".tar.gz"],
        strip_prefix = "yggdrasil-decision-forests-" + version,
        sha256 = checksum,
        patch_args = ["-p1"],
        patches = ["//third_party/yggdrasil_decision_forests:ydf.patch"],
    )
25 |
--------------------------------------------------------------------------------
/third_party/yggdrasil_decision_forests/ydf.patch:
--------------------------------------------------------------------------------
1 | diff --git a/yggdrasil_decision_forests/learner/decision_tree/BUILD b/yggdrasil_decision_forests/learner/decision_tree/BUILD
2 | index 9c021b04..0a8746a0 100644
3 | --- a/yggdrasil_decision_forests/learner/decision_tree/BUILD
4 | +++ b/yggdrasil_decision_forests/learner/decision_tree/BUILD
5 | @@ -63,7 +63,7 @@ cc_library_ydf(
6 | "@com_google_absl//absl/time",
7 | "@com_google_absl//absl/types:optional",
8 | "@com_google_absl//absl/types:span",
9 | - "@eigen_archive//:eigen3_internal",
10 | + "@eigen_archive//:eigen3",
11 | ] + select({
12 | "//conditions:default": [
13 | ],
14 | @@ -150,7 +150,7 @@ cc_library_ydf(
15 | "@com_google_absl//absl/time",
16 | "@com_google_absl//absl/types:optional",
17 | "@com_google_absl//absl/types:span",
18 | - "@eigen_archive//:eigen3_internal",
19 | + "@eigen_archive//:eigen3",
20 | ] + select({
21 | "//conditions:default": [
22 | ],
23 |
--------------------------------------------------------------------------------
/tools/BUILD:
--------------------------------------------------------------------------------
1 | package(
2 | default_visibility = ["//visibility:public"],
3 | licenses = ["notice"],
4 | )
5 |
--------------------------------------------------------------------------------
/tools/build_docs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Export the source code comments into the API Reference website."""
16 |
17 | import os
18 |
19 | from absl import app
20 | from absl import flags
21 |
22 | from tensorflow_docs.api_generator import generate_lib
23 | from tensorflow_docs.api_generator import public_api
24 | import tensorflow_decision_forests as tfdf
25 |
26 | FLAGS = flags.FLAGS
27 |
28 | flags.DEFINE_string("output_dir", "/tmp/tfdf_api", "Where to output the docs")
29 |
30 | flags.DEFINE_string("code_url_prefix", "", "The url prefix for links to code.")
31 |
32 | flags.DEFINE_bool("search_hints", True,
33 | "Include metadata search hints in the generated files")
34 |
35 | flags.DEFINE_string("site_path", "/decision_forests/api_docs/python",
36 | "Path prefix in the _toc.yaml")
37 |
38 |
def main(argv):
  """Generates the `tfdf` API reference into the `--output_dir` directory.

  Args:
    argv: Command-line arguments; anything beyond the first element is
      rejected.

  Raises:
    app.UsageError: If more than one command-line argument is provided.
  """
  unexpected_args = argv[1:]
  if unexpected_args:
    raise app.UsageError("Too many command-line arguments.")

  generator_options = dict(
      root_title="TensorFlow Decision Forests",
      py_modules=[("tfdf", tfdf)],  # ("tfdf.keras", tfdf.keras)
      base_dir=os.path.dirname(tfdf.__file__),
      code_url_prefix=FLAGS.code_url_prefix,
      search_hints=FLAGS.search_hints,
      site_path=FLAGS.site_path,
      callbacks=[public_api.explicit_package_contents_filter],
  )
  generate_lib.DocGenerator(**generator_options).build(FLAGS.output_dir)
52 |
53 |
54 | if __name__ == "__main__":
55 | app.run(main)
56 |
--------------------------------------------------------------------------------
/tools/build_pip_package.bat:
--------------------------------------------------------------------------------
1 | :: Copyright 2021 Google LLC.
2 | ::
3 | :: Licensed under the Apache License, Version 2.0 (the "License");
4 | :: you may not use this file except in compliance with the License.
5 | :: You may obtain a copy of the License at
6 | ::
7 | :: https://www.apache.org/licenses/LICENSE-2.0
8 | ::
9 | :: Unless required by applicable law or agreed to in writing, software
10 | :: distributed under the License is distributed on an "AS IS" BASIS,
11 | :: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | :: See the License for the specific language governing permissions and
13 | :: limitations under the License.
14 |
15 |
16 | set PATH=C:\Python38;c:\tools\msys64\usr\bin;%PATH%
17 | SET BASH=C:\tools\msys64\usr\bin\bash.exe
18 | DOSKEY python3=C:\Python38\python.exe $*
19 |
20 | %BASH% tools\build_pip_package.sh
21 |
--------------------------------------------------------------------------------
/tools/start_compile_docker.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2021 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 |
18 | # Start a docker able to compile TF-DF.
19 | #
20 | # Usage example
21 | #
22 | # # Create the pip packages for publication
23 | #
24 | # # Download and start the docker (this script)
25 | # ./tools/start_compile_docker.sh
26 | #
27 | # # Compile and test TF-DF.
28 | # RUN_TESTS=1 PY_VERSION=3.9 TF_VERSION=2.15.0 ./tools/test_bazel.sh
29 | #
30 | # # Create a Pip package for a specific version of python.
31 | # ./tools/build_pip_package.sh python3.9
32 | #
33 | # # Install the other versions of python (the docker only has py3.8).
34 | # sudo apt-get update
35 | # sudo apt-get install python3.9 python3-pip python3.10
36 | #
37 | # # Create the Pip package for the other version of python
38 | # ./tools/build_pip_package.sh python3.9
39 | # ./tools/build_pip_package.sh python3.10
40 | #
41 | # # Make the result of the docker world accessible (in case the docker is run
42 | # # in root).
43 | # chmod -R a+xrw .
44 | #
45 | # # Exit the docker
46 | # exit
47 | #
48 | # # Publish the pip packages
49 | # ./tools/submit_pip_package.sh
50 | #
51 | # Alternative ending
52 | #
53 | # # Create a Pip package for all the compatible version of pythons using pyenv.
54 | # ./tools/build_pip_package.sh ALL_VERSIONS
55 | #
56 | # # Create a Pip package for all the compatible version of python using the
57 | # # previous build_pip_package call results (i.e. the "tmp_package" directory)
58 | # ./tools/build_pip_package.sh ALL_VERSIONS_ALREADY_ASSEMBLED
59 | #
60 | # https://hub.docker.com/r/tensorflow/build/tags?page=1
61 |
# Current directory
# Useful if Yggdrasil Decision Forests is available locally in a neighbor
# directory.
TFDF_DIRNAME=${PWD##*/}

DOCKER_IMAGE=tensorflow/build:2.19-python3.9
DOCKER_CONTAINER=compile_tfdf

echo "Available containers:"
sudo docker container ls -a --size

set +e  # Ignore error if the container already exist
# An array keeps each flag value as a single word, even when the working
# directory contains spaces.
# NOTE(review): "-p" is combined with "--network host"; confirm the port
# mapping is still needed.
CREATE_DOCKER_FLAGS=(
  -i -t -p 8889:8889 --network host
  -v "${PWD}/..:/working_dir"
  -w "/working_dir/${TFDF_DIRNAME}"
)
sudo docker create "${CREATE_DOCKER_FLAGS[@]}" --name "${DOCKER_CONTAINER}" "${DOCKER_IMAGE}"
sudo docker start "${DOCKER_CONTAINER}"
set -e

# Command to run in the docker. "bash -c" takes a single command string:
# the script arguments are joined into one string, and an interactive shell is
# started when no argument is given.
if [[ $# -eq 0 ]]; then
  DOCKER_COMMAND="/bin/bash"
else
  DOCKER_COMMAND="$*"
fi

# Start docker
sudo docker exec -it "${DOCKER_CONTAINER}" /bin/bash -c "${DOCKER_COMMAND}"
81 |
--------------------------------------------------------------------------------
/tools/submit_pip_package.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2021 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 |
18 | python3 -m twine upload -u __token__ dist/*manylinux2014_x86_64.whl
19 | python3 -m twine upload -u __token__ dist/*macosx*.whl
20 | echo "The results is available at https://pypi.org/project/tensorflow-decision-forests"
21 |
--------------------------------------------------------------------------------
/tools/test_bazel.bat:
--------------------------------------------------------------------------------
1 | :: Copyright 2021 Google LLC.
2 | ::
3 | :: Licensed under the Apache License, Version 2.0 (the "License");
4 | :: you may not use this file except in compliance with the License.
5 | :: You may obtain a copy of the License at
6 | ::
7 | :: https://www.apache.org/licenses/LICENSE-2.0
8 | ::
9 | :: Unless required by applicable law or agreed to in writing, software
10 | :: distributed under the License is distributed on an "AS IS" BASIS,
11 | :: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | :: See the License for the specific language governing permissions and
13 | :: limitations under the License.
14 |
15 | :: Compiles TF-DF on Windows with Bazel. NOTE(review): the command below is "build", not "test" -- confirm whether the unit tests are expected to run.
16 | set BAZEL=bazel-3.7.2-windows-x86_64.exe
17 |
18 | :: TensorFlow bazelrc is required for the distributed compilation using
19 | :: RBE i.e. a remote server (fast).
20 | :: https://raw.githubusercontent.com/tensorflow/tensorflow/master/.bazelrc
21 | SET TENSORFLOW_BAZELRC=tensorflow_bazelrc
22 |
23 | %BAZEL% --bazelrc=%TENSORFLOW_BAZELRC% build^
24 | //tensorflow_decision_forests/...:all^
25 | --config=windows^
26 | --config=rbe_win^
27 | --config=rbe_win_py38^
28 | --config=tensorflow_testing_rbe_win^
29 | --flaky_test_attempts=1
30 |
31 | :: Use --output_user_root to specify a quick-to-access location.
32 | :: Propagate the Bazel exit code to the caller when the command failed.
33 | if %errorlevel% neq 0 exit /b %errorlevel%
34 |
--------------------------------------------------------------------------------
/tools/tf_serving/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow Decision Forests and TensorFlow Serving
2 |
3 | `build_tf_serving_with_tf_df.sh` is a script that compiles TF-Serving with
4 | TF-Decision Forests. See the
5 | [TensorFlow Decision Forests and TensorFlow Serving guide](https://www.tensorflow.org/decision_forests/tensorflow_serving)
6 | for background information. Results of this script are published as
7 | `tf_serving_linux.zip` on the
8 | [TF-DF GitHub release page](https://github.com/tensorflow/decision-forests/releases).
9 |
10 | **Compilation**
11 |
12 | *Note:* Make sure Docker is installed.
13 |
14 | ```shell
15 | # In a fresh directory
16 | git clone https://github.com/tensorflow/decision-forests.git
17 | git clone https://github.com/tensorflow/serving.git
18 | decision-forests/tools/tf_serving/build_tf_serving_with_tf_df.sh
19 |
20 | # Or sudo decision-forests/tools/tf_serving/build_tf_serving_with_tf_df.sh if
21 | # Docker needs to be run with sudo.
22 | ```
23 |
24 | **Usage example:**
25 |
26 | ```shell
27 | # Make sure TF-DF is installed (this is only necessary to train the model)
28 | pip3 install tensorflow_decision_forests -U
29 |
30 | # Train a Random Forest model on the adult dataset
31 | # Save the model to "/tmp/my_saved_model"
32 | python3 decision-forests/examples/minimal.py
33 |
34 | # Add a version number to the model (this is required for TF Serving)
35 | mkdir -p /tmp/my_saved_model_with_version
36 | cp -r /tmp/my_saved_model /tmp/my_saved_model_with_version/1
37 |
38 | # Compile or download TF-Serving, and set the TFSERVING variable accordingly.
39 | # TFServing binary.
40 | TFSERVING="./tensorflow_model_server"
41 |
42 | # Configure the model path and name.
43 | MODEL_PATH=/tmp/my_saved_model_with_version
44 | MODEL_NAME=my_model
45 |
46 | # Start a TF Serving server
47 | # Note: This command is blocking. You need to run it in a separate terminal (or
48 | # using &).
49 | ${TFSERVING} \
50 | --rest_api_port=8501 \
51 | --model_name=${MODEL_NAME} \
52 | --model_base_path=${MODEL_PATH}
53 |
54 | # Send requests to the model:
55 | # See https://www.tensorflow.org/tfx/serving/api_rest for the various solutions.
56 |
57 | # Predictions with the predict+instances API.
58 | curl http://localhost:8501/v1/models/${MODEL_NAME}:predict -X POST \
59 | -d '{"instances": [{"age":39,"workclass":"State-gov","fnlwgt":77516,"education":"Bachelors","education_num":13,"marital_status":"Never-married","occupation":"Adm-clerical","relationship":"Not-in-family","race":"White","sex":"Male","capital_gain":2174,"capital_loss":0,"hours_per_week":40,"native_country":"United-States"}]}'
60 |
61 | # Predictions with the predict+inputs API
62 | curl http://localhost:8501/v1/models/${MODEL_NAME}:predict -X POST \
63 | -d '{"inputs": {"age":[39],"workclass":["State-gov"],"fnlwgt":[77516],"education":["Bachelors"],"education_num":[13],"marital_status":["Never-married"],"occupation":["Adm-clerical"],"relationship":["Not-in-family"],"race":["White"],"sex":["Male"],"capital_gain":[2174],"capital_loss":[0],"hours_per_week":[40],"native_country":["United-States"]}}'
64 | ```
65 |
--------------------------------------------------------------------------------
/tools/tf_serving/build_tf_serving_with_tf_df.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2021 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 |
18 | # Compile TensorFlow Servo with TensorFlow Decision Forests.
19 | #
20 | # This script is equivalent to the instructions at:
21 | # https://www.tensorflow.org/decision_forests/tensorflow_serving
22 | #
23 | # This script is standalone: It compiles the version of TF-DF available
24 | # publicly on github.
25 | #
26 | # All the paths below are relative to the current directory, which is expected
27 | # to contain both the "serving" and the "decision-forests" checkouts.
28 | #
29 | # Usage example:
30 | # # In a fresh directory
31 | # git clone https://github.com/tensorflow/decision-forests.git
32 | # git clone https://github.com/tensorflow/serving.git
33 | # decision-forests/tools/tf_serving/build_tf_serving_with_tf_df.sh
34 | #
35 | set -e
36 |
37 | # Location of this tool in the TF-DF checkout.
38 | TFDF_TOOL_DIR="decision-forests/tools/tf_serving"
39 |
40 | # Add TF-DF as a dependency.
41 |
42 | # WORKSPACE
43 | # The ".bk" backup doubles as a marker: if it exists, the file was already
44 | # edited by a previous run and is left untouched.
45 | WORKSPACE_PATH="serving/WORKSPACE"
46 | if [ ! -f "${WORKSPACE_PATH}.bk" ]; then
47 |   echo "Edit WORKSPACE: ${WORKSPACE_PATH}"
48 |   cp "${WORKSPACE_PATH}" "${WORKSPACE_PATH}.bk"
49 |   # Insert the extra rules after the third line of the WORKSPACE file.
50 |   sed -i '3rdecision-forests/tools/tf_serving/build_tf_serving_workspace_extra.txt' "${WORKSPACE_PATH}"
51 | else
52 |   echo "WORKSPACE already edited: ${WORKSPACE_PATH}"
53 | fi
54 |
55 | # BUILD
56 | # Same backup/marker logic as for the WORKSPACE file.
57 | BUILD_PATH="serving/tensorflow_serving/model_servers/BUILD"
58 | if [ ! -f "${BUILD_PATH}.bk" ]; then
59 |   echo "Edit BUILD: ${BUILD_PATH}"
60 |   cp "${BUILD_PATH}" "${BUILD_PATH}.bk"
61 |   # Replace the empty "if_v2([])" list with a list containing the TF-DF
62 |   # inference kernel and op.
63 |   sed -i 's|if_v2(\[\])|if_v2([\n "@org_tensorflow_decision_forests//tensorflow_decision_forests/tensorflow/ops/inference:kernel_and_op"\n])|g' "${BUILD_PATH}"
64 | else
65 |   echo "BUILD already edited: ${BUILD_PATH}"
66 | fi
67 |
68 | # Compile Servo
69 | echo "Compile Servo"
70 | (cd serving && tools/run_in_docker.sh bazel build -c opt \
71 | --copt=-mfma \
72 | --define use_tensorflow_io=1 \
73 | --copt=-mavx2 tensorflow_serving/model_servers:tensorflow_model_server )
74 |
75 | # Pack the binary in a zip. This is the zip distributed in github.
76 | # The README is referenced by its path in the TF-DF checkout, since the script
77 | # runs from the parent directory; "-j" stores both files without their
78 | # directories.
79 | BINARY=serving/bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server
80 | zip -j tensorflow_model_server_linux.zip "${BINARY}" "${TFDF_TOOL_DIR}/README.md"
81 |
82 | # If you have a model, you can run it in TF Serving using the following command:
83 | # ${BINARY} \
84 | # --rest_api_port=8501 \
85 | # --model_name=my_model \
86 | # --model_base_path=/path/to/tfdf/model
87 |
--------------------------------------------------------------------------------
/tools/tf_serving/build_tf_serving_workspace_extra.txt:
--------------------------------------------------------------------------------
1 | # Additions to the TF Serving WORKSPACE file that add support for
2 | # TensorFlow Decision Forests.
3 |
4 | # Yggdrasil Decision Forests (YDF) sources.
5 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
6 |
7 | http_archive(
8 |     name = "ydf",
9 |     strip_prefix = "yggdrasil-decision-forests-0.2.3",
10 |     urls = ["https://github.com/google/yggdrasil-decision-forests/archive/refs/tags/0.2.3.zip"],
11 | )
12 |
13 | # Load the dependencies of YDF, except for the repositories that are already
14 | # injected by TensorFlow.
15 | load(
16 |     "@ydf//yggdrasil_decision_forests:library.bzl",
17 |     ydf_load_deps = "load_dependencies",
18 | )
19 |
20 | ydf_load_deps(
21 |     repo_name = "@ydf",
22 |     exclude_repo = [
23 |         "absl",
24 |         "protobuf",
25 |         "zlib",
26 |         "farmhash",
27 |         "gtest",
28 |         "tensorflow",
29 |         "grpc",
30 |     ],
31 | )
32 |
33 | # TensorFlow Decision Forests sources.
34 | load("//tensorflow_serving:repo.bzl", "tensorflow_http_archive")
35 |
36 | http_archive(
37 |     name = "org_tensorflow_decision_forests",
38 |     strip_prefix = "decision-forests-0.2.3",
39 |     urls = ["https://github.com/tensorflow/decision-forests/archive/refs/tags/0.2.3.zip"],
40 | )
41 |
--------------------------------------------------------------------------------