├── .gitmodules ├── AUTHORS ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── WORKSPACE ├── docs ├── _book.yaml ├── _index.yaml ├── build_docs.py ├── images │ ├── 2D_shape_constraints_picture_color.png │ ├── 2d_lattice.png │ ├── favicon.ico │ ├── flexible_fit.png │ ├── linear_fit.png │ ├── model_comparison.png │ ├── monotonic_fit.png │ ├── pwl_calibration_distance.png │ ├── pwl_calibration_price.png │ ├── regularized_fit.png │ └── tensorflow_lattice.png ├── install.md ├── overview.md └── tutorials │ ├── aggregate_function_models.ipynb │ ├── keras_layers.ipynb │ ├── premade_models.ipynb │ ├── shape_constraints.ipynb │ └── shape_constraints_for_ethics.ipynb ├── examples ├── BUILD ├── keras_functional_uci_heart.py └── keras_sequential_uci_heart.py ├── setup.py └── tensorflow_lattice ├── BUILD ├── __init__.py ├── layers └── __init__.py └── python ├── BUILD ├── __init__.py ├── aggregation_layer.py ├── aggregation_test.py ├── categorical_calibration_layer.py ├── categorical_calibration_lib.py ├── categorical_calibration_test.py ├── cdf_layer.py ├── cdf_test.py ├── conditional_cdf.py ├── conditional_cdf_test.py ├── conditional_pwl_calibration.py ├── conditional_pwl_calibration_test.py ├── configs.py ├── configs_test.py ├── internal_utils.py ├── internal_utils_test.py ├── kronecker_factored_lattice_layer.py ├── kronecker_factored_lattice_lib.py ├── kronecker_factored_lattice_test.py ├── lattice_layer.py ├── lattice_lib.py ├── lattice_test.py ├── linear_layer.py ├── linear_lib.py ├── linear_test.py ├── model_info.py ├── parallel_combination_layer.py ├── parallel_combination_test.py ├── premade.py ├── premade_lib.py ├── premade_test.py ├── pwl_calibration_layer.py ├── pwl_calibration_lib.py ├── pwl_calibration_test.py ├── rtl_layer.py ├── rtl_lib.py ├── rtl_test.py ├── test_utils.py ├── utils.py └── utils_test.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "tensorflow"] 2 | path = tensorflow 3 | url = 
https://github.com/tensorflow/tensorflow.git 4 | branch = r1.3 5 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the official list of TensorFlow Lattice authors for copyright purposes. 2 | # Names should be added to this file as: 3 | # Name or Organization <email address> 4 | # The email address is not required for organizations. 5 | Google Inc. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | 15 | # How to Contribute 16 | 17 | We'd love to accept your patches and contributions to this project. There are 18 | just a few small guidelines you need to follow. 19 | 20 | ## Contributor License Agreement 21 | 22 | Contributions to this project must be accompanied by a Contributor License 23 | Agreement. You (or your employer) retain the copyright to your contribution; 24 | this simply gives us permission to use and redistribute your contributions as 25 | part of the project. Head over to <https://cla.developers.google.com/> to see 26 | your current agreements on file or to sign a new one. 27 | 28 | You generally only need to submit a CLA once, so if you've already submitted one 29 | (even if it was for a different project), you probably don't need to do it 30 | again. 31 | 32 | ## Code reviews 33 | 34 | All submissions, including submissions by project members, require review. We 35 | use GitHub pull requests for this purpose. Consult 36 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 37 | information on using pull requests.
38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 15 | # TensorFlow Lattice 16 | 17 | TensorFlow Lattice is a library that implements constrained and interpretable 18 | lattice based models. It is an implementation of 19 | [Monotonic Calibrated Interpolated Look-Up Tables](http://jmlr.org/papers/v17/15-243.html) 20 | in [TensorFlow](https://www.tensorflow.org). 21 | 22 | The library enables you to inject domain knowledge into 23 | the learning process through common-sense or policy-driven shape constraints. 24 | This is done using a collection of Keras layers that can satisfy constraints 25 | such as monotonicity, convexity and pairwise trust: 26 | 27 | * PWLCalibration: piecewise linear calibration of signals. 28 | * CategoricalCalibration: mapping of categorical inputs into real values. 29 | * Lattice: interpolated look-up table implementation. 30 | * Linear: linear function with monotonicity and norm constraints. 
31 | 32 | The library also provides easy-to-set-up premade models for common use cases: 33 | 34 | * Calibrated Linear 35 | * Calibrated Lattice 36 | * Random Tiny Lattices (RTL) 37 | * Crystals 38 | 39 | With TF Lattice you can use domain knowledge to better extrapolate to the parts 40 | of the input space not covered by the training dataset. This helps avoid 41 | unexpected model behaviour when the serving distribution is different from the 42 | training distribution. 43 | 44 |
45 | 46 |
47 | 48 | You can install our prebuilt pip package using 49 | 50 | ```bash 51 | pip install tensorflow-lattice 52 | ``` 53 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Lattice Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License. You may obtain a copy of 5 | # the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations under 13 | # the License. 14 | # ============================================================================== 15 | 16 | workspace(name = "tensorflow_lattice") 17 | -------------------------------------------------------------------------------- /docs/_book.yaml: -------------------------------------------------------------------------------- 1 | upper_tabs: 2 | # Tabs left of dropdown menu 3 | - include: /_upper_tabs_left.yaml 4 | - include: /api_docs/_upper_tabs_api.yaml 5 | # Dropdown menu 6 | - name: Resources 7 | path: /resources 8 | is_default: true 9 | menu: 10 | - include: /resources/_menu_toc.yaml 11 | lower_tabs: 12 | # Subsite tabs 13 | other: 14 | - name: Guide & Tutorials 15 | contents: 16 | - title: Overview 17 | path: /lattice/overview 18 | - title: Install 19 | path: /lattice/install 20 | - heading: Tutorials 21 | - title: Shape Constraints 22 | path: /lattice/tutorials/shape_constraints 23 | - title: Ethical Constraints for ML Fairness 24 | path: /lattice/tutorials/shape_constraints_for_ethics 25 | - title: Keras Layers and Custom Models 26 | 
path: /lattice/tutorials/keras_layers 27 | - title: Keras Premade Models 28 | path: /lattice/tutorials/premade_models 29 | - title: Aggregate Function Models 30 | path: /lattice/tutorials/aggregate_function_models 31 | 32 | - name: API 33 | skip_translation: true 34 | contents: 35 | - title: All Symbols 36 | path: /lattice/api_docs/python/tfl/all_symbols 37 | - include: /lattice/api_docs/python/tfl/_toc.yaml 38 | 39 | - include: /_upper_tabs_right.yaml 40 | -------------------------------------------------------------------------------- /docs/_index.yaml: -------------------------------------------------------------------------------- 1 | book_path: /lattice/_book.yaml 2 | project_path: /lattice/_project.yaml 3 | description: A library for training constrained and interpretable lattice based models. Inject 4 | domain knowledge into the learning process through constraints on Keras layers. 5 | landing_page: 6 | custom_css_path: /site-assets/css/style.css 7 | rows: 8 | - heading: Flexible, controlled and interpretable ML with lattice based models 9 | items: 10 | - classname: devsite-landing-row-50 11 | description: > 12 |

TensorFlow Lattice is a library that implements constrained and interpretable lattice 13 | based models. The library enables you to inject domain knowledge into the learning process 14 | through common-sense or policy-driven 15 | shape constraints. This is done using a 16 | collection of Keras layers that can satisfy 17 | constraints such as monotonicity, convexity and how features interact. The library also 18 | provides easy to setup premade models.

19 |

With TF Lattice you can use domain knowledge to better extrapolate to the parts of the 20 | input space not covered by the training dataset. This helps avoid unexpected model behaviour 21 | when the serving distribution is different from the training distribution.

22 |
23 | 24 |
25 | 26 | code_block: | 27 |
28 |         import numpy as np
29 |         import tensorflow as tf
30 |         import tensorflow_lattice as tfl
31 | 
32 |         model = tf.keras.models.Sequential()
33 |         model.add(
34 |             tfl.layers.ParallelCombination([
35 |                 # Monotonic piece-wise linear calibration with bounded output
36 |                 tfl.layers.PWLCalibration(
37 |                     monotonicity='increasing',
38 |                     input_keypoints=np.linspace(1., 5., num=20),
39 |                     output_min=0.0,
40 |                     output_max=1.0),
41 |                 # Diminishing returns
42 |                 tfl.layers.PWLCalibration(
43 |                     monotonicity='increasing',
44 |                     convexity='concave',
45 |                     input_keypoints=np.linspace(0., 200., num=20),
46 |                     output_min=0.0,
47 |                     output_max=2.0),
48 |                 # Partially monotonic categorical calibration: calib(0) <= calib(1)
49 |                 tfl.layers.CategoricalCalibration(
50 |                     num_buckets=4,
51 |                     output_min=0.0,
52 |                     output_max=1.0,
53 |                     monotonicities=[(0, 1)]),
54 |             ]))
55 |         model.add(
56 |             tfl.layers.Lattice(
57 |                 lattice_sizes=[2, 3, 2],
58 |                 monotonicities=['increasing', 'increasing', 'increasing'],
59 |                 # Trust: model is more responsive to input 0 if input 1 increases
60 |                 edgeworth_trusts=(0, 1, 'positive')))
61 |         model.compile(...)
62 |         
63 | 64 | - classname: devsite-landing-row-cards 65 | items: 66 | - heading: "TensorFlow Lattice: Flexible, controlled and interpretable ML" 67 | image_path: /resources/images/tf-logo-card-16x9.png 68 | path: https://blog.tensorflow.org/2020/02/tensorflow-lattice-flexible-controlled-and-interpretable-ML.html 69 | buttons: 70 | - label: "Read on the TensorFlow blog" 71 | path: https://blog.tensorflow.org/2020/02/tensorflow-lattice-flexible-controlled-and-interpretable-ML.html 72 | - heading: "TensorFlow Lattice: Control your ML with monotonicity" 73 | youtube_id: ABBnNjbjv2Q 74 | buttons: 75 | - label: Watch the video 76 | path: https://www.youtube.com/watch?v=ABBnNjbjv2Q 77 | - heading: "TF Lattice on GitHub" 78 | image_path: /resources/images/github-card-16x9.png 79 | path: https://github.com/tensorflow/lattice 80 | buttons: 81 | - label: "View on GitHub" 82 | path: https://github.com/tensorflow/lattice 83 | -------------------------------------------------------------------------------- /docs/build_docs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Generate docs API for TF Lattice. 
15 | 16 | Example run: 17 | 18 | ``` 19 | python build_docs.py --output_dir=/path/to/output 20 | ``` 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import os 28 | import sys 29 | 30 | from absl import app 31 | from absl import flags 32 | 33 | from tensorflow_docs.api_generator import generate_lib 34 | from tensorflow_docs.api_generator import public_api 35 | 36 | import tensorflow_lattice as tfl 37 | 38 | flags.DEFINE_string('output_dir', '/tmp/tfl_api/', 39 | 'The path to output the files to') 40 | 41 | flags.DEFINE_string( 42 | 'code_url_prefix', 43 | 'https://github.com/tensorflow/lattice/blob/master/tensorflow_lattice', 44 | 'The url prefix for links to code.') 45 | 46 | flags.DEFINE_bool('search_hints', True, 47 | 'Include metadata search hints in the generated files') 48 | 49 | flags.DEFINE_string('site_path', 'lattice/api_docs/python', 50 | 'Path prefix in the _toc.yaml') 51 | 52 | FLAGS = flags.FLAGS 53 | 54 | 55 | def local_definitions_filter(path, parent, children): 56 | """Doc callback: keeps every child of tfl.layers; elsewhere applies public_api.local_definitions_filter (which drops imported names).""" 57 | if path == ('tfl', 'layers'):  # tfl.layers is exempt, so names imported into it are still documented there. 58 | return children 59 | return public_api.local_definitions_filter(path, parent, children) 60 | 61 | 62 | def main(_): 63 | private_map = {  # {module path: [symbol names]} passed to DocGenerator to hide those names at those paths. 64 | 'tfl': ['python'], 65 | 'tfl.aggregation_layer': ['Aggregation'], 66 | 'tfl.categorical_calibration_layer': ['CategoricalCalibration'], 67 | 'tfl.cdf_layer': ['CDF'], 68 | 'tfl.kronecker_factored_lattice_layer': ['KroneckerFactoredLattice'], 69 | 'tfl.lattice_layer': ['Lattice'], 70 | 'tfl.linear_layer': ['Linear'], 71 | 'tfl.pwl_calibration_layer': ['PWLCalibration'], 72 | 'tfl.parallel_combination_layer': ['ParallelCombination'], 73 | 'tfl.rtl_layer': ['RTL'], 74 | } 75 | doc_generator = generate_lib.DocGenerator( 76 | root_title='TensorFlow Lattice 2.0', 77 | py_modules=[('tfl', tfl)], 78 | base_dir=os.path.dirname(tfl.__file__), 79 | 
code_url_prefix=FLAGS.code_url_prefix, 80 | search_hints=FLAGS.search_hints, 81 | site_path=FLAGS.site_path, 82 | private_map=private_map, 83 | callbacks=[local_definitions_filter]) 84 | 85 | sys.exit(doc_generator.build(output_dir=FLAGS.output_dir)) 86 | 87 | 88 | if __name__ == '__main__': 89 | app.run(main) 90 | -------------------------------------------------------------------------------- /docs/images/2D_shape_constraints_picture_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/lattice/f52258331345031d179e24b8c8a93bffdd2d7597/docs/images/2D_shape_constraints_picture_color.png -------------------------------------------------------------------------------- /docs/images/2d_lattice.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/lattice/f52258331345031d179e24b8c8a93bffdd2d7597/docs/images/2d_lattice.png -------------------------------------------------------------------------------- /docs/images/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/lattice/f52258331345031d179e24b8c8a93bffdd2d7597/docs/images/favicon.ico -------------------------------------------------------------------------------- /docs/images/flexible_fit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/lattice/f52258331345031d179e24b8c8a93bffdd2d7597/docs/images/flexible_fit.png -------------------------------------------------------------------------------- /docs/images/linear_fit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/lattice/f52258331345031d179e24b8c8a93bffdd2d7597/docs/images/linear_fit.png
-------------------------------------------------------------------------------- /docs/images/model_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/lattice/f52258331345031d179e24b8c8a93bffdd2d7597/docs/images/model_comparison.png -------------------------------------------------------------------------------- /docs/images/monotonic_fit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/lattice/f52258331345031d179e24b8c8a93bffdd2d7597/docs/images/monotonic_fit.png -------------------------------------------------------------------------------- /docs/images/pwl_calibration_distance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/lattice/f52258331345031d179e24b8c8a93bffdd2d7597/docs/images/pwl_calibration_distance.png -------------------------------------------------------------------------------- /docs/images/pwl_calibration_price.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/lattice/f52258331345031d179e24b8c8a93bffdd2d7597/docs/images/pwl_calibration_price.png -------------------------------------------------------------------------------- /docs/images/regularized_fit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/lattice/f52258331345031d179e24b8c8a93bffdd2d7597/docs/images/regularized_fit.png -------------------------------------------------------------------------------- /docs/images/tensorflow_lattice.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/lattice/f52258331345031d179e24b8c8a93bffdd2d7597/docs/images/tensorflow_lattice.png 
-------------------------------------------------------------------------------- /docs/install.md: -------------------------------------------------------------------------------- 1 | # Install TensorFlow Lattice 2 | 3 | There are several ways to set up your environment to use TensorFlow Lattice 4 | (TFL). 5 | 6 | * The easiest way to learn and use TFL requires no installation: run any 7 | of the tutorials (e.g. 8 | [premade models](tutorials/premade_models.ipynb)). 9 | * To use TFL on a local machine, install the `tensorflow-lattice` pip package. 10 | * If you have a unique machine configuration, you can build the package from 11 | source. 12 | 13 | ## Install TensorFlow Lattice using pip 14 | 15 | Install using pip. 16 | 17 | ```shell 18 | pip install --upgrade tensorflow-lattice 19 | ``` 20 | 21 | Note that you will need to have the `tf_keras` package installed as well. 22 | 23 | ## Build from source 24 | 25 | Clone the GitHub repository: 26 | 27 | ```shell 28 | git clone https://github.com/tensorflow/lattice.git 29 | ``` 30 | 31 | Build the pip package from source: 32 | 33 | ```shell 34 | python setup.py sdist bdist_wheel --universal --release 35 | ``` 36 | 37 | Install the package: 38 | 39 | ```shell 40 | pip install --user --upgrade /path/to/pkg.whl 41 | ``` 42 | -------------------------------------------------------------------------------- /docs/overview.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Lattice (TFL) 2 | 3 | TensorFlow Lattice is a library that implements flexible, controlled and 4 | interpretable lattice based models. The library enables you to inject domain 5 | knowledge into the learning process through common-sense or policy-driven 6 | [shape constraints](tutorials/shape_constraints.ipynb). This is done using a 7 | collection of [Keras layers](tutorials/keras_layers.ipynb) that can satisfy 8 | constraints such as monotonicity, convexity and pairwise trust.
The library also 9 | provides easy-to-set-up [premade models](tutorials/premade_models.ipynb). 10 | 11 | ## Concepts 12 | 13 | This section is a simplified version of the description in 14 | [Monotonic Calibrated Interpolated Look-Up Tables](http://jmlr.org/papers/v17/15-243.html), 15 | JMLR 2016. 16 | 17 | ### Lattices 18 | 19 | A *lattice* is an interpolated look-up table that can approximate arbitrary 20 | input-output relationships in your data. It overlays a regular grid onto your 21 | input space and learns values for the output at the vertices of the grid. For a 22 | test point $x$, $f(x)$ is linearly interpolated from the lattice values 23 | surrounding $x$. 24 | 25 | 26 | 27 | The simple example above is a function with 2 input features and 4 parameters: 28 | $\theta=[0, 0.2, 0.4, 1]$, which are the function's values at the corners of the 29 | input space; the rest of the function is interpolated from these parameters. 30 | 31 | The function $f(x)$ can capture non-linear interactions between features. You 32 | can think of the lattice parameters as the heights of poles set in the ground on 33 | a regular grid, and the resulting function is like cloth pulled tight against 34 | the four poles. 35 | 36 | With $D$ features and 2 vertices along each dimension, a regular lattice will 37 | have $2^D$ parameters. To fit a more flexible function, you can specify a 38 | finer-grained lattice over the feature space with more vertices along each 39 | dimension. Lattice regression functions are continuous and piecewise infinitely 40 | differentiable. 41 | 42 | ### Calibration 43 | 44 | Let's say the preceding sample lattice represents a learned *user happiness* 45 | with a suggested local coffee shop calculated using features: 46 | 47 | * coffee price, in range 0 to 20 dollars 48 | * distance to the user, in range 0 to 30 kilometers 49 | 50 | We want our model to learn user happiness with a local coffee shop suggestion.
51 | TensorFlow Lattice models can use *piecewise linear functions* (with 52 | `tfl.layers.PWLCalibration`) to calibrate and normalize the input features to 53 | the range accepted by the lattice: 0.0 to 1.0 in the example lattice above. The 54 | following figures show examples of such calibration functions with 10 keypoints: 55 | 56 |

57 | 58 | 59 |

60 | 61 | It is often a good idea to use the quantiles of the features as input keypoints. 62 | TensorFlow Lattice [premade models](tutorials/premade_models.ipynb) can 63 | automatically set the input keypoints to the feature quantiles. 64 | 65 | For categorical features, TensorFlow Lattice provides categorical calibration 66 | (with `tfl.layers.CategoricalCalibration`) with similar output bounding to feed 67 | into a lattice. 68 | 69 | ### Ensembles 70 | 71 | The number of parameters of a lattice layer increases exponentially with the 72 | number of input features, and hence does not scale well to very high dimensions. To 73 | overcome this limitation, TensorFlow Lattice offers ensembles of lattices that 74 | combine (average) several *tiny* lattices, so that the model size grows 75 | linearly in the number of features. 76 | 77 | The library provides two variations of these ensembles: 78 | 79 | * **Random Tiny Lattices** (RTL): Each submodel uses a random subset of 80 | features (with replacement). 81 | 82 | * **Crystals**: The Crystals algorithm first trains a *prefitting* model that 83 | estimates pairwise feature interactions. It then arranges the final ensemble 84 | such that features with more non-linear interactions are in the same 85 | lattices. 86 | 87 | ## Why TensorFlow Lattice? 88 | 89 | You can find a brief introduction to TensorFlow Lattice in this 90 | [TF Blog post](https://blog.tensorflow.org/2020/02/tensorflow-lattice-flexible-controlled-and-interpretable-ML.html). 91 | 92 | ### Interpretability 93 | 94 | Since the parameters of each layer are the outputs of that layer at its keypoints or vertices, it is easy to 95 | analyze, understand and debug each part of the model. 96 | 97 | ### Accurate and Flexible Models 98 | 99 | Using fine-grained lattices, you can get *arbitrarily complex* functions with a 100 | single lattice layer. Using multiple layers of calibrators and lattices often 101 | works well in practice and can match or outperform DNN models of similar sizes.
102 | 103 | ### Common-Sense Shape Constraints 104 | 105 | Real-world training data may not sufficiently represent the run-time data. 106 | Flexible ML solutions such as DNNs or forests often act unexpectedly and even 107 | wildly in parts of the input space not covered by the training data. This 108 | behaviour is especially problematic when policy or fairness constraints can be 109 | violated. 110 | 111 | 112 | 113 | Even though common forms of regularization can result in more sensible 114 | extrapolation, standard regularizers cannot guarantee reasonable model behaviour 115 | across the entire input space, especially with high-dimensional inputs. 116 | Switching to simpler models with more controlled and predictable behaviour can 117 | come at a severe cost to the model accuracy. 118 | 119 | TF Lattice makes it possible to keep using flexible models while providing several 120 | options to inject domain knowledge into the learning process through 121 | semantically meaningful common-sense or policy-driven 122 | [shape constraints](tutorials/shape_constraints.ipynb): 123 | 124 | * **Monotonicity**: You can specify that the output should only 125 | increase/decrease with respect to an input. In our example, you may want to 126 | specify that increased distance to a coffee shop should only decrease the 127 | predicted user preference. 128 | 129 |

130 | 131 | 132 | 133 | 134 |

135 | 136 | * **Convexity/Concavity**: You can specify that the function shape should be 137 | convex or concave. Mixed with monotonicity, this can force the function to 138 | represent diminishing returns with respect to a given feature. 139 | 140 | * **Unimodality**: You can specify that the function should have a unique peak 141 | or unique valley. This lets you represent functions that have a *sweet spot* 142 | with respect to a feature. 143 | 144 | * **Pairwise trust**: This constraint works on a pair of features and suggests 145 | that one input feature semantically reflects trust in another feature. For 146 | example, a higher number of reviews makes you more confident in the average 147 | star rating of a restaurant. The model will be more sensitive with respect 148 | to the star rating (i.e. will have a larger slope with respect to the 149 | rating) when the number of reviews is higher. 150 | 151 | ### Controlled Flexibility with Regularizers 152 | 153 | In addition to shape constraints, TensorFlow Lattice provides a number of 154 | regularizers to control the flexibility and smoothness of the function for each 155 | layer. 156 | 157 | * **Laplacian Regularizer**: Outputs of the lattice/calibration 158 | vertices/keypoints are regularized towards the values of their respective 159 | neighbors. This results in a *flatter* function. 160 | 161 | * **Hessian Regularizer**: This penalizes changes in slope (the second derivative) of the PWL 162 | calibration layer to make the function *more linear*. 163 | 164 | * **Wrinkle Regularizer**: This penalizes changes in the second derivative (the third derivative) of the PWL 165 | calibration layer to avoid sudden changes in the curvature. It makes the 166 | function smoother. 167 | 168 | * **Torsion Regularizer**: Outputs of the lattice will be regularized towards 169 | preventing torsion among the features. In other words, the model will be 170 | regularized towards independence between the contributions of the features.
171 | 172 | ### Mix and match with other Keras layers 173 | 174 | You can use TF Lattice layers in combination with other Keras layers to 175 | construct partially constrained or regularized models. For example, lattice or 176 | PWL calibration layers can be used at the last layer of deeper networks that 177 | include embeddings or other Keras layers. 178 | 179 | ## Papers 180 | 181 | * [Deontological Ethics By Monotonicity Shape Constraints](https://arxiv.org/abs/2001.11990), 182 | Serena Wang, Maya Gupta, International Conference on Artificial Intelligence 183 | and Statistics (AISTATS), 2020 184 | * [Shape Constraints for Set Functions](http://proceedings.mlr.press/v97/cotter19a.html), 185 | Andrew Cotter, Maya Gupta, H. Jiang, Erez Louidor, Jim Muller, Taman 186 | Narayan, Serena Wang, Tao Zhu. International Conference on Machine Learning 187 | (ICML), 2019 188 | * [Diminishing Returns Shape Constraints for Interpretability and 189 | Regularization](https://papers.nips.cc/paper/7916-diminishing-returns-shape-constraints-for-interpretability-and-regularization), 190 | Maya Gupta, Dara Bahri, Andrew Cotter, Kevin Canini, Advances in Neural 191 | Information Processing Systems (NeurIPS), 2018 192 | * [Deep Lattice Networks and Partial Monotonic Functions](https://research.google.com/pubs/pub46327.html), 193 | Seungil You, Kevin Canini, David Ding, Jan Pfeifer, Maya R. 
Gupta, Advances 194 | in Neural Information Processing Systems (NeurIPS), 2017 195 | * [Fast and Flexible Monotonic Functions with Ensembles of Lattices](https://papers.nips.cc/paper/6377-fast-and-flexible-monotonic-functions-with-ensembles-of-lattices), 196 | Mahdi Milani Fard, Kevin Canini, Andrew Cotter, Jan Pfeifer, Maya Gupta, 197 | Advances in Neural Information Processing Systems (NeurIPS), 2016 198 | * [Monotonic Calibrated Interpolated Look-Up Tables](http://jmlr.org/papers/v17/15-243.html), 199 | Maya Gupta, Andrew Cotter, Jan Pfeifer, Konstantin Voevodski, Kevin Canini, 200 | Alexander Mangylov, Wojciech Moczydlowski, Alexander van Esbroeck, Journal 201 | of Machine Learning Research (JMLR), 2016 202 | * [Optimized Regression for Efficient Function Evaluation](http://ieeexplore.ieee.org/document/6203580/), 203 | Eric Garcia, Raman Arora, Maya R. Gupta, IEEE Transactions on Image 204 | Processing, 2012 205 | * [Lattice Regression](https://papers.nips.cc/paper/3694-lattice-regression), 206 | Eric Garcia, Maya Gupta, Advances in Neural Information Processing Systems 207 | (NeurIPS), 2009 208 | 209 | ## Tutorials and API docs 210 | 211 | For common model architectures, you can use 212 | [Keras premade models](tutorials/premade_models.ipynb). You can also create 213 | custom models using [TF Lattice Keras layers](tutorials/keras_layers.ipynb) or 214 | mix and match with other Keras layers. Check out the 215 | [full API docs](https://www.tensorflow.org/lattice/api_docs/python/tfl) for 216 | details. 217 | -------------------------------------------------------------------------------- /examples/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Lattice Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | load("//third_party/bazel_rules/rules_python/python:py_binary.bzl", "py_binary") 17 | 18 | licenses(["notice"]) 19 | 20 | package( 21 | default_visibility = [ 22 | "//tensorflow_lattice:__subpackages__", 23 | ], 24 | ) 25 | 26 | py_binary( 27 | name = "keras_sequential_uci_heart", 28 | srcs = ["keras_sequential_uci_heart.py"], 29 | python_version = "PY3", 30 | deps = [ 31 | # tensorflow dep, 32 | "//tensorflow_lattice", 33 | ], 34 | ) 35 | 36 | py_binary( 37 | name = "keras_functional_uci_heart", 38 | srcs = ["keras_functional_uci_heart.py"], 39 | python_version = "PY3", 40 | deps = [ 41 | # tensorflow dep, 42 | "//tensorflow_lattice", 43 | ], 44 | ) 45 | -------------------------------------------------------------------------------- /examples/keras_sequential_uci_heart.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Example usage of TFL within Keras models. 16 | 17 | This example builds and trains a calibrated lattice model for the UCI heart 18 | dataset. 19 | 20 | "Calibrated lattice" is a commonly used architecture for datasets where the number 21 | of input features does not exceed ~15. 22 | 23 | "Calibrated lattice" assumes every feature is transformed by PWLCalibration 24 | or CategoricalCalibration layers before nonlinearly fusing the calibration results 25 | within a lattice layer. 26 | 27 | Generally when you manually combine TFL layers you should keep track of: 28 | 1) Ensuring that inputs to TFL layers are within the expected range. 29 | - Input range for PWLCalibration layer is defined by smallest and largest of 30 | provided keypoints. 31 | - Input range for Lattice layer is [0.0, lattice_sizes[d] - 1.0] for any 32 | dimension d. 33 | TFL layers can constrain their output to be within a desired range. When feeding 34 | the output of other layers into TFL layers, you might want to ensure that something 35 | like sigmoid is used to constrain their output range. 36 | 2) Properly configure monotonicity. If your calibration layer is monotonic then 37 | corresponding dimension of lattice layer should also be monotonic. 38 | 39 | This example creates a Sequential Keras model and only uses TFL layers. For an 40 | example of functional model construction that also uses embedding layers, see 41 | keras_functional_uci_heart.py. 42 | 43 | In order to see how better generalization can be achieved with a properly 44 | constrained PWLCalibration layer compared to a vanilla embedding layer, compare 45 | training and validation losses of this model with one defined in 46 | keras_functional_uci_heart.py 47 | 48 | 49 | Note that the specifics of layer configurations are for demonstration purposes 50 | and might not result in optimal performance.
51 | 52 | Example usage: 53 | keras_sequential_uci_heart 54 | """ 55 | 56 | from __future__ import absolute_import 57 | from __future__ import division 58 | from __future__ import print_function 59 | 60 | from absl import app 61 | from absl import flags 62 | 63 | import numpy as np 64 | import pandas as pd 65 | 66 | import tensorflow as tf 67 | import tensorflow_lattice as tfl 68 | # pylint: disable=g-import-not-at-top 69 | # Use Keras 2: if tf.keras reports a 3.x version, fall back to the tf_keras package. 70 | version_fn = getattr(tf.keras, 'version', None) 71 | if version_fn and version_fn().startswith('3.'): 72 | import tf_keras as keras 73 | else: 74 | keras = tf.keras 75 | 76 | FLAGS = flags.FLAGS 77 | flags.DEFINE_integer('num_epochs', 200, 'Number of training epoch.') 78 | 79 | 80 | def main(_): 81 | # Download the UCI Statlog (Heart) dataset and shuffle its rows once with a fixed seed. 82 | csv_file = keras.utils.get_file( 83 | 'heart.csv', 84 | 'http://storage.googleapis.com/download.tensorflow.org/data/heart.csv', 85 | ) 86 | training_data_df = pd.read_csv(csv_file).sample( 87 | frac=1.0, random_state=41).reset_index(drop=True) 88 | 89 | # Feature columns.
90 | # 0 age 91 | # 1 sex 92 | # 2 cp chest pain type (4 values) 93 | # 3 trestbps resting blood pressure 94 | # 4 chol serum cholesterol in mg/dl 95 | # 5 fbs fasting blood sugar > 120 mg/dl 96 | # 6 restecg resting electrocardiographic results (values 0,1,2) 97 | # 7 thalach maximum heart rate achieved 98 | # 8 exang exercise induced angina 99 | # 9 oldpeak ST depression induced by exercise relative to rest 100 | # 10 slope the slope of the peak exercise ST segment 101 | # 11 ca number of major vessels (0-3) colored by fluoroscopy 102 | # 12 thal 3 = normal; 6 = fixed defect; 7 = reversible defect 103 | 104 | # Example slice of training data: 105 | # age sex cp trestbps chol fbs restecg thalach exang oldpeak 106 | # 0 63 1 1 145 233 1 2 150 0 2.3 107 | # 1 67 1 4 160 286 0 2 108 1 1.5 108 | # 2 67 1 4 120 229 0 2 129 1 2.6 109 | # 3 37 1 3 130 250 0 0 187 0 3.5 110 | # 4 41 0 2 130 204 0 2 172 0 1.4 111 | # 5 56 1 2 120 236 0 0 178 0 0.8 112 | # 6 62 0 4 140 268 0 2 160 0 3.6 113 | # 7 57 0 4 120 354 0 0 163 1 0.6 114 | # 8 63 1 4 130 254 0 2 147 0 1.4 115 | # 9 53 1 4 140 203 1 2 155 1 3.1 116 | 117 | # Lattice sizes per dimension for Lattice layer. 118 | # Lattice layer expects input[i] to be within [0, lattice_sizes[i] - 1.0], so 119 | # we need to define lattice sizes ahead of calibration layers so we can 120 | # properly specify output range of calibration layers. 121 | lattice_sizes = [3, 2, 2, 2, 2, 2, 2] 122 | 123 | # Use ParallelCombination helper layer to group together calibration layers 124 | # which have to be executed in parallel in order to be able to use Sequential 125 | # model. Alternatively, you can use the functional API. 126 | combined_calibrators = tfl.layers.ParallelCombination() 127 | 128 | # Configure calibration layers for every feature: 129 | 130 | # ############### age ############### 131 | 132 | calibrator = tfl.layers.PWLCalibration( 133 | # Every PWLCalibration layer must have keypoints of piecewise linear 134 | # function specified.
The easiest way to specify them is to uniformly cover 135 | # entire input range by using numpy.linspace(). 136 | input_keypoints=np.linspace(training_data_df['age'].min(), 137 | training_data_df['age'].max(), 138 | num=5), 139 | # You need to ensure that input keypoints have the same dtype as the layer input. 140 | # You can do it by setting dtype here or by providing keypoints in such 141 | # format which will be converted to desired tf.dtype by default. 142 | dtype=tf.float32, 143 | # Output range must correspond to expected lattice input range. 144 | output_min=0.0, 145 | output_max=lattice_sizes[0] - 1.0, 146 | monotonicity='increasing') 147 | combined_calibrators.append(calibrator) 148 | 149 | # ############### sex ############### 150 | 151 | # For boolean features, simply specify a CategoricalCalibration layer with 2 152 | # buckets. 153 | calibrator = tfl.layers.CategoricalCalibration( 154 | num_buckets=2, 155 | output_min=0.0, 156 | output_max=lattice_sizes[1] - 1.0, 157 | # Initializes all outputs to (output_min + output_max) / 2.0. 158 | kernel_initializer='constant') 159 | combined_calibrators.append(calibrator) 160 | 161 | # ############### cp ############### 162 | 163 | calibrator = tfl.layers.PWLCalibration( 164 | # Here, instead of specifying the layer dtype, we convert keypoints to 165 | # np.float32. 166 | input_keypoints=np.linspace(1, 4, num=4, dtype=np.float32), 167 | output_min=0.0, 168 | output_max=lattice_sizes[2] - 1.0, 169 | monotonicity='increasing', 170 | # You can specify TFL regularizers as a tuple ('regularizer name', l1, l2). 171 | kernel_regularizer=('hessian', 0.0, 1e-4)) 172 | combined_calibrators.append(calibrator) 173 | 174 | # ############### trestbps ############### 175 | 176 | calibrator = tfl.layers.PWLCalibration( 177 | # As an alternative to uniform keypoints, you might want to use quantiles as 178 | # keypoints.
179 | input_keypoints=np.quantile( 180 | training_data_df['trestbps'], np.linspace(0.0, 1.0, num=5)), 181 | dtype=tf.float32, 182 | # Together with quantile keypoints you might want to initialize piecewise 183 | # linear function to have 'equal_slopes' in order for output of layer 184 | # after initialization to preserve original distribution. 185 | kernel_initializer='equal_slopes', 186 | output_min=0.0, 187 | output_max=lattice_sizes[3] - 1.0, 188 | # You might consider clamping extreme inputs of the calibrator to output 189 | # bounds. 190 | clamp_min=True, 191 | clamp_max=True, 192 | monotonicity='increasing') 193 | combined_calibrators.append(calibrator) 194 | 195 | # ############### chol ############### 196 | 197 | calibrator = tfl.layers.PWLCalibration( 198 | # Explicit input keypoint initialization. 199 | input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0], 200 | dtype=tf.float32, 201 | output_min=0.0, 202 | output_max=lattice_sizes[4] - 1.0, 203 | # Monotonicity of calibrator can be 'decreasing'. Note that corresponding 204 | # lattice dimension must have 'increasing' monotonicity regardless of 205 | # monotonicity direction of calibrator. 206 | # It's not some weird configuration hack. It's just how math works :) 207 | monotonicity='decreasing', 208 | # Convexity together with decreasing monotonicity results in a diminishing 209 | # returns constraint. 210 | convexity='convex', 211 | # You can specify a list of regularizers. You are not limited to TFL 212 | # regularizers. Feel free to use any :) 213 | kernel_regularizer=[('laplacian', 0.0, 1e-4), 214 | keras.regularizers.l1_l2(l1=0.001)]) 215 | combined_calibrators.append(calibrator) 216 | 217 | # ############### fbs ############### 218 | 219 | calibrator = tfl.layers.CategoricalCalibration( 220 | num_buckets=2, 221 | output_min=0.0, 222 | output_max=lattice_sizes[5] - 1.0, 223 | # For categorical calibration layer monotonicity is specified for pairs 224 | # of indices of categories.
Output for first category in pair will be 225 | # smaller than output for second category. 226 | # 227 | # Don't forget to set monotonicity of corresponding dimension of Lattice 228 | # layer to 'increasing'. 229 | monotonicities=[(0, 1)], 230 | # This initializer is identical to the default one ('uniform'), but has a fixed 231 | # seed in order to simplify experimentation. 232 | kernel_initializer=keras.initializers.RandomUniform( 233 | minval=0.0, maxval=lattice_sizes[5] - 1.0, seed=1)) 234 | combined_calibrators.append(calibrator) 235 | 236 | # ############### restecg ############### 237 | 238 | calibrator = tfl.layers.CategoricalCalibration( 239 | num_buckets=3, 240 | output_min=0.0, 241 | output_max=lattice_sizes[6] - 1.0, 242 | # Categorical monotonicity can be a partial order. 243 | monotonicities=[(0, 1), (0, 2)], 244 | # Categorical calibration layer supports standard Keras regularizers. 245 | kernel_regularizer=keras.regularizers.l1_l2(l1=0.001), 246 | kernel_initializer='constant') 247 | combined_calibrators.append(calibrator) 248 | 249 | # Create Lattice layer to nonlinearly fuse the output of calibrators. Don't forget 250 | # to specify monotonicity 'increasing' for any dimension whose calibrator is 251 | # monotonic regardless of monotonicity direction of calibrator. This includes 252 | # partial monotonicity of CategoricalCalibration layer. 253 | lattice = tfl.layers.Lattice( 254 | lattice_sizes=lattice_sizes, 255 | monotonicities=['increasing', 'none', 'increasing', 'increasing', 256 | 'increasing', 'increasing', 'increasing'], 257 | output_min=0.0, 258 | output_max=1.0) 259 | 259 | model = keras.models.Sequential() 261 | # We have just 2 layers as far as Sequential model is concerned. 262 | # The ParallelCombination layer takes care of grouping calibrators.
263 | model.add(combined_calibrators) 264 | model.add(lattice) 265 | model.compile(loss=keras.losses.mean_squared_error, 266 | optimizer=keras.optimizers.Adagrad(learning_rate=1.0)) 267 | 268 | features = training_data_df[ 269 | ['age', 'sex', 'cp', 270 | 'trestbps', 'chol', 'fbs', 'restecg']].values.astype(np.float32) 271 | target = training_data_df[['target']].values.astype(np.float32) 272 | # Rows were already shuffled once above (random_state=41); shuffle=False keeps the batch order fixed. 273 | model.fit(features, 274 | target, 275 | batch_size=32, 276 | epochs=FLAGS.num_epochs, 277 | validation_split=0.2, 278 | shuffle=False) 279 | 280 | 281 | if __name__ == '__main__': 282 | app.run(main) 283 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Lattice Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | # use this file except in compliance with the License. You may obtain a copy of 5 | # the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations under 13 | # the License. 14 | # ============================================================================== 15 | """Package setup script for TensorFlow Lattice library.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import datetime 22 | import sys 23 | 24 | from setuptools import find_packages 25 | from setuptools import setup 26 | 27 | # This version number should always be that of the *next* (unreleased) version.
28 | # Immediately after uploading a package to PyPI, you should increment the 29 | # version number and push to GitHub. 30 | __version__ = "2.1.1" 31 | # Passing --release builds the stable 'tensorflow_lattice' package; the flag is removed from sys.argv so setuptools never sees it. 32 | if "--release" in sys.argv: 33 | sys.argv.remove("--release") 34 | _name = "tensorflow_lattice" 35 | else: 36 | # Without --release, build a nightly package by default; its version gets a .devYYYYMMDD date suffix. 37 | _name = "tensorflow_lattice_nightly" 38 | __version__ += datetime.datetime.now().strftime(".dev%Y%m%d") 39 | 40 | _install_requires = [ 41 | "absl-py", 42 | "numpy", 43 | "pandas", 44 | "six", 45 | "scikit-learn", 46 | "matplotlib", 47 | "graphviz", 48 | "tf-keras", 49 | ] 50 | 51 | # Part of the visualization code uses colabtools and IPython libraries. These 52 | # are not added as hard requirements as they are mainly used in jupyter/colabs. 53 | 54 | _extras_require = { 55 | "tensorflow": "tensorflow>=1.15", 56 | } 57 | # NOTE(review): the 'Programming Language :: Python :: 2' classifier below may be stale now that tf-keras is a hard requirement; confirm before the next release. 58 | _classifiers = [ 59 | "Development Status :: 4 - Beta", 60 | "Intended Audience :: Developers", 61 | "Intended Audience :: Education", 62 | "Intended Audience :: Science/Research", 63 | "License :: OSI Approved :: Apache Software License", 64 | "Operating System :: OS Independent", 65 | "Programming Language :: Python", 66 | "Programming Language :: Python :: 2", 67 | "Programming Language :: Python :: 3", 68 | "Topic :: Scientific/Engineering", 69 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 70 | "Topic :: Scientific/Engineering :: Mathematics", 71 | "Topic :: Software Development", 72 | "Topic :: Software Development :: Libraries", 73 | "Topic :: Software Development :: Libraries :: Python Modules", 74 | ] 75 | 76 | _description = ( 77 | "A library that implements optionally monotonic lattice based models.") 78 | _long_description = """\ 79 | TensorFlow Lattice is a library that implements fast-to-evaluate and 80 | interpretable (optionally monotonic) lattice based models, which are also known 81 | as *interpolated look-up tables*.
The library includes a collection of Keras 82 | layers for lattices and feature calibration that can be composed into custom 83 | models or used inside generic premade models. 84 | """ 85 | 86 | setup( 87 | name=_name, 88 | version=__version__, 89 | author="Google Inc.", 90 | author_email="no-reply@google.com", 91 | license="Apache 2.0", 92 | classifiers=_classifiers, 93 | install_requires=_install_requires, 94 | extras_require=_extras_require, 95 | packages=find_packages(), 96 | include_package_data=True, 97 | description=_description, 98 | long_description=_long_description, 99 | long_description_content_type="text/markdown", 100 | keywords="tensorflow lattice calibration machine learning", 101 | url=( 102 | "https://github.com/tensorflow/lattice" 103 | ), 104 | ) 105 | -------------------------------------------------------------------------------- /tensorflow_lattice/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Lattice Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | # NOTE(review): unlike examples/BUILD and python/BUILD, this file uses py_library without an explicit load(); confirm that is intended. 16 | package( 17 | default_visibility = [ 18 | "//visibility:public", 19 | ], 20 | ) 21 | 22 | licenses(["notice"]) 23 | 24 | exports_files(["LICENSE"]) 25 | 26 | py_library( 27 | name = "tensorflow_lattice", 28 | srcs = [ 29 | "__init__.py", 30 | "layers/__init__.py", 31 | ], 32 | srcs_version = "PY2AND3", 33 | deps = [ 34 | "//tensorflow_lattice/python:aggregation_layer", 35 | "//tensorflow_lattice/python:categorical_calibration_layer", 36 | "//tensorflow_lattice/python:categorical_calibration_lib", 37 | "//tensorflow_lattice/python:cdf_layer", 38 | "//tensorflow_lattice/python:conditional_cdf", 39 | "//tensorflow_lattice/python:conditional_pwl_calibration", 40 | "//tensorflow_lattice/python:configs", 41 | "//tensorflow_lattice/python:kronecker_factored_lattice_layer", 42 | "//tensorflow_lattice/python:kronecker_factored_lattice_lib", 43 | "//tensorflow_lattice/python:lattice_layer", 44 | "//tensorflow_lattice/python:lattice_lib", 45 | "//tensorflow_lattice/python:linear_layer", 46 | "//tensorflow_lattice/python:linear_lib", 47 | "//tensorflow_lattice/python:model_info", 48 | "//tensorflow_lattice/python:parallel_combination_layer", 49 | "//tensorflow_lattice/python:premade", 50 | "//tensorflow_lattice/python:premade_lib", 51 | "//tensorflow_lattice/python:pwl_calibration_layer", 52 | "//tensorflow_lattice/python:pwl_calibration_lib", 53 | "//tensorflow_lattice/python:rtl_layer", 54 | "//tensorflow_lattice/python:test_utils", 55 | "//tensorflow_lattice/python:utils", 56 | ], 57 | ) 58 | -------------------------------------------------------------------------------- /tensorflow_lattice/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """TensorFlow Lattice library. 16 | 17 | This package provides functions and classes for lattice modeling: the layer classes under `tensorflow_lattice.layers` plus the `tensorflow_lattice.python` modules imported below. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | 22 | import tensorflow_lattice.layers 23 | from tensorflow_lattice.python import aggregation_layer 24 | from tensorflow_lattice.python import categorical_calibration_layer 25 | from tensorflow_lattice.python import categorical_calibration_lib 26 | from tensorflow_lattice.python import cdf_layer 27 | from tensorflow_lattice.python import conditional_cdf 28 | from tensorflow_lattice.python import conditional_pwl_calibration 29 | from tensorflow_lattice.python import configs 30 | from tensorflow_lattice.python import kronecker_factored_lattice_layer 31 | from tensorflow_lattice.python import kronecker_factored_lattice_lib 32 | from tensorflow_lattice.python import lattice_layer 33 | from tensorflow_lattice.python import lattice_lib 34 | from tensorflow_lattice.python import linear_layer 35 | from tensorflow_lattice.python import linear_lib 36 | from tensorflow_lattice.python import model_info 37 | from tensorflow_lattice.python import parallel_combination_layer 38 | from tensorflow_lattice.python import premade 39 | from tensorflow_lattice.python import premade_lib 40 | from tensorflow_lattice.python import pwl_calibration_layer 41 | from tensorflow_lattice.python import pwl_calibration_lib 42 | from tensorflow_lattice.python import test_utils 43 | from tensorflow_lattice.python import utils 44 |
-------------------------------------------------------------------------------- /tensorflow_lattice/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """`tensorflow_lattice.layers` namespace re-exporting the public TFL layer classes (e.g. `Lattice`, `PWLCalibration`).""" 16 | 17 | from tensorflow_lattice.python.aggregation_layer import Aggregation 18 | from tensorflow_lattice.python.categorical_calibration_layer import CategoricalCalibration 19 | from tensorflow_lattice.python.cdf_layer import CDF 20 | from tensorflow_lattice.python.kronecker_factored_lattice_layer import KroneckerFactoredLattice 21 | from tensorflow_lattice.python.lattice_layer import Lattice 22 | from tensorflow_lattice.python.linear_layer import Linear 23 | from tensorflow_lattice.python.parallel_combination_layer import ParallelCombination 24 | from tensorflow_lattice.python.pwl_calibration_layer import PWLCalibration 25 | from tensorflow_lattice.python.rtl_layer import RTL 26 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Lattice Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | load("//third_party/bazel_rules/rules_python/python:py_library.bzl", "py_library") 17 | load("//third_party/bazel_rules/rules_python/python:py_test.bzl", "py_test") 18 | 19 | package( 20 | default_visibility = [ 21 | "//tensorflow_lattice:__subpackages__", 22 | ], 23 | ) 24 | 25 | licenses(["notice"]) 26 | 27 | # Build rules are alphabetized. Please add new rules alphabetically 28 | # to maintain the ordering. 29 | py_library( 30 | name = "aggregation_layer", 31 | srcs = ["aggregation_layer.py"], 32 | srcs_version = "PY2AND3", 33 | deps = [ 34 | # tensorflow dep, 35 | ], 36 | ) 37 | 38 | py_test( 39 | name = "aggregation_test", 40 | srcs = ["aggregation_test.py"], 41 | python_version = "PY3", 42 | srcs_version = "PY2AND3", 43 | deps = [ 44 | ":aggregation_layer", 45 | # tensorflow dep, 46 | ], 47 | ) 48 | 49 | py_library( 50 | name = "categorical_calibration_layer", 51 | srcs = ["categorical_calibration_layer.py"], 52 | srcs_version = "PY2AND3", 53 | deps = [ 54 | ":categorical_calibration_lib", 55 | # tensorflow:tensorflow_no_contrib dep, 56 | ], 57 | ) 58 | 59 | py_library( 60 | name = "categorical_calibration_lib", 61 | srcs = ["categorical_calibration_lib.py"], 62 | srcs_version = "PY2AND3", 63 | deps = [ 64 | ":internal_utils", 65 | # tensorflow:tensorflow_no_contrib dep, 66 | ], 67 | ) 68 | 69 | py_test( 70 | name = "categorical_calibration_test", 71 | size = "large", 72 | srcs = ["categorical_calibration_test.py"], 73 | 
python_version = "PY3", 74 | # shard_count = 4, 75 | srcs_version = "PY2AND3", 76 | deps = [ 77 | ":categorical_calibration_layer", 78 | ":parallel_combination_layer", 79 | ":test_utils", 80 | # absl/logging dep, 81 | # absl/testing:parameterized dep, 82 | # numpy dep, 83 | # tensorflow dep, 84 | ], 85 | ) 86 | 87 | py_library( 88 | name = "cdf_layer", 89 | srcs = ["cdf_layer.py"], 90 | srcs_version = "PY3", 91 | deps = [ 92 | ":utils", 93 | # tensorflow dep, 94 | ], 95 | ) 96 | 97 | py_test( 98 | name = "cdf_test", 99 | size = "large", 100 | srcs = ["cdf_test.py"], 101 | python_version = "PY3", 102 | # shard_count = 12, 103 | srcs_version = "PY3", 104 | deps = [ 105 | ":cdf_layer", 106 | ":test_utils", 107 | ":utils", 108 | # absl/logging dep, 109 | # absl/testing:parameterized dep, 110 | # numpy dep, 111 | # tensorflow dep, 112 | ], 113 | ) 114 | 115 | py_library( 116 | name = "configs", 117 | srcs = ["configs.py"], 118 | srcs_version = "PY2AND3", 119 | deps = [ 120 | # absl/logging dep, 121 | # tensorflow dep, 122 | ], 123 | ) 124 | 125 | py_test( 126 | name = "configs_test", 127 | size = "small", 128 | srcs = ["configs_test.py"], 129 | python_version = "PY3", 130 | srcs_version = "PY2AND3", 131 | deps = [ 132 | ":categorical_calibration_layer", 133 | ":configs", 134 | ":lattice_layer", 135 | ":linear_layer", 136 | ":premade", 137 | ":pwl_calibration_layer", 138 | # absl/logging dep, 139 | # tensorflow dep, 140 | ], 141 | ) 142 | 143 | py_library( 144 | name = "internal_utils", 145 | srcs = ["internal_utils.py"], 146 | srcs_version = "PY2AND3", 147 | deps = [ 148 | # tensorflow dep, 149 | ], 150 | ) 151 | 152 | py_test( 153 | name = "internal_utils_test", 154 | srcs = ["internal_utils_test.py"], 155 | python_version = "PY3", 156 | srcs_version = "PY2AND3", 157 | deps = [ 158 | ":internal_utils", 159 | # tensorflow dep, 160 | ], 161 | ) 162 | 163 | py_library( 164 | name = "kronecker_factored_lattice_layer", 165 | srcs = ["kronecker_factored_lattice_layer.py"], 
166 | srcs_version = "PY2AND3", 167 | deps = [ 168 | ":kronecker_factored_lattice_lib", 169 | ":utils", 170 | # tensorflow dep, 171 | ], 172 | ) 173 | 174 | py_library( 175 | name = "kronecker_factored_lattice_lib", 176 | srcs = ["kronecker_factored_lattice_lib.py"], 177 | srcs_version = "PY2AND3", 178 | deps = [ 179 | ":utils", 180 | # numpy dep, 181 | # tensorflow dep, 182 | ], 183 | ) 184 | 185 | py_test( 186 | name = "kronecker_factored_lattice_test", 187 | size = "large", 188 | srcs = ["kronecker_factored_lattice_test.py"], 189 | python_version = "PY3", 190 | # shard_count = 12, 191 | srcs_version = "PY2AND3", 192 | deps = [ 193 | ":kronecker_factored_lattice_layer", 194 | ":kronecker_factored_lattice_lib", 195 | ":test_utils", 196 | # absl/logging dep, 197 | # absl/testing:parameterized dep, 198 | # numpy dep, 199 | # tensorflow dep, 200 | ], 201 | ) 202 | 203 | py_library( 204 | name = "lattice_layer", 205 | srcs = ["lattice_layer.py"], 206 | srcs_version = "PY2AND3", 207 | deps = [ 208 | ":lattice_lib", 209 | ":pwl_calibration_layer", 210 | ":utils", 211 | # tensorflow:tensorflow_no_contrib dep, 212 | ], 213 | ) 214 | 215 | py_library( 216 | name = "lattice_lib", 217 | srcs = ["lattice_lib.py"], 218 | srcs_version = "PY2AND3", 219 | deps = [ 220 | ":utils", 221 | # absl/logging dep, 222 | # numpy dep, 223 | # tensorflow:tensorflow_no_contrib dep, 224 | ], 225 | ) 226 | 227 | py_test( 228 | name = "lattice_test", 229 | size = "large", 230 | srcs = ["lattice_test.py"], 231 | python_version = "PY3", 232 | # shard_count = 12, 233 | srcs_version = "PY2AND3", 234 | deps = [ 235 | ":lattice_layer", 236 | ":test_utils", 237 | # absl/logging dep, 238 | # absl/testing:parameterized dep, 239 | # numpy dep, 240 | # tensorflow dep, 241 | ], 242 | ) 243 | 244 | py_library( 245 | name = "linear_layer", 246 | srcs = ["linear_layer.py"], 247 | srcs_version = "PY2AND3", 248 | deps = [ 249 | ":linear_lib", 250 | ":utils", 251 | # tensorflow:tensorflow_no_contrib dep, 252 | ], 
253 | ) 254 | 255 | py_library( 256 | name = "linear_lib", 257 | srcs = ["linear_lib.py"], 258 | srcs_version = "PY2AND3", 259 | deps = [ 260 | ":internal_utils", 261 | ":utils", 262 | # tensorflow:tensorflow_no_contrib dep, 263 | ], 264 | ) 265 | 266 | py_test( 267 | name = "linear_test", 268 | size = "large", 269 | srcs = ["linear_test.py"], 270 | python_version = "PY3", 271 | srcs_version = "PY2AND3", 272 | deps = [ 273 | ":linear_layer", 274 | ":test_utils", 275 | ":utils", 276 | # absl/logging dep, 277 | # absl/testing:parameterized dep, 278 | # numpy dep, 279 | # tensorflow dep, 280 | ], 281 | ) 282 | 283 | py_library( 284 | name = "model_info", 285 | srcs = ["model_info.py"], 286 | srcs_version = "PY2AND3", 287 | ) 288 | 289 | py_library( 290 | name = "parallel_combination_layer", 291 | srcs = ["parallel_combination_layer.py"], 292 | srcs_version = "PY2AND3", 293 | deps = [ 294 | ":categorical_calibration_layer", 295 | ":lattice_layer", 296 | ":linear_layer", 297 | ":pwl_calibration_layer", 298 | # tensorflow:tensorflow_no_contrib dep, 299 | ], 300 | ) 301 | 302 | py_test( 303 | name = "parallel_combination_test", 304 | size = "large", 305 | srcs = ["parallel_combination_test.py"], 306 | python_version = "PY3", 307 | srcs_version = "PY2AND3", 308 | deps = [ 309 | ":lattice_layer", 310 | ":parallel_combination_layer", 311 | # absl/testing:parameterized dep, 312 | # numpy dep, 313 | # tensorflow dep, 314 | ], 315 | ) 316 | 317 | py_library( 318 | name = "premade", 319 | srcs = ["premade.py"], 320 | srcs_version = "PY2AND3", 321 | deps = [ 322 | ":aggregation_layer", 323 | ":categorical_calibration_layer", 324 | ":configs", 325 | ":kronecker_factored_lattice_layer", 326 | ":lattice_layer", 327 | ":parallel_combination_layer", 328 | ":premade_lib", 329 | ":pwl_calibration_layer", 330 | # absl/logging dep, 331 | # tensorflow dep, 332 | ], 333 | ) 334 | 335 | py_library( 336 | name = "premade_lib", 337 | srcs = ["premade_lib.py"], 338 | srcs_version = "PY2AND3", 
339 | deps = [ 340 | ":aggregation_layer", 341 | ":categorical_calibration_layer", 342 | ":configs", 343 | ":kronecker_factored_lattice_layer", 344 | ":kronecker_factored_lattice_lib", 345 | ":lattice_layer", 346 | ":lattice_lib", 347 | ":linear_layer", 348 | ":pwl_calibration_layer", 349 | ":rtl_layer", 350 | ":utils", 351 | # absl/logging dep, 352 | # numpy dep, 353 | # six dep, 354 | # tensorflow dep, 355 | ], 356 | ) 357 | 358 | py_test( 359 | name = "premade_test", 360 | size = "large", 361 | srcs = ["premade_test.py"], 362 | python_version = "PY3", 363 | # shard_count = 10, 364 | srcs_version = "PY2AND3", 365 | deps = [ 366 | ":configs", 367 | ":premade", 368 | ":premade_lib", 369 | # absl/logging dep, 370 | # absl/testing:parameterized dep, 371 | # numpy dep, 372 | # tensorflow dep, 373 | ], 374 | ) 375 | 376 | py_library( 377 | name = "pwl_calibration_layer", 378 | srcs = ["pwl_calibration_layer.py"], 379 | srcs_version = "PY2AND3", 380 | deps = [ 381 | ":pwl_calibration_lib", 382 | ":utils", 383 | # absl/logging dep, 384 | # tensorflow:tensorflow_no_contrib dep, 385 | ], 386 | ) 387 | 388 | py_library( 389 | name = "pwl_calibration_lib", 390 | srcs = ["pwl_calibration_lib.py"], 391 | srcs_version = "PY2AND3", 392 | deps = [ 393 | ":utils", 394 | # tensorflow:tensorflow_no_contrib dep, 395 | ], 396 | ) 397 | 398 | py_test( 399 | name = "pwl_calibration_test", 400 | size = "large", 401 | srcs = ["pwl_calibration_test.py"], 402 | python_version = "PY3", 403 | # shard_count = 12, 404 | srcs_version = "PY2AND3", 405 | deps = [ 406 | ":parallel_combination_layer", 407 | ":pwl_calibration_layer", 408 | ":test_utils", 409 | ":utils", 410 | # absl/logging dep, 411 | # absl/testing:parameterized dep, 412 | # numpy dep, 413 | # tensorflow dep, 414 | # tensorflow:tensorflow_no_contrib dep, 415 | ], 416 | ) 417 | 418 | py_library( 419 | name = "rtl_layer", 420 | srcs = ["rtl_layer.py"], 421 | srcs_version = "PY2AND3", 422 | deps = [ 423 | 
":kronecker_factored_lattice_layer", 424 | ":lattice_layer", 425 | ":rtl_lib", 426 | # tensorflow:tensorflow_no_contrib dep, 427 | ], 428 | ) 429 | 430 | py_library( 431 | name = "rtl_lib", 432 | srcs = ["rtl_lib.py"], 433 | srcs_version = "PY2AND3", 434 | deps = [ 435 | # six dep, 436 | ], 437 | ) 438 | 439 | py_test( 440 | name = "rtl_test", 441 | size = "large", 442 | srcs = ["rtl_test.py"], 443 | python_version = "PY3", 444 | srcs_version = "PY2AND3", 445 | deps = [ 446 | ":linear_layer", 447 | ":pwl_calibration_layer", 448 | ":rtl_layer", 449 | # absl/testing:parameterized dep, 450 | # numpy dep, 451 | # tensorflow dep, 452 | ], 453 | ) 454 | 455 | py_library( 456 | name = "test_utils", 457 | srcs = ["test_utils.py"], 458 | srcs_version = "PY2AND3", 459 | deps = [ 460 | # absl/logging dep, 461 | # numpy dep, 462 | ], 463 | ) 464 | 465 | py_library( 466 | name = "utils", 467 | srcs = ["utils.py"], 468 | srcs_version = "PY2AND3", 469 | deps = [ 470 | # six dep, 471 | ], 472 | ) 473 | 474 | py_library( 475 | name = "conditional_pwl_calibration", 476 | srcs = ["conditional_pwl_calibration.py"], 477 | deps = [ 478 | # numpy dep, 479 | # tensorflow:tensorflow_no_contrib dep, 480 | ], 481 | ) 482 | 483 | py_library( 484 | name = "conditional_cdf", 485 | srcs = ["conditional_cdf.py"], 486 | deps = [ 487 | # tensorflow:tensorflow_no_contrib dep, 488 | ], 489 | ) 490 | 491 | py_test( 492 | name = "conditional_cdf_test", 493 | srcs = ["conditional_cdf_test.py"], 494 | deps = [ 495 | ":conditional_cdf", 496 | # absl/testing:parameterized dep, 497 | # tensorflow:tensorflow_no_contrib dep, 498 | ], 499 | ) 500 | 501 | py_test( 502 | name = "conditional_pwl_calibration_test", 503 | srcs = ["conditional_pwl_calibration_test.py"], 504 | deps = [ 505 | ":conditional_pwl_calibration", 506 | # tensorflow:tensorflow_no_contrib dep, 507 | ], 508 | ) 509 | 510 | py_test( 511 | name = "utils_test", 512 | srcs = ["utils_test.py"], 513 | python_version = "PY3", 514 | srcs_version = 
"PY2AND3", 515 | deps = [ 516 | ":utils", 517 | # absl/testing:parameterized dep, 518 | # tensorflow dep, 519 | ], 520 | ) 521 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """TensorFlow Lattice python package.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/aggregation_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Layer which represents aggregation function. 15 | 16 | See class level comment. 17 | 18 | This layer applies the provided model to the ragged input tensor and aggregates 19 | the results. 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import tensorflow as tf 27 | # pylint: disable=g-import-not-at-top 28 | # Use Keras 2. 29 | version_fn = getattr(tf.keras, 'version', None) 30 | if version_fn and version_fn().startswith('3.'): 31 | import tf_keras as keras 32 | else: 33 | keras = tf.keras 34 | 35 | 36 | class Aggregation(keras.layers.Layer): 37 | # pyformat: disable 38 | """Layer which represents an aggregation function. 39 | 40 | Calls the model on each of the ragged dimensions and takes the mean. 41 | 42 | Input shape: 43 | A list or dictionary with num_input_dims Rank-2 ragged tensors with 44 | shape: (batch_size, ?) 45 | 46 | Output shape: 47 | Rank-2 tensor with shape: (batch_size, 1) 48 | 49 | Attributes: 50 | - All `__init__ `arguments. 51 | 52 | Example: 53 | 54 | ```python 55 | model = keras.Model(inputs=inputs, outputs=outputs) 56 | layer = tfl.layers.Aggregation(model) 57 | ``` 58 | """ 59 | # pyformat: enable 60 | 61 | def __init__(self, model, **kwargs): 62 | """initializes an instance of `Aggregation`. 63 | 64 | Args: 65 | model: A keras.Model instance. 66 | **kwargs: Other args passed to `keras.layers.Layer` initializer. 67 | 68 | Raises: 69 | ValueError: if model is not at `keras.Model` instance. 
70 | """ 71 | if not isinstance(model, keras.Model): 72 | raise ValueError('Model must be a keras.Model instance.') 73 | super(Aggregation, self).__init__(**kwargs) 74 | # This flag enables inputs to be Ragged Tensors 75 | self._supports_ragged_inputs = True 76 | self.model = model 77 | 78 | def call(self, x): 79 | """Standard Keras call() method.""" 80 | return tf.reduce_mean(tf.ragged.map_flat_values(self.model, x), axis=1) 81 | 82 | def get_config(self): 83 | """Standard Keras get_config() method.""" 84 | config = super(Aggregation, self).get_config().copy() 85 | config.update( 86 | {'model': keras.utils.legacy.serialize_keras_object(self.model)} 87 | ) 88 | return config 89 | 90 | @classmethod 91 | def from_config(cls, config, custom_objects=None): 92 | model = keras.utils.legacy.deserialize_keras_object( 93 | config.pop('model'), custom_objects=custom_objects 94 | ) 95 | return cls(model, **config) 96 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/aggregation_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for Tensorflow Lattice premade.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | from tensorflow_lattice.python import aggregation_layer 22 | # pylint: disable=g-import-not-at-top 23 | # Use Keras 2. 24 | version_fn = getattr(tf.keras, 'version', None) 25 | if version_fn and version_fn().startswith('3.'): 26 | import tf_keras as keras 27 | else: 28 | keras = tf.keras 29 | 30 | 31 | test_input = [ 32 | tf.ragged.constant([[1, 2], [1, 2, 3], [3]]), 33 | tf.ragged.constant([[4, 5], [4, 4, 4], [6]]), 34 | tf.ragged.constant([[1, 6], [5, 5, 5], [9]]) 35 | ] 36 | 37 | expected_output = tf.constant([32, 40, 162]) 38 | 39 | 40 | class AggregationTest(tf.test.TestCase): 41 | 42 | def testAggregationLayer(self): 43 | # First we test our assertion that the model must be a keras.Model 44 | with self.assertRaisesRegex(ValueError, 45 | 'Model must be a keras.Model instance.'): 46 | aggregation_layer.Aggregation(None) 47 | # Now let's make sure our layer aggregates properly. 48 | inputs = [keras.Input(shape=()) for _ in range(len(test_input))] 49 | output = keras.layers.multiply(inputs) 50 | model = keras.Model(inputs=inputs, outputs=output) 51 | agg_layer = aggregation_layer.Aggregation(model) 52 | self.assertAllEqual(agg_layer(test_input), expected_output) 53 | 54 | 55 | if __name__ == '__main__': 56 | tf.test.main() 57 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/categorical_calibration_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Categorical calibration layer with monotonicity and bound constraints. 15 | 16 | Keras implementation of tensorflow lattice categorical calibration layer. This 17 | layer takes single or multi-dimensional input and transforms it using lookup 18 | tables satisfying monotonicity and bounds constraints if specified. 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | import tensorflow as tf 25 | # pylint: disable=g-import-not-at-top 26 | # Use Keras 2. 27 | version_fn = getattr(tf.keras, "version", None) 28 | if version_fn and version_fn().startswith("3."): 29 | import tf_keras as keras 30 | else: 31 | keras = tf.keras 32 | from . import categorical_calibration_lib 33 | 34 | DEFAULT_INPUT_VALUE_NAME = "default_input_value" 35 | CATEGORICAL_CALIBRATION_KERNEL_NAME = "categorical_calibration_kernel" 36 | 37 | # TODO: implement variation/variance regularizer. 38 | 39 | 40 | class CategoricalCalibration(keras.layers.Layer): 41 | # pyformat: disable 42 | """Categorical calibration layer with monotonicity and bound constraints. 43 | 44 | This layer takes input of shape `(batch_size, units)` or `(batch_size, 1)` and 45 | transforms it using `units` number of lookup tables satisfying monotonicity 46 | and bounds constraints if specified. If multi dimensional input is provided, 47 | each output will be for the corresponding input, otherwise all calibration 48 | functions will act on the same input. 
All units share the same layer 49 | configuration, but each one has their separate set of trained parameters. 50 | 51 | Input shape: 52 | Rank-2 tensor with shape: `(batch_size, units)` or `(batch_size, 1)`. 53 | 54 | Output shape: 55 | If units > 1 and split_outputs is True, a length `units` list of Rank-2 56 | tensors with shape `(batch_size, 1)`. Otherwise, a Rank-2 tensor with shape: 57 | `(batch_size, units)` 58 | 59 | Attributes: 60 | - All `__init__` args. 61 | kernel: TF variable of shape `(batch_size, units)` which stores the lookup 62 | table. 63 | 64 | Example: 65 | 66 | ```python 67 | calibrator = tfl.layers.CategoricalCalibration( 68 | # Number of categories. 69 | num_buckets=3, 70 | # Output can be bounded. 71 | output_min=0.0, 72 | output_max=1.0, 73 | # For categorical calibration layer monotonicity is specified for pairs of 74 | # indices of categories. Output for first category in pair will be less 75 | # than or equal to output for second category. 76 | monotonicities=[(0, 1), (0, 2)]) 77 | ``` 78 | 79 | Usage with functional models: 80 | 81 | ```python 82 | input_feature = keras.layers.Input(shape=[1]) 83 | calibrated_feature = tfl.layers.CategoricalCalibration( 84 | num_buckets=3, 85 | output_min=0.0, 86 | output_max=1.0, 87 | monotonicities=[(0, 1), (0, 2)], 88 | )(feature) 89 | ... 90 | model = keras.models.Model( 91 | inputs=[input_feature, ...], 92 | outputs=...) 93 | ``` 94 | """ 95 | # pyformat: enable 96 | 97 | def __init__(self, 98 | num_buckets, 99 | units=1, 100 | output_min=None, 101 | output_max=None, 102 | monotonicities=None, 103 | kernel_initializer="uniform", 104 | kernel_regularizer=None, 105 | default_input_value=None, 106 | split_outputs=False, 107 | **kwargs): 108 | # pyformat: disable 109 | """Initializes a `CategoricalCalibration` instance. 110 | 111 | Args: 112 | num_buckets: Number of categories. 113 | units: Output dimension of the layer. See class comments for details. 114 | output_min: Minimum output of calibrator. 
115 | output_max: Maximum output of calibrator. 116 | monotonicities: List of pairs with `(i, j)` indices indicating `output(i)` 117 | should be less than or equal to `output(j)`. 118 | kernel_initializer: None or one of: 119 | - `'uniform'`: If `output_min` and `output_max` are provided initial 120 | values will be uniformly sampled from `[output_min, output_max]` 121 | range. 122 | - `'constant'`: If `output_min` and `output_max` are provided all output 123 | values will be initlized to the constant 124 | `(output_min + output_max) / 2`. 125 | - Any Keras initializer object. 126 | kernel_regularizer: None or single element or list of any Keras 127 | regularizer objects. 128 | default_input_value: If set, all inputs which are equal to this value will 129 | be treated as default and mapped to the last bucket. 130 | split_outputs: Whether to split the output tensor into a list of 131 | outputs for each unit. Ignored if units < 2. 132 | **kwargs: Other args passed to `keras.layers.Layer` initializer. 133 | 134 | Raises: 135 | ValueError: If layer hyperparameters are invalid. 
136 | """ 137 | # pyformat: enable 138 | dtype = kwargs.pop("dtype", tf.float32) # output dtype 139 | super(CategoricalCalibration, self).__init__(dtype=dtype, **kwargs) 140 | 141 | categorical_calibration_lib.verify_hyperparameters( 142 | num_buckets=num_buckets, 143 | output_min=output_min, 144 | output_max=output_max, 145 | monotonicities=monotonicities) 146 | self.num_buckets = num_buckets 147 | self.units = units 148 | self.output_min = output_min 149 | self.output_max = output_max 150 | self.monotonicities = monotonicities 151 | if output_min is not None and output_max is not None: 152 | if kernel_initializer == "constant": 153 | kernel_initializer = keras.initializers.Constant( 154 | (output_min + output_max) / 2) 155 | elif kernel_initializer == "uniform": 156 | kernel_initializer = keras.initializers.RandomUniform( 157 | output_min, output_max) 158 | self.kernel_initializer = keras.initializers.get(kernel_initializer) 159 | self.kernel_regularizer = [] 160 | if kernel_regularizer: 161 | if callable(kernel_regularizer): 162 | kernel_regularizer = [kernel_regularizer] 163 | for reg in kernel_regularizer: 164 | self.kernel_regularizer.append(keras.regularizers.get(reg)) 165 | self.default_input_value = default_input_value 166 | self.split_outputs = split_outputs 167 | 168 | def build(self, input_shape): 169 | """Standard Keras build() method.""" 170 | if (self.output_min is not None or self.output_max is not None or 171 | self.monotonicities): 172 | constraints = CategoricalCalibrationConstraints( 173 | output_min=self.output_min, 174 | output_max=self.output_max, 175 | monotonicities=self.monotonicities) 176 | else: 177 | constraints = None 178 | 179 | if not self.kernel_regularizer: 180 | kernel_reg = None 181 | elif len(self.kernel_regularizer) == 1: 182 | kernel_reg = self.kernel_regularizer[0] 183 | else: 184 | # Keras interface assumes only one regularizer, so summ all regularization 185 | # losses which we have. 
186 | kernel_reg = lambda x: tf.add_n([r(x) for r in self.kernel_regularizer]) 187 | 188 | # categorical calibration layer kernel is units-column matrix with value of 189 | # output(i) = self.kernel[i]. Default value converted to the last index. 190 | self.kernel = self.add_weight( 191 | CATEGORICAL_CALIBRATION_KERNEL_NAME, 192 | shape=[self.num_buckets, self.units], 193 | initializer=self.kernel_initializer, 194 | regularizer=kernel_reg, 195 | constraint=constraints, 196 | dtype=self.dtype) 197 | 198 | if self.kernel_regularizer and not tf.executing_eagerly(): 199 | # Keras has its own mechanism to handle regularization losses which 200 | # does not use GraphKeys, but we want to also add losses to graph keys so 201 | # they are easily accessable when layer is being used outside of Keras. 202 | # Adding losses to GraphKeys will not interfer with Keras. 203 | for reg in self.kernel_regularizer: 204 | tf.compat.v1.add_to_collection( 205 | tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES, reg(self.kernel)) 206 | 207 | super(CategoricalCalibration, self).build(input_shape) 208 | 209 | def call(self, inputs): 210 | """Standard Keras call() method.""" 211 | if inputs.dtype not in [tf.uint8, tf.int32, tf.int64]: 212 | inputs = tf.cast(inputs, dtype=tf.int32) 213 | 214 | if self.default_input_value is not None: 215 | default_input_value_tensor = tf.constant( 216 | int(self.default_input_value), 217 | dtype=inputs.dtype, 218 | name=DEFAULT_INPUT_VALUE_NAME) 219 | replacement = tf.zeros_like(inputs) + (self.num_buckets - 1) 220 | inputs = tf.where( 221 | tf.equal(inputs, default_input_value_tensor), replacement, inputs) 222 | 223 | # We can't use tf.gather_nd(self.kernel, inputs) as it doesn't support 224 | # constraints (constraint functions are not supported for IndexedSlices). 225 | # Instead we use matrix multiplication by one-hot encoding of the index. 226 | if self.units == 1: 227 | # This can be slightly faster as it uses matmul. 
228 | return tf.matmul( 229 | tf.one_hot(tf.squeeze(inputs, axis=[-1]), depth=self.num_buckets), 230 | self.kernel) 231 | result = tf.reduce_sum( 232 | tf.one_hot(inputs, axis=1, depth=self.num_buckets) * self.kernel, 233 | axis=1) 234 | 235 | if self.split_outputs: 236 | result = tf.split(result, self.units, axis=1) 237 | 238 | return result 239 | 240 | def compute_output_shape(self, input_shape): 241 | """Standard Keras compute_output_shape() method.""" 242 | del input_shape 243 | if self.units > 1 and self.split_outputs: 244 | return [(None, 1)] * self.units 245 | else: 246 | return (None, self.units) 247 | 248 | def get_config(self): 249 | """Standard Keras config for serialization.""" 250 | config = { 251 | "num_buckets": self.num_buckets, 252 | "units": self.units, 253 | "output_min": self.output_min, 254 | "output_max": self.output_max, 255 | "monotonicities": self.monotonicities, 256 | "kernel_initializer": 257 | keras.initializers.serialize( 258 | self.kernel_initializer, use_legacy_format=True), 259 | "kernel_regularizer": 260 | [keras.regularizers.serialize(r, use_legacy_format=True) 261 | for r in self.kernel_regularizer], 262 | "default_input_value": self.default_input_value, 263 | "split_outputs": self.split_outputs, 264 | } # pyformat: disable 265 | config.update(super(CategoricalCalibration, self).get_config()) 266 | return config 267 | 268 | def assert_constraints(self, eps=1e-6): 269 | """Asserts that layer weights satisfy all constraints. 270 | 271 | In graph mode builds and returns list of assertion ops. Note that ops will 272 | be created at the moment when this function is being called. 273 | In eager mode directly executes assertions. 274 | 275 | Args: 276 | eps: Allowed constraints violation. 277 | 278 | Returns: 279 | List of assertion ops in graph mode or immediately asserts in eager mode. 
280 | """ 281 | return categorical_calibration_lib.assert_constraints( 282 | weights=self.kernel, 283 | output_min=self.output_min, 284 | output_max=self.output_max, 285 | monotonicities=self.monotonicities, 286 | eps=eps) 287 | 288 | 289 | class CategoricalCalibrationConstraints(keras.constraints.Constraint): 290 | # pyformat: disable 291 | """Monotonicity and bounds constraints for categorical calibration layer. 292 | 293 | Updates the weights of CategoricalCalibration layer to satify bound and 294 | monotonicity constraints. The update is an approximate L2 projection into the 295 | constrained parameter space. 296 | 297 | Attributes: 298 | - All `__init__` arguments. 299 | """ 300 | # pyformat: enable 301 | 302 | def __init__(self, output_min=None, output_max=None, monotonicities=None): 303 | """Initializes an instance of `CategoricalCalibrationConstraints`. 304 | 305 | Args: 306 | output_min: Minimum possible output of categorical function. 307 | output_max: Maximum possible output of categorical function. 308 | monotonicities: Monotonicities of CategoricalCalibration layer. 
309 | """ 310 | categorical_calibration_lib.verify_hyperparameters( 311 | output_min=output_min, 312 | output_max=output_max, 313 | monotonicities=monotonicities) 314 | self.monotonicities = monotonicities 315 | self.output_min = output_min 316 | self.output_max = output_max 317 | 318 | def __call__(self, w): 319 | """Applies constraints to w.""" 320 | return categorical_calibration_lib.project( 321 | weights=w, 322 | output_min=self.output_min, 323 | output_max=self.output_max, 324 | monotonicities=self.monotonicities) 325 | 326 | def get_config(self): 327 | """Standard Keras config for serialization.""" 328 | return { 329 | "output_min": self.output_min, 330 | "output_max": self.output_max, 331 | "monotonicities": self.monotonicities, 332 | } # pyformat: disable 333 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/categorical_calibration_lib.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Helpers and computations of categorical calibration layer.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from . 
import internal_utils 21 | import tensorflow as tf 22 | 23 | 24 | def project(weights, output_min, output_max, monotonicities): 25 | """Monotonicity/bounds constraints implementation for categorical calibration. 26 | 27 | Returns the approximate L2 projection of the CategoricalCalibration weights 28 | into the constrained parameter space. 29 | 30 | Args: 31 | weights: Tensor which represents weights of Categorical calibration layer. 32 | output_min: Lower bound constraint on weights. 33 | output_max: Upper bound constraint on weights. 34 | monotonicities: List of pair of indices `(i, j)`, indicating constraint 35 | `weight[i] <= weight[j]`. 36 | 37 | Returns: 38 | Projected `weights` tensor. 39 | 40 | Raises: 41 | ValueError: If monotonicities are not of the correct format or are circular. 42 | """ 43 | num_buckets = weights.shape[0] 44 | verify_hyperparameters( 45 | num_buckets=num_buckets, 46 | output_min=output_min, 47 | output_max=output_max, 48 | monotonicities=monotonicities) 49 | 50 | projected_weights = weights 51 | 52 | if monotonicities: 53 | projected_weights = ( 54 | internal_utils.approximately_project_categorical_partial_monotonicities( 55 | projected_weights, monotonicities)) 56 | 57 | if output_min is not None: 58 | projected_weights = tf.maximum(projected_weights, output_min) 59 | if output_max is not None: 60 | projected_weights = tf.minimum(projected_weights, output_max) 61 | return projected_weights 62 | 63 | 64 | def assert_constraints(weights, 65 | output_min, 66 | output_max, 67 | monotonicities, 68 | debug_tensors=None, 69 | eps=1e-6): 70 | """Asserts that `weights` satisfiy constraints. 71 | 72 | Args: 73 | weights: Tensor which represents weights of Categorical calibration layer. 74 | output_min: Lower bound constraint on weights. 75 | output_max: Upper bound constraint on weights. 76 | monotonicities: List of pair of indices `(i, j)`, indicating constraint 77 | `weight[i] <= weight[j]`. 
78 | debug_tensors: None or list of anything convertible to tensor (for example 79 | tensors or strings) which will be printed in case of constraints 80 | violation. 81 | eps: Allowed constraints violation. 82 | 83 | Returns: 84 | List of assertion ops in graph mode or immideately asserts in eager mode. 85 | """ 86 | num_buckets = weights.shape[0] 87 | verify_hyperparameters( 88 | num_buckets=num_buckets, 89 | output_min=output_min, 90 | output_max=output_max, 91 | monotonicities=monotonicities) 92 | 93 | info = ["Outputs: ", weights, "Epsilon: ", eps] 94 | if debug_tensors: 95 | info += debug_tensors 96 | asserts = [] 97 | 98 | if output_min is not None: 99 | min_output = tf.reduce_min(weights) 100 | asserts.append( 101 | tf.Assert( 102 | min_output >= output_min - eps, 103 | data=["Lower bound violation.", "output_min:", output_min] + info, 104 | summarize=num_buckets)) 105 | 106 | if output_max is not None: 107 | max_output = tf.reduce_max(weights) 108 | asserts.append( 109 | tf.Assert( 110 | max_output <= output_max + eps, 111 | data=["Upper bound violation.", "output_max:", output_max] + info, 112 | summarize=num_buckets)) 113 | 114 | if monotonicities: 115 | left = tf.gather_nd(weights, [[i] for (i, j) in monotonicities]) 116 | right = tf.gather_nd(weights, [[j] for (i, j) in monotonicities]) 117 | asserts.append( 118 | tf.Assert( 119 | tf.reduce_min(left - right) < eps, 120 | data=["Monotonicity violation.", "monotonicities:", monotonicities] 121 | + info, 122 | summarize=num_buckets)) 123 | 124 | return asserts 125 | 126 | 127 | def verify_hyperparameters(num_buckets=None, 128 | output_min=None, 129 | output_max=None, 130 | monotonicities=None): 131 | """Verifies that all given hyperparameters are consistent. 132 | 133 | See `tfl.layers.CategoricalCalibration` class level comment for detailes. 134 | 135 | Args: 136 | num_buckets: `num_buckets` of CategoricalCalibration layer. 137 | output_min: `smallest output` of CategoricalCalibration layer. 
138 | output_max: `largest output` of CategoricalCalibration layer. 139 | monotonicities: `monotonicities` of CategoricalCalibration layer. 140 | 141 | Raises: 142 | ValueError: If parameters are incorrect or inconsistent. 143 | """ 144 | if output_min is not None and output_max is not None: 145 | if output_max < output_min: 146 | raise ValueError( 147 | "If specified output_max must be greater than output_min. " 148 | "They are: ({}, {})".format(output_min, output_max)) 149 | 150 | if monotonicities: 151 | if (not isinstance(monotonicities, list) or not all( 152 | isinstance(m, (list, tuple)) and len(m) == 2 for m in monotonicities)): 153 | raise ValueError( 154 | "Monotonicities should be a list of pairs (list/tuples).") 155 | for (i, j) in monotonicities: 156 | if (i < 0 or j < 0 or (num_buckets is not None and 157 | (i >= num_buckets or j >= num_buckets))): 158 | raise ValueError( 159 | "Monotonicities should be pairs of be indices in range " 160 | "[0, num_buckets). They are: {}".format(monotonicities)) 161 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/categorical_calibration_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for categorical calibration layer. 
15 | 16 | This test should be run with "-c opt" since otherwise it's slow. 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | from absl import logging 24 | from absl.testing import parameterized 25 | import numpy as np 26 | import tensorflow as tf 27 | from tensorflow_lattice.python import categorical_calibration_layer as categorical_calibraion 28 | from tensorflow_lattice.python import parallel_combination_layer as parallel_combination 29 | from tensorflow_lattice.python import test_utils 30 | # pylint: disable=g-import-not-at-top 31 | # Use Keras 2. 32 | version_fn = getattr(tf.keras, "version", None) 33 | if version_fn and version_fn().startswith("3."): 34 | import tf_keras as keras 35 | else: 36 | keras = tf.keras 37 | 38 | 39 | class CategoricalCalibrationLayerTest(parameterized.TestCase, tf.test.TestCase): 40 | 41 | def setUp(self): 42 | super(CategoricalCalibrationLayerTest, self).setUp() 43 | self._disable_all = False 44 | self._loss_eps = 1e-2 45 | self._loss_diff_eps = 1e-4 46 | keras.utils.set_random_seed(42) 47 | 48 | def _ResetAllBackends(self): 49 | keras.backend.clear_session() 50 | tf.compat.v1.reset_default_graph() 51 | 52 | def _ScatterXUniformly(self, units, num_points, num_buckets, 53 | missing_probability, default_input_value): 54 | """Cycles through bucket ids, pads with missing values, shuffles per unit.""" 55 | data = [] 56 | for unit_idx in range(units): 57 | if missing_probability > 0.0: 58 | missing_points = int(num_points * missing_probability) 59 | else: 60 | missing_points = 0 61 | 62 | x = ([default_input_value for _ in range(missing_points)] + 63 | [i % num_buckets for i in range(num_points - missing_points)]) 64 | np.random.seed(unit_idx) 65 | np.random.shuffle(x) 66 | if data: 67 | data = [values + (value,) for values, value in zip(data, x)] 68 | else: 69 | data = [(value,) for value in x] 70 | 71 | return [np.asarray(v, dtype=np.int32) for v in data]
72 | 73 | def _SetDefaults(self, config): 74 | config.setdefault("units", 1) 75 | config.setdefault("use_multi_calibration_layer", False) 76 | config.setdefault("one_d_input", False) 77 | config.setdefault("output_min", None) 78 | config.setdefault("output_max", None) 79 | config.setdefault("default_input_value", None) 80 | config.setdefault("monotonicities", None) 81 | config.setdefault("missing_probability", 0.0) 82 | config.setdefault("constraint_assertion_eps", 1e-6) 83 | config.setdefault("kernel_regularizer", None) 84 | config.setdefault("model_dir", "/tmp/test_pwl_model_dir/") 85 | return config 86 | 87 | def _TrainModel(self, config): 88 | """Trains model and returns loss. 89 | 90 | Args: 91 | config: Layer config internal for this test which specifies params of 92 | categorical calibration layer to train. 93 | 94 | Returns: 95 | Training loss. 96 | """ 97 | logging.info("Testing config:") 98 | logging.info(config) 99 | config = self._SetDefaults(config) 100 | 101 | self._ResetAllBackends() 102 | 103 | if config["default_input_value"] is not None: 104 | # default_input_value is mapped to the last bucket, hence x_generator 105 | # needs to generate in [0, ..., num_buckets-2] range. 106 | num_random_buckets = config["num_buckets"] - 1 107 | else: 108 | num_random_buckets = config["num_buckets"] 109 | 110 | # The input to the model can either be single or multi dimensional.
111 | input_units = 1 if config["one_d_input"] else config["units"] 112 | 113 | training_inputs = config["x_generator"]( 114 | units=input_units, 115 | num_points=config["num_training_records"], 116 | num_buckets=num_random_buckets, 117 | missing_probability=config["missing_probability"], 118 | default_input_value=config["default_input_value"]) 119 | training_labels = [config["y_function"](x) for x in training_inputs] 120 | 121 | # Either create multiple CategoricalCalibration layers and combine using a 122 | # ParallelCombination layer, or create a single CategoricalCalibration with 123 | # multiple output dimensions. 124 | if config["use_multi_calibration_layer"]: 125 | num_calibration_layers = config["units"] 126 | categorical_calibraion_units = 1 127 | else: 128 | num_calibration_layers = 1 129 | categorical_calibraion_units = config["units"] 130 | 131 | model = keras.models.Sequential() 132 | model.add(keras.layers.Input(shape=[input_units], dtype=tf.int32)) 133 | calibration_layers = [] 134 | for _ in range(num_calibration_layers): 135 | calibration_layers.append( 136 | categorical_calibraion.CategoricalCalibration( 137 | units=categorical_calibraion_units, 138 | kernel_initializer="constant", 139 | num_buckets=config["num_buckets"], 140 | output_min=config["output_min"], 141 | output_max=config["output_max"], 142 | monotonicities=config["monotonicities"], 143 | kernel_regularizer=config["kernel_regularizer"], 144 | default_input_value=config["default_input_value"])) 145 | if len(calibration_layers) == 1: 146 | model.add(calibration_layers[0]) 147 | else: 148 | model.add(parallel_combination.ParallelCombination(calibration_layers)) 149 | if config["units"] > 1: 150 | model.add( 151 | keras.layers.Lambda( 152 | lambda x: tf.reduce_mean(x, axis=1, keepdims=True))) 153 | model.compile( 154 | loss=keras.losses.mean_squared_error, 155 | optimizer=config["optimizer"](learning_rate=config["learning_rate"])) 156 | 157 | training_data = (training_inputs,
training_labels) 158 | 159 | loss = test_utils.run_training_loop( 160 | config=config, 161 | training_data=training_data, 162 | keras_model=model, 163 | input_dtype=np.int32) 164 | 165 | assetion_ops = [] 166 | for calibration_layer in calibration_layers: 167 | assetion_ops.extend( 168 | calibration_layer.assert_constraints( 169 | eps=config["constraint_assertion_eps"])) 170 | if not tf.executing_eagerly() and assetion_ops: 171 | tf.compat.v1.keras.backend.get_session().run(assetion_ops) 172 | 173 | return loss 174 | 175 | @parameterized.parameters((np.mean,), (lambda x: -np.mean(x),)) 176 | def testUnconstrainedNoMissingValue(self, y_function): 177 | if self._disable_all: 178 | return 179 | config = { 180 | "num_training_records": 200, 181 | "num_training_epoch": 500, 182 | "optimizer": keras.optimizers.Adam, 183 | "learning_rate": 0.15, 184 | "x_generator": self._ScatterXUniformly, 185 | "y_function": y_function, 186 | "num_buckets": 10, 187 | "output_min": None, 188 | "output_max": None, 189 | "monotonicities": None, 190 | } 191 | loss = self._TrainModel(config) 192 | self.assertAlmostEqual(loss, 0.0, delta=self._loss_eps) 193 | config["units"] = 3 194 | loss = self._TrainModel(config) 195 | self.assertAlmostEqual(loss, 0.0, delta=self._loss_eps) 196 | config["one_d_input"] = True 197 | loss = self._TrainModel(config) 198 | self.assertAlmostEqual(loss, 0.0, delta=self._loss_eps) 199 | 200 | @parameterized.parameters((np.mean,), (lambda x: -np.mean(x),)) 201 | def testUnconstrainedWithMissingValue(self, y_function): 202 | if self._disable_all: 203 | return 204 | config = { 205 | "num_training_records": 200, 206 | "num_training_epoch": 500, 207 | "optimizer": keras.optimizers.Adam, 208 | "learning_rate": 0.15, 209 | "x_generator": self._ScatterXUniformly, 210 | "y_function": y_function, 211 | "num_buckets": 10, 212 | "output_min": None, 213 | "output_max": None, 214 | "monotonicities": None, 215 | "default_input_value": -1, 216 | "missing_probability": 0.1, 217 |
} 218 | loss = self._TrainModel(config) 219 | self.assertAlmostEqual(loss, 0.0, delta=self._loss_eps) 220 | config["units"] = 3 221 | loss = self._TrainModel(config) 222 | self.assertAlmostEqual(loss, 0.0, delta=self._loss_eps) 223 | config["one_d_input"] = True 224 | loss = self._TrainModel(config) 225 | self.assertAlmostEqual(loss, 0.0, delta=self._loss_eps) 226 | 227 | @parameterized.parameters( 228 | (0.0, 9.0, None, 0.0), 229 | (1.0, 8.0, None, 0.2), 230 | (1.0, 8.0, [(6, 5)], 0.25), 231 | (1.0, 8.0, [(6, 5), (5, 4)], 0.4), 232 | (1.0, 8.0, [(6, 5), (7, 5)], 0.4), 233 | (1.0, 8.0, [(6, 5), (5, 4), (4, 3)], 0.7), 234 | (1.0, 8.0, [(7, 6), (6, 5), (4, 3), (3, 2)], 0.6), 235 | (1.0, 8.0, [(7, 6), (6, 5), (5, 4), (4, 3), (3, 2)], 1.95), 236 | ) 237 | def testConstraints(self, output_min, output_max, monotonicities, 238 | expected_loss): 239 | if self._disable_all: 240 | return 241 | config = { 242 | "num_training_records": 1000, 243 | "num_training_epoch": 1000, 244 | "optimizer": keras.optimizers.Adam, 245 | "learning_rate": 1.0, 246 | "x_generator": self._ScatterXUniformly, 247 | "y_function": np.mean, 248 | "num_buckets": 10, 249 | "output_min": output_min, 250 | "output_max": output_max, 251 | "monotonicities": monotonicities, 252 | } 253 | 254 | loss = self._TrainModel(config) 255 | self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps) 256 | 257 | # Same input with multiple calibration units, should give out the same loss. NOTE(review): "units" stays at its default of 1 in this test. 258 | config["one_d_input"] = True 259 | loss = self._TrainModel(config) 260 | self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps) 261 | 262 | # With independently sampled unit-dim inputs loss is scaled by 1/units. 263 | config["one_d_input"] = False 264 | loss = self._TrainModel(config) 265 | self.assertAlmostEqual( 266 | loss, 267 | expected_loss / config["units"], 268 | delta=self._loss_eps * config["units"]) 269 | 270 | # Using separate calibration layers should give out the same loss.
271 | config["use_multi_calibration_layer"] = True 272 | loss_multi_calib = self._TrainModel(config) 273 | self.assertAlmostEqual(loss, loss_multi_calib, delta=self._loss_diff_eps) 274 | 275 | def testCircularMonotonicites(self): 276 | if self._disable_all: 277 | return 278 | config = { 279 | "num_training_records": 200, 280 | "num_training_epoch": 500, 281 | "optimizer": keras.optimizers.Adam, 282 | "learning_rate": 0.15, 283 | "x_generator": self._ScatterXUniformly, 284 | "y_function": float, 285 | "num_buckets": 5, 286 | "monotonicities": [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)], 287 | } 288 | 289 | with self.assertRaises(ValueError): 290 | self._TrainModel(config) 291 | 292 | @parameterized.parameters( 293 | # Standard Keras regularizer: 294 | ( 295 | keras.regularizers.l1_l2(l1=0.01, l2=0.001),), 296 | # Tuple of regularizers: 297 | ( 298 | (keras.regularizers.l1_l2( 299 | l1=0.01, l2=0.0), keras.regularizers.l1_l2(l1=0.0, l2=0.001)),), 300 | ) 301 | def testRegularizers(self, regularizer): 302 | if self._disable_all: 303 | return 304 | config = { 305 | "num_training_records": 20, 306 | "num_training_epoch": 0, 307 | "optimizer": keras.optimizers.Adam, 308 | "learning_rate": 1.0, 309 | "x_generator": self._ScatterXUniformly, 310 | "y_function": lambda _: 2.0, 311 | "kernel_regularizer": regularizer, 312 | "num_buckets": 3, 313 | "output_min": 0.0, 314 | "output_max": 4.0, 315 | } 316 | loss = self._TrainModel(config) 317 | # This loss is pure regularization loss because initializer matches target 318 | # function and there were 0 training epochs.
319 | self.assertAlmostEqual(loss, 0.072, delta=self._loss_eps) 320 | 321 | def testOutputShape(self): 322 | if self._disable_all: 323 | return 324 | 325 | # Not Splitting 326 | units = 10 327 | input_shape, output_shape = (units,), (None, units) 328 | input_a = keras.layers.Input(shape=input_shape) 329 | cat_cal_0 = categorical_calibraion.CategoricalCalibration( 330 | num_buckets=3, units=units) 331 | output = cat_cal_0(input_a) 332 | self.assertAllEqual(output_shape, 333 | cat_cal_0.compute_output_shape(input_a.shape)) 334 | self.assertAllEqual(output_shape, output.shape) 335 | 336 | # Splitting 337 | output_shape = [(None, 1)] * units 338 | cat_cal_1 = categorical_calibraion.CategoricalCalibration( 339 | num_buckets=3, units=units, split_outputs=True) 340 | output = cat_cal_1(input_a) 341 | self.assertAllEqual(output_shape, 342 | cat_cal_1.compute_output_shape(input_a.shape)) 343 | self.assertAllEqual(output_shape, [o.shape for o in output]) 344 | 345 | 346 | if __name__ == "__main__": 347 | tf.test.main() 348 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/cdf_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Projection free Cumulative Distribution Function layer. 15 | 16 | Keras implementation of TensorFlow Lattice CDF layer. Layer takes single or
Layer takes single or 17 | multi-dimensional input and transforms it using a set of step functions. The 18 | layer is naturally monotonic and bounded to the range [0, 1]. 19 | """ 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import tensorflow as tf 25 | # pylint: disable=g-import-not-at-top 26 | # Use Keras 2. 27 | version_fn = getattr(tf.keras, "version", None) 28 | if version_fn and version_fn().startswith("3."): 29 | import tf_keras as keras 30 | else: 31 | keras = tf.keras 32 | from . import utils 33 | 34 | 35 | class CDF(keras.layers.Layer): 36 | # pyformat: disable 37 | """Cumulative Distribution Function (CDF) layer. 38 | 39 | Layer takes input of shape `(batch_size, input_dim)` or `(batch_size, 1)` and 40 | transforms it using `input_dim` number of cumulative distribution functions, 41 | which are naturally monotonic and bounded to the range [0, 1]. If multi 42 | dimensional input is provided, each output will be for the corresponding 43 | input, otherwise all CDF functions will act on the same input. All units share 44 | the same layer configuration, but each has their separate set of trained 45 | parameters. The smoothness of the cumulative distribution functions depends on 46 | the number of keypoints (i.e. step functions), the activation, and input 47 | scaling. 48 | 49 | Input shape: 50 | Single input should be a rank-2 tensor with shape: `(batch_size, input_dim)` 51 | or `(batch_size, 1)`. 52 | 53 | Output shape: 54 | Rank-2 tensor with shape `(batch, input_dim / factor, units)` if 55 | `reduction=='none'`. Otherwise a rank-2 tensor with shape 56 | `(batch_size, units)`. 57 | 58 | Attributes: 59 | - All `__init__` arguments. 60 | kernel: TF variable which stores weights of each cdf function. 61 | input_scaling: A constant if `input_scaling_type` is `'fixed'`, and a TF 62 | variable if set to `'learned'`. 
63 | 64 | Example: 65 | 66 | ```python 67 | cdf = tfl.layers.CDF( 68 | num_keypoints=10, 69 | units=10, 70 | # You can specify the type of activation to use for the step functions. 71 | activation="sigmoid", 72 | # You can specifyc the type of reduction to use across the input dimension. 73 | reduction="mean", 74 | # The input scaling type determines whether or not to use a fixed value or 75 | # to learn the value during training. 76 | input_scaling_type="fixed", 77 | # You can make the layer less connected by increasing the pruning factor, 78 | # which must be a divisor of both the input dimension and units. 79 | sparsity_factor=1, 80 | ) 81 | ``` 82 | """ 83 | 84 | def __init__(self, 85 | num_keypoints, 86 | units=1, 87 | activation="relu6", 88 | reduction="mean", 89 | input_scaling_init=None, 90 | input_scaling_type="fixed", 91 | input_scaling_monotonicity="increasing", 92 | sparsity_factor=1, 93 | kernel_initializer="random_uniform", 94 | **kwargs): 95 | # pyformat: disable 96 | """Initializes an instance of `Lattice`. 97 | 98 | Args: 99 | num_keypoints: The number of keypoints (i.e. step functions) to use for 100 | each of `units` CDF functions. 101 | units: The output dimension of the layer. 102 | activation: The activation function to use for the step functions. One of: 103 | - `'relu6'`: The `tf.nn.relu6` function. 104 | - `'sigmoid'`: The `tf.nn.sigmoid` function. 105 | reduction: The reduction used for each of the `units` CDF functions to 106 | combine the CDF function output for each input dimension. One of: 107 | - `'mean'`: The `tf.reduce_mean` function. 108 | - `'geometric_mean'`: The n'th root of the product of each of the n 109 | input dimensions. 110 | - `'none'`: No input reduction. 111 | input_scaling_init: The value used to initialize the input scaling. 112 | Defaults to `num_keypoints` if set to `None`. 113 | input_scaling_type: The type of input scaling to use. 
One of: 114 | - `'fixed'`: input scaling will be a constant with value 115 | `input_scaling_init`. This will be the value used for all input 116 | dimensions. 117 | - `'learned_shared'`: input scaling will be a weight learned during 118 | training initialized with value `input_scaling_init`. This will be the 119 | value used for all input dimensions. 120 | - `'learned_per_input'`: input scaling will be a weight learned during 121 | training initialized with value `input_scaling_init`. A separate value 122 | will be learned for each input dimension. 123 | input_scaling_monotonicity: One of: 124 | - `'increasing'` or `1`: input scaling will be constrained to be 125 | non-negative such that the output of the layer is monotonic in each 126 | dimension. 127 | - `'none'` or `0`: input scaling will not be constrained and the output 128 | of the layer will no be guaranteed to be monotonic. 129 | sparsity_factor: The factor by which to prune the connectivity of the 130 | layer. If set to `1` there will be no pruning and the layer will be 131 | fully connected. If set to `>1` the layer will be partially connected 132 | where the number of connections will be reduced by this factor. Must be 133 | a divisor of both the `input_dim` and `units`. 134 | kernel_initializer: None or one of: 135 | - `'random_uniform'`: initializes parameters as uniform 136 | random functions in the range [0, 1]. 137 | - Any Keras initializer object. 138 | **kwargs: Any additional `keras.layers.Layer` arguments. 
139 | """ 140 | # pyformat: enable 141 | super(CDF, self).__init__(**kwargs) 142 | self.num_keypoints = num_keypoints 143 | self.units = units 144 | self.activation = activation 145 | self.reduction = reduction 146 | if input_scaling_init is None: 147 | self.input_scaling_init = float(num_keypoints) 148 | else: 149 | self.input_scaling_init = float(input_scaling_init) 150 | self.input_scaling_type = input_scaling_type 151 | self.input_scaling_monotonicity = utils.canonicalize_monotonicity( 152 | input_scaling_monotonicity) 153 | self.sparsity_factor = sparsity_factor 154 | 155 | self.kernel_initializer = create_kernel_initializer( 156 | kernel_initializer_id=kernel_initializer) 157 | 158 | def build(self, input_shape): 159 | """Standard Keras build() method.""" 160 | input_dim = int(input_shape[-1]) 161 | if input_dim % self.sparsity_factor != 0: 162 | raise ValueError( 163 | "sparsity_factor ({}) must be a divisor of input_dim ({})".format( 164 | self.sparsity_factor, input_dim)) 165 | if self.units % self.sparsity_factor != 0: 166 | raise ValueError( 167 | "sparsity_factor ({}) must be a divisor of units ({})".format( 168 | self.sparsity_factor, self.units)) 169 | 170 | # Each keypoint represents a step function defined by the activation 171 | # function specified. For an activation like relu6, this represents the 172 | # the hinge point. 173 | self.kernel = self.add_weight( 174 | "kernel", 175 | initializer=self.kernel_initializer, 176 | shape=[ 177 | 1, input_dim, self.num_keypoints, 178 | int(self.units // self.sparsity_factor) 179 | ]) 180 | 181 | # Input scaling ultimately represents the slope of the step function used. 182 | # If the type is "learned_*" then input scaling will be a variable weight 183 | # that is constrained depending on the monotonicity specified. 
184 | if self.input_scaling_type == "fixed": 185 | self.input_scaling = tf.constant(self.input_scaling_init) 186 | elif self.input_scaling_type == "learned_shared": 187 | self.input_scaling = self.add_weight( 188 | "input_scaling", 189 | initializer=keras.initializers.Constant(self.input_scaling_init), 190 | constraint=keras.constraints.NonNeg() 191 | if self.input_scaling_monotonicity else None, 192 | shape=[1]) 193 | elif self.input_scaling_type == "learned_per_input": 194 | self.input_scaling = self.add_weight( 195 | "input_scaling", 196 | initializer=keras.initializers.Constant(self.input_scaling_init), 197 | constraint=keras.constraints.NonNeg() 198 | if self.input_scaling_monotonicity else None, 199 | shape=[1, input_dim, 1, 1]) 200 | else: 201 | raise ValueError("Invalid input_scaling_type: {}".format( 202 | self.input_scaling_type)) 203 | 204 | def call(self, inputs): 205 | """Standard Keras call() method.""" 206 | input_dim = int(inputs.shape[-1]) 207 | # We add new axes to enable broadcasting. 208 | x = inputs[..., tf.newaxis, tf.newaxis] 209 | 210 | # Shape: (batch, input_dim, 1, 1) 211 | # --> (batch, input_dim, num_keypoints, units / factor) 212 | # --> (batch, input_dim, units / factor) 213 | if self.activation == "relu6": 214 | cdfs = tf.reduce_mean( 215 | tf.nn.relu6(self.input_scaling * (x - self.kernel)), axis=2) / 6 216 | elif self.activation == "sigmoid": 217 | cdfs = tf.reduce_mean( 218 | tf.nn.sigmoid(self.input_scaling * (x - self.kernel)), axis=2) 219 | else: 220 | raise ValueError("Invalid activation: {}".format(self.activation)) 221 | 222 | result = cdfs 223 | 224 | if self.sparsity_factor != 1: 225 | # Shape: (batch, input_dim, units / factor) 226 | # --> (batch, input_dim / factor, units) 227 | result = tf.reshape( 228 | result, [-1, int(input_dim // self.sparsity_factor), self.units]) 229 | 230 | # Shape: (batch, input_dim / factor, units) 231 | #. 
--> (batch, units) 232 | if self.reduction == "mean": 233 | result = tf.reduce_mean(result, axis=1) 234 | elif self.reduction == "geometric_mean": 235 | num_terms = input_dim // self.sparsity_factor 236 | result = tf.math.exp( 237 | tf.reduce_sum(tf.math.log(result + 1e-3), axis=1) / num_terms) 238 | # we use the log form above so that we can add the epsilon term 239 | # tf.pow(tf.reduce_prod(cdfs, axis=1), 1. / num_terms) 240 | elif self.reduction != "none": 241 | raise ValueError("Invalid reduction: {}".format(self.reduction)) 242 | 243 | return result 244 | 245 | def get_config(self): 246 | """Standard Keras get_config() method.""" 247 | config = { 248 | "num_keypoints": 249 | self.num_keypoints, 250 | "units": 251 | self.units, 252 | "activation": 253 | self.activation, 254 | "reduction": 255 | self.reduction, 256 | "input_scaling_init": 257 | self.input_scaling_init, 258 | "input_scaling_type": 259 | self.input_scaling_type, 260 | "input_scaling_monotonicity": 261 | self.input_scaling_monotonicity, 262 | "sparsity_factor": 263 | self.sparsity_factor, 264 | "kernel_initializer": 265 | keras.initializers.serialize( 266 | self.kernel_initializer, use_legacy_format=True), 267 | } 268 | config.update(super(CDF, self).get_config()) 269 | return config 270 | 271 | 272 | def create_kernel_initializer(kernel_initializer_id): 273 | """Returns a kernel Keras initializer object from its id. 274 | 275 | This function is used to convert the 'kernel_initializer' parameter in the 276 | constructor of `tfl.layers.CDF` into the corresponding initializer object. 277 | 278 | Args: 279 | kernel_initializer_id: See the documentation of the 'kernel_initializer' 280 | parameter in the constructor of `tfl.layers.CDF`. 281 | 282 | Returns: 283 | The Keras initializer object for the `tfl.layers.CDF` kernel variable. 
284 | """ 285 | if kernel_initializer_id in ["random_uniform", "RandomUniform"]: 286 | return keras.initializers.RandomUniform(0.0, 1.0) 287 | else: 288 | return keras.initializers.get(kernel_initializer_id) 289 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/conditional_cdf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Implements CDF transformation with derived parameters (kernels). 15 | 16 | `cdf_fn` is similar to `tfl.layers.CDF`, which is an additive / multiplicative 17 | average of a few shifted and scaled `sigmoid` or `relu6` basis functions, 18 | with the difference that the functions are parametrized by the provided 19 | parameters instead of learnable weights belonging to a `tfl.layers.CDF` layer. 20 | 21 | These parameters can be one of: 22 | 23 | - constants, 24 | - trainable variables, 25 | - outputs from other TF modules. 26 | 27 | For inputs of shape `(batch_size, input_dim)`, two sets of free-form 28 | parameters are used to configure the CDF function: 29 | 30 | - `location_parameters` for where to place the sigmoid / relu6 transformation 31 | basis, 32 | - `scaling_parameters` (optional) for the horizontal scaling before applying 33 | the transformation basis. 
34 | """ 35 | 36 | from typing import Optional, Union, Tuple 37 | import tensorflow as tf 38 | 39 | 40 | def _verify_cdf_params( 41 | inputs: tf.Tensor, 42 | location_parameters: tf.Tensor, 43 | scaling_parameters: Optional[tf.Tensor], 44 | units: int, 45 | activation: str, 46 | reduction: str, 47 | sparsity_factor: int, 48 | ) -> None: 49 | """Verifies the arguments of a `cdf_fn` call, raising ValueError if invalid. 50 | 51 | Args: 52 | inputs: inputs to the CDF function. 53 | location_parameters: parameters for deciding the locations of the 54 | transformations. 55 | scaling_parameters: parameters for deciding the horizontal scaling of the 56 | transformations. 57 | units: output dimension. 58 | activation: either `sigmoid` or `relu6` for selecting the transformation. 59 | reduction: either `mean`, `geometric_mean`, or `none` to specify whether to 60 | perform averaging and which average to perform. 61 | sparsity_factor: deciding the level of sparsity during reduction. 62 | `input_dim` and `units` should both be divisible by `sparsity_factor`. 63 | """ 64 | if activation not in ("sigmoid", "relu6"): 65 | raise ValueError( 66 | f"activation = {activation} is not supported. Use 'sigmoid' or 'relu6'." 67 | ) 68 | if reduction not in ("mean", "geometric_mean", "none"): 69 | raise ValueError( 70 | f"reduction = {reduction} is not supported. Use 'mean'," 71 | " 'geometric_mean' or 'none'." 72 | ) 73 | 74 | if len(inputs.shape) != 2: 75 | raise ValueError( 76 | f"inputs shape {inputs.shape} is not (batch_size, input_dim)." 77 | ) 78 | 79 | input_dim = inputs.shape[1]  # NOTE(review): assumes a static (non-None) dim. 80 | if units % sparsity_factor != 0: 81 | raise ValueError( 82 | f"units = {units} is not divisible by sparsity_factor =" 83 | f" {sparsity_factor}." 84 | ) 85 | if input_dim % sparsity_factor != 0: 86 | raise ValueError( 87 | f"input_dim = {input_dim} is not divisible by sparsity_factor =" 88 | f" {sparsity_factor}."
89 | ) 90 | 91 | if ( 92 | len(location_parameters.shape) != 4 93 | or location_parameters.shape[1] != input_dim 94 | or location_parameters.shape[3] != units // sparsity_factor 95 | ): 96 | raise ValueError( 97 | "location_parameters shape" 98 | f" {location_parameters.shape} is not (batch, input_dim, " 99 | f"num_functions, units / sparsity_factor = {units // sparsity_factor})." 100 | ) 101 | 102 | if scaling_parameters is not None: 103 | try:  # Broadcast only as a compatibility probe; the result is discarded. 104 | _ = tf.broadcast_to( 105 | scaling_parameters, 106 | location_parameters.shape, 107 | name="cdf_fn_try_broadcasting", 108 | ) 109 | except Exception as err: 110 | raise ValueError( 111 | "scaling_parameters and location_parameters likely" 112 | " are not broadcastable. Shapes of scaling_parameters:" 113 | f" {scaling_parameters.shape}, location_parameters:" 114 | f" {location_parameters.shape}." 115 | ) from err 116 | 117 | 118 | @tf.function 119 | def cdf_fn( 120 | inputs: tf.Tensor, 121 | location_parameters: tf.Tensor, 122 | scaling_parameters: Optional[tf.Tensor] = None, 123 | units: int = 1, 124 | activation: str = "relu6", 125 | reduction: str = "mean", 126 | sparsity_factor: int = 1, 127 | scaling_exp_transform_multiplier: Optional[float] = None, 128 | return_derived_parameters: bool = False, 129 | ) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor, tf.Tensor]]: 130 | r"""Maps `inputs` through a CDF function specified by keypoint parameters. 131 | 132 | `cdf_fn` is similar to `tfl.layers.CDF`, which is an additive / multiplicative 133 | average of a few shifted and scaled `sigmoid` or `relu6` basis functions, 134 | with the difference that the functions are parametrized by the provided 135 | parameters instead of learnable weights belonging to a `tfl.layers.CDF` layer. 136 | 137 | These parameters can be one of: 138 | 139 | - constants, 140 | - trainable variables, 141 | - outputs from other TF modules.
142 | 143 | For inputs of shape `(batch_size, input_dim)`, two sets of free-form 144 | parameters are used to configure the CDF function: 145 | 146 | - `location_parameters` for where to place the sigmoid / relu6 transformation 147 | basis, 148 | - `scaling_parameters` (optional) for the horizontal scaling before applying 149 | the transformation basis. 150 | 151 | The transformation per dimension is `x -> activation(scale * (x - location))`, 152 | where: 153 | 154 | - `scale` (specified via `scaling_parameter`) is the input scaling for each 155 | dimension and needs to be strictly positive for the CDF function to become 156 | monotonic. If needed, you can set `scaling_exp_transform_multiplier` to get 157 | `scale = exp(scaling_parameter * scaling_exp_transform_multiplier)` and 158 | guarantees strict positivity. 159 | - `location` (specified via `location_parameter`) is the input shift. Notice 160 | for `relu6` this is where the transformation starts to be nonzero, whereas for 161 | `sigmoid` this is where the transformation hits 0.5. 162 | - `activation` is either `sigmoid` or `relu6` (for `relu6 / 6`). 163 | 164 | An optional `reduction` operation will compute the additive / multiplicative 165 | average for the input dims after their individual CDF transformation. `mean` 166 | and `geometric_mean` are supported if sepcified. 167 | 168 | `sparsity_factor` decides the level of sparsity during reduction. For 169 | instance, default of `sparsity = 1` calculates the average of *all* input 170 | dims, whereas `sparsity = 2` calculates the average of *every other* input 171 | dim, and so on. 172 | 173 | Input shape: 174 | We denote `num_functions` as the number of `sigmoid` or `relu6 / 6` basis 175 | functions used for each CDF transformation. 176 | 177 | `inputs` should be: 178 | 179 | - `(batch_size, input_dim)`. 180 | 181 | `location_parameters` should be: 182 | 183 | - `(batch_size, input_dim, num_functions, units // sparsity_factor)`. 
184 | 185 | `scaling_parameters` when provided should be broadcast friendly 186 | with `location_parameters`, e.g. one of 187 | 188 | - `(batch_size, input_dim, 1, 1)`, 189 | - `(batch_size, input_dim, num_functions, 1)`, 190 | - `(batch_size, input_dim, 1, units // sparsity_factor)`, 191 | - `(batch_size, input_dim, num_functions, units // sparsity_factor)`. 192 | 193 | Args: 194 | inputs: inputs to the CDF function. 195 | location_parameters: parameters for deciding the locations of the 196 | transformations. 197 | scaling_parameters: parameters for deciding the horizontal scaling of the 198 | transformations. 199 | units: output dimension. 200 | activation: either `sigmoid` or `relu6` for selecting the transformation. 201 | reduction: either `mean`, `geometric_mean`, or `none` to specify whether to 202 | perform averaging and which average to perform. 203 | sparsity_factor: deciding the level of sparsity during reduction. 204 | `input_dim` and `units` should both be divisible by `sparsity_factor`. 205 | scaling_exp_transform_multiplier: if provided, will be used inside an 206 | exponential transformation for `scaling_parameters`. This can be useful if 207 | `scaling_parameters` is free-form. 208 | return_derived_parameters: Whether `location_parameters` and 209 | `scaling_parameters` should be output along with the model output (e.g. 210 | for loss function computation purpoeses). 211 | 212 | Returns: 213 | If `return_derived_parameters = False`: 214 | 215 | - The CDF transformed outputs as a tensor with shape either 216 | `(batch_size, units)` if `reduction = 'mean' / 'geometric_mean'`, or 217 | `(batch_size, input_dim // sparsity_factor, units)` if 218 | `reduction = 'none'`. 219 | 220 | If `return_derived_parameters = True`: 221 | 222 | - A tuple of three elements: 223 | 224 | 1. The CDF transformed outputs. 225 | 2. `location_parameters`. 226 | 3. `scaling_parameters`, with `exp` transformation applied if specified. 
227 | """ 228 | 229 | _verify_cdf_params( 230 | inputs, 231 | location_parameters, 232 | scaling_parameters, 233 | units, 234 | activation, 235 | reduction, 236 | sparsity_factor, 237 | ) 238 | input_dim = inputs.shape[1] 239 | x = inputs[..., tf.newaxis, tf.newaxis] - location_parameters 240 | if scaling_parameters is not None: 241 | if scaling_exp_transform_multiplier is not None: 242 | scaling_parameters = tf.math.exp( 243 | scaling_parameters * scaling_exp_transform_multiplier 244 | ) 245 | x *= scaling_parameters 246 | else: 247 | # For use when return_derived_parameters = True. 248 | scaling_parameters = tf.ones_like(location_parameters, dtype=tf.float32) 249 | 250 | # Shape: (batch, input_dim, 1, 1) 251 | # --> (batch, input_dim, num_functions, units / factor) 252 | # --> (batch, input_dim, units / factor). 253 | if activation == "relu6": 254 | result = tf.reduce_mean(tf.nn.relu6(x), axis=2) / 6 255 | else: # activation == "sigmoid": 256 | result = tf.reduce_mean(tf.nn.sigmoid(x), axis=2) 257 | 258 | if sparsity_factor != 1: 259 | # Shape: (batch, input_dim, units / factor) 260 | # --> (batch, input_dim / factor, units). 261 | result = tf.reshape(result, (-1, input_dim // sparsity_factor, units)) 262 | 263 | # Shape: (batch, input_dim / factor, units) --> (batch, units). 264 | if reduction == "mean": 265 | result = tf.reduce_mean(result, axis=1) 266 | elif reduction == "geometric_mean": 267 | # We use the log form so that we can add the epsilon term 268 | # tf.pow(tf.reduce_prod(cdfs, axis=1), 1. / num_terms). 269 | result = tf.math.exp(tf.reduce_mean(tf.math.log(result + 1e-8), axis=1)) 270 | # Otherwise reduction == "none". 
271 | 272 | if return_derived_parameters: 273 | return (result, location_parameters, scaling_parameters) 274 | else: 275 | return result 276 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/configs_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for TFL model configuration library.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | from tensorflow_lattice.python import categorical_calibration_layer 22 | from tensorflow_lattice.python import configs 23 | from tensorflow_lattice.python import lattice_layer 24 | from tensorflow_lattice.python import linear_layer 25 | from tensorflow_lattice.python import premade 26 | from tensorflow_lattice.python import pwl_calibration_layer 27 | 28 | tfl_custom_objects = { 29 | 'CalibratedLatticeEnsemble': 30 | premade.CalibratedLatticeEnsemble, 31 | 'CalibratedLattice': 32 | premade.CalibratedLattice, 33 | 'CalibratedLinear': 34 | premade.CalibratedLinear, 35 | 'CategoricalCalibration': 36 | categorical_calibration_layer.CategoricalCalibration, 37 | 'FeatureConfig': 38 | configs.FeatureConfig, 39 | 'RegularizerConfig': 40 | configs.RegularizerConfig, 41 | 'TrustConfig': 42 | 
configs.TrustConfig, 43 | 'DominanceConfig': 44 | configs.DominanceConfig, 45 | 'CalibratedLatticeEnsembleConfig': 46 | configs.CalibratedLatticeEnsembleConfig, 47 | 'CalibratedLatticeConfig': 48 | configs.CalibratedLatticeConfig, 49 | 'CalibratedLinearConfig': 50 | configs.CalibratedLinearConfig, 51 | 'Lattice': 52 | lattice_layer.Lattice, 53 | 'Linear': 54 | linear_layer.Linear, 55 | 'PWLCalibration': 56 | pwl_calibration_layer.PWLCalibration, 57 | } 58 | 59 | 60 | class ConfigsTest(tf.test.TestCase): 61 | 62 | def test_from_config(self): 63 | feature_configs = [ 64 | configs.FeatureConfig( 65 | name='feature_a', 66 | pwl_calibration_input_keypoints='quantiles', 67 | pwl_calibration_num_keypoints=8, 68 | monotonicity=1, 69 | pwl_calibration_clip_max=100, 70 | ), 71 | configs.FeatureConfig( 72 | name='feature_b', 73 | lattice_size=3, 74 | unimodality='valley', 75 | pwl_calibration_input_keypoints='uniform', 76 | pwl_calibration_num_keypoints=5, 77 | pwl_calibration_clip_min=130, 78 | pwl_calibration_convexity='convex', 79 | regularizer_configs=[ 80 | configs.RegularizerConfig(name='calib_hesian', l2=3e-3), 81 | ], 82 | ), 83 | configs.FeatureConfig( 84 | name='feature_c', 85 | pwl_calibration_input_keypoints=[0.0, 0.5, 1.0], 86 | reflects_trust_in=[ 87 | configs.TrustConfig(feature_name='feature_a'), 88 | configs.TrustConfig(feature_name='feature_b', direction=-1), 89 | ], 90 | dominates=[ 91 | configs.DominanceConfig( 92 | feature_name='feature_d', dominance_type='monotonic'), 93 | ], 94 | ), 95 | configs.FeatureConfig( 96 | name='feature_d', 97 | num_buckets=3, 98 | vocabulary_list=['a', 'b', 'c'], 99 | default_value=-1, 100 | ), 101 | ] 102 | # First we test CalibratedLatticeEnsembleConfig 103 | model_config = configs.CalibratedLatticeEnsembleConfig( 104 | feature_configs=feature_configs, 105 | lattices=[['feature_a', 'feature_b'], ['feature_c', 'feature_d']], 106 | separate_calibrators=True, 107 | regularizer_configs=[ 108 | 
configs.RegularizerConfig('torsion', l2=1e-4), 109 | ], 110 | output_min=0.0, 111 | output_max=1.0, 112 | output_calibration=True, 113 | output_calibration_num_keypoints=5, 114 | output_initialization=[0.0, 1.0]) 115 | model_config_copy = configs.CalibratedLatticeEnsembleConfig.from_config( 116 | model_config.get_config(), tfl_custom_objects) 117 | self.assertDictEqual(model_config.get_config(), 118 | model_config_copy.get_config()) 119 | # Next we test CalibratedLatticeConfig 120 | model_config = configs.CalibratedLatticeConfig( 121 | feature_configs=feature_configs, 122 | regularizer_configs=[ 123 | configs.RegularizerConfig('torsion', l2=1e-4), 124 | ], 125 | output_min=0.0, 126 | output_max=1.0, 127 | output_calibration=True, 128 | output_calibration_num_keypoints=8, 129 | output_initialization='quantiles') 130 | model_config_copy = configs.CalibratedLatticeConfig.from_config( 131 | model_config.get_config(), tfl_custom_objects) 132 | self.assertDictEqual(model_config.get_config(), 133 | model_config_copy.get_config()) 134 | # Last we test CalibratedLinearConfig 135 | model_config = configs.CalibratedLinearConfig( 136 | feature_configs=feature_configs, 137 | regularizer_configs=[ 138 | configs.RegularizerConfig('calib_hessian', l2=1e-4), 139 | ], 140 | use_bias=True, 141 | output_min=0.0, 142 | output_max=None, 143 | output_calibration=True, 144 | output_initialization='uniform') 145 | model_config_copy = configs.CalibratedLinearConfig.from_config( 146 | model_config.get_config(), tfl_custom_objects) 147 | self.assertDictEqual(model_config.get_config(), 148 | model_config_copy.get_config()) 149 | 150 | def test_updates(self): 151 | model_config = configs.CalibratedLatticeConfig( 152 | output_min=0, 153 | regularizer_configs=[ 154 | configs.RegularizerConfig(name='torsion', l2=2e-3), 155 | ], 156 | feature_configs=[ 157 | configs.FeatureConfig( 158 | name='feature_a', 159 | pwl_calibration_input_keypoints='quantiles', 160 | pwl_calibration_num_keypoints=8, 161 | 
monotonicity=1, 162 | pwl_calibration_clip_max=100, 163 | ), 164 | configs.FeatureConfig( 165 | name='feature_b', 166 | lattice_size=3, 167 | unimodality='valley', 168 | pwl_calibration_input_keypoints='uniform', 169 | pwl_calibration_num_keypoints=5, 170 | pwl_calibration_clip_min=130, 171 | pwl_calibration_convexity='convex', 172 | regularizer_configs=[ 173 | configs.RegularizerConfig(name='calib_hessian', l2=3e-3), 174 | ], 175 | ), 176 | configs.FeatureConfig( 177 | name='feature_c', 178 | pwl_calibration_input_keypoints=[0.0, 0.5, 1.0], 179 | reflects_trust_in=[ 180 | configs.TrustConfig(feature_name='feature_a'), 181 | configs.TrustConfig(feature_name='feature_b', direction=-1), 182 | ], 183 | ), 184 | configs.FeatureConfig( 185 | name='feature_d', 186 | num_buckets=3, 187 | vocabulary_list=['a', 'b', 'c'], 188 | default_value=-1, 189 | ), 190 | ]) 191 | 192 | updates = [ 193 | # Update values can be passed in as numbers. 194 | ('output_max', 1.0), # update 195 | ('regularizer__torsion__l2', 0.004), # update 196 | ('regularizer__calib_hessian__l1', 0.005), # insert 197 | ('feature__feature_a__lattice_size', 3), # update 198 | ('feature__feature_e__lattice_size', 4), # insert 199 | # Update values can be strings. 
200 | ('unrelated_hparams_not_affecting_config', 'unrelated'), 201 | ('feature__feature_a__regularizer__calib_wrinkle__l1', '0.6'), # insert 202 | ('feature__feature_b__regularizer__calib_hessian__l1', '0.7'), # update 203 | ('yet__another__unrelated_config', '4'), 204 | ] 205 | self.assertEqual(configs.apply_updates(model_config, updates), 7) 206 | 207 | model_config.feature_config_by_name('feature_a').monotonicity = 'none' 208 | model_config.feature_config_by_name('feature_f').num_buckets = 4 # insert 209 | 210 | feature_names = [ 211 | feature_config.name for feature_config in model_config.feature_configs 212 | ] 213 | expected_feature_names = [ 214 | 'feature_a', 'feature_b', 'feature_c', 'feature_d', 'feature_e', 215 | 'feature_f' 216 | ] 217 | self.assertCountEqual(feature_names, expected_feature_names) 218 | 219 | global_regularizer_names = [ 220 | regularizer_config.name 221 | for regularizer_config in model_config.regularizer_configs 222 | ] 223 | expected_global_regularizer_names = ['torsion', 'calib_hessian'] 224 | self.assertCountEqual(global_regularizer_names, 225 | expected_global_regularizer_names) 226 | 227 | self.assertEqual(model_config.output_max, 1.0) 228 | self.assertEqual( 229 | model_config.feature_config_by_name('feature_a').lattice_size, 3) 230 | self.assertEqual( 231 | model_config.feature_config_by_name( 232 | 'feature_b').pwl_calibration_convexity, 'convex') 233 | self.assertEqual( 234 | model_config.feature_config_by_name('feature_e').lattice_size, 4) 235 | self.assertEqual( 236 | model_config.regularizer_config_by_name('torsion').l2, 0.004) 237 | self.assertEqual( 238 | model_config.regularizer_config_by_name('calib_hessian').l1, 0.005) 239 | self.assertEqual( 240 | model_config.feature_config_by_name( 241 | 'feature_a').regularizer_config_by_name('calib_wrinkle').l1, 0.6) 242 | self.assertEqual( 243 | model_config.feature_config_by_name( 244 | 'feature_b').regularizer_config_by_name('calib_hessian').l1, 0.7) 245 | 246 | 247 | if 
__name__ == '__main__': 248 | tf.test.main() 249 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/internal_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Internal helpers shared by multiple modules in TFL. 15 | 16 | Note that this module is not expected to be used by TFL users, and that it is 17 | not exposed in the TFL package. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import collections 25 | import tensorflow as tf 26 | 27 | 28 | def _topological_sort(key_less_than_values): 29 | """Topological sort for monotonicities. 30 | 31 | Args: 32 | key_less_than_values: A defaultdict from index to a list of indices, such 33 | that for j in key_less_than_values[i] we must have output(i) <= output(j). 34 | 35 | Returns: 36 | A topologically sorted list of indices. 37 | 38 | Raises: 39 | ValueError: If monotonicities are circular. 
40 | """ 41 | all_values = set() 42 | for values in key_less_than_values.values(): 43 | all_values.update(values) 44 | 45 | q = [k for k in key_less_than_values if k not in all_values] 46 | if not q: 47 | raise ValueError( 48 | "Circular monotonicity constraints: {}".format(key_less_than_values)) 49 | 50 | result = [] 51 | seen = set() 52 | while q: 53 | v = q[-1] 54 | seen.add(v) 55 | expand = [x for x in key_less_than_values[v] if x not in seen] 56 | if not expand: 57 | result = [v] + result 58 | q.pop() 59 | else: 60 | q.append(expand[0]) 61 | 62 | return result 63 | 64 | 65 | def _min_projection(weights, sorted_indices, key_less_than_values, step): 66 | """Returns an approximate partial min projection with the given step_size. 67 | 68 | Args: 69 | weights: A list of tensors of shape `(units,)` to be approximatly projected 70 | based on the monotonicity constraints. 71 | sorted_indices: Topologically sorted list of indices based on the 72 | monotonicity constraints. 73 | key_less_than_values: A defaultdict from index to a list of indices, such 74 | that for `j` in `key_less_than_values[i]` we must have `weight[i] <= 75 | weight[j]`. 76 | step: A value defining if we should apply a full projection (`step == 1`) or 77 | a partial projection (`step < 1`). 78 | 79 | Returns: 80 | Projected list of tensors. 81 | """ 82 | projected_weights = list(weights) # copy 83 | for i in sorted_indices[::-1]: 84 | if key_less_than_values[i]: 85 | min_projection = projected_weights[i] 86 | for j in key_less_than_values[i]: 87 | min_projection = tf.minimum(min_projection, projected_weights[j]) 88 | if step == 1: 89 | projected_weights[i] = min_projection 90 | else: 91 | projected_weights[i] = ( 92 | step * min_projection + (1 - step) * projected_weights[i]) 93 | return projected_weights 94 | 95 | 96 | def _max_projection(weights, sorted_indices, key_greater_than_values, step): 97 | """Returns an approximate partial max projection with the given step_size. 
98 | 99 | Args: 100 | weights: A list of tensors of shape `(units,)` to be approximatly projected 101 | based on the monotonicity constraints. 102 | sorted_indices: Topologically sorted list of indices based on the 103 | monotonicity constraints. 104 | key_greater_than_values: A defaultdict from index to a list of indices, 105 | indicating that for index `j` in `key_greater_than_values[i]` we must have 106 | `weight[i] >= weight[j]`. 107 | step: A value defining if we should apply a full projection (`step == 1`) or 108 | a partial projection (`step < 1`). 109 | 110 | Returns: 111 | Projected list of tensors. 112 | """ 113 | projected_weights = list(weights) # copy 114 | for i in sorted_indices: 115 | if key_greater_than_values[i]: 116 | max_projection = projected_weights[i] 117 | for j in key_greater_than_values[i]: 118 | max_projection = tf.maximum(max_projection, projected_weights[j]) 119 | if step == 1: 120 | projected_weights[i] = max_projection 121 | else: 122 | projected_weights[i] = ( 123 | step * max_projection + (1 - step) * projected_weights[i]) 124 | return projected_weights 125 | 126 | 127 | def approximately_project_categorical_partial_monotonicities( 128 | weights, monotonicities): 129 | """Returns an approximation L2 projection for categorical monotonicities. 130 | 131 | Categorical monotonocities are monotonicity constraints applied to the real 132 | values that are mapped from categorical inputs. Each monotonicity constraint 133 | is specified by a pair of categorical input indices. The projection is also 134 | used to constrain pairs of coefficients in linear models. 135 | 136 | Args: 137 | weights: Tensor of weights to be approximately projected based on the 138 | monotonicity constraints. 139 | monotonicities: List of pairs of indices `(i, j)`, indicating constraint 140 | `weights[i] <= weights[j]`. 
141 | """ 142 | key_less_than_values = collections.defaultdict(list) 143 | key_greater_than_values = collections.defaultdict(list) 144 | for i, j in monotonicities: 145 | key_less_than_values[i].append(j) 146 | key_greater_than_values[j].append(i) 147 | 148 | sorted_indices = _topological_sort(key_less_than_values) 149 | 150 | projected_weights = tf.unstack(weights) 151 | 152 | # A 0.5 min projection followed by a full max projection. 153 | projected_weights_min_max = _min_projection(projected_weights, sorted_indices, 154 | key_less_than_values, 0.5) 155 | projected_weights_min_max = _max_projection(projected_weights_min_max, 156 | sorted_indices, 157 | key_greater_than_values, 1) 158 | projected_weights_min_max = tf.stack(projected_weights_min_max) 159 | 160 | # A 0.5 max projection followed by a full min projection. 161 | projected_weights_max_min = _max_projection(projected_weights, sorted_indices, 162 | key_greater_than_values, 0.5) 163 | projected_weights_max_min = _min_projection(projected_weights_max_min, 164 | sorted_indices, 165 | key_less_than_values, 1) 166 | projected_weights_max_min = tf.stack(projected_weights_max_min) 167 | 168 | # Take the average of the two results to avoid sliding to one direction. 169 | projected_weights = (projected_weights_min_max + 170 | projected_weights_max_min) / 2 171 | return projected_weights 172 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/internal_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for Tensorflow Lattice utility functions.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from absl.testing import parameterized 21 | import numpy as np 22 | import tensorflow as tf 23 | from tensorflow_lattice.python import internal_utils 24 | 25 | 26 | class InternalUtilsTest(parameterized.TestCase, tf.test.TestCase): 27 | 28 | def _ResetAllBackends(self): 29 | tf.compat.v1.reset_default_graph() 30 | 31 | @parameterized.parameters( 32 | ([3., 4.], [(0, 1)], [3., 4.]), ([4., 3.], [(0, 1)], [3.5, 3.5]), 33 | ([1., 0.], [(0, 1)], [0.5, 0.5]), ([-1., 0.], [(1, 0)], [-0.5, -0.5]), 34 | ([4., 3., 2., 1., 0.], [(0, 1), (1, 2), (2, 3), 35 | (3, 4)], [2., 2., 2., 2., 2.])) 36 | def testApproximatelyProjectCategoricalPartialMonotonicities( 37 | self, weights, monotonicities, expected_projected_weights): 38 | self._ResetAllBackends() 39 | weights = tf.Variable(weights) 40 | projected_weights = ( 41 | internal_utils.approximately_project_categorical_partial_monotonicities( 42 | weights, monotonicities)) 43 | self.evaluate(tf.compat.v1.global_variables_initializer()) 44 | self.assertAllClose( 45 | self.evaluate(projected_weights), np.array(expected_projected_weights)) 46 | 47 | 48 | if __name__ == '__main__': 49 | tf.test.main() 50 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/model_info.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Classes defining trained TFL model structure and parameter information. 16 | 17 | This package provides representations and tools for analysis of a trained 18 | TF Lattice model, e.g. a canned estimator in saved model format. 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import collections 26 | 27 | 28 | class ModelGraph( 29 | collections.namedtuple('ModelGraph', ['nodes', 'output_node'])): 30 | """Model info and parameter as a graph. 31 | 32 | Note that this is not a TF graph, but rather a graph of python object that 33 | describe model structure and parameters. 34 | 35 | Attributes: 36 | nodes: List of all the nodes in the model. 37 | output_node: The output node of the model. 38 | """ 39 | 40 | 41 | class InputFeatureNode( 42 | collections.namedtuple('InputFeatureNode', 43 | ['name', 'is_categorical', 'vocabulary_list'])): 44 | """Input features to the model. 45 | 46 | Attributes: 47 | name: Name of the input feature. 48 | is_categorical: If the feature is categorical. 49 | vocabulary_list: Category values for categorical features or None. 
50 | """ 51 | 52 | 53 | class PWLCalibrationNode( 54 | collections.namedtuple('PWLCalibrationNode', [ 55 | 'input_node', 'input_keypoints', 'output_keypoints', 'default_input', 56 | 'default_output' 57 | ])): 58 | """Represetns a PWL calibration layer. 59 | 60 | Attributes: 61 | input_node: Input node for the calibration. 62 | input_keypoints: Input keypoints for PWL calibration. 63 | output_keypoints: Output keypoints for PWL calibration. 64 | default_input: Default/missing input value or None. 65 | default_output: Default/missing output value or None. 66 | """ 67 | 68 | 69 | class CategoricalCalibrationNode( 70 | collections.namedtuple('CategoricalCalibrationNode', 71 | ['input_node', 'output_values', 'default_input'])): 72 | """Represetns a categorical calibration layer. 73 | 74 | Attributes: 75 | input_node: Input node for the calibration. 76 | output_values: Output calibration values. If the calibrated feature has 77 | default/missing values, the last value will be for default/missing. 78 | default_input: Default/missing input value or None. 79 | """ 80 | 81 | 82 | class LinearNode( 83 | collections.namedtuple('LinearNode', 84 | ['input_nodes', 'coefficients', 'bias'])): 85 | """Represents a linear layer. 86 | 87 | Attributes: 88 | input_nodes: List of input nodes to the linear layer. 89 | coefficients: Linear weights. 90 | bias: Bias term for the linear layer. 91 | """ 92 | 93 | 94 | class LatticeNode( 95 | collections.namedtuple('LatticeNode', ['input_nodes', 'weights'])): 96 | """Represetns a lattice layer. 97 | 98 | Attributes: 99 | input_nodes: List of input nodes to the lattice layer. 100 | weights: Lattice parameters. 101 | """ 102 | 103 | 104 | class KroneckerFactoredLatticeNode( 105 | collections.namedtuple('KroneckerFactoredLatticeNode', 106 | ['input_nodes', 'weights', 'scale', 'bias'])): 107 | """Represents a kronecker-factored lattice layer. 108 | 109 | Attributes: 110 | input_nodes: List of input nodes to the kronecker-factored lattice layer. 
111 | weights: Kronecker-factored lattice kernel parameters of shape 112 | `(1, lattice_sizes, units * dims, num_terms)`. 113 | scale: Kronecker-factored lattice scale parameters of shape 114 | `(units, num_terms)`. 115 | bias: Kronecker-factored lattice bias parameters of shape `(units)`. 116 | """ 117 | 118 | 119 | class MeanNode(collections.namedtuple('MeanNode', ['input_nodes'])): 120 | """Represents an averaging layer. 121 | 122 | Attributes: 123 | input_nodes: List of input nodes to the average layer. 124 | """ 125 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/parallel_combination_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ParallelCombination layer for combining several parallel calibration layers. 15 | 16 | This layer wraps several calibration layers under single ParallelCombination one 17 | that can be used by `Sequential` Keras model. 
18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import tensorflow as tf 25 | from tensorflow_lattice.python import categorical_calibration_layer 26 | from tensorflow_lattice.python import lattice_layer 27 | from tensorflow_lattice.python import linear_layer 28 | from tensorflow_lattice.python import pwl_calibration_layer 29 | # pylint: disable=g-import-not-at-top 30 | # Use Keras 2. 31 | version_fn = getattr(tf.keras, "version", None) 32 | if version_fn and version_fn().startswith("3."): 33 | import tf_keras as keras 34 | else: 35 | keras = tf.keras 36 | 37 | 38 | # TODO: Add support for calibrators with units > 1. 39 | class ParallelCombination(keras.layers.Layer): 40 | # pyformat: disable 41 | """Wraps several parallel calibration layers under single one. 42 | 43 | `ParallelCombination` is designed for combning several calibration layers 44 | which output goes into single `Lattice` or `Linear` layer in order to be able 45 | to use calibration layers within `Sequential` model. 46 | 47 | Difference from `keras.layers.Concatenate` is that last one operates on 48 | already built objects and thus cannot be used to group layers for `Sequential` 49 | model. 50 | 51 | Input shape: 52 | `(batch_size, k)` or list of length `k` of shapes: `(batch_size, 1)` where 53 | `k` is a number of associated calibration layers. 54 | 55 | Output shape: 56 | `(batch_size, k)` or list of length `k` of shapes: `(batch_size, 1)` where 57 | `k` is a number of associated calibration layers. Shape of output depends on 58 | `single_output` parameter. 59 | 60 | Attributes: 61 | - All `__init__` arguments. 62 | 63 | Example: 64 | 65 | Example usage with a Sequential model: 66 | 67 | ```python 68 | model = keras.models.Sequential() 69 | combined_calibrators = ParallelCombination() 70 | for i in range(num_dims): 71 | calibration_layer = PWLCalibration(...) 
72 | combined_calibrators.append(calibration_layer) 73 | model.add(combined_calibrators) 74 | model.add(Lattice(...)) 75 | ``` 76 | """ 77 | # pyformat: enable 78 | 79 | def __init__(self, calibration_layers=None, single_output=True, **kwargs): 80 | """Initializes an instance of `ParallelCombination`. 81 | 82 | Args: 83 | calibration_layers: List of `PWLCalibration` or `CategoricalCalibration` 84 | objects or any other layers taking and returning tensor of shape 85 | `(batch_size, 1)`. 86 | single_output: if True returns output as single tensor of shape 87 | `(batch_size, k)`. Otherwise returns list of `k` tensors of shape 88 | `(batch_size, 1)`. 89 | **kwargs: other args passed to `keras.layers.Layer` initializer. 90 | """ 91 | super(ParallelCombination, self).__init__(**kwargs) 92 | self.calibration_layers = [] 93 | for calibration_layer in calibration_layers or []: 94 | if not isinstance(calibration_layer, dict): 95 | self.calibration_layers.append(calibration_layer) 96 | else: 97 | # Keras deserialization logic must have explicit acceess to all custom 98 | # classes. This is standard way to provide such access. 
99 | with keras.utils.custom_object_scope({ 100 | "Lattice": 101 | lattice_layer.Lattice, 102 | "Linear": 103 | linear_layer.Linear, 104 | "PWLCalibration": 105 | pwl_calibration_layer.PWLCalibration, 106 | "CategoricalCalibration": 107 | categorical_calibration_layer.CategoricalCalibration, 108 | }): 109 | self.calibration_layers.append( 110 | keras.layers.deserialize( 111 | calibration_layer, use_legacy_format=True 112 | ) 113 | ) 114 | self.single_output = single_output 115 | 116 | def append(self, calibration_layer): 117 | """Appends new calibration layer to the end.""" 118 | self.calibration_layers.append(calibration_layer) 119 | 120 | def build(self, input_shape): 121 | """Standard Keras build() method.""" 122 | if isinstance(input_shape, list): 123 | if len(input_shape) != len(self.calibration_layers): 124 | raise ValueError("Number of ParallelCombination input tensors does not " 125 | "match number of calibration layers. input_shape: %s, " 126 | "layers: %s" % (input_shape, self.calibration_layers)) 127 | else: 128 | if input_shape[1] != len(self.calibration_layers): 129 | raise ValueError("Second dimension of ParallelCombination input tensor " 130 | "does not match number of calibration layers. " 131 | "input_shape: %s, layers: %s" % 132 | (input_shape, self.calibration_layers)) 133 | super(ParallelCombination, self).build(input_shape) 134 | 135 | def call(self, inputs): 136 | """Standard Keras call() method.""" 137 | if not isinstance(inputs, list): 138 | if len(inputs.shape) != 2: 139 | raise ValueError("'inputs' is expected to have rank-2. " 140 | "Given: %s" % inputs) 141 | inputs = tf.split(inputs, axis=1, num_or_size_splits=inputs.shape[1]) 142 | if len(inputs) != len(self.calibration_layers): 143 | raise ValueError("Number of ParallelCombination input tensors does not " 144 | "match number of calibration layers. 
inputs: %s, " 145 | "layers: %s" % (inputs, self.calibration_layers)) 146 | outputs = [ 147 | layer(one_d_input) 148 | for layer, one_d_input in zip(self.calibration_layers, inputs) 149 | ] 150 | if self.single_output: 151 | return tf.concat(outputs, axis=1) 152 | else: 153 | return outputs 154 | 155 | def compute_output_shape(self, input_shape): 156 | if self.single_output: 157 | return tf.TensorShape([None, len(self.calibration_layers)]) 158 | else: 159 | return [tf.TensorShape([None, 1])] * len(self.calibration_layers) 160 | 161 | def get_config(self): 162 | """Standard Keras config for serialization.""" 163 | config = { 164 | "calibration_layers": [ 165 | keras.layers.serialize(layer, use_legacy_format=True) 166 | for layer in self.calibration_layers 167 | ], 168 | "single_output": self.single_output, 169 | } # pyformat: disable 170 | config.update(super(ParallelCombination, self).get_config()) 171 | return config 172 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/parallel_combination_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for Lattice Layer.""" 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import tempfile 20 | from absl.testing import parameterized 21 | import numpy as np 22 | import tensorflow as tf 23 | from tensorflow_lattice.python import lattice_layer as ll 24 | from tensorflow_lattice.python import parallel_combination_layer as pcl 25 | # pylint: disable=g-import-not-at-top 26 | # Use Keras 2. 27 | version_fn = getattr(tf.keras, "version", None) 28 | if version_fn and version_fn().startswith("3."): 29 | import tf_keras as keras 30 | else: 31 | keras = tf.keras 32 | 33 | 34 | class ParallelCombinationTest(parameterized.TestCase, tf.test.TestCase): 35 | 36 | def setUp(self): 37 | super(ParallelCombinationTest, self).setUp() 38 | self.disable_all = False 39 | keras.utils.set_random_seed(42) 40 | 41 | def testParallelCombinationSingleInput(self): 42 | if self.disable_all: 43 | return 44 | all_calibrators = pcl.ParallelCombination() 45 | for i in range(3): 46 | # Its not typical to use 1-d Lattice layer for calibration, but lets do it 47 | # to avoid redundant dependency on PWLCalibration layer. 48 | calibrator = ll.Lattice( 49 | lattice_sizes=[2], output_min=0.0, output_max=i + 1.0) 50 | all_calibrators.append(calibrator) 51 | 52 | # Given output range specified below linear initializer will have lattice to 53 | # simply sum up inputs. 
54 | simple_sum = ll.Lattice( 55 | lattice_sizes=[5] * 3, 56 | kernel_initializer="linear_initializer", 57 | output_min=0.0, 58 | output_max=12.0, 59 | name="SummingLattice") 60 | model = keras.models.Sequential() 61 | model.add(all_calibrators) 62 | model.add(simple_sum) 63 | 64 | test_inputs = np.asarray([ 65 | [0.0, 0.0, 0.0], 66 | [0.1, 0.2, 0.3], 67 | [1.0, 1.0, 1.0], 68 | ]) 69 | predictions = model.predict(test_inputs) 70 | print("predictions") 71 | print(predictions) 72 | self.assertTrue(np.allclose(predictions, np.asarray([[0.0], [1.4], [6.0]]))) 73 | 74 | def testParallelCombinationMultipleInputs(self): 75 | if self.disable_all: 76 | return 77 | input_layers = [keras.layers.Input(shape=[1]) for _ in range(3)] 78 | all_calibrators = pcl.ParallelCombination(single_output=False) 79 | for i in range(3): 80 | # Its not typical to use 1-d Lattice layer for calibration, but lets do it 81 | # to avoid redundant dependency on PWLCalibration layer. 82 | calibrator = ll.Lattice( 83 | lattice_sizes=[2], output_min=0.0, output_max=i + 1.0) 84 | all_calibrators.append(calibrator) 85 | 86 | # Given output range specified below linear initializer will have lattice to 87 | # simply sum up inputs. 
88 | simple_sum = ll.Lattice( 89 | lattice_sizes=[5] * 3, 90 | kernel_initializer="linear_initializer", 91 | output_min=0.0, 92 | output_max=12.0, 93 | name="SummingLattice", 94 | trainable=False) 95 | 96 | output = simple_sum(all_calibrators(input_layers)) 97 | model = keras.models.Model(inputs=input_layers, outputs=output) 98 | 99 | test_inputs = [ 100 | np.asarray([[0.0], [0.1], [1.0]]), 101 | np.asarray([[0.0], [0.2], [1.0]]), 102 | np.asarray([[0.0], [0.3], [1.0]]), 103 | ] 104 | predictions = model.predict(test_inputs) 105 | print("predictions") 106 | print(predictions) 107 | self.assertTrue(np.allclose(predictions, np.asarray([[0.0], [1.4], [6.0]]))) 108 | 109 | def testParallelCombinationClone(self): 110 | if self.disable_all: 111 | return 112 | input_layers = [keras.layers.Input(shape=[1]) for _ in range(3)] 113 | all_calibrators = pcl.ParallelCombination(single_output=False) 114 | for i in range(3): 115 | # Its not typical to use 1-d Lattice layer for calibration, but lets do it 116 | # to avoid redundant dependency on PWLCalibration layer. 117 | calibrator = ll.Lattice( 118 | lattice_sizes=[2], output_min=0.0, output_max=i + 1.0) 119 | all_calibrators.append(calibrator) 120 | 121 | # Given output range specified below linear initializer will have lattice to 122 | # simply sum up inputs. 
123 | simple_sum = ll.Lattice( 124 | lattice_sizes=[5] * 3, 125 | kernel_initializer="linear_initializer", 126 | output_min=0.0, 127 | output_max=12.0, 128 | name="SummingLattice", 129 | trainable=False) 130 | 131 | output = simple_sum(all_calibrators(input_layers)) 132 | model = keras.models.Model(inputs=input_layers, outputs=output) 133 | clone = keras.models.clone_model(model) 134 | 135 | test_inputs = [ 136 | np.asarray([[0.0], [0.1], [1.0]]), 137 | np.asarray([[0.0], [0.2], [1.0]]), 138 | np.asarray([[0.0], [0.3], [1.0]]), 139 | ] 140 | predictions = clone.predict(test_inputs) 141 | print("predictions") 142 | print(predictions) 143 | self.assertTrue(np.allclose(predictions, np.asarray([[0.0], [1.4], [6.0]]))) 144 | 145 | with tempfile.NamedTemporaryFile(suffix=".h5") as f: 146 | model.save(f.name) 147 | loaded_model = keras.models.load_model( 148 | f.name, 149 | custom_objects={ 150 | "ParallelCombination": pcl.ParallelCombination, 151 | "Lattice": ll.Lattice, 152 | }, 153 | ) 154 | predictions = loaded_model.predict(test_inputs) 155 | self.assertTrue( 156 | np.allclose(predictions, np.asarray([[0.0], [1.4], [6.0]]))) 157 | 158 | 159 | if __name__ == "__main__": 160 | tf.test.main() 161 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/rtl_lib.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Implementation of algorithms required for RTL layer.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import six 21 | 22 | 23 | def verify_hyperparameters(lattice_size, 24 | input_shape=None, 25 | output_min=None, 26 | output_max=None, 27 | interpolation="hypercube", 28 | parameterization="all_vertices", 29 | kernel_initializer=None, 30 | kernel_regularizer=None): 31 | """Verifies that all given hyperparameters are consistent. 32 | 33 | See `tfl.layers.RTL` class level comment for detailed description of 34 | arguments. 35 | 36 | Args: 37 | lattice_size: Lattice size to check againts. 38 | input_shape: Shape of layer input. 39 | output_min: Minimum output of `RTL` layer. 40 | output_max: Maximum output of `RTL` layer. 41 | interpolation: One of 'simplex' or 'hypercube' interpolation. 42 | parameterization: One of 'all_vertices' or 'kronecker_factored' 43 | parameterizations. 44 | kernel_initializer: Initizlier to check against. 45 | kernel_regularizer: Regularizers to check against. 46 | 47 | Raises: 48 | ValueError: If lattice_size < 2. 49 | KeyError: If input_shape is a dict with incorrect keys. 50 | ValueError: If output_min >= output_max. 51 | ValueError: If interpolation is not one of 'simplex' or 'hypercube'. 52 | ValueError: If parameterization is 'kronecker_factored' and 53 | kernel_initializer is 'linear_initializer'. 54 | ValueError: If parameterization is 'kronecker_factored' and 55 | kernel_regularizer is not None. 56 | ValueError: If kernel_regularizer contains a tuple with len != 3. 57 | ValueError: If kernel_regularizer contains a tuple with non-float l1 value. 58 | ValueError: If kernel_regularizer contains a tuple with non-flaot l2 value. 59 | 60 | """ 61 | if lattice_size < 2: 62 | raise ValueError( 63 | "Lattice size must be at least 2. 
Given: {}".format(lattice_size)) 64 | 65 | if input_shape: 66 | if isinstance(input_shape, dict): 67 | for key in input_shape: 68 | if key not in ["unconstrained", "increasing"]: 69 | raise KeyError("Input shape keys should be either 'unconstrained' " 70 | "or 'increasing', but seeing: {}".format(key)) 71 | 72 | if output_min is not None and output_max is not None: 73 | if output_min >= output_max: 74 | raise ValueError("'output_min' must be not greater than 'output_max'. " 75 | "'output_min': %f, 'output_max': %f" % 76 | (output_min, output_max)) 77 | 78 | if interpolation not in ["hypercube", "simplex"]: 79 | raise ValueError("RTL interpolation type should be either 'simplex' " 80 | "or 'hypercube': %s" % interpolation) 81 | 82 | if (parameterization == "kronecker_factored" and 83 | kernel_initializer == "linear_initializer"): 84 | raise ValueError("'kronecker_factored' parameterization does not currently " 85 | "support linear iniitalization. 'parameterization': %s, " 86 | "'kernel_initializer': %s" % 87 | (parameterization, kernel_initializer)) 88 | 89 | if (parameterization == "kronecker_factored" and 90 | kernel_regularizer is not None): 91 | raise ValueError("'kronecker_factored' parameterization does not currently " 92 | "support regularization. 'parameterization': %s, " 93 | "'kernel_regularizer': %s" % 94 | (parameterization, kernel_regularizer)) 95 | 96 | if kernel_regularizer: 97 | if isinstance(kernel_regularizer, list): 98 | regularizers = kernel_regularizer 99 | if isinstance(kernel_regularizer[0], six.string_types): 100 | regularizers = [kernel_regularizer] 101 | for regularizer in regularizers: 102 | if len(regularizer) != 3: 103 | raise ValueError("Regularizer tuples/lists must have three elements " 104 | "(type, l1, and l2). Given: {}".format(regularizer)) 105 | _, l1, l2 = regularizer 106 | if not isinstance(l1, float): 107 | raise ValueError( 108 | "Regularizer l1 must be a single float. 
Given: {}".format( 109 | type(l1))) 110 | if not isinstance(l2, float): 111 | raise ValueError( 112 | "Regularizer l2 must be a single float. Given: {}".format( 113 | type(l2))) 114 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/rtl_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for Lattice Layer.""" 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import tempfile 20 | from absl.testing import parameterized 21 | import numpy as np 22 | import tensorflow as tf 23 | from tensorflow_lattice.python import linear_layer 24 | from tensorflow_lattice.python import pwl_calibration_layer 25 | from tensorflow_lattice.python import rtl_layer 26 | # pylint: disable=g-import-not-at-top 27 | # Use Keras 2. 
28 | version_fn = getattr(tf.keras, "version", None) 29 | if version_fn and version_fn().startswith("3."): 30 | import tf_keras as keras 31 | else: 32 | keras = tf.keras 33 | 34 | 35 | class RTLTest(parameterized.TestCase, tf.test.TestCase): 36 | 37 | def setUp(self): 38 | super(RTLTest, self).setUp() 39 | self.disable_all = False 40 | keras.utils.set_random_seed(42) 41 | 42 | def testRTLInputShapes(self): 43 | if self.disable_all: 44 | return 45 | data_size = 100 46 | 47 | # Dense input format. 48 | a = np.random.random_sample(size=(data_size, 10)) 49 | b = np.random.random_sample(size=(data_size, 20)) 50 | target_ab = ( 51 | np.max(a, axis=1, keepdims=True) + np.min(b, axis=1, keepdims=True)) 52 | 53 | input_a = keras.layers.Input(shape=(10,)) 54 | input_b = keras.layers.Input(shape=(20,)) 55 | 56 | rtl_0 = rtl_layer.RTL(num_lattices=6, lattice_rank=5) 57 | rtl_outputs = rtl_0({"unconstrained": input_a, "increasing": input_b}) 58 | outputs = keras.layers.Dense(1)(rtl_outputs) 59 | model = keras.Model(inputs=[input_a, input_b], outputs=outputs) 60 | model.compile(loss="mse") 61 | model.fit([a, b], target_ab) 62 | model.predict([a, b]) 63 | 64 | # Inputs to be calibrated. 
65 | c = np.random.random_sample(size=(data_size, 1)) 66 | d = np.random.random_sample(size=(data_size, 1)) 67 | e = np.random.random_sample(size=(data_size, 1)) 68 | f = np.random.random_sample(size=(data_size, 1)) 69 | target_cdef = np.sin(np.pi * c) * np.cos(np.pi * d) - e * f 70 | 71 | input_c = keras.layers.Input(shape=(1,)) 72 | input_d = keras.layers.Input(shape=(1,)) 73 | input_e = keras.layers.Input(shape=(1,)) 74 | input_f = keras.layers.Input(shape=(1,)) 75 | 76 | input_keypoints = np.linspace(0.0, 1.0, 10) 77 | calib_c = pwl_calibration_layer.PWLCalibration( 78 | units=2, 79 | input_keypoints=input_keypoints, 80 | output_min=0.0, 81 | output_max=1.0)( 82 | input_c) 83 | calib_d = pwl_calibration_layer.PWLCalibration( 84 | units=3, 85 | input_keypoints=input_keypoints, 86 | output_min=0.0, 87 | output_max=1.0)( 88 | input_d) 89 | calib_e = pwl_calibration_layer.PWLCalibration( 90 | units=4, 91 | input_keypoints=input_keypoints, 92 | output_min=0.0, 93 | output_max=1.0, 94 | monotonicity="decreasing")( 95 | input_e) 96 | calib_f = pwl_calibration_layer.PWLCalibration( 97 | units=5, 98 | input_keypoints=input_keypoints, 99 | output_min=0.0, 100 | output_max=1.0, 101 | monotonicity="decreasing")( 102 | input_f) 103 | 104 | rtl_0 = rtl_layer.RTL(num_lattices=10, lattice_rank=3) 105 | rtl_0_outputs = rtl_0({ 106 | "unconstrained": [calib_c, calib_d], 107 | "increasing": [calib_e, calib_f] 108 | }) 109 | outputs = linear_layer.Linear( 110 | num_input_dims=10, monotonicities=[1] * 10)( 111 | rtl_0_outputs) 112 | model = keras.Model( 113 | inputs=[input_c, input_d, input_e, input_f], outputs=outputs 114 | ) 115 | model.compile(loss="mse") 116 | model.fit([c, d, e, f], target_cdef) 117 | model.predict([c, d, e, f]) 118 | 119 | # Two layer RTL model. 
120 | rtl_0 = rtl_layer.RTL( 121 | num_lattices=10, 122 | lattice_rank=3, 123 | output_min=0.0, 124 | output_max=1.0, 125 | separate_outputs=True) 126 | rtl_0_outputs = rtl_0({ 127 | "unconstrained": [calib_c, calib_d], 128 | "increasing": [calib_e, calib_f] 129 | }) 130 | rtl_1 = rtl_layer.RTL(num_lattices=3, lattice_rank=4) 131 | rtl_1_outputs = rtl_1(rtl_0_outputs) 132 | outputs = linear_layer.Linear( 133 | num_input_dims=3, monotonicities=[1] * 3)( 134 | rtl_1_outputs) 135 | model = keras.Model( 136 | inputs=[input_c, input_d, input_e, input_f], outputs=outputs 137 | ) 138 | model.compile(loss="mse") 139 | model.fit([c, d, e, f], target_cdef) 140 | model.predict([c, d, e, f]) 141 | 142 | def testRTLOutputShape(self): 143 | if self.disable_all: 144 | return 145 | 146 | # Multiple Outputs Per Lattice 147 | input_shape, output_shape = (30,), (None, 6) 148 | input_a = keras.layers.Input(shape=input_shape) 149 | rtl_0 = rtl_layer.RTL(num_lattices=6, lattice_rank=5) 150 | output = rtl_0(input_a) 151 | self.assertAllEqual(output_shape, rtl_0.compute_output_shape(input_a.shape)) 152 | self.assertAllEqual(output_shape, output.shape) 153 | 154 | # Average Outputs 155 | output_shape = (None, 1) 156 | rtl_1 = rtl_layer.RTL(num_lattices=6, lattice_rank=5, average_outputs=True) 157 | output = rtl_1(input_a) 158 | self.assertAllEqual(output_shape, rtl_1.compute_output_shape(input_a.shape)) 159 | self.assertAllEqual(output_shape, output.shape) 160 | 161 | def testRTLSaveLoad(self): 162 | if self.disable_all: 163 | return 164 | 165 | input_c = keras.layers.Input(shape=(1,)) 166 | input_d = keras.layers.Input(shape=(1,)) 167 | input_e = keras.layers.Input(shape=(1,)) 168 | input_f = keras.layers.Input(shape=(1,)) 169 | 170 | input_keypoints = np.linspace(0.0, 1.0, 10) 171 | calib_c = pwl_calibration_layer.PWLCalibration( 172 | units=2, 173 | input_keypoints=input_keypoints, 174 | output_min=0.0, 175 | output_max=1.0)( 176 | input_c) 177 | calib_d = 
pwl_calibration_layer.PWLCalibration( 178 | units=3, 179 | input_keypoints=input_keypoints, 180 | output_min=0.0, 181 | output_max=1.0)( 182 | input_d) 183 | calib_e = pwl_calibration_layer.PWLCalibration( 184 | units=4, 185 | input_keypoints=input_keypoints, 186 | output_min=0.0, 187 | output_max=1.0, 188 | monotonicity="decreasing")( 189 | input_e) 190 | calib_f = pwl_calibration_layer.PWLCalibration( 191 | units=5, 192 | input_keypoints=input_keypoints, 193 | output_min=0.0, 194 | output_max=1.0, 195 | monotonicity="decreasing")( 196 | input_f) 197 | 198 | rtl_0 = rtl_layer.RTL( 199 | num_lattices=10, 200 | lattice_rank=3, 201 | output_min=0.0, 202 | output_max=1.0, 203 | separate_outputs=True) 204 | rtl_0_outputs = rtl_0({ 205 | "unconstrained": [calib_c, calib_d], 206 | "increasing": [calib_e, calib_f] 207 | }) 208 | rtl_1 = rtl_layer.RTL(num_lattices=3, lattice_rank=4) 209 | rtl_1_outputs = rtl_1(rtl_0_outputs) 210 | outputs = linear_layer.Linear( 211 | num_input_dims=3, monotonicities=[1] * 3)( 212 | rtl_1_outputs) 213 | model = keras.Model( 214 | inputs=[input_c, input_d, input_e, input_f], outputs=outputs 215 | ) 216 | model.compile(loss="mse") 217 | model.use_legacy_config = True 218 | 219 | with tempfile.NamedTemporaryFile(suffix=".h5") as f: 220 | model.save(f.name) 221 | _ = keras.models.load_model( 222 | f.name, 223 | custom_objects={ 224 | "RTL": rtl_layer.RTL, 225 | "PWLCalibration": pwl_calibration_layer.PWLCalibration, 226 | "Linear": linear_layer.Linear, 227 | }, 228 | ) 229 | 230 | 231 | if __name__ == "__main__": 232 | tf.test.main() 233 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Helpers to train simple model for tests and print debug output.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import time 22 | 23 | from absl import logging 24 | import numpy as np 25 | 26 | 27 | class TimeTracker(object): 28 | """Tracks time. 29 | 30 | Keeps track of time spent in its scope and appends it to 'list_to_append' 31 | on exit from scope divided by 'num_steps' if provided. 32 | 33 | Example: 34 | training_step_times = [] 35 | with TimeTracker(training_step_times, num_steps=num_epochs): 36 | model.fit(... epochs=num_epochs ...) 37 | print np.median(training_step_times) 38 | """ 39 | 40 | def __init__(self, list_to_append, num_steps=1): 41 | self._list_to_append = list_to_append 42 | self._num_steps = float(num_steps) 43 | 44 | def __enter__(self): 45 | self._start_time = time.time() 46 | return self 47 | 48 | def __exit__(self, unuesd_type, unuesd_value, unuesd_traceback): 49 | duration = time.time() - self._start_time 50 | self._list_to_append.append( 51 | duration / self._num_steps if self._num_steps else 0.0) 52 | 53 | 54 | def run_training_loop(config, 55 | training_data, 56 | keras_model, 57 | input_dtype=np.float32, 58 | label_dtype=np.float32): 59 | """Trains models and prints debug info. 60 | 61 | Args: 62 | config: dictionary of test case parameters. See tests for TensorFlow Lattice 63 | layers. 
64 | training_data: tuple: (training_inputs, labels) where 65 | training_inputs and labels are proper data to train models passed via 66 | other parameters. 67 | keras_model: Keras model to train on training_data. 68 | input_dtype: dtype for input conversion. 69 | label_dtype: dtype for label conversion. 70 | 71 | Returns: 72 | Loss measured on training data and tf.session() if one was initialized 73 | explicitly during training. 74 | """ 75 | (training_inputs, training_labels) = training_data 76 | np_training_inputs = np.asarray(training_inputs).astype(input_dtype) 77 | np_training_labels = np.asarray(training_labels).astype(label_dtype) 78 | 79 | logging.info(" {0: <10}{1: <10}".format("it", "Loss")) 80 | 81 | num_steps = 10 82 | training_step_times = [] 83 | for step in range(num_steps): 84 | begin = (config["num_training_epoch"] * step) // num_steps 85 | end = (config["num_training_epoch"] * (step + 1)) // num_steps 86 | num_epochs = end - begin 87 | if num_epochs == 0: 88 | continue 89 | 90 | loss = keras_model.evaluate(np_training_inputs, np_training_labels, 91 | batch_size=len(np_training_inputs), 92 | verbose=0) 93 | with TimeTracker(training_step_times, num_steps=num_epochs): 94 | keras_model.fit(np_training_inputs, np_training_labels, 95 | batch_size=len(np_training_inputs), 96 | epochs=num_epochs, 97 | verbose=0) 98 | logging.info("{0: <10}{1: <10,.6f}".format(begin, loss)) 99 | # End of: 'for step in range(num_steps):' 100 | 101 | loss = keras_model.evaluate(np_training_inputs, np_training_labels, 102 | batch_size=len(np_training_inputs), 103 | verbose=0) 104 | logging.info("Final loss: %f", loss) 105 | 106 | if training_step_times: 107 | logging.info("Median training step time: %f", 108 | np.median(training_step_times)) 109 | 110 | return loss 111 | 112 | 113 | def two_dim_mesh_grid(num_points, x_min, y_min, x_max, y_max): 114 | """Generates uniform 2-d mesh grid for 3-d surfaces visualisation via pyplot. 
115 | 116 | Uniformly distributes 'num_points' within rectangle: 117 | (x_min, y_min) - (x_max, y_max) 118 | 'num_points' should be such that uniform distribution is possible. In other 119 | words there should exist such integers 'x_points' and 'y_points' that: 120 | - x_points * y_points == num_points 121 | - x_points / y_points == (x_max - x_min) / (y_max - y_min) 122 | 123 | Args: 124 | num_points: number of points in the grid. 125 | x_min: bounds of the grid. 126 | y_min: bounds of the grid. 127 | x_max: bounds of the grid. 128 | y_max: bounds of the grid. 129 | 130 | Returns: 131 | Tuple containing 2 numpy arrays which represent X and Y coordinates of mesh 132 | grid 133 | 134 | Raises: 135 | ValueError: if it's impossible to uniformly distribute 'num_points' across 136 | specified grid. 137 | 138 | """ 139 | x_size = x_max - x_min 140 | y_size = y_max - y_min 141 | x_points = (num_points * x_size / y_size)**0.5 142 | y_points = num_points / x_points 143 | 144 | eps = 1e-7 145 | is_int = lambda x: abs(x - int(x + eps)) < eps 146 | if not is_int(x_points) or not is_int(y_points): 147 | raise ValueError("Cannot evenly distribute %d points across sides of " 148 | "lengths: %f and %f" % (num_points, x_size, y_size)) 149 | 150 | x_grid = np.linspace(start=x_min, stop=x_max, num=int(x_points + eps)) 151 | y_grid = np.linspace(start=y_min, stop=y_max, num=int(y_points + eps)) 152 | 153 | # Convert list returned by meshgrid() to tuple so we can easily distinguish 154 | # mesh grid vs list of points. 155 | return tuple(np.meshgrid(x_grid, y_grid)) 156 | 157 | 158 | def sample_uniformly(num_points, lower_bounds, upper_bounds): 159 | """Deterministically generates num_point random points within bounds. 160 | 161 | Points will be such that: 162 | lower_bounds[i] <= p[i] <= upper_bounds[i] 163 | 164 | Number of dimensions is defined by lengths of lower_bounds list. 165 | 166 | Args: 167 | num_points: number of points to generate. 
168 | lower_bounds: list or tuple of lower bounds. 169 | upper_bounds: list or tuple of upper bounds. 170 | 171 | Returns: 172 | List of generated points. 173 | """ 174 | if len(lower_bounds) != len(upper_bounds): 175 | raise ValueError("Lower and upper bounds must have same length. They are: " 176 | "lower_bounds: %s, upper_bounds: %s" % 177 | (lower_bounds, upper_bounds)) 178 | np.random.seed(41) 179 | x = [] 180 | for _ in range(num_points): 181 | point = [ 182 | lower + np.random.random() * (upper - lower) 183 | for lower, upper in zip(lower_bounds, upper_bounds) 184 | ] 185 | x.append(np.asarray(point)) 186 | return x 187 | 188 | 189 | def get_hypercube_interpolation_fn(coefficients): 190 | """Returns function which does hypercube interpolation. 191 | 192 | This is only for 2^d lattice aka hypercube. 193 | 194 | Args: 195 | coefficients: coefficients of hypercube ordered according to index of 196 | corresponding vertex. 197 | 198 | Returns: 199 | Function which takes d-dimension point and performs hypercube interpolation 200 | with given coefficients. 201 | """ 202 | 203 | def hypercube_interpolation_fn(x): 204 | """Does hypercube interpolation.""" 205 | if 2**len(x) != len(coefficients): 206 | raise ValueError("Number of coefficients(%d) does not correspond to " 207 | "dimension 'x'(%s)" % (len(coefficients), x)) 208 | result = 0.0 209 | for coefficient_index in range(len(coefficients)): 210 | weight = 1.0 211 | for input_dimension in range(len(x)): 212 | if coefficient_index & (1 << input_dimension): 213 | # If statement checks whether 'input_dimension' bit of 214 | # 'coefficient_index' is set to 1. 
# NOTE(review): The seven lines below are a flattened repository dump. Each
# line packs many original source lines, each prefixed with its original line
# number as `NNN |`. File sections are separated by a dashed rule naming the
# next file. Statement boundaries do not line up with these lines, so this span
# is documented here rather than rewritten.
#
# Contents, in order:
#  1. Tail (original lines 215-221) of a helper that returns the inner function
#     `hypercube_interpolation_fn`. That inner function accumulates
#     `coefficients[coefficient_index] * weight` and returns the sum. The
#     helper's head lies above this view: presumably test_utils.py, the file
#     preceding utils.py in the HEAD tree listing -- confirm.
#  2. get_linear_lattice_interpolation_fn(lattice_sizes, monotonicities,
#     output_min, output_max). It returns a closure mapping a point x to:
#       output_min + sum over dims i with a truthy monotonicity of
#       x[i] * ((output_max - output_min) / num_monotonic_dims)
#       / (lattice_sizes[i] - 1.0).
#     If `monotonicities` has no non-zero entry, every dimension is treated as
#     monotonic. Observations:
#     - num_monotonic_dims is recomputed on every call, although it depends
#       only on the closure arguments.
#     - `monotonicities.count(0)` only discounts integer zeros, so a string
#       such as "none" would be counted (and weighted) as monotonic.
#     - A monotonic dimension with lattice size 1 divides by zero.
#     - The docstring has typos: "dimesions", "cotribute", "numer".
#  3. All of tensorflow_lattice/python/utils.py (original lines 1-242):
#     - canonicalize_convexity: None -> None; -1/0/1 are returned as given;
#       case-insensitive "concave"/"none"/"convex" -> -1/0/1; anything else
#       raises ValueError. The `in [-1, 0, 1]` test compares with ==, so values
#       such as 1.0 or True are accepted and returned unconverted.
#     - canonicalize_input_bounds: falsy input (None or empty) -> None. Float
#       and None items are kept, and "none" (any case) becomes None. Anything
#       else raises ValueError.
#       NOTE(review): the error message says int is accepted, but
#       `isinstance(item, float)` rejects ints, and utils_test expects
#       [0, 1.0, 2.0] to be invalid. So the message, not the check, is
#       misleading.
#     - canonicalize_monotonicity: same shape as canonicalize_convexity, using
#       "decreasing"/"none"/"increasing". With allow_decreasing=False, both -1
#       and "decreasing" raise ValueError.
#     - canonicalize_monotonicities: falsy input -> None; otherwise maps
#       canonicalize_monotonicity (forwarding allow_decreasing) over the items.
#     - canonicalize_trust: falsy input -> None. Each constraint must have
#       exactly 3 elements. A -1/1 direction keeps the original constraint
#       object, which may be a list rather than the documented tuple.
#       "negative"/"positive" become new (feature_a, feature_b, -1/1) tuples.
#     - canonicalize_unimodalities: falsy input -> None; -1/0/1 are kept, and
#       "peak"/"none"/"valley" map to -1/0/1. The error message reports the
#       whole list, not the offending element.
#     - count_non_zeros(*iterables): the number of elements != 0 across all
#       arguments that are not None.
#  4. The license header (original lines 1-13) of
#     tensorflow_lattice/python/utils_test.py.
215 | weight *= x[input_dimension] 216 | else: 217 | weight *= (1.0 - x[input_dimension]) 218 | result += coefficients[coefficient_index] * weight 219 | return result 220 | 221 | return hypercube_interpolation_fn 222 | 223 | 224 | def get_linear_lattice_interpolation_fn(lattice_sizes, monotonicities, 225 | output_min, output_max): 226 | """Returns function which does lattice interpolation. 227 | 228 | Returned function matches lattice_layer.LinearInitializer with corresponding 229 | parameters. 230 | 231 | Args: 232 | lattice_sizes: list or tuple of integers which represents lattice sizes. 233 | monotonicities: monotonicity constraints. 234 | output_min: minimum output of linear function. 235 | output_max: maximum output of linear function. 236 | 237 | Returns: 238 | Function which takes d-dimension point and performs lattice interpolation 239 | assuming lattice weights are such that lattice represents linear function 240 | with given output_min and output_max. All monotonic dimesions of this linear 241 | function cotribute with same weight despite of numer of vertices per 242 | dimension. All non monotonic dimensions have weight 0.0. 
243 | """ 244 | 245 | def linear_interpolation_fn(x): 246 | """Linear along monotonic dims and 0.0 along non monotonic.""" 247 | result = output_min 248 | num_monotonic_dims = len(monotonicities) - monotonicities.count(0) 249 | if num_monotonic_dims == 0: 250 | local_monotonicities = [1] * len(lattice_sizes) 251 | num_monotonic_dims = len(lattice_sizes) 252 | else: 253 | local_monotonicities = monotonicities 254 | 255 | weight = (output_max - output_min) / num_monotonic_dims 256 | for i in range(len(x)): 257 | if local_monotonicities[i]: 258 | result += x[i] * weight / (lattice_sizes[i] - 1.0) 259 | return result 260 | 261 | return linear_interpolation_fn 262 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Helpers shared by multiple modules in TFL.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import six 21 | 22 | 23 | # TODO: update library not to explicitly check if None so we can return 24 | # an empty list instead of None for these canonicalization methods. 25 | def canonicalize_convexity(convexity): 26 | """Converts string constants representing convexity into integers. 
27 | 28 | Args: 29 | convexity: The convexity hyperparameter of `tfl.layers.PWLCalibration` 30 | layer. 31 | 32 | Returns: 33 | convexity represented as -1, 0, 1, or None. 34 | 35 | Raises: 36 | ValueError: If convexity is not in the set 37 | {-1, 0, 1, 'concave', 'none', 'convex'}. 38 | """ 39 | if convexity is None: 40 | return None 41 | 42 | if convexity in [-1, 0, 1]: 43 | return convexity 44 | elif isinstance(convexity, six.string_types): 45 | if convexity.lower() == "concave": 46 | return -1 47 | if convexity.lower() == "none": 48 | return 0 49 | if convexity.lower() == "convex": 50 | return 1 51 | raise ValueError("'convexity' must be from: [-1, 0, 1, 'concave', " 52 | "'none', 'convex']. Given: {}".format(convexity)) 53 | 54 | 55 | def canonicalize_input_bounds(input_bounds): 56 | """Converts string constant 'none' representing unspecified bound into None. 57 | 58 | Args: 59 | input_bounds: The input_min or input_max hyperparameter of 60 | `tfl.layers.Linear` layer. 61 | 62 | Returns: 63 | A list of [val, val, ...] where val can be a float or None, or the value 64 | None if input_bounds is None. 65 | 66 | Raises: 67 | ValueError: If one of elements in input_bounds is not a float, None or 68 | 'none'. 69 | """ 70 | if input_bounds: 71 | canonicalized = [] 72 | for item in input_bounds: 73 | if isinstance(item, float) or item is None: 74 | canonicalized.append(item) 75 | elif isinstance(item, six.string_types) and item.lower() == "none": 76 | canonicalized.append(None) 77 | else: 78 | raise ValueError("Both 'input_min' and 'input_max' elements must be " 79 | "either int, float, None, or 'none'. Given: {}".format( 80 | input_bounds)) 81 | return canonicalized 82 | return None 83 | 84 | 85 | def canonicalize_monotonicity(monotonicity, allow_decreasing=True): 86 | """Converts string constants representing monotonicity into integers. 87 | 88 | Args: 89 | monotonicity: The monotonicities hyperparameter of a `tfl.layers` Layer 90 | (e.g. 
`tfl.layers.PWLCalibration`). 91 | allow_decreasing: If decreasing monotonicity is considered a valid 92 | monotonicity. 93 | 94 | Returns: 95 | monotonicity represented as -1, 0, 1, or None. 96 | 97 | Raises: 98 | ValueError: If monotonicity is not in the set 99 | {-1, 0, 1, 'decreasing', 'none', 'increasing'} and allow_decreasing is 100 | True. 101 | ValueError: If monotonicity is not in the set {0, 1, 'none', 'increasing'} 102 | and allow_decreasing is False. 103 | """ 104 | if monotonicity is None: 105 | return None 106 | 107 | if monotonicity in [-1, 0, 1]: 108 | if not allow_decreasing and monotonicity == -1: 109 | raise ValueError( 110 | "'monotonicities' must be from: [0, 1, 'none', 'increasing']. " 111 | "Given: {}".format(monotonicity)) 112 | return monotonicity 113 | elif isinstance(monotonicity, six.string_types): 114 | if monotonicity.lower() == "decreasing": 115 | if not allow_decreasing: 116 | raise ValueError( 117 | "'monotonicities' must be from: [0, 1, 'none', 'increasing']. " 118 | "Given: {}".format(monotonicity)) 119 | return -1 120 | if monotonicity.lower() == "none": 121 | return 0 122 | if monotonicity.lower() == "increasing": 123 | return 1 124 | raise ValueError("'monotonicities' must be from: [-1, 0, 1, 'decreasing', " 125 | "'none', 'increasing']. Given: {}".format(monotonicity)) 126 | 127 | 128 | def canonicalize_monotonicities(monotonicities, allow_decreasing=True): 129 | """Converts string constants representing monotonicities into integers. 130 | 131 | Args: 132 | monotonicities: monotonicities hyperparameter of a `tfl.layers` Layer (e.g. 133 | `tfl.layers.Lattice`). 134 | allow_decreasing: If decreasing monotonicity is considered a valid 135 | monotonicity. 136 | 137 | Returns: 138 | A list of monotonicities represented as -1, 0, 1, or the value None 139 | if monotonicities is None. 
140 | 141 | Raises: 142 | ValueError: If one of monotonicities is not in the set 143 | {-1, 0, 1, 'decreasing', 'none', 'increasing'} and allow_decreasing is 144 | True. 145 | ValueError: If one of monotonicities is not in the set 146 | {0, 1, 'none', 'increasing'} and allow_decreasing is False. 147 | """ 148 | if monotonicities: 149 | return [ 150 | canonicalize_monotonicity( 151 | monotonicity, allow_decreasing=allow_decreasing) 152 | for monotonicity in monotonicities 153 | ] 154 | return None 155 | 156 | 157 | def canonicalize_trust(trusts): 158 | """Converts string constants representing trust direction into integers. 159 | 160 | Args: 161 | trusts: edgeworth_trusts or trapezoid_trusts hyperparameter of 162 | `tfl.layers.Lattice` layer. 163 | 164 | Returns: 165 | A list of trust constraint tuples of the form 166 | (feature_a, feature_b, direction) where direction can be -1 or 1, or the 167 | value None if trusts is None. 168 | 169 | Raises: 170 | ValueError: If one of trust constraints does not have 3 elements. 171 | ValueError: If one of trust constraints' direction is not in the set 172 | {-1, 1, 'negative', 'positive'}. 173 | """ 174 | if trusts: 175 | canonicalized = [] 176 | for trust in trusts: 177 | if len(trust) != 3: 178 | raise ValueError("Trust constraints must consist of 3 elements. Seeing " 179 | "constraint tuple {}".format(trust)) 180 | feature_a, feature_b, direction = trust 181 | if direction in [-1, 1]: 182 | canonicalized.append(trust) 183 | elif (isinstance(direction, six.string_types) and 184 | direction.lower() == "negative"): 185 | canonicalized.append((feature_a, feature_b, -1)) 186 | elif (isinstance(direction, six.string_types) and 187 | direction.lower() == "positive"): 188 | canonicalized.append((feature_a, feature_b, 1)) 189 | else: 190 | raise ValueError("trust constraint direction must be from: [-1, 1, " 191 | "'negative', 'positive']. 
Given: {}".format(direction)) 192 | return canonicalized 193 | return None 194 | 195 | 196 | def canonicalize_unimodalities(unimodalities): 197 | """Converts string constants representing unimodalities into integers. 198 | 199 | Args: 200 | unimodalities: unimodalities hyperparameter of `tfl.layers.Lattice` layer. 201 | 202 | Returns: 203 | A list of unimodalities represented as -1, 0, 1, or the value None if 204 | unimodalities is None. 205 | 206 | Raises: 207 | ValueError: If one of unimodalities is not in the set 208 | {-1, 0, 1, 'peak', 'none', 'valley'}. 209 | """ 210 | if not unimodalities: 211 | return None 212 | canonicalized = [] 213 | for unimodality in unimodalities: 214 | if unimodality in [-1, 0, 1]: 215 | canonicalized.append(unimodality) 216 | elif isinstance(unimodality, 217 | six.string_types) and unimodality.lower() == "peak": 218 | canonicalized.append(-1) 219 | elif isinstance(unimodality, 220 | six.string_types) and unimodality.lower() == "none": 221 | canonicalized.append(0) 222 | elif isinstance(unimodality, 223 | six.string_types) and unimodality.lower() == "valley": 224 | canonicalized.append(1) 225 | else: 226 | raise ValueError( 227 | "'unimodalities' elements must be from: [-1, 0, 1, 'peak', 'none', " 228 | "'valley']. Given: {}".format(unimodalities)) 229 | return canonicalized 230 | 231 | 232 | def count_non_zeros(*iterables): 233 | """Returns total number of non 0 elements in given iterables. 234 | 235 | Args: 236 | *iterables: Any number of the value None or iterables of numeric values. 
237 | """ 238 | result = 0 239 | for iterable in iterables: 240 | if iterable is not None: 241 | result += sum(1 for element in iterable if element != 0) 242 | return result 243 | -------------------------------------------------------------------------------- /tensorflow_lattice/python/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
"""Tests for TensorFlow Lattice utility functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import tensorflow as tf
from tensorflow_lattice.python import utils


class UtilsTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the canonicalization and counting helpers in `utils`."""

  @parameterized.parameters((-1, -1), (0, 0), (1, 1), ("concave", -1),
                            ("none", 0), ("convex", 1))
  def testCanonicalizeConvexity(self, convexity,
                                expected_canonicalized_convexity):
    canonicalized_convexity = utils.canonicalize_convexity(convexity)
    self.assertEqual(canonicalized_convexity, expected_canonicalized_convexity)

  @parameterized.parameters(-2, 0.5, 3, "invalid_convexity", "concaves",
                            "nonw", "conve")
  def testInvalidConvexity(self, invalid_convexity):
    error_message = (
        "'convexity' must be from: [-1, 0, 1, 'concave', 'none', 'convex']. "
        "Given: {}").format(invalid_convexity)
    with self.assertRaisesWithLiteralMatch(ValueError, error_message):
      utils.canonicalize_convexity(invalid_convexity)

  def testCanonicalizeScalarNone(self):
    # None means "not specified" and must be passed through unchanged.
    self.assertIsNone(utils.canonicalize_convexity(None))
    self.assertIsNone(utils.canonicalize_monotonicity(None))

  # Note: must use mapping format because otherwise input parameter list is
  # considered multiple parameters (not just a single list parameter).
  @parameterized.parameters(
      {
          "input_bounds": [0.0, -3.0],
          "expected_canonicalized_input_bounds": [0.0, -3.0]
      }, {
          "input_bounds": [float("-inf"), 0.12345],
          "expected_canonicalized_input_bounds": [float("-inf"), 0.12345]
      }, {
          "input_bounds": ["none", None],
          "expected_canonicalized_input_bounds": [None, None]
      })
  def testCanonicalizeInputBounds(self, input_bounds,
                                  expected_canonicalized_input_bounds):
    canonicalized_input_bounds = utils.canonicalize_input_bounds(input_bounds)
    self.assertAllEqual(canonicalized_input_bounds,
                        expected_canonicalized_input_bounds)

  # Note: the first case pins that integer bounds (the leading 0) are rejected,
  # even though the error message lists int among the accepted types.
  @parameterized.parameters({"invalid_input_bounds": [0, 1.0, 2.0]},
                            {"invalid_input_bounds": [None, "nonw"]})
  def testInvalidInputBounds(self, invalid_input_bounds):
    error_message = (
        "Both 'input_min' and 'input_max' elements must be either int, float, "
        "None, or 'none'. Given: {}").format(invalid_input_bounds)
    with self.assertRaisesWithLiteralMatch(ValueError, error_message):
      utils.canonicalize_input_bounds(invalid_input_bounds)

  @parameterized.parameters((-1, -1), (0, 0), (1, 1), ("decreasing", -1),
                            ("none", 0), ("increasing", 1))
  def testCanonicalizeMonotonicity(self, monotonicity,
                                   expected_canonicalized_monotonicity):
    canonicalized_monotonicity = utils.canonicalize_monotonicity(monotonicity)
    self.assertEqual(canonicalized_monotonicity,
                     expected_canonicalized_monotonicity)

  @parameterized.parameters(-2, 0.5, 3, "invalid_monotonicity", "decrease",
                            "increase")
  def testInvalidMonotonicity(self, invalid_monotonicity):
    error_message = (
        "'monotonicities' must be from: [-1, 0, 1, 'decreasing', 'none', "
        "'increasing']. Given: {}").format(invalid_monotonicity)
    with self.assertRaisesWithLiteralMatch(ValueError, error_message):
      utils.canonicalize_monotonicity(invalid_monotonicity)

  @parameterized.parameters("decreasing", -1)
  def testInvalidDecreasingMonotonicity(self, invalid_monotonicity):
    error_message = (
        "'monotonicities' must be from: [0, 1, 'none', 'increasing']. "
        "Given: {}").format(invalid_monotonicity)
    with self.assertRaisesWithLiteralMatch(ValueError, error_message):
      utils.canonicalize_monotonicity(
          invalid_monotonicity, allow_decreasing=False)

  # Note: canonicalize_monotonicities delegates per-element validation to
  # canonicalize_monotonicity, which is covered above. The list-level tests
  # below check the element mapping and that `allow_decreasing` is forwarded.
  @parameterized.parameters(([-1, 0, 1], [-1, 0, 1]),
                            (["decreasing", "none", "increasing"], [-1, 0, 1]),
                            (["decreasing", -1], [-1, -1]),
                            (["none", 0], [0, 0]), (["increasing", 1], [1, 1]))
  def testCanonicalizeMonotonicities(self, monotonicities,
                                     expected_canonicalized_monotonicities):
    canonicalized_monotonicities = utils.canonicalize_monotonicities(
        monotonicities)
    self.assertAllEqual(canonicalized_monotonicities,
                        expected_canonicalized_monotonicities)

  @parameterized.parameters((["increasing", "none", 0, 1], [1, 0, 0, 1]),
                            ([0, 0], [0, 0]))
  def testCanonicalizeMonotonicitiesWithoutDecreasing(
      self, monotonicities, expected_canonicalized_monotonicities):
    canonicalized_monotonicities = utils.canonicalize_monotonicities(
        monotonicities, allow_decreasing=False)
    self.assertAllEqual(canonicalized_monotonicities,
                        expected_canonicalized_monotonicities)

  def testInvalidDecreasingMonotonicities(self):
    error_message = ("'monotonicities' must be from: [0, 1, 'none', "
                     "'increasing']. Given: decreasing")
    with self.assertRaisesWithLiteralMatch(ValueError, error_message):
      utils.canonicalize_monotonicities(["increasing", "decreasing"],
                                        allow_decreasing=False)

  @parameterized.parameters(
      ([("a", "b", -1), ("b", "c", 1)], [("a", "b", -1), ("b", "c", 1)]),
      ([("a", "b", "negative"), ("b", "c", "positive")],
       [("a", "b", -1), ("b", "c", 1)]))
  def testCanonicalizeTrust(self, trusts, expected_canonicalized_trusts):
    canonicalized_trusts = utils.canonicalize_trust(trusts)
    self.assertAllEqual(canonicalized_trusts, expected_canonicalized_trusts)

  # Note 1: this test assumes the first trust in the list has the incorrect
  # direction. A list with a single trust tuple is sufficient.
  # Note 2: must use mapping format because otherwise input parameter list is
  # considered multiple parameters (not just a single list parameter).
  @parameterized.parameters({"invalid_trusts": [("a", "b", 0)]},
                            {"invalid_trusts": [("a", "b", "negativ")]})
  def testInvalidTrustDirection(self, invalid_trusts):
    error_message = (
        "trust constraint direction must be from: [-1, 1, 'negative', "
        "'positive']. Given: {}").format(invalid_trusts[0][2])
    with self.assertRaisesWithLiteralMatch(ValueError, error_message):
      utils.canonicalize_trust(invalid_trusts)

  # Note 1: this test assumes the first trust in the list has the incorrect
  # size. A list with a single trust tuple is sufficient.
  # Note 2: must use mapping format because otherwise input parameter list is
  # considered multiple parameters (not just a single list parameter).
  @parameterized.parameters({"invalid_trusts": [("a", 1)]},
                            {"invalid_trusts": [("a", "b", -1, 1)]})
  def testInvalidTrustLength(self, invalid_trusts):
    error_message = (
        "Trust constraints must consist of 3 elements. Seeing constraint "
        "tuple {}").format(invalid_trusts[0])
    with self.assertRaisesWithLiteralMatch(ValueError, error_message):
      utils.canonicalize_trust(invalid_trusts)

  @parameterized.parameters(([0, 1, 1, 0], [1, 0], 3),
                            ([0, 0, 0], [0, 0, 0], 0),
                            ([-1, 0, 0, 1], [0, 0], 2),
                            (None, [1, 1, 1, 1, 1], 5),
                            (None, None, 0))
  def testCountNonZeros(self, monotonicities, unimodalities,
                        expected_non_zeros):
    non_zeros = utils.count_non_zeros(monotonicities, unimodalities)
    self.assertEqual(non_zeros, expected_non_zeros)

  def testCountNonZerosWithoutIterables(self):
    self.assertEqual(utils.count_non_zeros(), 0)

  @parameterized.parameters(
      ([-1, 0, 1], [-1, 0, 1]), (["peak", "none", "valley"], [-1, 0, 1]),
      (["peak", -1], [-1, -1]), (["none", 0], [0, 0]), (["valley", 1], [1, 1]))
  def testCanonicalizeUnimodalities(self, unimodalities,
                                    expected_canonicalized_unimodalities):
    canonicalized_unimodalities = utils.canonicalize_unimodalities(
        unimodalities)
    self.assertAllEqual(canonicalized_unimodalities,
                        expected_canonicalized_unimodalities)

  # Note: must use mapping format because otherwise input parameter list is
  # considered multiple parameters (not just a single list parameter).
  @parameterized.parameters({"invalid_unimodalities": ["vally", 0]},
                            {"invalid_unimodalities": [-1, 0, 2]})
  def testInvalidUnimodalities(self, invalid_unimodalities):
    error_message = (
        "'unimodalities' elements must be from: [-1, 0, 1, 'peak', 'none', "
        "'valley']. Given: {}").format(invalid_unimodalities)
    with self.assertRaisesWithLiteralMatch(ValueError, error_message):
      utils.canonicalize_unimodalities(invalid_unimodalities)

  # Note: must use mapping format because otherwise a list parameter (here the
  # empty list) is unpacked into multiple (here: zero) positional arguments.
  @parameterized.parameters({"value": None}, {"value": []})
  def testCanonicalizeNoneOrEmptyListReturnsNone(self, value):
    # The list canonicalizers treat both None and an empty list as
    # "unspecified" and return None (see the TODO in utils.py).
    self.assertIsNone(utils.canonicalize_input_bounds(value))
    self.assertIsNone(utils.canonicalize_monotonicities(value))
    self.assertIsNone(utils.canonicalize_trust(value))
    self.assertIsNone(utils.canonicalize_unimodalities(value))


if __name__ == "__main__":
  tf.test.main()