├── .coveragerc
├── .gitignore
├── AUTHORS
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── RELEASE.md
├── WORKSPACE
├── adanet
│   ├── BUILD
│   ├── __init__.py
│   ├── adanet_test.py
│   ├── autoensemble
│   │   ├── BUILD
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── estimator.py
│   │   ├── estimator_test.py
│   │   ├── estimator_v2_test.py
│   │   └── tpu_estimator_test.py
│   ├── core
│   │   ├── BUILD
│   │   ├── __init__.py
│   │   ├── architecture.py
│   │   ├── architecture_test.py
│   │   ├── candidate.py
│   │   ├── candidate_test.py
│   │   ├── ensemble_builder.py
│   │   ├── ensemble_builder_test.py
│   │   ├── estimator.py
│   │   ├── estimator_distributed_test.py
│   │   ├── estimator_distributed_test_runner.py
│   │   ├── estimator_test.py
│   │   ├── estimator_v2_test.py
│   │   ├── eval_metrics.py
│   │   ├── eval_metrics_test.py
│   │   ├── evaluator.py
│   │   ├── evaluator_test.py
│   │   ├── iteration.py
│   │   ├── iteration_test.py
│   │   ├── report_accessor.py
│   │   ├── report_accessor_test.py
│   │   ├── report_materializer.py
│   │   ├── report_materializer_test.py
│   │   ├── summary.py
│   │   ├── summary_test.py
│   │   ├── summary_v2_test.py
│   │   ├── testing_utils.py
│   │   ├── timer.py
│   │   ├── timer_test.py
│   │   ├── tpu_estimator.py
│   │   └── tpu_estimator_test.py
│   ├── distributed
│   │   ├── BUILD
│   │   ├── __init__.py
│   │   ├── devices.py
│   │   ├── devices_test.py
│   │   ├── placement.py
│   │   └── placement_test.py
│   ├── ensemble
│   │   ├── BUILD
│   │   ├── __init__.py
│   │   ├── ensembler.py
│   │   ├── mean.py
│   │   ├── mean_test.py
│   │   ├── strategy.py
│   │   ├── strategy_test.py
│   │   ├── weighted.py
│   │   └── weighted_test.py
│   ├── examples
│   │   ├── BUILD
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── simple_dnn.py
│   │   ├── simple_dnn_test.py
│   │   └── tutorials
│   │       ├── BUILD
│   │       ├── README.md
│   │       ├── adanet_objective.ipynb
│   │       ├── adanet_tpu.ipynb
│   │       ├── customizing_adanet.ipynb
│   │       └── customizing_adanet_with_tfhub.ipynb
│   ├── experimental
│   │   ├── BUILD
│   │   ├── __init__.py
│   │   ├── adanet_modelflow_tutorial.ipynb
│   │   ├── controllers
│   │   │   ├── BUILD
│   │   │   ├── __init__.py
│   │   │   ├── controller.py
│   │   │   └── sequential_controller.py
│   │   ├── keras
│   │   │   ├── BUILD
│   │   │   ├── __init__.py
│   │   │   ├── ensemble_model.py
│   │   │   ├── ensemble_model_test.py
│   │   │   ├── model_search.py
│   │   │   ├── model_search_test.py
│   │   │   └── testing_utils.py
│   │   ├── phases
│   │   │   ├── BUILD
│   │   │   ├── __init__.py
│   │   │   ├── autoensemble_phase.py
│   │   │   ├── input_phase.py
│   │   │   ├── keras_trainer_phase.py
│   │   │   ├── keras_tuner_phase.py
│   │   │   ├── phase.py
│   │   │   └── repeat_phase.py
│   │   ├── schedulers
│   │   │   ├── BUILD
│   │   │   ├── __init__.py
│   │   │   ├── in_process_scheduler.py
│   │   │   └── scheduler.py
│   │   ├── storages
│   │   │   ├── BUILD
│   │   │   ├── __init__.py
│   │   │   ├── in_memory_storage.py
│   │   │   └── storage.py
│   │   └── work_units
│   │       ├── BUILD
│   │       ├── __init__.py
│   │       ├── keras_trainer_work_unit.py
│   │       ├── keras_tuner_work_unit.py
│   │       └── work_unit.py
│   ├── modelflow_test.py
│   ├── pip_package
│   │   ├── BUILD
│   │   ├── PIP.md
│   │   ├── build_pip_package.sh
│   │   ├── setup.cfg
│   │   └── setup.py
│   ├── replay
│   │   ├── BUILD
│   │   └── __init__.py
│   ├── subnetwork
│   │   ├── BUILD
│   │   ├── __init__.py
│   │   ├── generator.py
│   │   ├── generator_test.py
│   │   ├── report.py
│   │   └── report_test.py
│   ├── tf_compat
│   │   ├── BUILD
│   │   └── __init__.py
│   └── version.py
├── docs
│   ├── Makefile
│   ├── make.bat
│   ├── requirements.txt
│   └── source
│       ├── _static
│       │   └── custom.css
│       ├── adanet.distributed.rst
│       ├── adanet.ensemble.rst
│       ├── adanet.replay.rst
│       ├── adanet.rst
│       ├── adanet.subnetwork.rst
│       ├── algorithm.md
│       ├── assets
│       │   ├── adanet_tangram_logo.png
│       │   ├── adanet_tangram_logo.svg
│       │   ├── candidates.png
│       │   ├── different_complexity_ensemble.svg
│       │   ├── lifecycle.svg
│       │   ├── replication_strategy.svg
│       │   ├── round_robin.svg
│       │   ├── shared_embedding.svg
│       │   └── terminology.svg
│       ├── conf.py
│       ├── distributed.md
│       ├── index.rst
│       ├── overview.md
│       ├── quick_start.md
│       ├── tensorboard.md
│       ├── theory.md
│       ├── tpu.md
│       └── tutorials.md
├── images
│   ├── adanet_animation.gif
│   └── adanet_tangram_logo.png
├── oss_scripts
│   └── oss_pip_install.sh
├── requirements.txt
├── research
│   └── improve_nas
│       ├── README.md
│       ├── config.yaml
│       ├── config_test.yaml
│       ├── images
│       │   ├── cif100_caption.png
│       │   ├── cif10_caption.png
│       │   ├── ensemble.png
│       │   ├── ensemble_accuracy_cif10.png
│       │   ├── ensemble_accuracy_cif100.png
│       │   └── search_space.png
│       ├── setup.py
│       └── trainer
│           ├── BUILD
│           ├── __init__.py
│           ├── adanet_improve_nas.py
│           ├── adanet_improve_nas_test.py
│           ├── cifar10.py
│           ├── cifar100.py
│           ├── cifar100_test.py
│           ├── cifar10_test.py
│           ├── fake_data.py
│           ├── image_processing.py
│           ├── improve_nas.py
│           ├── improve_nas_test.py
│           ├── nasnet.py
│           ├── nasnet_utils.py
│           ├── optimizer.py
│           ├── subnetwork_utils.py
│           └── trainer.py
└── setup.cfg

/.coveragerc:
--------------------------------------------------------------------------------
 1 | # .coveragerc to control coverage.py
 2 | [run]
 3 | branch = True
 4 | source = adanet
 5 | parallel = True
 6 | 
 7 | [report]
 8 | # Regexes for lines to exclude from consideration
 9 | exclude_lines =
10 |     # Have to re-enable the standard pragma
11 |     pragma: no cover
12 | 
13 |     # Don't complain about missing debug-only code:
14 |     def __repr__
15 |     if self\.debug
16 | 
17 |     # Don't complain if tests don't hit defensive assertion code:
18 |     raise AssertionError
19 |     raise NotImplementedError
20 | 
21 |     # Don't complain if non-runnable code isn't run:
22 |     if 0:
23 |     if __name__ == .__main__.:
24 | 
25 | ignore_errors = True
26 | 
27 | omit =
28 |     *_test.py
29 |     */nasnet.py
30 | 
31 | [html]
32 | directory = coverage_html_report
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # Documentation
 2 | docs/build*
 3 | 
 4 | # Compiled python modules.
 5 | *.pyc
 6 | 
 7 | # Bazel outputs.
 8 | bazel-*
 9 | *_pb2.py
10 | 
11 | # Byte-compiled
12 | __pycache__/
13 | .cache/
14 | 
15 | # Python egg metadata, regenerated from source files by setuptools.
16 | /*.egg-info
17 | .eggs/
18 | 
19 | # PyPI distribution artifacts.
20 | build/
21 | dist/
22 | 
23 | # Sublime project files
24 | *.sublime-project
25 | *.sublime-workspace
26 | 
27 | # Tests
28 | .pytest_cache/
29 | .coverage
30 | 
31 | # Pytype
32 | .pytype
33 | 
34 | # Other
35 | *.DS_Store
--------------------------------------------------------------------------------
/AUTHORS:
--------------------------------------------------------------------------------
 1 | # This is the official list of AdaNet authors for copyright purposes.
 2 | #
 3 | # Names should be added to this file as:
 4 | #   Name or Organization <email address>
 5 | # The email address is not required for organizations.
 6 | Google LLC.
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
 1 | # How to Contribute
 2 | 
 3 | We'd love to accept your patches and contributions to this project. There are
 4 | just a few small guidelines you need to follow.
 5 | 
 6 | ## Contributor License Agreement
 7 | 
 8 | Contributions to this project must be accompanied by a Contributor License
 9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows [Google's Open Source Community 28 | Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The AdaNet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | workspace(name = "org_adanet") 16 | 17 | load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") 18 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 19 | 20 | git_repository( 21 | name = "protobuf_bzl", 22 | # v3.6.1.3 23 | commit = "ab8edf1dbe2237b4717869eaab11a2998541ad8d", 24 | remote = "https://github.com/google/protobuf.git", 25 | ) 26 | 27 | bind( 28 | name = "protobuf", 29 | actual = "@protobuf_bzl//:protobuf", 30 | ) 31 | 32 | bind( 33 | name = "protobuf_python", 34 | actual = "@protobuf_bzl//:protobuf_python", 35 | ) 36 | 37 | bind( 38 | name = "protobuf_python_genproto", 39 | actual = "@protobuf_bzl//:protobuf_python_genproto", 40 | ) 41 | 42 | bind( 43 | name = "protoc", 44 | actual = "@protobuf_bzl//:protoc", 45 | ) 46 | 47 | # Using protobuf version 3.6.1.3 48 | http_archive( 49 | name = "com_google_protobuf", 50 | strip_prefix = "protobuf-3.6.1.3", 51 | urls = ["https://github.com/google/protobuf/archive/v3.6.1.3.zip"], 52 | ) 53 | 54 | # required by protobuf_python 55 | http_archive( 56 | name = "six_archive", 57 | build_file = "@protobuf_bzl//:six.BUILD", 58 | sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a", 59 | url = "https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz#md5=34eed507548117b2ab523ab14b2f8b55", 60 | ) 61 | 62 | bind( 63 | name = "six", 64 | actual = "@six_archive//:six", 65 | ) 66 | 67 | # Google abseil py library 68 | http_archive( 69 | name = "absl_py", 70 | sha256 = "95160f778a62c7a60ddeadc7bf2d83f85a23a27359814aca12cf949e896fa82c", 71 | strip_prefix = "abseil-py-pypi-v0.2.2", 72 | urls = [ 73 | "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz", 74 | "https://github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz", 75 | ], 76 | ) 77 | 78 | # TensorFlow models repository for slim. 
79 | http_archive(
80 |     name = "tf_slim",
81 |     strip_prefix = "models-master/research",
82 |     urls = ["https://github.com/tensorflow/models/archive/master.zip"],
83 | )
84 | 
--------------------------------------------------------------------------------
/adanet/BUILD:
--------------------------------------------------------------------------------
 1 | # Description:
 2 | #   adanet is a TensorFlow AutoML framework for designing, training, and serving
 3 | #   adaptive neural networks using the AdaNet algorithm.
 4 | #   This is a reference implementation of AdaNet as a TensorFlow library.
 5 | 
 6 | licenses(["notice"])
 7 | 
 8 | exports_files(["LICENSE"])
 9 | 
10 | py_library(
11 |     name = "adanet",
12 |     srcs = [
13 |         "__init__.py",
14 |         "version.py",
15 |     ],
16 |     visibility = ["//visibility:public"],
17 |     deps = [
18 |         "//adanet/autoensemble",
19 |         "//adanet/core",
20 |         "//adanet/distributed",
21 |         "//adanet/ensemble",
22 |         "//adanet/replay",
23 |         "//adanet/subnetwork",
24 |     ],
25 | )
26 | 
27 | py_test(
28 |     name = "adanet_test",
29 |     srcs = ["adanet_test.py"],
30 |     srcs_version = "PY3",
31 |     deps = [
32 |         ":adanet",
33 |         "//adanet/examples:simple_dnn",
34 |     ],
35 | )
36 | 
37 | py_test(
38 |     name = "modelflow_test",
39 |     srcs = ["modelflow_test.py"],
40 |     srcs_version = "PY3",
41 |     deps = [
42 |         ":adanet",
43 |         "//adanet/experimental",
44 |     ],
45 | )
46 | 
--------------------------------------------------------------------------------
/adanet/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2018 The AdaNet Authors. All Rights Reserved.
 2 | 
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | 
 7 | #   https://www.apache.org/licenses/LICENSE-2.0
 8 | 
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """AdaNet: Fast and flexible AutoML with learning guarantees."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | from adanet import distributed
22 | from adanet import ensemble
23 | from adanet import replay
24 | from adanet import subnetwork
25 | from adanet.autoensemble import AutoEnsembleEstimator
26 | from adanet.autoensemble import AutoEnsembleSubestimator
27 | from adanet.autoensemble import AutoEnsembleTPUEstimator
28 | from adanet.core import Estimator
29 | from adanet.core import Evaluator
30 | from adanet.core import ReportMaterializer
31 | from adanet.core import Summary
32 | from adanet.core import TPUEstimator
33 | # For backwards compatibility. Previously all Ensemblers were complexity
34 | # regularized using the AdaNet objective.
35 | from adanet.ensemble import ComplexityRegularized as Ensemble 36 | from adanet.ensemble import MixtureWeightType 37 | from adanet.ensemble import WeightedSubnetwork 38 | from adanet.subnetwork import Subnetwork 39 | 40 | from adanet.version import __version__ 41 | 42 | __all__ = [ 43 | "AutoEnsembleEstimator", 44 | "AutoEnsembleSubestimator", 45 | "AutoEnsembleTPUEstimator", 46 | "distributed", 47 | "ensemble", 48 | "Ensemble", 49 | "Estimator", 50 | "Evaluator", 51 | "replay", 52 | "ReportMaterializer", 53 | "subnetwork", 54 | "Summary", 55 | "TPUEstimator", 56 | "MixtureWeightType", 57 | "WeightedSubnetwork", 58 | "Subnetwork", 59 | ] 60 | -------------------------------------------------------------------------------- /adanet/adanet_test.py: -------------------------------------------------------------------------------- 1 | """Test AdaNet package. 2 | 3 | Copyright 2018 The AdaNet Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | """ 17 | 18 | import adanet 19 | from adanet.examples import simple_dnn 20 | import tensorflow.compat.v1 as tf 21 | 22 | 23 | class AdaNetTest(tf.test.TestCase): 24 | 25 | def test_public(self): 26 | self.assertIsNotNone(adanet.__version__) 27 | self.assertIsNotNone(adanet.AutoEnsembleEstimator) 28 | self.assertIsNotNone(adanet.AutoEnsembleSubestimator) 29 | self.assertIsNotNone(adanet.AutoEnsembleTPUEstimator) 30 | self.assertIsNotNone(adanet.distributed.PlacementStrategy) 31 | self.assertIsNotNone(adanet.distributed.ReplicationStrategy) 32 | self.assertIsNotNone(adanet.distributed.RoundRobinStrategy) 33 | self.assertIsNotNone(adanet.ensemble.Ensemble) 34 | self.assertIsNotNone(adanet.ensemble.Ensembler) 35 | self.assertIsNotNone(adanet.ensemble.TrainOpSpec) 36 | self.assertIsNotNone(adanet.ensemble.AllStrategy) 37 | self.assertIsNotNone(adanet.ensemble.Candidate) 38 | self.assertIsNotNone(adanet.ensemble.GrowStrategy) 39 | self.assertIsNotNone(adanet.ensemble.Strategy) 40 | self.assertIsNotNone(adanet.ensemble.ComplexityRegularized) 41 | self.assertIsNotNone(adanet.ensemble.ComplexityRegularizedEnsembler) 42 | self.assertIsNotNone(adanet.ensemble.MeanEnsemble) 43 | self.assertIsNotNone(adanet.ensemble.MeanEnsembler) 44 | self.assertIsNotNone(adanet.ensemble.MixtureWeightType) 45 | self.assertIsNotNone(adanet.ensemble.WeightedSubnetwork) 46 | self.assertIsNotNone(adanet.Ensemble) 47 | self.assertIsNotNone(adanet.Estimator) 48 | self.assertIsNotNone(adanet.Evaluator) 49 | self.assertIsNotNone(adanet.MixtureWeightType) 50 | self.assertIsNotNone(adanet.replay.Config) 51 | self.assertIsNotNone(adanet.ReportMaterializer) 52 | self.assertIsNotNone(adanet.Subnetwork) 53 | self.assertIsNotNone(adanet.subnetwork.Builder) 54 | self.assertIsNotNone(adanet.subnetwork.Generator) 55 | self.assertIsNotNone(adanet.subnetwork.Subnetwork) 56 | self.assertIsNotNone(adanet.subnetwork.TrainOpSpec) 57 | self.assertIsNotNone(adanet.Summary) 58 | self.assertIsNotNone(adanet.TPUEstimator) 59 | 
    self.assertIsNotNone(adanet.WeightedSubnetwork)
60 |     self.assertIsNotNone(simple_dnn.Generator)
61 | 
62 | if __name__ == "__main__":
63 |   tf.test.main()
64 | 
--------------------------------------------------------------------------------
/adanet/autoensemble/BUILD:
--------------------------------------------------------------------------------
 1 | # Description:
 2 | #   Auto-ensemble logic.
 3 | 
 4 | licenses(["notice"])  # Apache 2.0
 5 | 
 6 | exports_files(["LICENSE"])
 7 | 
 8 | py_library(
 9 |     name = "autoensemble",
10 |     srcs = ["__init__.py"],
11 |     visibility = ["//adanet:__subpackages__"],
12 |     deps = [
13 |         ":common",
14 |         ":estimator",
15 |     ],
16 | )
17 | 
18 | py_library(
19 |     name = "common",
20 |     srcs = ["common.py"],
21 |     visibility = ["//adanet:__subpackages__"],
22 |     deps = [
23 |         "//adanet/subnetwork",
24 |         "//adanet/tf_compat",
25 |     ],
26 | )
27 | 
28 | py_library(
29 |     name = "estimator",
30 |     srcs = ["estimator.py"],
31 |     deps = [
32 |         ":common",
33 |         "//adanet/core",
34 |     ],
35 | )
36 | 
37 | py_test(
38 |     name = "estimator_test",
39 |     size = "large",
40 |     srcs = ["estimator_test.py"],
41 |     shard_count = 5,
42 |     deps = [
43 |         ":estimator",
44 |         "//adanet/tf_compat",
45 |         "@absl_py//absl/testing:parameterized",
46 |     ],
47 | )
48 | 
49 | py_test(
50 |     name = "tpu_estimator_test",
51 |     size = "large",
52 |     srcs = ["tpu_estimator_test.py"],
53 |     shard_count = 3,
54 |     deps = [
55 |         ":estimator",
56 |         "//adanet/tf_compat",
57 |         "@absl_py//absl/flags",
58 |         "@absl_py//absl/testing:parameterized",
59 |     ],
60 | )
61 | 
62 | py_test(
63 |     name = "estimator_v2_test",
64 |     size = "large",
65 |     srcs = ["estimator_v2_test.py"],
66 |     shard_count = 5,
67 |     deps = [
68 |         ":estimator",
69 |         "@absl_py//absl/testing:parameterized",
70 |     ],
71 | )
72 | 
--------------------------------------------------------------------------------
/adanet/autoensemble/__init__.py:
--------------------------------------------------------------------------------
 1 | """The TensorFlow AdaNet autoensemble module.
 2 | 
 3 | Copyright 2018 The AdaNet Authors. All Rights Reserved.
 4 | 
 5 | Licensed under the Apache License, Version 2.0 (the "License");
 6 | you may not use this file except in compliance with the License.
 7 | You may obtain a copy of the License at
 8 | 
 9 |   https://www.apache.org/licenses/LICENSE-2.0
10 | 
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 | """
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | from adanet.autoensemble.common import AutoEnsembleSubestimator
23 | from adanet.autoensemble.estimator import AutoEnsembleEstimator
24 | from adanet.autoensemble.estimator import AutoEnsembleTPUEstimator
25 | 
26 | __all__ = [
27 |     "AutoEnsembleEstimator",
28 |     "AutoEnsembleSubestimator",
29 |     "AutoEnsembleTPUEstimator",
30 | ]
31 | 
--------------------------------------------------------------------------------
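Since this `__init__.py` only re-exports the auto-ensembling API, a minimal usage sketch may help orient readers. The feature column, head, candidate pool, and step counts below are illustrative assumptions, not values taken from this repository:

```python
import adanet
import tensorflow.compat.v1 as tf

# Illustrative setup: one numeric feature and an assumed regression head.
feature_columns = [tf.feature_column.numeric_column("x")]
head = tf.estimator.RegressionHead(label_dimension=1)

estimator = adanet.AutoEnsembleEstimator(
    head=head,
    # The candidate pool holds ordinary canned estimators; AdaNet then
    # searches over ensembles of them instead of training each in isolation.
    candidate_pool={
        "linear": tf.estimator.LinearEstimator(
            head=head, feature_columns=feature_columns),
        "dnn": tf.estimator.DNNEstimator(
            head=head, feature_columns=feature_columns,
            hidden_units=[64, 32]),
    },
    # Number of steps each AdaNet iteration trains the candidates before
    # selecting the best ensemble and moving to the next iteration.
    max_iteration_steps=100)

# estimator.train(input_fn=train_input_fn, max_steps=300)  # hypothetical input_fn
```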
--------------------------------------------------------------------------------
/adanet/core/__init__.py:
--------------------------------------------------------------------------------
 1 | """TensorFlow AdaNet core logic.
 2 | 
 3 | Copyright 2018 The AdaNet Authors. All Rights Reserved.
 4 | 
 5 | Licensed under the Apache License, Version 2.0 (the "License");
 6 | you may not use this file except in compliance with the License.
 7 | You may obtain a copy of the License at
 8 | 
 9 |   https://www.apache.org/licenses/LICENSE-2.0
10 | 
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 | """
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | from adanet.core.estimator import Estimator
23 | from adanet.core.evaluator import Evaluator
24 | from adanet.core.report_materializer import ReportMaterializer
25 | from adanet.core.summary import Summary
26 | from adanet.core.tpu_estimator import TPUEstimator
27 | 
28 | __all__ = [
29 |     "Estimator",
30 |     "Evaluator",
31 |     "ReportMaterializer",
32 |     "Summary",
33 |     "TPUEstimator",
34 | ]
35 | 
--------------------------------------------------------------------------------
/adanet/core/architecture_test.py:
--------------------------------------------------------------------------------
 1 | """Test for the AdaNet architecture.
 2 | 
 3 | Copyright 2019 The AdaNet Authors. All Rights Reserved.
 4 | 
 5 | Licensed under the Apache License, Version 2.0 (the "License");
 6 | you may not use this file except in compliance with the License.
 7 | You may obtain a copy of the License at
 8 | 
 9 |   https://www.apache.org/licenses/LICENSE-2.0
10 | 
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from absl.testing import parameterized 23 | from adanet.core.architecture import _Architecture 24 | import tensorflow.compat.v1 as tf 25 | 26 | 27 | class ArchitectureTest(parameterized.TestCase, tf.test.TestCase): 28 | 29 | @parameterized.named_parameters({ 30 | "testcase_name": "empty", 31 | "subnetworks": [], 32 | "want": (), 33 | }, { 34 | "testcase_name": "single", 35 | "subnetworks": [(0, "linear")], 36 | "want": ((0, "linear"),), 37 | }, { 38 | "testcase_name": "different_iterations", 39 | "subnetworks": [(0, "linear"), (1, "dnn")], 40 | "want": ((0, "linear"), (1, "dnn")), 41 | }, { 42 | "testcase_name": "same_iterations", 43 | "subnetworks": [(0, "linear"), (0, "dnn"), (1, "dnn")], 44 | "want": ((0, "linear"), (0, "dnn"), (1, "dnn")), 45 | }) 46 | def test_subnetworks(self, subnetworks, want): 47 | arch = _Architecture("foo", "dummy_ensembler_name") 48 | for subnetwork in subnetworks: 49 | arch.add_subnetwork(*subnetwork) 50 | self.assertEqual(want, arch.subnetworks) 51 | 52 | @parameterized.named_parameters({ 53 | "testcase_name": "empty", 54 | "subnetworks": [], 55 | "want": (), 56 | }, { 57 | "testcase_name": "single", 58 | "subnetworks": [(0, "linear")], 59 | "want": ((0, ("linear",)),), 60 | }, { 61 | "testcase_name": "different_iterations", 62 | "subnetworks": [(0, "linear"), (1, "dnn")], 63 | "want": ((0, ("linear",)), (1, ("dnn",))), 64 | }, { 65 | "testcase_name": "same_iterations", 66 | "subnetworks": [(0, "linear"), (0, "dnn"), (1, "dnn")], 67 | "want": ((0, ("linear", "dnn")), (1, ("dnn",))), 68 | }) 69 | def test_subnetworks_grouped_by_iteration(self, subnetworks, want): 70 | arch = _Architecture("foo", "dummy_ensembler_name") 71 | for subnetwork in subnetworks: 72 | arch.add_subnetwork(*subnetwork) 73 | self.assertEqual(want, arch.subnetworks_grouped_by_iteration) 74 | 75 | def test_set_and_add_replay_index(self): 76 | arch = _Architecture("foo", "dummy_ensembler_name") 77 | arch.set_replay_indices([1, 2, 3]) 78 | self.assertAllEqual([1, 2, 3], arch.replay_indices) 79 | arch.add_replay_index(4) 80 | self.assertAllEqual([1, 2, 3, 4], arch.replay_indices) 81 | 82 | def test_serialization_lifecycle(self): 83 | arch = _Architecture("foo", "dummy_ensembler_name", replay_indices=[1, 2]) 84 | arch.add_subnetwork(0, "linear") 85 | arch.add_subnetwork(0, "dnn") 86 | arch.add_subnetwork(1, "dnn") 87 | self.assertEqual("foo", arch.ensemble_candidate_name) 88 | self.assertEqual("dummy_ensembler_name", arch.ensembler_name) 89 | self.assertEqual(((0, ("linear", "dnn")), (1, ("dnn",))), 90 | arch.subnetworks_grouped_by_iteration) 91 | iteration_number = 2 92 | global_step = 100 93 | serialized = arch.serialize(iteration_number, global_step) 94 | self.assertEqual( 95 | '{"ensemble_candidate_name": "foo", "ensembler_name": ' 96 | '"dummy_ensembler_name", "global_step": 100, "iteration_number": 2, ' 97 | '"replay_indices": [1, 2], ' 98 | '"subnetworks": [{"builder_name": "linear", "iteration_number": 0}, ' 99 | '{"builder_name": "dnn", "iteration_number": 0},' 100 | ' {"builder_name": "dnn", "iteration_number": 1}]}', serialized) 101 | deserialized_arch = _Architecture.deserialize(serialized) 102 | self.assertEqual(arch.ensemble_candidate_name, 103 | deserialized_arch.ensemble_candidate_name) 104 | self.assertEqual(arch.ensembler_name, 105 | deserialized_arch.ensembler_name) 106 | self.assertEqual(arch.subnetworks_grouped_by_iteration, 107 
| deserialized_arch.subnetworks_grouped_by_iteration) 108 | self.assertEqual(global_step, deserialized_arch.global_step) 109 | 110 | 111 | if __name__ == "__main__": 112 | tf.test.main() 113 | -------------------------------------------------------------------------------- /adanet/core/candidate_test.py: -------------------------------------------------------------------------------- 1 | """Test AdaNet single graph candidate implementation. 2 | 3 | Copyright 2018 The AdaNet Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import contextlib 23 | 24 | from absl.testing import parameterized 25 | from adanet import tf_compat 26 | from adanet.core.candidate import _Candidate 27 | from adanet.core.candidate import _CandidateBuilder 28 | import adanet.core.testing_utils as tu 29 | import tensorflow.compat.v2 as tf 30 | # pylint: disable=g-direct-tensorflow-import 31 | from tensorflow.python.eager import context 32 | from tensorflow.python.framework import test_util 33 | # pylint: enable=g-direct-tensorflow-import 34 | 35 | 36 | class CandidateTest(parameterized.TestCase, tf.test.TestCase): 37 | 38 | @parameterized.named_parameters({ 39 | "testcase_name": "valid", 40 | "ensemble_spec": tu.dummy_ensemble_spec("foo"), 41 | "adanet_loss": [.1], 42 | }) 43 | @test_util.run_in_graph_and_eager_modes 44 | def test_new(self, ensemble_spec, adanet_loss, variables=None): 45 | with self.test_session(): 46 | got = _Candidate(ensemble_spec, adanet_loss, variables) 47 | self.assertEqual(got.ensemble_spec, ensemble_spec) 48 | self.assertEqual(got.adanet_loss, adanet_loss) 49 | 50 | @parameterized.named_parameters( 51 | { 52 | "testcase_name": "none_ensemble_spec", 53 | "ensemble_spec": None, 54 | "adanet_loss": [.1], 55 | }, { 56 | "testcase_name": "none_adanet_loss", 57 | "ensemble_spec": tu.dummy_ensemble_spec("foo"), 58 | "adanet_loss": None, 59 | }) 60 | @test_util.run_in_graph_and_eager_modes 61 | def test_new_errors(self, ensemble_spec, adanet_loss, variables=None): 62 | with self.test_session(): 63 | with self.assertRaises(ValueError): 64 | _Candidate(ensemble_spec, adanet_loss, variables) 65 | 66 | 67 | class _FakeSummary(object): 68 | """A fake adanet.Summary.""" 69 | 70 | def scalar(self, name, tensor, family=None): 71 | del name 72 | del tensor 73 | del family 74 | return "fake_scalar" 75 | 76 | @contextlib.contextmanager 77 | def current_scope(self): 78 | yield 79 | 80 | 81 | class CandidateBuilderTest(parameterized.TestCase, tf.test.TestCase): 82 | 83 | @parameterized.named_parameters( 84 | { 85 | "testcase_name": "evaluate", 86 | "training": False, 87 | "want_adanet_losses": [0., 0., 0.], 88 | }, { 89 | "testcase_name": "train_exactly_max_steps", 90 | "training": True, 91 | "want_adanet_losses": [1., .750, .583], 92 | }, { 93 | "testcase_name": "train_one_step_max_one_step", 94 | "training": True, 95 | 
"want_adanet_losses": [1.], 96 | }, { 97 | "testcase_name": "train_two_steps_max_two_steps", 98 | "training": True, 99 | "want_adanet_losses": [1., .750], 100 | }, { 101 | "testcase_name": "train_three_steps_max_four_steps", 102 | "training": True, 103 | "want_adanet_losses": [1., .750, .583], 104 | }, { 105 | "testcase_name": "eval_one_step", 106 | "training": False, 107 | "want_adanet_losses": [0.], 108 | }) 109 | @test_util.run_in_graph_and_eager_modes 110 | def test_build_candidate(self, training, want_adanet_losses): 111 | # `Cadidate#build_candidate` will only ever be called in graph mode. 112 | with context.graph_mode(): 113 | # A fake adanet_loss that halves at each train step: 1.0, 0.5, 0.25, ... 114 | fake_adanet_loss = tf.Variable(1.) 115 | fake_train_op = fake_adanet_loss.assign(fake_adanet_loss / 2) 116 | fake_ensemble_spec = tu.dummy_ensemble_spec( 117 | "new", adanet_loss=fake_adanet_loss, train_op=fake_train_op) 118 | 119 | builder = _CandidateBuilder() 120 | candidate = builder.build_candidate( 121 | ensemble_spec=fake_ensemble_spec, 122 | training=training, 123 | summary=_FakeSummary()) 124 | self.evaluate(tf_compat.v1.global_variables_initializer()) 125 | adanet_losses = [] 126 | for _ in range(len(want_adanet_losses)): 127 | adanet_loss = self.evaluate(candidate.adanet_loss) 128 | adanet_losses.append(adanet_loss) 129 | self.evaluate(fake_train_op) 130 | 131 | # Verify that adanet_loss moving average works. 132 | self.assertAllClose(want_adanet_losses, adanet_losses, atol=1e-3) 133 | 134 | 135 | if __name__ == "__main__": 136 | tf.test.main() 137 | -------------------------------------------------------------------------------- /adanet/core/evaluator.py: -------------------------------------------------------------------------------- 1 | """An AdaNet evaluator implementation in Tensorflow using a single graph. 2 | 3 | Copyright 2018 The AdaNet Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import math 23 | 24 | from absl import logging 25 | from adanet import tf_compat 26 | import numpy as np 27 | import tensorflow.compat.v2 as tf 28 | 29 | 30 | # TODO: Remove uses of Evaluator once AdaNet Ranker is implemented. 31 | class Evaluator(object): 32 | """Evaluates candidate ensemble performance.""" 33 | 34 | class Objective(object): 35 | """The Evaluator objective for the metric being optimized. 36 | 37 | Two objectives are currently supported: 38 | - MINIMIZE: Lower is better for the metric being optimized. 39 | - MAXIMIZE: Higher is better for the metric being optimized. 40 | """ 41 | 42 | MINIMIZE = "minimize" 43 | MAXIMIZE = "maximize" 44 | 45 | def __init__(self, 46 | input_fn, 47 | metric_name="adanet_loss", 48 | objective=Objective.MINIMIZE, 49 | steps=None): 50 | """Initializes a new Evaluator instance. 
 51 | 
 52 |     Args:
 53 |       input_fn: Input function returning a tuple of: features - Dictionary of
 54 |         string feature name to `Tensor`. labels - `Tensor` of labels.
 55 |       metric_name: The name of the evaluation metric to use when choosing the
 56 |         best ensemble. Must refer to a valid evaluation metric.
 57 |       objective: Either `Objective.MINIMIZE` or `Objective.MAXIMIZE`.
 58 |       steps: Number of steps for which to evaluate the ensembles. If an
 59 |         `OutOfRangeError` occurs, evaluation stops. If set to None, will iterate
 60 |         the dataset until all inputs are exhausted.
 61 | 
 62 |     Returns:
 63 |       An :class:`adanet.Evaluator` instance.
 64 |     """
 65 |     self._input_fn = input_fn
 66 |     self._steps = steps
 67 |     self._metric_name = metric_name
 68 |     self._objective = objective
 69 |     if objective == self.Objective.MINIMIZE:
 70 |       self._objective_fn = np.nanargmin
 71 |     elif objective == self.Objective.MAXIMIZE:
 72 |       self._objective_fn = np.nanargmax
 73 |     else:
 74 |       raise ValueError(
 75 |           "Evaluator objective must be one of MINIMIZE or MAXIMIZE.")
 76 | 
 77 |   @property
 78 |   def input_fn(self):
 79 |     """Return the input_fn."""
 80 |     return self._input_fn
 81 | 
 82 |   @property
 83 |   def steps(self):
 84 |     """Return the number of evaluation steps."""
 85 |     return self._steps
 86 | 
 87 |   @property
 88 |   def metric_name(self):
 89 |     """Returns the name of the metric being optimized."""
 90 |     return self._metric_name
 91 | 
 92 |   @property
 93 |   def objective_fn(self):
 94 |     """Returns a fn which selects the best metric based on the objective."""
 95 |     return self._objective_fn
 96 | 
 97 |   def evaluate(self, sess, ensemble_metrics):
 98 |     """Evaluates the given AdaNet objectives on the data from `input_fn`.
 99 | 
100 |     The candidates are fed the same batches of features and labels as
101 |     provided by `input_fn`, and their losses are computed and summed over
102 |     `steps` batches.
103 | 
104 |     Args:
105 |       sess: `Session` instance with most recent variable values loaded.
106 |       ensemble_metrics: A list of dictionaries of `tf.metrics` for each
107 |         candidate ensemble.
108 | 
109 |     Returns:
110 |       List of evaluated metrics.
111 |     """
112 | 
113 |     evals_completed = 0
114 |     if self.steps is None:
115 |       logging_frequency = 1000
116 |     elif self.steps < 10:
117 |       logging_frequency = 1
118 |     else:
119 |       logging_frequency = math.floor(self.steps / 10.)
120 | 
121 |     objective_metrics = [em[self._metric_name] for em in ensemble_metrics]
122 | 
123 |     sess.run(tf_compat.v1.local_variables_initializer())
124 |     while True:
125 |       if self.steps is not None and evals_completed == self.steps:
126 |         break
127 |       try:
128 |         evals_completed += 1
129 |         if (evals_completed % logging_frequency == 0 or
130 |             self.steps == evals_completed):
131 |           logging.info("Ensemble evaluation [%d/%s]", evals_completed,
132 |                        self.steps or "??")
133 |         sess.run(objective_metrics)
134 |       except tf.errors.OutOfRangeError:
135 |         logging.info("Encountered end of input after %d evaluations",
136 |                      evals_completed)
137 |         break
138 | 
139 |     # Evaluating the first element is idempotent for metric tuples.
140 |     return sess.run([metric[0] for metric in objective_metrics])
141 | 
--------------------------------------------------------------------------------
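Inside the library, `Evaluator.evaluate` is called by `adanet.Estimator` at the end of each iteration; user code normally just constructs the `Evaluator` and passes it to the estimator. A minimal sketch, assuming a toy input pipeline (the feature and label values are placeholders for illustration):

```python
import adanet
import tensorflow.compat.v1 as tf

# Assumed toy evaluation input pipeline: two examples, one numeric feature.
def eval_input_fn():
  dataset = tf.data.Dataset.from_tensor_slices(
      ({"x": [[0.], [1.]]}, [[0.], [1.]]))
  return dataset.batch(2)

# Lower adanet_loss is better, so MINIMIZE picks the candidate whose summed
# loss over `steps` batches is smallest (via np.nanargmin above).
evaluator = adanet.Evaluator(
    input_fn=eval_input_fn,
    metric_name="adanet_loss",
    objective=adanet.Evaluator.Objective.MINIMIZE,
    steps=10)

# The evaluator is handed to the estimator rather than called directly;
# `evaluate()` then runs inside the estimator's own session.
# estimator = adanet.Estimator(..., evaluator=evaluator)
```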
--------------------------------------------------------------------------------
/adanet/core/timer.py:
--------------------------------------------------------------------------------
 1 | """A simple timer implementation.
 2 | 
 3 | Copyright 2018 The AdaNet Authors. All Rights Reserved.
 4 | 
 5 | Licensed under the Apache License, Version 2.0 (the "License");
 6 | you may not use this file except in compliance with the License.
 7 | You may obtain a copy of the License at
 8 | 
 9 |   https://www.apache.org/licenses/LICENSE-2.0
10 | 
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 | """
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import time
23 | 
24 | 
25 | class _CountDownTimer(object):
26 |   """A simple count down timer implementation."""
27 | 
28 |   def __init__(self, duration_secs):
29 |     """Initializes a `_CountDownTimer`.
30 | 
31 |     Args:
32 |       duration_secs: Float seconds for countdown.
33 | 
34 |     Returns:
35 |       A `_CountDownTimer` instance.
36 |     """
37 | 
38 |     self._start_time_secs = time.time()
39 |     self._duration_secs = duration_secs
40 | 
41 |   def secs_remaining(self):
42 |     """Returns the remaining countdown seconds."""
43 | 
44 |     diff = self._duration_secs - (time.time() - self._start_time_secs)
45 |     return max(0., diff)
46 | 
--------------------------------------------------------------------------------
/adanet/core/timer_test.py:
--------------------------------------------------------------------------------
 1 | """Tests for timer.
 2 | 
 3 | Copyright 2018 The AdaNet Authors. All Rights Reserved.
 4 | 
 5 | Licensed under the Apache License, Version 2.0 (the "License");
 6 | you may not use this file except in compliance with the License.
 7 | You may obtain a copy of the License at
 8 | 
 9 |   https://www.apache.org/licenses/LICENSE-2.0
10 | 
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 | """
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import time
23 | 
24 | from adanet.core.timer import _CountDownTimer
25 | import tensorflow.compat.v1 as tf
26 | 
27 | 
28 | class CountDownTimerTest(tf.test.TestCase):
29 | 
30 |   def test_secs_remaining_long(self):
31 |     timer = _CountDownTimer(60)
32 |     time.sleep(.1)
33 |     secs_remaining = timer.secs_remaining()
34 |     self.assertLess(0., secs_remaining)
35 |     self.assertGreater(60., secs_remaining)
36 | 
37 |   def test_secs_remaining_short(self):
38 |     timer = _CountDownTimer(.001)
39 |     time.sleep(.1)
40 |     secs_remaining = timer.secs_remaining()
41 |     self.assertEqual(0., secs_remaining)
42 | 
43 |   def test_secs_remaining_zero(self):
44 |     timer = _CountDownTimer(0.)
45 |     time.sleep(.01)
46 |     secs_remaining = timer.secs_remaining()
47 |     self.assertEqual(0., secs_remaining)
48 | 
49 | 
50 | if __name__ == "__main__":
51 |   tf.test.main()
52 | 
--------------------------------------------------------------------------------
/adanet/distributed/BUILD:
--------------------------------------------------------------------------------
 1 | # Description:
 2 | #   AdaNet distributed logic.
 3 | 
 4 | licenses(["notice"])  # Apache 2.0
 5 | 
 6 | exports_files(["LICENSE"])
 7 | 
 8 | py_library(
 9 |     name = "distributed",
10 |     srcs = ["__init__.py"],
11 |     visibility = ["//adanet:__subpackages__"],
12 |     deps = [
13 |         ":devices",
14 |         ":placement",
15 |     ],
16 | )
17 | 
18 | py_library(
19 |     name = "devices",
20 |     srcs = ["devices.py"],
21 |     visibility = ["//adanet/core:__subpackages__"],
22 |     deps = [
23 |     ],
24 | )
25 | 
26 | py_test(
27 |     name = "devices_test",
28 |     srcs = ["devices_test.py"],
29 |     deps = [
30 |         ":devices",
31 |         "@absl_py//absl/testing:parameterized",
32 |     ],
33 | )
34 | 
35 | py_library(
36 |     name = "placement",
37 |     srcs = ["placement.py"],
38 |     deps = [
39 |         ":devices",
40 |         "//adanet/tf_compat",
41 |         "@absl_py//absl/logging",
42 |         "@six_archive//:six",
43 |     ],
44 | )
45 | 
46 | py_test(
47 |     name = "placement_test",
48 |     srcs = ["placement_test.py"],
49 |     deps = [
50 |         ":placement",
51 |         "@absl_py//absl/testing:parameterized",
52 |     ],
53 | )
54 | 
--------------------------------------------------------------------------------
/adanet/distributed/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2019 The AdaNet Authors. All Rights Reserved.
 2 | 
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | 
 7 | #   https://www.apache.org/licenses/LICENSE-2.0
 8 | 
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """The `adanet.distributed` package.
15 | 
16 | This package provides methods for distributing computation using the
17 | TensorFlow computation graph.
18 | """
19 | 
20 | # TODO: Add more detailed documentation.
21 | 
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 | 
26 | from adanet.distributed.placement import PlacementStrategy
27 | from adanet.distributed.placement import ReplicationStrategy
28 | from adanet.distributed.placement import RoundRobinStrategy
29 | 
30 | __all__ = [
31 |     "PlacementStrategy",
32 |     "ReplicationStrategy",
33 |     "RoundRobinStrategy",
34 | ]
35 | 
--------------------------------------------------------------------------------
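These placement strategies are selected when constructing the estimator. A sketch of opting into round-robin placement follows; the `experimental_placement_strategy` keyword comes from AdaNet's distributed-training guide and does not appear in this excerpt, so treat it as an assumption:

```python
import adanet

# ReplicationStrategy (the default) gives every worker a copy of the full
# candidate search graph; RoundRobinStrategy instead assigns each candidate
# subnetwork to a different worker so candidates train in parallel.
strategy = adanet.distributed.RoundRobinStrategy()

# Hedged: the keyword below is documented in AdaNet's distributed guide but
# `adanet.Estimator`'s signature is not part of this excerpt.
# estimator = adanet.Estimator(
#     head=head,                          # hypothetical head
#     subnetwork_generator=my_generator,  # hypothetical generator
#     max_iteration_steps=100,
#     experimental_placement_strategy=strategy)
```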
14 | """Device placement functions.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import contextlib 21 | import hashlib 22 | 23 | 24 | class _OpNameHashStrategy(object): 25 | """Returns the ps task index for placement using a hash of the op name.""" 26 | 27 | def __init__(self, num_tasks): 28 | """Create a new `_OpNameHashStrategy`. 29 | 30 | Args: 31 | num_tasks: Number of ps tasks to cycle among. 32 | """ 33 | 34 | self._num_tasks = num_tasks 35 | 36 | def __call__(self, op): 37 | """Choose a ps task index for the given `Operation`. 38 | 39 | Hashes the op name and assigns it to a ps task modulo the number of tasks. 40 | This ensures that variables with the same name are always placed on the same 41 | parameter server. 42 | 43 | Args: 44 | op: An `Operation` to be placed on ps. 45 | 46 | Returns: 47 | The ps task index to use for the `Operation`. 48 | """ 49 | 50 | hashed = int(hashlib.sha256(op.name.encode("utf-8")).hexdigest(), 16) 51 | return hashed % self._num_tasks 52 | 53 | 54 | @contextlib.contextmanager 55 | def monkey_patch_default_variable_placement_strategy(): 56 | """Monkey patches the default variable placement strategy. 57 | 58 | This strategy is used by tf.train.replica_device_setter. The new strategy 59 | allows workers to having different graphs from the chief. 60 | 61 | Yields: 62 | A context with the monkey-patched default variable placement strategy. 63 | """ 64 | 65 | # Import here to avoid strict BUILD deps check. 66 | from tensorflow.python.training import device_setter # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top 67 | old_round_robin_strategy = device_setter._RoundRobinStrategy # pylint: disable=protected-access 68 | setattr(device_setter, "_RoundRobinStrategy", _OpNameHashStrategy) 69 | try: 70 | yield 71 | finally: 72 | setattr(device_setter, "_RoundRobinStrategy", old_round_robin_strategy) 73 | -------------------------------------------------------------------------------- /adanet/distributed/devices_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Device placement function tests.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from absl.testing import parameterized 21 | from adanet.distributed.devices import monkey_patch_default_variable_placement_strategy 22 | 23 | import tensorflow.compat.v2 as tf 24 | # pylint: disable=g-direct-tensorflow-import 25 | from tensorflow.python.eager import context 26 | from tensorflow.python.framework import test_util 27 | # pylint: enable=g-direct-tensorflow-import 28 | 29 | 30 | class DevicesTest(parameterized.TestCase, tf.test.TestCase): 31 | 32 | @test_util.run_in_graph_and_eager_modes 33 | def test_monkey_patch_default_variable_placement_strategy_no_ps(self): 34 | with context.graph_mode(): 35 | with monkey_patch_default_variable_placement_strategy(): 36 | device_fn = tf.compat.v1.train.replica_device_setter(ps_tasks=0) 37 | self.assertIsNone(device_fn) 38 | 39 | @parameterized.named_parameters( 40 | { 41 | "testcase_name": 42 | "one_ps", 43 | "num_tasks": 44 | 1, 45 | "op_names": ["foo", "bar", "baz"], 46 | "before_want_ps": 47 | ["/job:ps/task:0", "/job:ps/task:0", "/job:ps/task:0"], 48 | "after_want_ps": 49 | ["/job:ps/task:0", "/job:ps/task:0", "/job:ps/task:0"], 50 | }, { 51 | "testcase_name": 52 | "three_ps", 53 | "num_tasks": 54 | 3, 55 | "op_names": ["foo", "bar", "baz"], 56 | "before_want_ps": 57 | ["/job:ps/task:0", "/job:ps/task:1", "/job:ps/task:2"], 58 | "after_want_ps": 59 | ["/job:ps/task:2", "/job:ps/task:0", "/job:ps/task:1"], 60 | }, { 61 | "testcase_name": 62 | "reverse_three_ps", 63 | "num_tasks": 64 | 3, 65 | "op_names": ["baz", "bar", "foo"], 66 | "before_want_ps": 67 | ["/job:ps/task:0", "/job:ps/task:1", "/job:ps/task:2"], 68 | "after_want_ps": 69 | ["/job:ps/task:1", "/job:ps/task:0", "/job:ps/task:2"], 70 | }, { 71 | "testcase_name": 72 | "six_ps", 73 | "num_tasks": 74 | 6, 75 | "op_names": ["foo", "bar", "baz"], 76 | "before_want_ps": 77 | ["/job:ps/task:0", "/job:ps/task:1", "/job:ps/task:2"], 78 | "after_want_ps": 79 | ["/job:ps/task:2", "/job:ps/task:3", "/job:ps/task:4"], 80 | }, { 81 | "testcase_name": 82 | "reverse_six_ps", 83 | "num_tasks": 84 | 6, 85 | "op_names": ["baz", "bar", "foo"], 86 | "before_want_ps": 87 | ["/job:ps/task:0", "/job:ps/task:1", "/job:ps/task:2"], 88 | "after_want_ps": 89 | ["/job:ps/task:4", "/job:ps/task:3", "/job:ps/task:2"], 90 | }) 91 | @test_util.run_in_graph_and_eager_modes 92 | def test_monkey_patch_default_variable_placement_strategy( 93 | self, num_tasks, op_names, before_want_ps, after_want_ps): 94 | """Checks that ps placement is based on var name.""" 95 | 96 | with context.graph_mode(): 97 | var_ops = [tf.Variable(0., name=op_name).op for op_name in op_names] 98 | before_device_fn = tf.compat.v1.train.replica_device_setter( 99 | ps_tasks=num_tasks) 100 | self.assertEqual(before_want_ps, [before_device_fn(op) for op in var_ops]) 101 | 102 | with monkey_patch_default_variable_placement_strategy(): 103 | after_device_fn = tf.compat.v1.train.replica_device_setter( 104 | ps_tasks=num_tasks) 105 | self.assertEqual(after_want_ps, [after_device_fn(op) for op in var_ops]) 106 | 107 | # Check that monkey-patch is only for the context. 
108 | before_device_fn = tf.compat.v1.train.replica_device_setter( 109 | ps_tasks=num_tasks) 110 | self.assertEqual(before_want_ps, [before_device_fn(op) for op in var_ops]) 111 | 112 | 113 | if __name__ == "__main__": 114 | tf.test.main() 115 | -------------------------------------------------------------------------------- /adanet/ensemble/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # AdaNet ensemble logic. 3 | 4 | licenses(["notice"]) # Apache 2.0 5 | 6 | exports_files(["LICENSE"]) 7 | 8 | py_library( 9 | name = "ensemble", 10 | srcs = ["__init__.py"], 11 | visibility = ["//adanet:__subpackages__"], 12 | deps = [ 13 | ":ensembler", 14 | ":mean", 15 | ":strategy", 16 | ":weighted", 17 | ], 18 | ) 19 | 20 | py_library( 21 | name = "ensembler", 22 | srcs = ["ensembler.py"], 23 | deps = [ 24 | "@six_archive//:six", 25 | ], 26 | ) 27 | 28 | py_library( 29 | name = "strategy", 30 | srcs = ["strategy.py"], 31 | deps = [ 32 | "@six_archive//:six", 33 | ], 34 | ) 35 | 36 | py_test( 37 | name = "strategy_test", 38 | srcs = ["strategy_test.py"], 39 | deps = [ 40 | ":ensemble", 41 | "//adanet/subnetwork", 42 | ], 43 | ) 44 | 45 | py_library( 46 | name = "weighted", 47 | srcs = ["weighted.py"], 48 | deps = [ 49 | ":ensembler", 50 | "//adanet/tf_compat", 51 | "@absl_py//absl/logging", 52 | ], 53 | ) 54 | 55 | py_test( 56 | name = "weighted_test", 57 | srcs = ["weighted_test.py"], 58 | deps = [ 59 | ":ensemble", 60 | "//adanet/core", 61 | "//adanet/subnetwork", 62 | "//adanet/tf_compat", 63 | "@absl_py//absl/testing:parameterized", 64 | ], 65 | ) 66 | 67 | py_library( 68 | name = "mean", 69 | srcs = ["mean.py"], 70 | deps = [ 71 | ":ensembler", 72 | "@absl_py//absl/logging", 73 | ], 74 | ) 75 | 76 | py_test( 77 | name = "mean_test", 78 | srcs = ["mean_test.py"], 79 | deps = [ 80 | ":ensemble", 81 | "//adanet/core", 82 | "//adanet/subnetwork", 83 | "@absl_py//absl/testing:parameterized", 84 | ], 85 | ) 86 | -------------------------------------------------------------------------------- /adanet/ensemble/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Defines built-in ensemble methods and interfaces for custom ensembles.""" 16 | 17 | # TODO: Add more detailed documentation. 
18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | from adanet.ensemble.ensembler import Ensemble 24 | from adanet.ensemble.ensembler import Ensembler 25 | from adanet.ensemble.ensembler import TrainOpSpec 26 | from adanet.ensemble.mean import MeanEnsemble 27 | from adanet.ensemble.mean import MeanEnsembler 28 | from adanet.ensemble.strategy import AllStrategy 29 | from adanet.ensemble.strategy import Candidate 30 | from adanet.ensemble.strategy import GrowStrategy 31 | from adanet.ensemble.strategy import SoloStrategy 32 | from adanet.ensemble.strategy import Strategy 33 | from adanet.ensemble.weighted import ComplexityRegularized 34 | from adanet.ensemble.weighted import ComplexityRegularizedEnsembler 35 | from adanet.ensemble.weighted import MixtureWeightType 36 | from adanet.ensemble.weighted import WeightedSubnetwork 37 | 38 | __all__ = [ 39 | "Ensemble", 40 | "Ensembler", 41 | "TrainOpSpec", 42 | "AllStrategy", 43 | "Candidate", 44 | "GrowStrategy", 45 | "SoloStrategy", 46 | "Strategy", 47 | "ComplexityRegularized", 48 | "ComplexityRegularizedEnsembler", 49 | "MeanEnsemble", 50 | "MeanEnsembler", 51 | "MixtureWeightType", 52 | "WeightedSubnetwork", 53 | ] 54 | -------------------------------------------------------------------------------- /adanet/ensemble/strategy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Search strategy algorithms.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import abc 21 | import collections 22 | 23 | import six 24 | 25 | 26 | class Candidate( 27 | collections.namedtuple("Candidate", [ 28 | "name", "subnetwork_builders", "previous_ensemble_subnetwork_builders" 29 | ])): 30 | """An ensemble candidate found during the search phase. 31 | 32 | Args: 33 | name: String name of this ensemble candidate. 34 | subnetwork_builders: Candidate :class:`adanet.subnetwork.Builder` instances 35 | to include in the ensemble. 36 | previous_ensemble_subnetwork_builders: :class:`adanet.subnetwork.Builder` 37 | instances to include from the previous ensemble. 
38 | """ 39 | 40 | def __new__(cls, name, subnetwork_builders, 41 | previous_ensemble_subnetwork_builders): 42 | return super(Candidate, cls).__new__( 43 | cls, 44 | name=name, 45 | subnetwork_builders=tuple(subnetwork_builders), 46 | previous_ensemble_subnetwork_builders=tuple( 47 | previous_ensemble_subnetwork_builders or [])) 48 | 49 | 50 | @six.add_metaclass(abc.ABCMeta) 51 | class Strategy(object): # pytype: disable=ignored-metaclass 52 | """An abstract ensemble strategy.""" 53 | 54 | __metaclass__ = abc.ABCMeta 55 | 56 | @abc.abstractmethod 57 | def generate_ensemble_candidates(self, subnetwork_builders, 58 | previous_ensemble_subnetwork_builders): 59 | """Generates ensemble candidates to search over this iteration. 60 | 61 | Args: 62 | subnetwork_builders: Candidate :class:`adanet.subnetwork.Builder` 63 | instances for this iteration. 64 | previous_ensemble_subnetwork_builders: :class:`adanet.subnetwork.Builder` 65 | instances from the previous ensemble. Including only a subset of these 66 | in a returned :class:`adanet.ensemble.Candidate` is equivalent to 67 | pruning the previous ensemble. 68 | 69 | Returns: 70 | An iterable of :class:`adanet.ensemble.Candidate` instances to train and 71 | consider this iteration. 72 | """ 73 | 74 | # TODO: Pruning the previous subnetwork may require more metadata 75 | # such as `subnetwork.Reports` and `ensemble.Reports` to make smart 76 | # decisions. 77 | 78 | 79 | class SoloStrategy(Strategy): 80 | """Produces a model composed of a single subnetwork. 81 | 82 | *An ensemble of one.* 83 | 84 | This is effectively the same as pruning all previous ensemble subnetworks, 85 | and only adding one subnetwork candidate to the ensemble. 86 | """ 87 | 88 | def generate_ensemble_candidates(self, subnetwork_builders, 89 | previous_ensemble_subnetwork_builders): 90 | return [ 91 | Candidate("{}_solo".format(subnetwork_builder.name), 92 | [subnetwork_builder], None) 93 | for subnetwork_builder in subnetwork_builders 94 | ] 95 | 96 | 97 | class GrowStrategy(Strategy): 98 | """Greedily grows an ensemble, one subnetwork at a time.""" 99 | 100 | def generate_ensemble_candidates(self, subnetwork_builders, 101 | previous_ensemble_subnetwork_builders): 102 | return [ 103 | Candidate("{}_grow".format(subnetwork_builder.name), 104 | [subnetwork_builder], previous_ensemble_subnetwork_builders) 105 | for subnetwork_builder in subnetwork_builders 106 | ] 107 | 108 | 109 | class AllStrategy(Strategy): 110 | """Ensembles all subnetworks from the current iteration.""" 111 | 112 | def generate_ensemble_candidates(self, subnetwork_builders, 113 | previous_ensemble_subnetwork_builders): 114 | return [ 115 | Candidate("all", subnetwork_builders, 116 | previous_ensemble_subnetwork_builders) 117 | ] 118 | -------------------------------------------------------------------------------- /adanet/ensemble/strategy_test.py: -------------------------------------------------------------------------------- 1 | """Test AdaNet single graph subnetwork implementation. 2 | 3 | Copyright 2019 The AdaNet Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from adanet import ensemble 23 | from adanet import subnetwork 24 | import mock 25 | import tensorflow.compat.v2 as tf 26 | # pylint: disable=g-direct-tensorflow-import 27 | from tensorflow.python.framework import test_util 28 | # pylint: enable=g-direct-tensorflow-import 29 | 30 | 31 | class StrategyTest(tf.test.TestCase): 32 | 33 | def setUp(self): 34 | self.fake_builder_1 = mock.create_autospec(spec=subnetwork.Builder) 35 | self.fake_builder_2 = mock.create_autospec(spec=subnetwork.Builder) 36 | self.fake_builder_3 = mock.create_autospec(spec=subnetwork.Builder) 37 | self.fake_builder_4 = mock.create_autospec(spec=subnetwork.Builder) 38 | 39 | @test_util.run_in_graph_and_eager_modes 40 | def test_solo_strategy(self): 41 | want = [ 42 | ensemble.Candidate("{}_solo".format(self.fake_builder_1.name), 43 | [self.fake_builder_1], []), 44 | ensemble.Candidate("{}_solo".format(self.fake_builder_2.name), 45 | [self.fake_builder_2], []) 46 | ] 47 | got = ensemble.SoloStrategy().generate_ensemble_candidates( 48 | [self.fake_builder_1, self.fake_builder_2], None) 49 | 50 | self.assertEqual(want, got) 51 | 52 | @test_util.run_in_graph_and_eager_modes 53 | def test_solo_strategy_with_previous_ensemble_subnetwork_builders(self): 54 | want = [ 55 | ensemble.Candidate("{}_solo".format(self.fake_builder_1.name), 56 | [self.fake_builder_1], []), 57 | ensemble.Candidate("{}_solo".format(self.fake_builder_2.name), 58 | [self.fake_builder_2], []) 59 | ] 60 | got = ensemble.SoloStrategy().generate_ensemble_candidates( 61 | [self.fake_builder_1, self.fake_builder_2], 62 | [self.fake_builder_3, self.fake_builder_4]) 63 | 64 | self.assertEqual(want, got) 65 | 66 | @test_util.run_in_graph_and_eager_modes 67 | def test_grow_strategy(self): 68 | want = [ 69 | ensemble.Candidate("{}_grow".format(self.fake_builder_1.name), 70 | [self.fake_builder_1], []), 71 | ensemble.Candidate("{}_grow".format(self.fake_builder_2.name), 72 | [self.fake_builder_2], []) 73 | ] 74 | got = ensemble.GrowStrategy().generate_ensemble_candidates( 75 | [self.fake_builder_1, self.fake_builder_2], None) 76 | self.assertEqual(want, got) 77 | 78 | @test_util.run_in_graph_and_eager_modes 79 | def test_grow_strategy_with_previous_ensemble_subnetwork_builders(self): 80 | want = [ 81 | ensemble.Candidate("{}_grow".format(self.fake_builder_1.name), 82 | [self.fake_builder_1], 83 | [self.fake_builder_3, self.fake_builder_4]), 84 | ensemble.Candidate("{}_grow".format(self.fake_builder_2.name), 85 | [self.fake_builder_2], 86 | [self.fake_builder_3, self.fake_builder_4]) 87 | ] 88 | got = ensemble.GrowStrategy().generate_ensemble_candidates( 89 | [self.fake_builder_1, self.fake_builder_2], 90 | [self.fake_builder_3, self.fake_builder_4]) 91 | self.assertEqual(want, got) 92 | 93 | @test_util.run_in_graph_and_eager_modes 94 | def test_all_strategy(self): 95 | want = [ 96 | ensemble.Candidate("all", [self.fake_builder_1, self.fake_builder_2], 97 | []) 98 | ] 99 | got = ensemble.AllStrategy().generate_ensemble_candidates( 100 | [self.fake_builder_1, self.fake_builder_2], None) 101 | self.assertEqual(want, got) 102 | 103 | @test_util.run_in_graph_and_eager_modes 104 | def test_all_strategy_with_previous_ensemble_subnetwork_builders(self): 105 | want = [ 106 | ensemble.Candidate("all", 
[self.fake_builder_1, self.fake_builder_2], 107 | [self.fake_builder_3, self.fake_builder_4]) 108 | ] 109 | got = ensemble.AllStrategy().generate_ensemble_candidates( 110 | [self.fake_builder_1, self.fake_builder_2], 111 | [self.fake_builder_3, self.fake_builder_4]) 112 | self.assertEqual(want, got) 113 | 114 | 115 | if __name__ == "__main__": 116 | tf.test.main() 117 | -------------------------------------------------------------------------------- /adanet/examples/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | exports_files(["LICENSE"]) 4 | 5 | py_library( 6 | name = "examples", 7 | srcs = ["__init__.py"], 8 | srcs_version = "PY3", 9 | visibility = ["//visibility:public"], 10 | deps = [ 11 | ":simple_dnn", 12 | ], 13 | ) 14 | 15 | py_library( 16 | name = "simple_dnn", 17 | srcs = ["simple_dnn.py"], 18 | srcs_version = "PY3", 19 | visibility = ["//visibility:public"], 20 | deps = [ 21 | "//adanet", 22 | ], 23 | ) 24 | 25 | py_test( 26 | name = "simple_dnn_test", 27 | srcs = ["simple_dnn_test.py"], 28 | srcs_version = "PY3", 29 | deps = [ 30 | ":simple_dnn", 31 | "@absl_py//absl/testing:parameterized", 32 | ], 33 | ) 34 | -------------------------------------------------------------------------------- /adanet/examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | This directory contains some example user-defined subnetworks, generators, and complexity measures for AdaNets. 4 | -------------------------------------------------------------------------------- /adanet/examples/__init__.py: -------------------------------------------------------------------------------- 1 | """Some examples using AdaNet. 2 | 3 | Copyright 2018 The AdaNet Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | """ 17 | -------------------------------------------------------------------------------- /adanet/examples/tutorials/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | exports_files(["LICENSE"]) 4 | -------------------------------------------------------------------------------- /adanet/examples/tutorials/README.md: -------------------------------------------------------------------------------- 1 | ## Tutorials 2 | 3 | Welcome to `adanet`! For a tour of this python package's capabilities, please work through the following notebooks: 4 | 5 | 1. [The AdaNet objective](./adanet_objective.ipynb) 6 | 1. [Customizing AdaNet](./customizing_adanet.ipynb) 7 | 1. [AdaNet on TPU](./adanet_tpu.ipynb) 8 | -------------------------------------------------------------------------------- /adanet/experimental/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # AdaNet experimental code. 
3 | # **HIGHLY EXPERIMENTAL AND SUBJECT TO CHANGE** 4 | 5 | licenses(["notice"]) # Apache 2.0 6 | 7 | py_library( 8 | name = "experimental", 9 | srcs = ["__init__.py"], 10 | srcs_version = "PY3", 11 | visibility = ["//visibility:public"], 12 | deps = [ 13 | "//adanet/experimental/controllers", 14 | "//adanet/experimental/keras", 15 | "//adanet/experimental/phases", 16 | "//adanet/experimental/schedulers", 17 | "//adanet/experimental/storages", 18 | "//adanet/experimental/work_units", 19 | ], 20 | ) 21 | -------------------------------------------------------------------------------- /adanet/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2020 The AdaNet Authors. All Rights Reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """AdaNet experimental directory.""" 16 | 17 | from adanet.experimental import controllers 18 | from adanet.experimental import keras 19 | from adanet.experimental import phases 20 | from adanet.experimental import schedulers 21 | from adanet.experimental import storages 22 | from adanet.experimental import work_units 23 | 24 | 25 | __all__ = [ 26 | "controllers", 27 | "keras", 28 | "phases", 29 | "schedulers", 30 | "storages", 31 | "work_units", 32 | ] 33 | -------------------------------------------------------------------------------- /adanet/experimental/controllers/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # AdaNet controller. 3 | # **HIGHLY EXPERIMENTAL AND SUBJECT TO CHANGE** 4 | 5 | licenses(["notice"]) # Apache 2.0 6 | 7 | exports_files(["LICENSE"]) 8 | 9 | py_library( 10 | name = "controllers", 11 | srcs = ["__init__.py"], 12 | srcs_version = "PY3", 13 | visibility = ["//adanet:__subpackages__"], 14 | deps = [":sequential_controller"], 15 | ) 16 | 17 | py_library( 18 | name = "controller", 19 | srcs = ["controller.py"], 20 | srcs_version = "PY3", 21 | visibility = ["//adanet:__subpackages__"], 22 | deps = [ 23 | "//adanet/experimental/storages:storage", 24 | "//adanet/experimental/work_units:work_unit", 25 | ], 26 | ) 27 | 28 | py_library( 29 | name = "sequential_controller", 30 | srcs = ["sequential_controller.py"], 31 | srcs_version = "PY3", 32 | visibility = ["//adanet:__subpackages__"], 33 | deps = [ 34 | ":controller", 35 | "//adanet/experimental/phases:phase", 36 | "//adanet/experimental/storages:in_memory_storage", 37 | "//adanet/experimental/storages:storage", 38 | "//adanet/experimental/work_units:work_unit", 39 | ], 40 | ) 41 | -------------------------------------------------------------------------------- /adanet/experimental/controllers/__init__.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2020 The AdaNet Authors. All Rights Reserved. 
3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """AdaNet ModelFlow controllers.""" 16 | 17 | from adanet.experimental.controllers.sequential_controller import SequentialController 18 | 19 | 20 | __all__ = [ 21 | "SequentialController", 22 | ] 23 | -------------------------------------------------------------------------------- /adanet/experimental/controllers/controller.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """The AutoML controller for AdaNet.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import abc 22 | from typing import Iterator, Sequence 23 | 24 | from adanet.experimental.work_units.work_unit import WorkUnit 25 | import tensorflow.compat.v2 as tf 26 | 27 | 28 | class Controller(abc.ABC): 29 | """Defines the machine learning workflow to produce high-quality models.""" 30 | 31 | @abc.abstractmethod 32 | def work_units(self) -> Iterator[WorkUnit]: 33 | """Yields `WorkUnit` instances.""" 34 | pass 35 | 36 | @abc.abstractmethod 37 | def get_best_models(self, num_models) -> Sequence[tf.keras.Model]: 38 | """Returns the top models produced from executing the controller.""" 39 | pass 40 | -------------------------------------------------------------------------------- /adanet/experimental/controllers/sequential_controller.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
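# As a hedged illustration of the `Controller` interface above (not code from
# this package): a trivial controller could wrap one pre-built work unit and
# one model, where `my_work_unit` and `my_model` are hypothetical.
#
#   class SingleWorkUnitController(Controller):
#
#     def __init__(self, my_work_unit, my_model):
#       self._work_unit = my_work_unit
#       self._model = my_model
#
#     def work_units(self):
#       yield self._work_unit
#
#     def get_best_models(self, num_models):
#       return [self._model][:num_models]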
15 | """A manual controller for model search.""" 16 | 17 | from typing import Iterator, Sequence 18 | from adanet.experimental.controllers.controller import Controller 19 | from adanet.experimental.phases.phase import ModelProvider 20 | from adanet.experimental.phases.phase import Phase 21 | from adanet.experimental.work_units.work_unit import WorkUnit 22 | import tensorflow.compat.v2 as tf 23 | 24 | 25 | class SequentialController(Controller): 26 | """A controller where the user specifies the sequences of phase to execute.""" 27 | 28 | # TODO: Add checks to make sure phases are valid. 29 | def __init__(self, phases: Sequence[Phase]): 30 | """Initializes a SequentialController. 31 | 32 | Args: 33 | phases: A list of `Phase` instances. 34 | """ 35 | 36 | self._phases = phases 37 | 38 | def work_units(self) -> Iterator[WorkUnit]: 39 | previous_phase = None 40 | for phase in self._phases: 41 | for work_unit in phase.work_units(previous_phase): 42 | yield work_unit 43 | previous_phase = phase 44 | 45 | def get_best_models(self, num_models: int) -> Sequence[tf.keras.Model]: 46 | final_phase = self._phases[-1] 47 | if isinstance(final_phase, ModelProvider): 48 | return self._phases[-1].get_best_models(num_models) 49 | raise RuntimeError('Final phase does not provide models.') 50 | -------------------------------------------------------------------------------- /adanet/experimental/keras/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # Main Keras API logic. 3 | 4 | licenses(["notice"]) # Apache 2.0 5 | 6 | py_library( 7 | name = "keras", 8 | srcs = ["__init__.py"], 9 | srcs_version = "PY3", 10 | visibility = ["//adanet:__subpackages__"], 11 | deps = [ 12 | ":ensemble_model", 13 | ":model_search", 14 | ], 15 | ) 16 | 17 | py_library( 18 | name = "ensemble_model", 19 | srcs = ["ensemble_model.py"], 20 | srcs_version = "PY3", 21 | visibility = ["//adanet:__subpackages__"], 22 | deps = [ 23 | ], 24 | ) 25 | 26 | py_library( 27 | name = "model_search", 28 | srcs = ["model_search.py"], 29 | srcs_version = "PY3", 30 | visibility = ["//adanet:__subpackages__"], 31 | deps = [ 32 | "//adanet/experimental/controllers:controller", 33 | "//adanet/experimental/schedulers:in_process_scheduler", 34 | "//adanet/experimental/schedulers:scheduler", 35 | ], 36 | ) 37 | 38 | py_test( 39 | name = "model_search_test", 40 | size = "large", 41 | srcs = ["model_search_test.py"], 42 | srcs_version = "PY3", 43 | deps = [ 44 | ":ensemble_model", 45 | ":model_search", 46 | ":testing_utils", 47 | "//adanet/experimental/controllers:sequential_controller", 48 | "//adanet/experimental/phases:autoensemble_phase", 49 | "//adanet/experimental/phases:input_phase", 50 | "//adanet/experimental/phases:keras_trainer_phase", 51 | "//adanet/experimental/phases:keras_tuner_phase", 52 | "//adanet/experimental/phases:repeat_phase", 53 | "//adanet/experimental/storages:in_memory_storage", 54 | "@absl_py//absl/flags", 55 | "@absl_py//absl/testing:parameterized", 56 | ], 57 | ) 58 | 59 | py_library( 60 | name = "testing_utils", 61 | srcs = ["testing_utils.py"], 62 | srcs_version = "PY3", 63 | deps = [ 64 | ], 65 | ) 66 | 67 | py_test( 68 | name = "ensemble_model_test", 69 | srcs = ["ensemble_model_test.py"], 70 | srcs_version = "PY3", 71 | deps = [ 72 | ":ensemble_model", 73 | ":testing_utils", 74 | "@absl_py//absl/testing:parameterized", 75 | ], 76 | ) 77 | -------------------------------------------------------------------------------- /adanet/experimental/keras/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2020 The AdaNet Authors. All Rights Reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """AdaNet Keras models.""" 16 | 17 | from adanet.experimental.keras.ensemble_model import EnsembleModel 18 | from adanet.experimental.keras.ensemble_model import MeanEnsemble 19 | from adanet.experimental.keras.ensemble_model import WeightedEnsemble 20 | from adanet.experimental.keras.model_search import ModelSearch 21 | 22 | 23 | __all__ = [ 24 | "EnsembleModel", 25 | "MeanEnsemble", 26 | "WeightedEnsemble", 27 | "ModelSearch", 28 | ] 29 | -------------------------------------------------------------------------------- /adanet/experimental/keras/ensemble_model.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """An AdaNet ensemble implementation.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from typing import Sequence 22 | 23 | import tensorflow.compat.v2 as tf 24 | 25 | 26 | class EnsembleModel(tf.keras.Model): 27 | """An ensemble of Keras models.""" 28 | 29 | def __init__(self, submodels: Sequence[tf.keras.Model], 30 | freeze_submodels: bool = True): 31 | """Initializes an EnsembleModel. 32 | 33 | Args: 34 | submodels: A list of `tf.keras.Model` that compose the ensemble. 35 | freeze_submodels: Whether to freeze the weights of submodels. 
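    A typical flow, mirroring the unit tests (a minimal sketch; `model_a` and
    `model_b` are assumed to be pre-trained, compiled `tf.keras.Model`
    instances, and `train_dataset` a batched `tf.data.Dataset`):

      ensemble = MeanEnsemble(submodels=[model_a, model_b])
      ensemble.compile(optimizer='adam', loss='mse', metrics=['mae'])
      ensemble.fit(train_dataset)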
36 | """ 37 | 38 | super().__init__() 39 | if freeze_submodels: 40 | for submodel in submodels: 41 | for layer in submodel.layers: 42 | layer.trainable = False 43 | self._submodels = submodels 44 | 45 | @property 46 | def submodels(self) -> Sequence[tf.keras.Model]: 47 | return self._submodels 48 | 49 | def call(self, inputs): 50 | raise NotImplementedError 51 | 52 | 53 | class MeanEnsemble(EnsembleModel): 54 | """An ensemble that averages submodel outputs.""" 55 | 56 | def call(self, inputs): 57 | if len(self._submodels) == 1: 58 | return self._submodels[0](inputs) 59 | 60 | submodel_outputs = [] 61 | for submodel in self._submodels: 62 | submodel_outputs.append(submodel(inputs)) 63 | return tf.keras.layers.average(submodel_outputs) 64 | 65 | 66 | class WeightedEnsemble(EnsembleModel): 67 | """An ensemble that linearly combines submodel outputs.""" 68 | 69 | # TODO: Extract output shapes from submodels instead of passing in 70 | # as argument. 71 | def __init__(self, submodels: Sequence[tf.keras.Model], output_units: int): 72 | """Initializes a WeightedEnsemble. 73 | 74 | Args: 75 | submodels: A list of `adanet.keras.SubModel` that compose the ensemble. 76 | output_units: The output size of the last layer of each submodel. 77 | """ 78 | 79 | super().__init__(submodels) 80 | self.dense = tf.keras.layers.Dense(units=output_units) 81 | 82 | def call(self, inputs): 83 | submodel_outputs = [] 84 | for submodel in self.submodels: 85 | submodel_outputs.append(submodel(inputs)) 86 | return self.dense(tf.stack(submodel_outputs)) 87 | -------------------------------------------------------------------------------- /adanet/experimental/keras/ensemble_model_test.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tests for adanet.experimental.keras.EnsembleModel.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import parameterized 22 | from adanet.experimental.keras import testing_utils 23 | from adanet.experimental.keras.ensemble_model import MeanEnsemble 24 | from adanet.experimental.keras.ensemble_model import WeightedEnsemble 25 | 26 | import tensorflow.compat.v2 as tf 27 | 28 | 29 | class EnsembleModelTest(parameterized.TestCase, tf.test.TestCase): 30 | 31 | @parameterized.named_parameters( 32 | { 33 | 'testcase_name': 'mean_ensemble', 34 | 'ensemble': MeanEnsemble, 35 | 'want_results': [0.07671691, 0.20448962], 36 | }, { 37 | 'testcase_name': 'weighted_ensemble', 38 | 'ensemble': WeightedEnsemble, 39 | 'output_units': 2, 40 | 'want_results': [0.42579408, 0.53439462], 41 | }) 42 | def test_lifecycle(self, ensemble, want_results, output_units=None): 43 | train_dataset, test_dataset = testing_utils.get_holdout_data( 44 | train_samples=128, 45 | test_samples=64, 46 | input_shape=(10,), 47 | num_classes=2, 48 | random_seed=42) 49 | 50 | # TODO: Consider performing `tf.data.Dataset` transformations 51 | # within get_test_data function. 52 | train_dataset = train_dataset.batch(32).repeat(10) 53 | test_dataset = test_dataset.batch(32).repeat(10) 54 | 55 | model1 = tf.keras.Sequential([ 56 | tf.keras.layers.Dense(64, activation='relu'), 57 | tf.keras.layers.Dense(64, activation='relu'), 58 | tf.keras.layers.Dense(2), 59 | ]) 60 | model1.compile( 61 | optimizer=tf.keras.optimizers.Adam(0.01), 62 | loss='mse') 63 | model1.fit(train_dataset) 64 | model1.trainable = False # Since models inside ensemble should be trained. 65 | model1_pre_train_weights = model1.get_weights() 66 | 67 | model2 = tf.keras.Sequential([ 68 | tf.keras.layers.Dense(64, activation='relu'), 69 | tf.keras.layers.Dense(64, activation='relu'), 70 | tf.keras.layers.Dense(2), 71 | ]) 72 | model2.compile( 73 | optimizer=tf.keras.optimizers.Adam(0.01), 74 | loss='mse') 75 | model2.fit(train_dataset) 76 | model2.trainable = False # Since models inside ensemble should be trained. 77 | model2_pre_train_weights = model2.get_weights() 78 | 79 | if output_units: 80 | ensemble = ensemble(submodels=[model1, model2], 81 | output_units=output_units) 82 | else: 83 | ensemble = ensemble(submodels=[model1, model2]) 84 | ensemble.compile( 85 | optimizer=tf.keras.optimizers.Adam(0.01), 86 | loss='mse', 87 | metrics=['mae']) 88 | 89 | ensemble.fit(train_dataset) 90 | 91 | # Make sure submodel weights were not altered during ensemble training. 92 | model1_post_train_weights = model1.get_weights() 93 | model2_post_train_weights = model2.get_weights() 94 | self.assertAllClose(model1_pre_train_weights, model1_post_train_weights) 95 | self.assertAllClose(model2_pre_train_weights, model2_post_train_weights) 96 | 97 | eval_results = ensemble.evaluate(test_dataset) 98 | self.assertAllClose(eval_results, want_results) 99 | 100 | 101 | if __name__ == '__main__': 102 | tf.enable_v2_behavior() 103 | tf.test.main() 104 | -------------------------------------------------------------------------------- /adanet/experimental/keras/model_search.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """An AdaNet interface for model search.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from typing import Sequence 22 | 23 | from adanet.experimental.controllers.controller import Controller 24 | from adanet.experimental.schedulers.in_process_scheduler import InProcessScheduler 25 | from adanet.experimental.schedulers.scheduler import Scheduler 26 | import tensorflow.compat.v2 as tf 27 | 28 | 29 | class ModelSearch(object): 30 | """An AutoML pipeline manager.""" 31 | 32 | def __init__(self, 33 | controller: Controller, 34 | scheduler: Scheduler = InProcessScheduler()): 35 | """Initializes a ModelSearch. 36 | 37 | Args: 38 | controller: A `Controller` instance. 39 | scheduler: A `Scheduler` instance. 40 | """ 41 | 42 | self._controller = controller 43 | self._scheduler = scheduler 44 | 45 | def run(self): 46 | """Executes the training workflow to generate models.""" 47 | self._scheduler.schedule(self._controller.work_units()) 48 | 49 | def get_best_models(self, num_models) -> Sequence[tf.keras.Model]: 50 | """Returns the top models from the run.""" 51 | return self._controller.get_best_models(num_models) 52 | -------------------------------------------------------------------------------- /adanet/experimental/keras/testing_utils.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Utilities for unit-testing AdaNet Keras.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from typing import Optional, Tuple 22 | 23 | import numpy as np 24 | import tensorflow.compat.v2 as tf 25 | 26 | 27 | # TODO: Add ability to choose the problem type: regression, 28 | # classification, multi-class etc. 29 | def get_holdout_data( 30 | train_samples: int, 31 | test_samples: int, 32 | input_shape: Tuple[int], 33 | num_classes: int, 34 | random_seed: Optional[int] = None 35 | ) -> Tuple[tf.data.Dataset, tf.data.Dataset]: 36 | """Generates training and test data. 37 | 38 | Args: 39 | train_samples: Number of training samples to generate. 40 | test_samples: Number of test samples to generate. 41 | input_shape: Shape of the inputs. 42 | num_classes: Number of classes for the data and targets. 43 | random_seed: A random seed for numpy to use.
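  For example, mirroring the usage in the Keras ensemble tests:

    train_dataset, test_dataset = get_holdout_data(
        train_samples=128,
        test_samples=64,
        input_shape=(10,),
        num_classes=2,
        random_seed=42)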
44 | 45 | Returns: 46 | A tuple of `tf.data.Datasets`. 47 | """ 48 | if random_seed: 49 | np.random.seed(random_seed) 50 | 51 | num_sample = train_samples + test_samples 52 | templates = 2 * num_classes * np.random.random((num_classes,) + input_shape) 53 | y = np.random.randint(0, num_classes, size=(num_sample,)) 54 | x = np.zeros((num_sample,) + input_shape, dtype=np.float32) 55 | for i in range(num_sample): 56 | x[i] = templates[y[i]] + np.random.normal(loc=0, scale=1., size=input_shape) 57 | 58 | train_dataset = tf.data.Dataset.from_tensor_slices( 59 | (x[:train_samples], y[:train_samples])) 60 | test_dataset = tf.data.Dataset.from_tensor_slices( 61 | (x[train_samples:], y[train_samples:])) 62 | return train_dataset, test_dataset 63 | -------------------------------------------------------------------------------- /adanet/experimental/phases/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # AdaNet phases. 3 | # **HIGHLY EXPERIMENTAL AND SUBJECT TO CHANGE** 4 | 5 | licenses(["notice"]) # Apache 2.0 6 | 7 | exports_files(["LICENSE"]) 8 | 9 | py_library( 10 | name = "phases", 11 | srcs = ["__init__.py"], 12 | srcs_version = "PY3", 13 | visibility = ["//visibility:public"], 14 | deps = [ 15 | ":autoensemble_phase", 16 | ":input_phase", 17 | ":keras_trainer_phase", 18 | ":keras_tuner_phase", 19 | ":repeat_phase", 20 | ], 21 | ) 22 | 23 | py_library( 24 | name = "phase", 25 | srcs = ["phase.py"], 26 | srcs_version = "PY3", 27 | visibility = ["//adanet:__subpackages__"], 28 | deps = [ 29 | "//adanet/experimental/storages:in_memory_storage", 30 | "//adanet/experimental/storages:storage", 31 | "//adanet/experimental/work_units:work_unit", 32 | ], 33 | ) 34 | 35 | py_library( 36 | name = "keras_trainer_phase", 37 | srcs = ["keras_trainer_phase.py"], 38 | srcs_version = "PY3", 39 | visibility = ["//adanet:__subpackages__"], 40 | deps = [ 41 | ":phase", 42 | "//adanet/experimental/storages:in_memory_storage", 43 | "//adanet/experimental/storages:storage", 44 | "//adanet/experimental/work_units:keras_trainer_work_unit", 45 | "//adanet/experimental/work_units:work_unit", 46 | ], 47 | ) 48 | 49 | py_library( 50 | name = "keras_tuner_phase", 51 | srcs = ["keras_tuner_phase.py"], 52 | srcs_version = "PY3", 53 | visibility = ["//adanet/experimental:__subpackages__"], 54 | deps = [ 55 | ":phase", 56 | "//adanet/experimental/work_units:keras_tuner_work_unit", 57 | "//adanet/experimental/work_units:work_unit", 58 | ], 59 | ) 60 | 61 | py_library( 62 | name = "input_phase", 63 | srcs = ["input_phase.py"], 64 | srcs_version = "PY3", 65 | visibility = ["//adanet/experimental:__subpackages__"], 66 | deps = [ 67 | ":phase", 68 | ], 69 | ) 70 | 71 | py_library( 72 | name = "autoensemble_phase", 73 | srcs = ["autoensemble_phase.py"], 74 | srcs_version = "PY3", 75 | visibility = ["//adanet/experimental:__subpackages__"], 76 | deps = [ 77 | ":phase", 78 | "//adanet/experimental/keras:ensemble_model", 79 | "//adanet/experimental/storages:in_memory_storage", 80 | "//adanet/experimental/storages:storage", 81 | "//adanet/experimental/work_units:keras_trainer_work_unit", 82 | "//adanet/experimental/work_units:work_unit", 83 | ], 84 | ) 85 | 86 | py_library( 87 | name = "repeat_phase", 88 | srcs = ["repeat_phase.py"], 89 | srcs_version = "PY3", 90 | visibility = ["//adanet/experimental:__subpackages__"], 91 | deps = [ 92 | ":phase", 93 | "//adanet/experimental/work_units:work_unit", 94 | ], 95 | ) 96 | 
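The phases in this BUILD file compose into a full pipeline. A minimal sketch
under assumed names (`train_ds` and `eval_ds` are `tf.data.Dataset` instances
and `model` is a compiled `tf.keras.Model`): an input phase feeds a trainer
phase through a `SequentialController`:

    from adanet.experimental.controllers.sequential_controller import SequentialController
    from adanet.experimental.keras.model_search import ModelSearch
    from adanet.experimental.phases.input_phase import InputPhase
    from adanet.experimental.phases.keras_trainer_phase import KerasTrainerPhase

    controller = SequentialController(phases=[
        InputPhase(train_ds, eval_ds),
        KerasTrainerPhase([model]),
    ])
    ModelSearch(controller).run()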
-------------------------------------------------------------------------------- /adanet/experimental/phases/__init__.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2020 The AdaNet Authors. All Rights Reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """AdaNet ModelFlow phases.""" 16 | 17 | from adanet.experimental.phases.autoensemble_phase import AutoEnsemblePhase 18 | from adanet.experimental.phases.input_phase import InputPhase 19 | from adanet.experimental.phases.keras_trainer_phase import KerasTrainerPhase 20 | from adanet.experimental.phases.keras_tuner_phase import KerasTunerPhase 21 | from adanet.experimental.phases.repeat_phase import RepeatPhase 22 | 23 | 24 | __all__ = [ 25 | "AutoEnsemblePhase", 26 | "InputPhase", 27 | "KerasTrainerPhase", 28 | "KerasTunerPhase", 29 | "RepeatPhase", 30 | ] 31 | -------------------------------------------------------------------------------- /adanet/experimental/phases/input_phase.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """A phase that provides datasets.""" 16 | 17 | from typing import Optional 18 | from adanet.experimental.phases.phase import DatasetProvider 19 | from adanet.experimental.phases.phase import Phase 20 | import tensorflow.compat.v2 as tf 21 | 22 | 23 | class InputPhase(DatasetProvider): 24 | """A phase that simply relays train and eval datasets.""" 25 | 26 | def __init__(self, train_dataset: tf.data.Dataset, 27 | eval_dataset: tf.data.Dataset): 28 | """Initializes an InputPhase. 29 | 30 | Args: 31 | train_dataset: A `tf.data.Dataset` for training. 32 | eval_dataset: A `tf.data.Dataset` for evaluation. 
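    For example (a minimal sketch; `x_train`, `y_train`, `x_eval`, and
    `y_eval` are assumed in-memory numpy arrays):

      train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
      eval_ds = tf.data.Dataset.from_tensor_slices((x_eval, y_eval))
      input_phase = InputPhase(train_ds, eval_ds)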
33 | """ 34 | 35 | self._train_dataset = train_dataset 36 | self._eval_dataset = eval_dataset 37 | 38 | def get_train_dataset(self) -> tf.data.Dataset: 39 | return self._train_dataset 40 | 41 | def get_eval_dataset(self) -> tf.data.Dataset: 42 | return self._eval_dataset 43 | 44 | def work_units(self, previous_phase: Optional[Phase]): 45 | return [] 46 | -------------------------------------------------------------------------------- /adanet/experimental/phases/keras_trainer_phase.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """A phase in the AdaNet workflow.""" 16 | 17 | from typing import Callable, Iterable, Iterator, Union 18 | from adanet.experimental.phases.phase import DatasetProvider 19 | from adanet.experimental.phases.phase import ModelProvider 20 | from adanet.experimental.storages.in_memory_storage import InMemoryStorage 21 | from adanet.experimental.storages.storage import Storage 22 | from adanet.experimental.work_units.keras_trainer_work_unit import KerasTrainerWorkUnit 23 | from adanet.experimental.work_units.work_unit import WorkUnit 24 | import tensorflow.compat.v2 as tf 25 | 26 | 27 | class KerasTrainerPhase(DatasetProvider, ModelProvider): 28 | """Trains Keras models.""" 29 | 30 | def __init__(self, 31 | models: Union[Iterable[tf.keras.Model], 32 | Callable[[], Iterable[tf.keras.Model]]], 33 | storage: Storage = InMemoryStorage()): 34 | """Initializes a KerasTrainerPhase. 35 | 36 | Args: 37 | models: A list of `tf.keras.Model` instances or a list of callables that 38 | return `tf.keras.Model` instances. 39 | storage: A `Storage` instance. 40 | """ 41 | # TODO: Consume arbitary fit inputs. 42 | # Dataset should be wrapped inside a work unit. 43 | # For instance when you create KerasTrainer work unit the dataset is 44 | # encapsulated inside that work unit. 45 | # What if you want to run on different (parts of the) datasets 46 | # what if a work units consumes numpy arrays? 
47 | super().__init__(storage) 48 | self._models = models 49 | 50 | def work_units(self, previous_phase: DatasetProvider) -> Iterator[WorkUnit]: 51 | self._train_dataset = previous_phase.get_train_dataset() 52 | self._eval_dataset = previous_phase.get_eval_dataset() 53 | models = self._models 54 | if callable(models): 55 | models = models() 56 | for model in models: 57 | yield KerasTrainerWorkUnit(model, self._train_dataset, self._eval_dataset, 58 | self._storage) 59 | 60 | def get_models(self) -> Iterable[tf.keras.Model]: 61 | return self._storage.get_models() 62 | 63 | def get_best_models(self, num_models) -> Iterable[tf.keras.Model]: 64 | return self._storage.get_best_models(num_models) 65 | 66 | def get_train_dataset(self) -> tf.data.Dataset: 67 | return self._train_dataset 68 | 69 | def get_eval_dataset(self) -> tf.data.Dataset: 70 | return self._eval_dataset 71 | -------------------------------------------------------------------------------- /adanet/experimental/phases/keras_tuner_phase.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """A phase in the AdaNet workflow.""" 16 | 17 | import sys 18 | 19 | from typing import Callable, Iterable, Iterator, Union 20 | from adanet.experimental.phases.phase import DatasetProvider 21 | from adanet.experimental.phases.phase import ModelProvider 22 | from adanet.experimental.work_units.keras_tuner_work_unit import KerasTunerWorkUnit 23 | from adanet.experimental.work_units.work_unit import WorkUnit 24 | from kerastuner.engine.tuner import Tuner 25 | import tensorflow.compat.v2 as tf 26 | 27 | 28 | class KerasTunerPhase(DatasetProvider, ModelProvider): 29 | """Tunes Keras Model hyperparameters using the Keras Tuner.""" 30 | 31 | def __init__(self, tuner: Union[Callable[..., Tuner], Tuner], *search_args, 32 | **search_kwargs): 33 | """Initializes a KerasTunerPhase. 34 | 35 | Args: 36 | tuner: A `kerastuner.tuners.tuner.Tuner` instance or a callable that 37 | returns a `kerastuner.tuners.tuner.Tuner` instance. 38 | *search_args: Arguments to pass to the tuner search method. 39 | **search_kwargs: Keyword arguments to pass to the tuner search method. 40 | """ 41 | 42 | if callable(tuner): 43 | self._tuner = tuner() 44 | else: 45 | self._tuner = tuner 46 | self._search_args = search_args 47 | self._search_kwargs = search_kwargs 48 | 49 | def work_units(self, previous_phase: DatasetProvider) -> Iterator[WorkUnit]: 50 | self._train_dataset = previous_phase.get_train_dataset() 51 | self._eval_dataset = previous_phase.get_eval_dataset() 52 | yield KerasTunerWorkUnit( 53 | self._tuner, 54 | x=self._train_dataset, 55 | validation_data=self._eval_dataset, 56 | *self._search_args, 57 | **self._search_kwargs) 58 | 59 | # TODO: Find a better way to get all models than to pass in a 60 | # large number. 
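  # A usage sketch, not code from this package (assumes a `kerastuner` tuner
  # such as `RandomSearch`, built from a hypothetical `build_model`
  # hypermodel function):
  #
  #   phase = KerasTunerPhase(RandomSearch(
  #       build_model, objective='val_loss', max_trials=3))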
61 | def get_models(self) -> Iterable[tf.keras.Model]: 62 | return self._tuner.get_best_models(num_models=sys.maxsize) 63 | 64 | def get_best_models(self, num_models) -> Iterable[tf.keras.Model]: 65 | return self._tuner.get_best_models(num_models=num_models) 66 | 67 | def get_train_dataset(self) -> tf.data.Dataset: 68 | return self._train_dataset 69 | 70 | def get_eval_dataset(self) -> tf.data.Dataset: 71 | return self._eval_dataset 72 | -------------------------------------------------------------------------------- /adanet/experimental/phases/phase.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """A phase in the AdaNet workflow.""" 16 | 17 | import abc 18 | 19 | from typing import Iterable, Iterator, Optional 20 | from adanet.experimental.storages.in_memory_storage import InMemoryStorage 21 | from adanet.experimental.storages.storage import Storage 22 | from adanet.experimental.work_units.work_unit import WorkUnit 23 | import tensorflow.compat.v2 as tf 24 | 25 | 26 | class Phase(abc.ABC): 27 | """A stage in a linear workflow.""" 28 | 29 | def __init__(self, storage: Storage = InMemoryStorage()): 30 | self._storage = storage 31 | 32 | # TODO: Find a better way to ensure work_units only gets called 33 | # once per phase. 34 | @abc.abstractmethod 35 | def work_units(self, previous_phase: Optional['Phase']) -> Iterator[WorkUnit]: 36 | pass 37 | 38 | 39 | class DatasetProvider(Phase, abc.ABC): 40 | """An interface for a phase that produces datasets.""" 41 | 42 | def __init__(self, storage: Storage = InMemoryStorage()): 43 | """Initializes a Phase. 44 | 45 | Args: 46 | storage: A `Storage` instance. 47 | """ 48 | 49 | super().__init__(storage) 50 | self._train_dataset = None 51 | self._eval_dataset = None 52 | 53 | @abc.abstractmethod 54 | def get_train_dataset(self) -> tf.data.Dataset: 55 | """Returns the dataset for train data.""" 56 | pass 57 | 58 | @abc.abstractmethod 59 | def get_eval_dataset(self) -> tf.data.Dataset: 60 | """Returns the dataset for eval data.""" 61 | pass 62 | 63 | 64 | class ModelProvider(Phase, abc.ABC): 65 | """An interface for a phase that produces models.""" 66 | 67 | @abc.abstractmethod 68 | def get_models(self) -> Iterable[tf.keras.Model]: 69 | """Returns the models produced by this phase.""" 70 | pass 71 | 72 | @abc.abstractmethod 73 | def get_best_models(self, num_models: int = 1) -> Iterable[tf.keras.Model]: 74 | """Returns the `k` best models produced by this phase.""" 75 | pass 76 | 77 | -------------------------------------------------------------------------------- /adanet/experimental/phases/repeat_phase.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2020 The AdaNet Authors. All Rights Reserved. 
3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """A phase that repeats its inner phases.""" 16 | 17 | from typing import Callable, Iterable, Iterator, List 18 | from adanet.experimental.phases.phase import DatasetProvider 19 | from adanet.experimental.phases.phase import ModelProvider 20 | from adanet.experimental.phases.phase import Phase 21 | from adanet.experimental.work_units.work_unit import WorkUnit 22 | import tensorflow.compat.v2 as tf 23 | 24 | 25 | class RepeatPhase(DatasetProvider, ModelProvider): 26 | """A phase that repeats its inner phases.""" 27 | 28 | def __init__(self, 29 | phase_factory: List[Callable[..., Phase]], 30 | repetitions: int): 31 | """Initializes a RepeatPhase. 32 | 33 | Args: 34 | phase_factory: A list of callables that return `Phase` instances. 35 | repetitions: Number of times to repeat the phases in the phase factory. 36 | """ 37 | self._phase_factory = phase_factory 38 | self._repetitions = repetitions 39 | self._final_phase = None 40 | 41 | def work_units(self, previous_phase: DatasetProvider) -> Iterator[WorkUnit]: 42 | for _ in range(self._repetitions): 43 | # Each repetition, the "first" previous phase is the one preceding the 44 | # repeat phase itself. 45 | prev_phase = previous_phase 46 | for phase in self._phase_factory: 47 | phase = phase() 48 | for work_unit in phase.work_units(prev_phase): 49 | yield work_unit 50 | prev_phase = phase 51 | self._final_phase = prev_phase 52 | 53 | def get_train_dataset(self) -> tf.data.Dataset: 54 | if not isinstance(self._final_phase, DatasetProvider): 55 | raise NotImplementedError( 56 | 'The last phase in repetition does not provide datasets.') 57 | return self._final_phase.get_train_dataset() 58 | 59 | def get_eval_dataset(self) -> tf.data.Dataset: 60 | if not isinstance(self._final_phase, DatasetProvider): 61 | raise NotImplementedError( 62 | 'The last phase in repetition does not provide datasets.') 63 | return self._final_phase.get_eval_dataset() 64 | 65 | def get_models(self) -> Iterable[tf.keras.Model]: 66 | if not isinstance(self._final_phase, ModelProvider): 67 | raise NotImplementedError( 68 | 'The last phase in repetition does not provide models.') 69 | return self._final_phase.get_models() 70 | 71 | def get_best_models(self, num_models=1) -> Iterable[tf.keras.Model]: 72 | if not isinstance(self._final_phase, ModelProvider): 73 | raise NotImplementedError( 74 | 'The last phase in repetition does not provide models.') 75 | return self._final_phase.get_best_models(num_models) 76 | -------------------------------------------------------------------------------- /adanet/experimental/schedulers/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # AdaNet schedulers.
3 | # **HIGHLY EXPERIMENTAL AND SUBJECT TO CHANGE** 4 | 5 | licenses(["notice"]) # Apache 2.0 6 | 7 | exports_files(["LICENSE"]) 8 | 9 | py_library( 10 | name = "schedulers", 11 | srcs = ["__init__.py"], 12 | srcs_version = "PY3", 13 | visibility = ["//adanet:__subpackages__"], 14 | deps = [":in_process_scheduler"], 15 | ) 16 | 17 | py_library( 18 | name = "scheduler", 19 | srcs = ["scheduler.py"], 20 | srcs_version = "PY3", 21 | visibility = ["//adanet:__subpackages__"], 22 | deps = [ 23 | "//adanet/experimental/work_units:work_unit", 24 | ], 25 | ) 26 | 27 | py_library( 28 | name = "in_process_scheduler", 29 | srcs = ["in_process_scheduler.py"], 30 | srcs_version = "PY3", 31 | visibility = ["//adanet:__subpackages__"], 32 | deps = [ 33 | ":scheduler", 34 | "//adanet/experimental/work_units:work_unit", 35 | ], 36 | ) 37 | -------------------------------------------------------------------------------- /adanet/experimental/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2020 The AdaNet Authors. All Rights Reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """AdaNet ModelFlow schedulers.""" 16 | 17 | from adanet.experimental.schedulers.in_process_scheduler import InProcessScheduler 18 | 19 | 20 | __all__ = [ 21 | "InProcessScheduler", 22 | ] 23 | -------------------------------------------------------------------------------- /adanet/experimental/schedulers/in_process_scheduler.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """An in-process scheduler for managing AdaNet phases.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from typing import Iterator 22 | 23 | from adanet.experimental.schedulers import scheduler 24 | from adanet.experimental.work_units.work_unit import WorkUnit 25 | 26 | 27 | class InProcessScheduler(scheduler.Scheduler): 28 | """A scheduler that executes in a single process.""" 29 | 30 | def schedule(self, work_units: Iterator[WorkUnit]): 31 | """Schedules and executes work units in a single process. 32 | 33 | Args: 34 | work_units: An iterator that yields `WorkUnit` instances.
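    For example (a minimal sketch; `controller` is any `Controller`
    implementation, such as a `SequentialController`):

      InProcessScheduler().schedule(controller.work_units())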
35 | """ 36 | 37 | for work_unit in work_units: 38 | work_unit.execute() 39 | -------------------------------------------------------------------------------- /adanet/experimental/schedulers/scheduler.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """A scheduler for managing AdaNet phases.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import abc 22 | from typing import Iterator 23 | 24 | from adanet.experimental.work_units.work_unit import WorkUnit 25 | 26 | 27 | class Scheduler(abc.ABC): 28 | """Abstract interface for a scheduler to be used in ModelFlow pipelines.""" 29 | 30 | @abc.abstractmethod 31 | def schedule(self, work_units: Iterator[WorkUnit]): 32 | """Schedules and executes work units. 33 | 34 | Args: 35 | work_units: An iterator that yields `WorkUnit` instances. 36 | """ 37 | pass 38 | -------------------------------------------------------------------------------- /adanet/experimental/storages/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # AdaNet storages. 3 | # **HIGHLY EXPERIMENTAL AND SUBJECT TO CHANGE** 4 | 5 | licenses(["notice"]) # Apache 2.0 6 | 7 | exports_files(["LICENSE"]) 8 | 9 | py_library( 10 | name = "storages", 11 | srcs = ["__init__.py"], 12 | srcs_version = "PY3", 13 | visibility = ["//adanet:__subpackages__"], 14 | deps = [":in_memory_storage"], 15 | ) 16 | 17 | py_library( 18 | name = "storage", 19 | srcs = ["storage.py"], 20 | srcs_version = "PY3", 21 | visibility = ["//adanet:__subpackages__"], 22 | deps = [ 23 | ], 24 | ) 25 | 26 | py_library( 27 | name = "in_memory_storage", 28 | srcs = ["in_memory_storage.py"], 29 | srcs_version = "PY3", 30 | visibility = ["//adanet:__subpackages__"], 31 | deps = [ 32 | ":storage", 33 | ], 34 | ) 35 | -------------------------------------------------------------------------------- /adanet/experimental/storages/__init__.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2020 The AdaNet Authors. All Rights Reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """AdaNet ModelFlow storages.""" 16 | 17 | from adanet.experimental.storages.in_memory_storage import InMemoryStorage 18 | 19 | 20 | __all__ = [ 21 | "InMemoryStorage", 22 | ] 23 | -------------------------------------------------------------------------------- /adanet/experimental/storages/in_memory_storage.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """A storage for persisting results and managing stage.""" 16 | 17 | import heapq 18 | 19 | from typing import List 20 | from adanet.experimental.storages.storage import ModelContainer 21 | from adanet.experimental.storages.storage import Storage 22 | import tensorflow.compat.v2 as tf 23 | 24 | 25 | class InMemoryStorage(Storage): 26 | """In memory storage for testing-only. 27 | 28 | Uses a priority queue under the hood to sort the models according to their 29 | score. 30 | 31 | Currently the only supported score is 'loss'. 32 | """ 33 | 34 | def __init__(self): 35 | self._model_containers = [] 36 | 37 | def save_model(self, model_container: ModelContainer): 38 | """Stores a model. 39 | 40 | Args: 41 | model_container: A `ModelContainer` instance. 42 | """ 43 | # We use a counter since heappush will compare on the second item in the 44 | # tuple in the case of a tie in the first item comparison. This is for the 45 | # off chance that two models have the same loss. 46 | heapq.heappush(self._model_containers, model_container) 47 | 48 | def get_models(self) -> List[tf.keras.Model]: 49 | """Returns all stored models.""" 50 | return [c.model for c in self._model_containers] 51 | 52 | def get_best_models(self, num_models: int = 1) -> List[tf.keras.Model]: 53 | """Returns the top `num_models` stored models in descending order.""" 54 | return [c.model 55 | for c in heapq.nsmallest(num_models, self._model_containers)] 56 | 57 | def get_model_metrics(self) -> List[List[float]]: 58 | """Returns the metrics for all stored models.""" 59 | return [c.metrics for c in self._model_containers] 60 | -------------------------------------------------------------------------------- /adanet/experimental/storages/storage.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-------------------------------------------------------------------------------- /adanet/experimental/storages/storage.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3
2 | # Copyright 2019 The AdaNet Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """A storage for persisting results and managing state."""
16 |
17 | import abc
18 |
19 | from typing import Iterable, List
20 | import tensorflow.compat.v2 as tf
21 |
22 |
23 | class ModelContainer:
24 | """A container for a model and its metadata."""
25 |
26 | def __init__(self, score: float, model: tf.keras.Model, metrics: List[float]):
27 | self.score = score
28 | self.model = model
29 | self.metrics = metrics
30 |
31 | def __eq__(self, other: 'ModelContainer'):
32 | return self.score == other.score
33 |
34 | def __lt__(self, other: 'ModelContainer'):
35 | return self.score < other.score
36 |
37 |
38 | class Storage(abc.ABC):
39 | """A storage for persisting results and managing state."""
40 |
41 | @abc.abstractmethod
42 | def save_model(self, model_container: ModelContainer):
43 | """Stores a model and its metadata."""
44 | # TODO: How do we enforce that save_model is called only once per
45 | # model?
46 | pass
47 |
48 | @abc.abstractmethod
49 | def get_models(self) -> Iterable[tf.keras.Model]:
50 | """Returns all stored models."""
51 | pass
52 |
53 | @abc.abstractmethod
54 | def get_best_models(self, num_models: int = 1) -> Iterable[tf.keras.Model]:
55 | """Returns the top `num_models` stored models, best (lowest score) first."""
56 | pass
57 |
58 | @abc.abstractmethod
59 | def get_model_metrics(self) -> Iterable[Iterable[float]]:
60 | """Returns the metrics for all stored models."""
61 | pass
62 | -------------------------------------------------------------------------------- /adanet/experimental/work_units/BUILD: -------------------------------------------------------------------------------- 1 | # Description:
2 | # AdaNet work units.
3 | # **HIGHLY EXPERIMENTAL AND SUBJECT TO CHANGE**
4 |
5 | licenses(["notice"]) # Apache 2.0
6 |
7 | exports_files(["LICENSE"])
8 |
9 | py_library(
10 | name = "work_units",
11 | srcs = ["__init__.py"],
12 | srcs_version = "PY3",
13 | visibility = ["//adanet:__subpackages__"],
14 | deps = [
15 | ":keras_trainer_work_unit",
16 | ":keras_tuner_work_unit",
17 | ],
18 | )
19 |
20 | py_library(
21 | name = "work_unit",
22 | srcs = ["work_unit.py"],
23 | srcs_version = "PY3",
24 | visibility = ["//adanet/experimental:__subpackages__"],
25 | deps = [
26 | ],
27 | )
28 |
29 | py_library(
30 | name = "keras_trainer_work_unit",
31 | srcs = ["keras_trainer_work_unit.py"],
32 | srcs_version = "PY3",
33 | visibility = ["//adanet/experimental:__subpackages__"],
34 | deps = [
35 | ":work_unit",
36 | "//adanet/experimental/storages:storage",
37 | ],
38 | )
39 |
40 | py_library(
41 | name = "keras_tuner_work_unit",
42 | srcs = ["keras_tuner_work_unit.py"],
43 | srcs_version = "PY3",
44 | visibility = ["//adanet/experimental:__subpackages__"],
45 | deps = [
46 | ":work_unit",
47 | ],
48 | )
49 | -------------------------------------------------------------------------------- /adanet/experimental/work_units/__init__.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3
2 | # Copyright 2020 The AdaNet Authors. All Rights Reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """AdaNet ModelFlow work units."""
16 |
17 | from adanet.experimental.work_units.keras_trainer_work_unit import KerasTrainerWorkUnit
18 | from adanet.experimental.work_units.keras_tuner_work_unit import KerasTunerWorkUnit
19 |
20 |
21 | __all__ = [
22 | "KerasTrainerWorkUnit",
23 | "KerasTunerWorkUnit",
24 | ]
25 | -------------------------------------------------------------------------------- /adanet/experimental/work_units/keras_trainer_work_unit.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3
2 | # Copyright 2019 The AdaNet Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """A work unit for training, evaluating, and saving a Keras model."""
16 |
17 | import os
18 | import time
19 |
20 | from adanet.experimental.storages.storage import ModelContainer
21 | from adanet.experimental.storages.storage import Storage
22 | from adanet.experimental.work_units import work_unit
23 | import tensorflow.compat.v2 as tf
24 |
25 |
26 | class KerasTrainerWorkUnit(work_unit.WorkUnit):
27 | """Trains, evaluates, and saves a Keras model."""
28 |
29 | def __init__(self, model: tf.keras.Model,
30 | train_dataset: tf.data.Dataset,
31 | eval_dataset: tf.data.Dataset,
32 | storage: Storage,
33 | tensorboard_base_dir: str = '/tmp'):
34 | self._model = model
35 | self._train_dataset = train_dataset
36 | self._eval_dataset = eval_dataset
37 | self._storage = storage
38 | self._tensorboard_base_dir = tensorboard_base_dir
39 |
40 | # TODO: Allow better customization of TensorBoard log_dir.
41 | def execute(self):
42 | log_dir = os.path.join(self._tensorboard_base_dir, str(int(time.time())))
43 | tensorboard = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
44 | update_freq='batch')
45 | if self._model.trainable:
46 | self._model.fit(self._train_dataset, callbacks=[tensorboard])
47 | else:
48 | print('Skipping training since model.trainable is set to False.')
49 | results = self._model.evaluate(self._eval_dataset, callbacks=[tensorboard])
50 | # If the model was compiled with metrics, the result is a list of the loss
51 | # followed by the metric values. If the model was compiled without
52 | # metrics, it is a scalar loss.
53 | if not isinstance(results, list):
54 | results = [results]
55 | self._storage.save_model(ModelContainer(results[0], self._model, results))
56 |
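A short, illustrative usage sketch (not part of the library): running a `KerasTrainerWorkUnit` in-process against an `InMemoryStorage`. The tiny model and random dataset are hypothetical stand-ins; `execute` expects a compiled model.

```python
import tensorflow.compat.v2 as tf

from adanet.experimental.storages.in_memory_storage import InMemoryStorage
from adanet.experimental.work_units.keras_trainer_work_unit import KerasTrainerWorkUnit

# A trivial regression model; execute() calls fit() and evaluate() on it.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer='adam', loss='mse', metrics=['mae'])

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([8, 4]), tf.random.normal([8, 1]))).batch(4)

storage = InMemoryStorage()
work_unit = KerasTrainerWorkUnit(
    model=model,
    train_dataset=dataset,
    eval_dataset=dataset,
    storage=storage)
work_unit.execute()  # Trains, evaluates, and saves the model into `storage`.
```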
-------------------------------------------------------------------------------- /adanet/experimental/work_units/keras_tuner_work_unit.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3
2 | # Copyright 2019 The AdaNet Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """A work unit for tuning, training, and saving Keras models with Keras Tuner."""
16 |
17 | import os
18 | import time
19 |
20 | from adanet.experimental.work_units import work_unit
21 | from kerastuner.engine.tuner import Tuner
22 | import tensorflow.compat.v2 as tf
23 |
24 |
25 | class KerasTunerWorkUnit(work_unit.WorkUnit):
26 | """Trains, evaluates and saves a tuned Keras model."""
27 |
28 | def __init__(self, tuner: Tuner, *search_args, **search_kwargs):
29 | self._tuner = tuner
30 | self._search_args = search_args
31 | self._search_kwargs = search_kwargs
32 |
33 | # TODO: Allow better customization of TensorBoard log_dir.
34 | def execute(self):
35 | log_dir = os.path.join('/tmp', str(int(time.time())))
36 | tensorboard = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
37 | update_freq='batch')
38 | # We don't need to evaluate and store results here, because the Tuner
39 | # does both for us during the search.
40 | self._tuner.search(callbacks=[tensorboard], *self._search_args,
41 | **self._search_kwargs)
42 | -------------------------------------------------------------------------------- /adanet/experimental/work_units/work_unit.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3
2 | # Copyright 2019 The AdaNet Authors. All Rights Reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """A work unit for an AdaNet scheduler."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import abc
22 |
23 |
24 | class WorkUnit(abc.ABC):
25 | """A self-contained unit of work that a scheduler can execute."""
26 |
27 | # `execute` is a method, not a property, so it is marked with
28 | # `abc.abstractmethod` (subclasses implement it as a regular method).
29 | @abc.abstractmethod
30 | def execute(self):
31 | """Runs the work unit to completion."""
32 | pass
33 | -------------------------------------------------------------------------------- /adanet/modelflow_test.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3
2 | # Copyright 2020 The AdaNet Authors. All Rights Reserved.
3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Test ModelFlow imports.""" 16 | 17 | import adanet.experimental as adanet 18 | import tensorflow.compat.v2 as tf 19 | 20 | 21 | class ModelFlowTest(tf.test.TestCase): 22 | 23 | def test_public(self): 24 | self.assertIsNotNone(adanet.controllers.SequentialController) 25 | self.assertIsNotNone(adanet.keras.EnsembleModel) 26 | self.assertIsNotNone(adanet.keras.MeanEnsemble) 27 | self.assertIsNotNone(adanet.keras.WeightedEnsemble) 28 | self.assertIsNotNone(adanet.keras.ModelSearch) 29 | self.assertIsNotNone(adanet.phases.AutoEnsemblePhase) 30 | self.assertIsNotNone(adanet.phases.InputPhase) 31 | self.assertIsNotNone(adanet.phases.KerasTrainerPhase) 32 | self.assertIsNotNone(adanet.phases.KerasTunerPhase) 33 | self.assertIsNotNone(adanet.phases.RepeatPhase) 34 | self.assertIsNotNone(adanet.schedulers.InProcessScheduler) 35 | self.assertIsNotNone(adanet.storages.InMemoryStorage) 36 | self.assertIsNotNone(adanet.work_units.KerasTrainerWorkUnit) 37 | self.assertIsNotNone(adanet.work_units.KerasTunerWorkUnit) 38 | 39 | if __name__ == "__main__": 40 | tf.test.main() 41 | -------------------------------------------------------------------------------- /adanet/pip_package/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The AdaNet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | licenses(["notice"]) # Apache 2.0 17 | 18 | package(default_visibility = ["//visibility:private"]) 19 | 20 | sh_binary( 21 | name = "build_pip_package", 22 | srcs = ["build_pip_package.sh"], 23 | data = [ 24 | "//adanet", 25 | "//adanet/examples", 26 | ], 27 | ) 28 | -------------------------------------------------------------------------------- /adanet/pip_package/PIP.md: -------------------------------------------------------------------------------- 1 | 15 | # Creating the adanet pip package using Linux 16 | 17 | This requires Python, Bazel and Git. (And TensorFlow for testing the package.) 
18 | 19 | ### Activate virtualenv 20 | 21 | Install virtualenv if it's not installed already: 22 | 23 | ```shell 24 | ~$ sudo apt-get install python-virtualenv 25 | ``` 26 | 27 | Create a virtual environment for the package creation: 28 | 29 | ```shell 30 | ~$ virtualenv --system-site-packages adanet_env 31 | ``` 32 | 33 | And activate it: 34 | 35 | ```shell 36 | ~$ source ~/adanet_env/bin/activate # bash, sh, ksh, or zsh 37 | ~$ source ~/adanet_env/bin/activate.csh # csh or tcsh 38 | ``` 39 | 40 | ### Clone the adanet repository. 41 | 42 | ```shell 43 | (adanet_env)~$ git clone https://github.com/tensorflow/adanet && cd adanet 44 | ``` 45 | 46 | ### Build adanet pip packaging script 47 | 48 | To build a pip package for adanet: 49 | 50 | ```shell 51 | (adanet_env)~/adanet$ bazel build //adanet/pip_package:build_pip_package 52 | ``` 53 | 54 | ### Create the adanet pip package 55 | 56 | ```shell 57 | (adanet_env)~/adanet$ bazel-bin/adanet/pip_package/build_pip_package /tmp/adanet_pkg 58 | ``` 59 | 60 | ### Install and test the pip package (optional) 61 | 62 | Run the following command to install the pip package: 63 | 64 | ```shell 65 | (adanet_env)~/adanet$ pip install /tmp/adanet_pkg/*.whl 66 | ``` 67 | 68 | Finally try importing `adanet` in Python outside the cloned directory: 69 | 70 | ```shell 71 | (adanet_env)~/adanet$ cd ~ 72 | (adanet_env)~$ python -c "import adanet" 73 | ``` 74 | 75 | ### De-activate the virtualenv 76 | 77 | ```shell 78 | (adanet_env)~/$ deactivate 79 | ``` 80 | -------------------------------------------------------------------------------- /adanet/pip_package/build_pip_package.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2018 The AdaNet Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | # This script should be run from the repo root. 18 | 19 | set -e 20 | set -o pipefail 21 | 22 | die() { 23 | printf >&2 '%s\n' "$1" 24 | exit 1 25 | } 26 | 27 | function main() { 28 | if [ $# -lt 1 ] ; then 29 | die "ERROR: no destination dir provided" 30 | fi 31 | 32 | DEST=$1 33 | TMPDIR=$(mktemp -d -t XXXXXXXXXXXXXXXXXXXXXXX_adanet_pip_pkg) 34 | RUNFILES="bazel-bin/adanet/pip_package/build_pip_package.runfiles/org_adanet" 35 | 36 | echo $(date) : "=== Using tmpdir: ${TMPDIR}" 37 | 38 | bazel build //adanet/pip_package:build_pip_package 39 | 40 | if [ ! -d bazel-bin/adanet ]; then 41 | echo `pwd` 42 | die "ERROR: Could not find bazel-bin. Did you run from the build root?" 
43 | fi
44 |
45 | cp "adanet/pip_package/setup.py" "${TMPDIR}"
46 | cp "adanet/pip_package/setup.cfg" "${TMPDIR}"
47 | cp "LICENSE" "${TMPDIR}/LICENSE.txt"
48 | cp -R "${RUNFILES}/adanet" "${TMPDIR}"
49 |
50 | pushd ${TMPDIR}
51 | rm -f MANIFEST
52 |
53 | echo $(date) : "=== Building universal python wheel in $PWD"
54 | python setup.py bdist_wheel --universal >/dev/null
55 | mkdir -p ${DEST}
56 | cp dist/* ${DEST}
57 | popd
58 | rm -rf ${TMPDIR}
59 | echo $(date) : "=== Output wheel files are in: ${DEST}"
60 | }
61 |
62 | main "$@"
63 | -------------------------------------------------------------------------------- /adanet/pip_package/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata]
2 | license_file = LICENSE.txt
3 | -------------------------------------------------------------------------------- /adanet/pip_package/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The AdaNet Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | # Lint as: python3
16 | """Setup for pip package."""
17 |
18 | from adanet import version
19 | import setuptools
20 |
21 | # Execute version.py directly (the Python 3 analogue of `execfile`) so that
22 | # `__version__` is defined here even when the full adanet package cannot be
23 | # imported during setup.
24 | with open('adanet/version.py') as in_file:
25 |   exec(in_file.read())
26 |
27 | REQUIRED_PACKAGES = [
28 | 'absl-py>=0.7,<1.0',
29 | 'six>=1.11,<2.0',
30 | 'numpy>=1.15,<2.0',
31 | 'nose>=1.3,<2.0',
32 | 'rednose>=1.3,<2.0',
33 | 'coverage>=4.5,<5.0',
34 | 'protobuf>=3.6,<4.0',
35 | 'mock>=3.0,<4.0',
36 | ]
37 |
38 | setuptools.setup(
39 | name='adanet', # Automatic: adanet, etc. Case insensitive.
40 | version=version.__version__.replace('-', ''),
41 | description=(
42 | 'adanet is a lightweight and scalable TensorFlow AutoML framework for '
43 | 'training and deploying adaptive neural networks using the AdaNet '
44 | 'algorithm [Cortes et al. ICML 2017](https://arxiv.org/abs/1607.01097).'
45 | ),
46 | long_description='',
47 | url='https://github.com/tensorflow/adanet',
48 | author='Google LLC',
49 | install_requires=REQUIRED_PACKAGES,
50 | packages=setuptools.find_packages(),
51 | # PyPI package information.
51 | classifiers=[ 52 | 'Development Status :: 4 - Beta', 53 | 'Intended Audience :: Developers', 54 | 'Intended Audience :: Education', 55 | 'Intended Audience :: Science/Research', 56 | 'License :: OSI Approved :: Apache Software License', 57 | 'Programming Language :: Python :: 3', 58 | 'Programming Language :: Python :: 3.4', 59 | 'Programming Language :: Python :: 3.5', 60 | 'Programming Language :: Python :: 3.6', 61 | 'Topic :: Scientific/Engineering', 62 | 'Topic :: Scientific/Engineering :: Mathematics', 63 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 64 | 'Topic :: Software Development', 65 | 'Topic :: Software Development :: Libraries', 66 | 'Topic :: Software Development :: Libraries :: Python Modules', 67 | ], 68 | license='Apache 2.0', 69 | keywords=('tensorflow machine learning automl module subgraph framework ' 70 | 'ensemble neural network adaptive metalearning'), 71 | ) 72 | -------------------------------------------------------------------------------- /adanet/replay/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # AdaNet replay. 3 | 4 | licenses(["notice"]) # Apache 2.0 5 | 6 | exports_files(["LICENSE"]) 7 | 8 | py_library( 9 | name = "replay", 10 | srcs = ["__init__.py"], 11 | visibility = ["//adanet:__subpackages__"], 12 | deps = [ 13 | ], 14 | ) 15 | -------------------------------------------------------------------------------- /adanet/replay/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Defines mechanisms for deterministically replaying an AdaNet model search.""" 15 | 16 | # TODO: Add more detailed documentation. 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import json 23 | import os 24 | 25 | import tensorflow.compat.v1 as tf 26 | 27 | 28 | class Config(object): # pylint: disable=g-classes-have-attributes 29 | # pyformat: disable 30 | """Defines how to deterministically replay an AdaNet model search. 31 | 32 | Specifically, it reconstructs the previous model and trains its components 33 | in the correct order without performing any search. 34 | 35 | Args: 36 | best_ensemble_indices: A list of the best ensemble indices (one per 37 | iteration). 38 | 39 | Returns: 40 | An :class:`adanet.replay.Config` instance. 
41 | """ 42 | # pyformat: enable 43 | 44 | def __init__(self, best_ensemble_indices=None): 45 | self._best_ensemble_indices = best_ensemble_indices 46 | 47 | @property 48 | def best_ensemble_indices(self): 49 | """The best ensemble indices per iteration.""" 50 | return self._best_ensemble_indices 51 | 52 | def get_best_ensemble_index(self, iteration_number): 53 | """Returns the best ensemble index given an iteration number.""" 54 | # If we are provided the list 55 | if (self._best_ensemble_indices 56 | and iteration_number < len(self._best_ensemble_indices)): 57 | return self._best_ensemble_indices[iteration_number] 58 | 59 | return None 60 | 61 | 62 | __all__ = ["Config"] 63 | -------------------------------------------------------------------------------- /adanet/subnetwork/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # Core AdaNet subnetwork logic. 3 | 4 | licenses(["notice"]) # Apache 2.0 5 | 6 | exports_files(["LICENSE"]) 7 | 8 | py_library( 9 | name = "subnetwork", 10 | srcs = ["__init__.py"], 11 | visibility = ["//adanet:__subpackages__"], 12 | deps = [ 13 | ":generator", 14 | ":report", 15 | ], 16 | ) 17 | 18 | py_library( 19 | name = "generator", 20 | srcs = ["generator.py"], 21 | deps = [ 22 | "@six_archive//:six", 23 | ], 24 | ) 25 | 26 | py_test( 27 | name = "generator_test", 28 | srcs = ["generator_test.py"], 29 | deps = [ 30 | ":generator", 31 | "//adanet/tf_compat", 32 | "@absl_py//absl/testing:parameterized", 33 | ], 34 | ) 35 | 36 | py_library( 37 | name = "report", 38 | srcs = ["report.py"], 39 | deps = [ 40 | "//adanet/tf_compat", 41 | "@six_archive//:six", 42 | ], 43 | ) 44 | 45 | py_test( 46 | name = "report_test", 47 | srcs = ["report_test.py"], 48 | deps = [ 49 | ":report", 50 | "@absl_py//absl/testing:parameterized", 51 | ], 52 | ) 53 | -------------------------------------------------------------------------------- /adanet/subnetwork/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The AdaNet Authors. All Rights Reserved. 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Low-level APIs for defining custom subnetworks and search spaces.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from adanet.subnetwork.generator import Builder 22 | from adanet.subnetwork.generator import Generator 23 | from adanet.subnetwork.generator import SimpleGenerator 24 | from adanet.subnetwork.generator import Subnetwork 25 | from adanet.subnetwork.generator import TrainOpSpec 26 | from adanet.subnetwork.report import MaterializedReport 27 | from adanet.subnetwork.report import Report 28 | 29 | __all__ = [ 30 | "Builder", 31 | "Generator", 32 | "MaterializedReport", 33 | "Report", 34 | "SimpleGenerator", 35 | "Subnetwork", 36 | "TrainOpSpec", 37 | ] 38 | -------------------------------------------------------------------------------- /adanet/subnetwork/report_test.py: -------------------------------------------------------------------------------- 1 | """Test AdaNet single graph subnetwork implementation. 2 | 3 | Copyright 2018 The AdaNet Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from absl.testing import parameterized 23 | from adanet.subnetwork.report import Report 24 | import tensorflow.compat.v2 as tf 25 | 26 | # pylint: disable=g-direct-tensorflow-import 27 | from tensorflow.python.eager import context 28 | from tensorflow.python.framework import test_util 29 | # pylint: enable=g-direct-tensorflow-import 30 | 31 | 32 | class ReportTest(parameterized.TestCase, tf.test.TestCase): 33 | 34 | # pylint: disable=g-long-lambda 35 | @parameterized.named_parameters( 36 | { 37 | "testcase_name": "empty", 38 | "hparams": {}, 39 | "attributes": lambda: {}, 40 | "metrics": lambda: {}, 41 | }, { 42 | "testcase_name": "non_empty", 43 | "hparams": { 44 | "hoo": 1 45 | }, 46 | "attributes": lambda: { 47 | "aoo": tf.constant(1) 48 | }, 49 | "metrics": lambda: { 50 | "moo": (tf.constant(1), tf.constant(1)) 51 | }, 52 | }, { 53 | "testcase_name": "non_tensor_update_op", 54 | "hparams": { 55 | "hoo": 1 56 | }, 57 | "attributes": lambda: { 58 | "aoo": tf.constant(1) 59 | }, 60 | "metrics": lambda: { 61 | "moo": (tf.constant(1), tf.no_op()) 62 | }, 63 | }) 64 | # pylint: enable=g-long-lambda 65 | @test_util.run_in_graph_and_eager_modes 66 | def test_new(self, hparams, attributes, metrics): 67 | with context.graph_mode(): 68 | _ = tf.constant(0) # Just to have a non-empty graph. 
69 | report = Report( 70 | hparams=hparams, attributes=attributes(), metrics=metrics()) 71 | self.assertEqual(hparams, report.hparams) 72 | self.assertEqual( 73 | self.evaluate(attributes()), self.evaluate(report.attributes)) 74 | self.assertEqual(self.evaluate(metrics()), self.evaluate(report.metrics)) 75 | 76 | @test_util.run_in_graph_and_eager_modes 77 | def test_drop_non_scalar_metric(self): 78 | """Tests b/118632346.""" 79 | 80 | hparams = {"hoo": 1} 81 | attributes = {"aoo": tf.constant(1)} 82 | metrics = { 83 | "moo1": (tf.constant(1), tf.constant(1)), 84 | "moo2": (tf.constant([1, 1]), tf.constant([1, 1])), 85 | } 86 | want_metrics = metrics.copy() 87 | del want_metrics["moo2"] 88 | with self.test_session(): 89 | report = Report(hparams=hparams, attributes=attributes, metrics=metrics) 90 | self.assertEqual(hparams, report.hparams) 91 | self.assertEqual(attributes, report.attributes) 92 | self.assertEqual(want_metrics, report.metrics) 93 | 94 | @parameterized.named_parameters( 95 | { 96 | "testcase_name": "tensor_hparams", 97 | "hparams": { 98 | "hoo": tf.constant(1) 99 | }, 100 | "attributes": {}, 101 | "metrics": {}, 102 | }, { 103 | "testcase_name": "non_tensor_attributes", 104 | "hparams": {}, 105 | "attributes": { 106 | "aoo": 1, 107 | }, 108 | "metrics": {}, 109 | }, { 110 | "testcase_name": "non_tuple_metrics", 111 | "hparams": {}, 112 | "attributes": {}, 113 | "metrics": { 114 | "moo": tf.constant(1) 115 | }, 116 | }, { 117 | "testcase_name": "one_item_tuple_metrics", 118 | "hparams": {}, 119 | "attributes": {}, 120 | "metrics": { 121 | "moo": (tf.constant(1),) 122 | }, 123 | }) 124 | @test_util.run_in_graph_and_eager_modes 125 | def test_new_errors(self, hparams, attributes, metrics): 126 | with self.assertRaises(ValueError): 127 | Report(hparams=hparams, attributes=attributes, metrics=metrics) 128 | 129 | 130 | if __name__ == "__main__": 131 | tf.test.main() 132 | -------------------------------------------------------------------------------- /adanet/tf_compat/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # AdaNet-TensorFlow compatibility logic. 3 | 4 | licenses(["notice"]) # Apache 2.0 5 | 6 | exports_files(["LICENSE"]) 7 | 8 | py_library( 9 | name = "tf_compat", 10 | srcs = ["__init__.py"], 11 | visibility = [ 12 | "//adanet:__subpackages__", 13 | ], 14 | deps = [ 15 | ], 16 | ) 17 | -------------------------------------------------------------------------------- /adanet/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The AdaNet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains the version string.""" 16 | 17 | __version__ = u"0.9.0" 18 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The AdaNet Authors. All Rights Reserved. 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Minimal makefile for Sphinx documentation 16 | # 17 | 18 | # You can set these variables from the command line. 19 | SPHINXOPTS = 20 | SPHINXBUILD = sphinx-build 21 | SPHINXPROJ = AdaNet 22 | SOURCEDIR = source 23 | BUILDDIR = build 24 | 25 | # Put it first so that "make" without argument is like "make help". 26 | help: 27 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 28 | 29 | .PHONY: help Makefile 30 | 31 | # Catch-all target: route all unknown targets to Sphinx using the new 32 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 33 | %: Makefile 34 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 35 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | set SPHINXPROJ=AdaNet 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 
24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow>=2.0,<3.0 2 | sphinx==1.8.* 3 | sphinx_rtd_theme==0.4.* 4 | recommonmark==0.5.* 5 | -------------------------------------------------------------------------------- /docs/source/_static/custom.css: -------------------------------------------------------------------------------- 1 | .wy-nav-side { 2 | border-right: 1px solid #dedede; 3 | } 4 | 5 | .wy-side-nav-search { 6 | background-color: #fafafa; 7 | } 8 | 9 | .wy-side-nav-search>a { 10 | color: #404040; 11 | } 12 | 13 | .wy-side-nav-search>div.version { 14 | color: #808080; 15 | } 16 | -------------------------------------------------------------------------------- /docs/source/adanet.distributed.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | adanet.distributed 5 | ============================== 6 | 7 | .. automodule:: adanet.distributed 8 | .. currentmodule:: adanet.distributed 9 | 10 | :hidden:`PlacementStrategy` 11 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 12 | 13 | .. autoclass:: PlacementStrategy 14 | :members: 15 | 16 | :hidden:`ReplicationStrategy` 17 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 18 | 19 | .. autoclass:: ReplicationStrategy 20 | :members: 21 | 22 | :hidden:`RoundRobinStrategy` 23 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 24 | 25 | .. autoclass:: RoundRobinStrategy 26 | :members: 27 | -------------------------------------------------------------------------------- /docs/source/adanet.ensemble.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | adanet.ensemble 5 | ============================== 6 | 7 | 8 | .. automodule:: adanet.ensemble 9 | .. currentmodule:: adanet.ensemble 10 | 11 | Ensembles 12 | --------------- 13 | 14 | Interfaces and containers for defining ensembles. 15 | 16 | :hidden:`Ensemble` 17 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 18 | 19 | .. autoclass:: Ensemble 20 | :members: 21 | 22 | :hidden:`ComplexityRegularized` 23 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 24 | 25 | .. autoclass:: ComplexityRegularized 26 | :members: 27 | 28 | :hidden:`MeanEnsemble` 29 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 30 | 31 | .. autoclass:: MeanEnsemble 32 | :members: 33 | 34 | :hidden:`MixtureWeightType` 35 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 36 | 37 | .. autoclass:: MixtureWeightType 38 | :members: 39 | 40 | :hidden:`WeightedSubnetwork` 41 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 42 | 43 | .. autoclass:: WeightedSubnetwork 44 | :members: 45 | 46 | Ensemblers 47 | --------------- 48 | 49 | Ensemble learning definitions. 50 | 51 | :hidden:`Ensembler` 52 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 53 | 54 | .. autoclass:: Ensembler 55 | :members: 56 | 57 | :hidden:`ComplexityRegularizedEnsembler` 58 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 59 | 60 | .. autoclass:: ComplexityRegularizedEnsembler 61 | :members: 62 | 63 | :hidden:`MeanEnsembler` 64 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 65 | 66 | .. 
autoclass:: MeanEnsembler 67 | :members: 68 | 69 | :hidden:`TrainOpSpec` 70 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 71 | 72 | .. autoclass:: TrainOpSpec 73 | :members: 74 | 75 | Strategies 76 | --------------- 77 | 78 | Ensemble strategies for grouping subnetworks. 79 | 80 | :hidden:`Strategy` 81 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 82 | 83 | .. autoclass:: Strategy 84 | :members: 85 | 86 | :hidden:`SoloStrategy` 87 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 88 | 89 | .. autoclass:: SoloStrategy 90 | :members: 91 | 92 | :hidden:`GrowStrategy` 93 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 94 | 95 | .. autoclass:: GrowStrategy 96 | :members: 97 | 98 | :hidden:`AllStrategy` 99 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 100 | 101 | .. autoclass:: AllStrategy 102 | :members: 103 | 104 | :hidden:`Candidate` 105 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 106 | 107 | .. autoclass:: Candidate 108 | :members: 109 | -------------------------------------------------------------------------------- /docs/source/adanet.replay.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | adanet.replay 5 | ============================== 6 | 7 | 8 | .. automodule:: adanet.replay 9 | .. currentmodule:: adanet.replay 10 | 11 | :hidden:`Config` 12 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 13 | 14 | .. autoclass:: Config 15 | :members: 16 | -------------------------------------------------------------------------------- /docs/source/adanet.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | adanet 5 | ============== 6 | 7 | .. automodule:: adanet 8 | .. currentmodule:: adanet 9 | 10 | Estimators 11 | --------------- 12 | 13 | High-level APIs for training, evaluating, predicting, and serving AdaNet model. 14 | 15 | :hidden:`AutoEnsembleEstimator` 16 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 17 | 18 | .. autoclass:: AutoEnsembleEstimator 19 | :members: 20 | :show-inheritance: 21 | :inherited-members: 22 | 23 | :hidden:`AutoEnsembleSubestimator` 24 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 25 | 26 | .. autoclass:: AutoEnsembleSubestimator 27 | :members: 28 | :show-inheritance: 29 | :inherited-members: 30 | 31 | :hidden:`AutoEnsembleTPUEstimator` 32 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 33 | 34 | .. autoclass:: AutoEnsembleTPUEstimator 35 | :members: 36 | :show-inheritance: 37 | :inherited-members: 38 | 39 | :hidden:`Estimator` 40 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 41 | 42 | .. autoclass:: Estimator 43 | :members: 44 | :show-inheritance: 45 | :inherited-members: 46 | 47 | :hidden:`TPUEstimator` 48 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 49 | 50 | .. autoclass:: TPUEstimator 51 | :members: 52 | :show-inheritance: 53 | :inherited-members: 54 | 55 | Evaluator 56 | --------------- 57 | 58 | Measures :class:`adanet.Ensemble` performance on a given dataset. 59 | 60 | :hidden:`Evaluator` 61 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 62 | 63 | .. autoclass:: Evaluator 64 | :members: 65 | 66 | Keras 67 | --------------- 68 | 69 | **Experimental** Keras API for training, evaluating, predicting, and serving 70 | AdaNet models. 71 | 72 | :hidden:`AutoEnsemble` 73 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 74 | 75 | .. autoclass:: AutoEnsemble 76 | :members: 77 | :show-inheritance: 78 | :inherited-members: 79 | 80 | :hidden:`Model` 81 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 82 | 83 | .. 
autoclass:: Model 84 | :members: 85 | 86 | Summary 87 | --------------- 88 | 89 | Extends :mod:`tf.summary` to power AdaNet's TensorBoard integration. 90 | 91 | :hidden:`Summary` 92 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 93 | 94 | .. autoclass:: Summary 95 | :members: 96 | 97 | ReportMaterializer 98 | --------------- 99 | 100 | :hidden:`ReportMaterializer` 101 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 102 | 103 | .. autoclass:: ReportMaterializer 104 | :members: 105 | -------------------------------------------------------------------------------- /docs/source/adanet.subnetwork.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | adanet.subnetwork 5 | ============================== 6 | 7 | 8 | .. automodule:: adanet.subnetwork 9 | .. currentmodule:: adanet.subnetwork 10 | 11 | Generators 12 | --------------- 13 | 14 | Interfaces and containers for defining subnetworks, search spaces, and search algorithms. 15 | 16 | :hidden:`Subnetwork` 17 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 18 | 19 | .. autoclass:: Subnetwork 20 | :members: 21 | 22 | :hidden:`TrainOpSpec` 23 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 24 | 25 | .. autoclass:: TrainOpSpec 26 | :members: 27 | 28 | :hidden:`Builder` 29 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 30 | 31 | .. autoclass:: Builder 32 | :members: 33 | :show-inheritance: 34 | 35 | :hidden:`Generator` 36 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 37 | 38 | .. autoclass:: Generator 39 | :members: 40 | :show-inheritance: 41 | 42 | Reports 43 | --------------- 44 | 45 | Containers for metadata about trained subnetworks. 46 | 47 | :hidden:`Report` 48 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 49 | 50 | .. autoclass:: Report 51 | :members: 52 | 53 | :hidden:`MaterializedReport` 54 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 55 | 56 | .. autoclass:: MaterializedReport 57 | :members: 58 | -------------------------------------------------------------------------------- /docs/source/algorithm.md: -------------------------------------------------------------------------------- 1 | # Algorithm 2 | 3 | ## Neural architecture search 4 | 5 | AutoML is a family of techniques and algorithms seeking to automatically solve 6 | supervised learning tasks. Recently, researchers in AutoML have investigated 7 | whether we can automate learning the structure of a neural network for a given 8 | dataset, automating a task that requires significant domain expertise. This 9 | subdomain known as neural architecture search has seen advances in the 10 | state-of-the-art using reinforcement learning 11 | [[Zoph et al. '17](https://arxiv.org/abs/1707.07012)], evolutionary strategies 12 | [[Real et al., '17](https://arxiv.org/abs/1802.01548)], and gradient-based 13 | methods [[Liu et al., '18](https://arxiv.org/abs/1806.09055)] to learn neural 14 | network substructures. However, in these papers, the high-level structure of the 15 | network generally remains user defined. 16 | 17 | ![Two candidate ensembles](./assets/candidates.png "Two candidate ensembles.") 18 | 19 | > This illustration shows the algorithm’s incremental construction of a 20 | > fully-connected neural network. The input layer is indicated in blue, the 21 | > output layer in green. Units in the yellow block are added at the first 22 | > iteration while units in purple are added at the second iteration. Two 23 | > candidate extensions of the architecture are considered at the third iteration 24 | > (shown in red): (a) a two-layer extension; (b) a three-layer extension. 
Here,
25 | > a line between two blocks of units indicates that these blocks are
26 | > fully-connected.
27 |
28 | ## Neural networks are ensembles
29 |
30 | Ensembles of neural networks have shown remarkable performance in domains such
31 | as natural language processing, image recognition, and many others. The two
32 | composing techniques are interesting in their own rights: ensemble techniques
33 | have a rich history and theoretical understanding, while neural networks provide
34 | a general framework for solving complex tasks across many domains at scale.
35 |
36 | Coincidentally, an ensemble of neural networks whose outputs are linearly
37 | combined is also a neural network. With that definition in mind, we seek to
38 | answer the question: Can we learn a neural network architecture as an ensemble
39 | of subnetworks? And can we adaptively learn such an ensemble with fewer
40 | trainable parameters and that performs better than any single neural network
41 | trained end-to-end?
42 |
43 | ## Adaptive architecture search
44 |
45 | Our algorithm for performing adaptive neural architecture search is AdaNet
46 | [[Cortes et al., ICML '17](https://arxiv.org/abs/1607.01097)], which iteratively
47 | grows an ensemble of neural networks while providing learning guarantees. It is
48 | *adaptive* because at each iteration the candidate subnetworks are generated and
49 | trained based on the current state of the neural network.
50 |
51 | We show this algorithm can in fact learn a neural network (ensemble) that
52 | achieves state-of-the-art results across several datasets. We also show how this
53 | algorithm is complementary to the neural architecture search research
54 | mentioned earlier, as it learns to combine these substructures in a principled
55 | manner to achieve these results.
56 |
57 | ## The AdaNet algorithm
58 |
59 | The AdaNet algorithm works as follows: a generator iteratively creates a set of
60 | candidate base learners to consider including in the final ensemble. How these
61 | base learners are trained is left completely up to the user, but generally they
62 | are trained to optimize some common loss function such as cross-entropy loss or
63 | mean squared error. At every iteration, the trained base learners are then
64 | evaluated on their ability to minimize the AdaNet objective $F$, and the best
65 | one is included in the final ensemble.
66 |
67 | $$\begin{aligned} &F\left ( w \right ) = \frac{1}{m} \sum_{i=1}^{m} \Phi \left (\sum_{j=1}^{N}w_jh_j(x_i), y_i \right ) + \sum_{j=1}^{N} \left (\lambda r(h_j) + \beta \right )\left | w_j \right |\\ &\text{where } m \text{ is the number of training examples,}\\ &N \text{ is the number of candidate base learners,}\\ &w_j \text{ is the weight of model } j \text{'s contribution to the ensemble,}\\ &h_j \text{ is model } j,\\ &\Phi \text{ is the loss function,}\\ &r(h_j) \text{ is model } j\text{'s complexity, and}\\ &\lambda \text{ and } \beta \text{ are tunable hyperparameters.} \end{aligned}$$
68 |
69 | For every iteration after the first, the generator can generate neural networks
70 | based on the current state of the ensemble. This allows AdaNet to create complex
71 | structures or use advanced techniques for training candidates so that they will
72 | most significantly improve the ensemble.
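To make the objective $F$ above concrete, here is a small, self-contained NumPy sketch (illustrative only: it fixes $\Phi$ to the squared loss, and `complexities` is a hypothetical stand-in for the $r(h_j)$ terms):

```python
import numpy as np

def adanet_objective(w, subnetwork_outputs, labels, complexities,
                     lam=0.01, beta=0.001):
  """Computes F(w) with Phi chosen as the squared loss.

  Args:
    w: [N] ensemble weight per base learner.
    subnetwork_outputs: [N, m] outputs h_j(x_i), one row per base learner.
    labels: [m] ground-truth labels y_i.
    complexities: [N] complexity measures r(h_j).
    lam: The lambda hyperparameter.
    beta: The beta hyperparameter.
  """
  ensemble_outputs = w @ subnetwork_outputs  # sum_j w_j * h_j(x_i) for each i.
  training_loss = np.mean((ensemble_outputs - labels) ** 2)
  regularization = np.sum((lam * complexities + beta) * np.abs(w))
  return training_loss + regularization
```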
For example, knowledge
73 | distillation [[Hinton et al., '15](https://arxiv.org/abs/1503.02531)] is a
74 | technique that uses a teacher network's logits as the ground truth when
75 | computing the loss of a trainable student network, and is shown to produce
76 | students that perform better than an identical network trained without it. At
77 | every iteration, we can use the current ensemble as a teacher network and the
78 | candidates as students, to obtain base learners that perform better and
79 | significantly improve the performance of the final ensemble.
80 |
81 | ## More information
82 |
83 | * [A step by step walkthrough of the AdaNet algorithm](https://docs.google.com/presentation/d/19NL1nI-MpwysxDsjSNmHbzLnr4NGacw6a8YGo88VA2Y/present?slide=id.g3d1c8865a3_0_0)
84 | -------------------------------------------------------------------------------- /docs/source/assets/adanet_tangram_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/adanet/0364cc46810ff3831b3e4a37125de862a28da9bd/docs/source/assets/adanet_tangram_logo.png -------------------------------------------------------------------------------- /docs/source/assets/adanet_tangram_logo.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/source/assets/candidates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/adanet/0364cc46810ff3831b3e4a37125de862a28da9bd/docs/source/assets/candidates.png -------------------------------------------------------------------------------- /docs/source/distributed.md: 1 | # Distributed training
2 |
3 |
4 |
5 | AdaNet uses the same distributed training model as `tf.estimator.Estimator`.
6 |
7 | For training TensorFlow estimators on Google Cloud ML Engine, please see
8 | [this guide](https://cloud.google.com/blog/products/gcp/easy-distributed-training-with-tensorflow-using-tfestimatortrain-and-evaluate-on-cloud-ml-engine).
9 |
10 |
11 | ## Placement Strategies
12 |
13 | Given a cluster of worker and parameter servers, AdaNet will manage distributed
14 | training automatically. When creating an AdaNet `Estimator`, you can specify the
15 | `adanet.distributed.PlacementStrategy` to decide which subnetworks each worker
16 | will be responsible for training.
17 |
18 | ### Replication Strategy
19 |
20 | The default distributed training strategy is the same as the default
21 | `tf.estimator.Estimator` model: each worker will create the full training graph,
22 | including all subnetworks and ensembles, and optimize all the trainable
23 | parameters. Each variable will be randomly allocated to a parameter server to
24 | minimize bottlenecks in workers fetching them. Workers' updates will be sent to
25 | the parameter servers, which apply the updates to their managed variables.
26 |
27 | ![Replication strategy](./assets/replication_strategy.svg)
28 |
29 | To learn more, see the implementation at
30 | [`adanet.distributed.ReplicationStrategy`](https://adanet.readthedocs.io/en/latest/adanet.distributed.html#replicationstrategy).
31 |
32 | ### Round Robin Strategy (experimental)
33 |
34 | A strategy that scales better than the Replication Strategy is the experimental
35 | Round Robin Strategy. Instead of replicating the same graph on each worker,
36 | AdaNet assigns each worker a single subnetwork to train, in round-robin fashion.
37 |
38 | ![Round robin strategy](./assets/round_robin.svg)
39 |
40 | To learn more, see the implementation at
41 | [`adanet.distributed.RoundRobinStrategy`](https://adanet.readthedocs.io/en/latest/adanet.distributed.html#roundrobinstrategy).
42 |
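As a sketch of how a strategy might be selected (following the `...` placeholder style used in the quick start guide): note that the constructor argument name `experimental_placement_strategy` below is an assumption for illustration, not a documented contract, and may differ between AdaNet versions.

```python
import adanet

estimator = adanet.Estimator(
    head=...,                        # As in the quick start guide.
    subnetwork_generator=...,        # A user-defined adanet.subnetwork.Generator.
    max_iteration_steps=...,
    # Assumed argument name: swaps the default ReplicationStrategy for
    # round-robin placement of subnetworks across workers.
    experimental_placement_strategy=adanet.distributed.RoundRobinStrategy())
```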
Instead of replicating the same graph on each worker, 36 | AdaNet will round robin assign workers to train a single subnetwork. 37 | 38 | ![Round robin strategy](./assets/round_robin.svg) 39 | 40 | To learn more, see the implementation at 41 | [`adanet.distributed.RoundRobinStrategy`](https://adanet.readthedocs.io/en/latest/adanet.distributed.html#roundrobinstrategy). 42 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. AdaNet documentation master file, created by 2 | sphinx-quickstart on Fri Nov 30 18:10:08 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | :github_url: https://github.com/tensorflow/adanet 7 | 8 | AdaNet documentation 9 | ================================== 10 | 11 | AdaNet is a TensorFlow framework for fast and flexible AutoML with learning guarantees. 12 | 13 | .. raw:: html 14 | 15 |
16 | adanet_tangram_logo

17 |
18 |
19 | **AdaNet** is a lightweight TensorFlow-based framework for automatically learning high-quality models with minimal expert intervention. AdaNet builds on recent AutoML efforts to be fast and flexible while providing learning guarantees. Importantly, AdaNet provides a general framework for not only learning a neural network architecture, but also for learning to ensemble to obtain even better models.
20 |
21 | This project is based on the `AdaNet algorithm`, presented in “`AdaNet: Adaptive Structural Learning of Artificial Neural Networks <https://arxiv.org/abs/1607.01097>`_” at `ICML 2017 <https://icml.cc/Conferences/2017>`_, for learning the structure of a neural network as an ensemble of subnetworks.
22 |
23 | AdaNet has the following goals:
24 |
25 | * **Ease of use**: Provide familiar APIs (e.g. Keras, Estimator) for training, evaluating, and serving models.
26 | * **Speed**: Scale with available compute and quickly produce high quality models.
27 | * **Flexibility**: Allow researchers and practitioners to extend AdaNet to novel subnetwork architectures, search spaces, and tasks.
28 | * **Learning guarantees**: Optimize an objective that offers theoretical learning guarantees.
29 |
30 | The following animation shows AdaNet adaptively growing an ensemble of neural networks. At each iteration, it measures the ensemble loss for each candidate, and selects the best one to move onto the next iteration. At subsequent iterations, the blue subnetworks are frozen, and only yellow subnetworks are trained:
31 |
32 | .. raw:: html
33 |
34 |    <div align="center">
35 |      <img src="https://raw.githubusercontent.com/tensorflow/adanet/master/images/adanet_animation.gif" alt="adanet_animation">
36 |    </div>
37 |
38 | AdaNet was first announced on the Google AI research blog: "`Introducing AdaNet: Fast and Flexible AutoML with Learning Guarantees <https://ai.googleblog.com/2018/10/introducing-adanet-fast-and-flexible.html>`_".
39 |
40 | This is not an official Google product.
41 |
42 | .. toctree::
43 | :glob:
44 | :maxdepth: 1
45 | :caption: Getting Started
46 |
47 | overview
48 | quick_start
49 | tutorials
50 | tensorboard
51 | distributed
52 | tpu
53 |
54 | .. toctree::
55 | :glob:
56 | :maxdepth: 1
57 | :caption: Research
58 |
59 | algorithm
60 | theory
61 |
62 |
63 | .. toctree::
64 | :glob:
65 | :maxdepth: 1
66 | :caption: Package Reference
67 |
68 | adanet
69 | adanet.ensemble
70 | adanet.subnetwork
71 | adanet.distributed
72 |
73 | Indices and tables
74 | ==================
75 |
76 | * :ref:`genindex`
77 | * :ref:`modindex`
78 | -------------------------------------------------------------------------------- /docs/source/overview.md: 1 | # Overview
2 |
3 |
4 |
5 | AdaNet is an extended implementation of [*AdaNet: Adaptive Structural Learning
6 | of Artificial Neural Networks*](https://arxiv.org/abs/1607.01097) by Cortes et
7 | al. (ICML 2017), an algorithm for iteratively learning
8 | both the **structure** and **weights** of a neural network as an **ensemble of
9 | subnetworks**.
10 |
11 | ## Ensembles of subnetworks
12 |
13 | In AdaNet, **ensembles** are first-class objects. Every model you train will be
14 | one form of an ensemble or another. An ensemble is composed of one or more
15 | **subnetworks** whose outputs are combined via an **ensembler**.
16 |
17 | ![Terminology.](./assets/terminology.svg "An ensemble is composed of subnetworks whose outputs are combined via an ensembler.")
18 |
19 | Ensembles are model-agnostic, meaning a subnetwork can be as complex as a deep
20 | neural network, or as simple as an if-statement. All that matters is that for a
21 | given input tensor, the subnetworks' outputs can be combined by the ensembler to
22 | form a single prediction.
23 |
24 | ## Adaptive architecture search
25 |
26 | <div align="center">
27 |   <img src="https://raw.githubusercontent.com/tensorflow/adanet/master/images/adanet_animation.gif" alt="adanet_animation">
28 | </div>
29 |
30 | In the animation above, the AdaNet algorithm iteratively performs the following
31 | architecture search to grow an ensemble of subnetworks:
32 |
33 | 1. Generates a pool of candidate subnetworks.
34 | 1. Trains the subnetworks in whatever manner the user defines.
35 | 1. Evaluates the performance of the subnetworks as part of the ensemble, which
36 | is an ensemble of one at the first iteration.
37 | 1. Adds the subnetwork that most improves the ensemble performance to the
38 | ensemble for the next iteration.
39 | 1. Prunes the other subnetworks from the graph.
40 | 1. Adapts the subnetwork search space according to the information gained from
41 | the current iteration.
42 | 1. Moves onto the next iteration.
43 | 1. Repeats.
44 |
45 | ## Iteration lifecycle
46 |
47 | Each AdaNet **iteration** has the following lifecycle:
48 |
49 | ![AdaNet iteration lifecycle](./assets/lifecycle.svg "The lifecycle of an AdaNet iteration.")
50 |
51 | Each of these concepts has an associated Python object:
52 |
53 | * **Subnetwork Generator** and **Subnetwork** are defined in the
54 | [`adanet.subnetwork`](https://adanet.readthedocs.io/en/latest/adanet.subnetwork.html)
55 | package.
56 | * **Ensemble Strategy**, **Ensembler**, and **Ensemble** are defined in the
57 | [`adanet.ensemble`](https://adanet.readthedocs.io/en/latest/adanet.ensemble.html)
58 | package.
59 |
60 | (A minimal sketch of a custom `Builder` and `Generator` appears at the end of
61 | this page.)
62 |
63 | ## Design
64 |
65 | AdaNet is designed to operate primarily inside TensorFlow's computation
66 | graph. This allows it to efficiently utilize available resources like
67 | distributed compute, GPU, and TPU, using TensorFlow primitives.
68 |
69 | AdaNet provides a unique adaptive computation graph, which can support building
70 | models that create and remove ops and variables over time, but still have the
71 | optimizations and scalability of TensorFlow's graph-mode. This adaptive graph
72 | enables users to develop progressively growing models (e.g. boosting style),
73 | develop architecture search algorithms, and perform hyper-parameter tuning
74 | without needing to manage an external for-loop.
75 |
76 | ## Example ensembles
77 |
78 | Below are a few more examples of ensembles you can obtain with AdaNet depending
79 | on the **search space** you define. First, there is an ensemble composed of
80 | increasingly complex neural network subnetworks whose outputs are simply
81 | averaged:
82 |
83 | ![Ensemble of subnetworks with different complexities.](./assets/different_complexity_ensemble.svg "An ensemble is composed of subnetworks with different complexities.")
84 |
85 | Another common example is an ensemble learned on top of a shared embedding.
86 | This is useful when the majority of the model's parameters are in the embedding
87 | of a feature. The individual subnetworks' predictions are combined using a
88 | learned linear combination:
89 |
90 | ![Subnetworks sharing a common embedding.](./assets/shared_embedding.svg "An ensemble is composed of subnetworks whose outputs are combined via an ensembler.")
91 |
92 | ## Quick start
93 |
94 | Now that you are familiar with AdaNet, you can explore our
95 | [quick start guide](./quick_start.md).
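To make the lifecycle concrete, here is a minimal sketch of a custom subnetwork **Builder** and **Generator** (illustrative only: the single dense feature `"x"`, the fixed architecture, and the constant complexity measure are all hypothetical simplifications):

```python
import adanet
import tensorflow.compat.v1 as tf


class _SimpleDNNBuilder(adanet.subnetwork.Builder):
  """Builds a one-hidden-layer DNN subnetwork."""

  def build_subnetwork(self, features, logits_dimension, training,
                       iteration_step, summary, previous_ensemble=None):
    # Assumes a single dense float feature named "x".
    hidden = tf.keras.layers.Dense(64, activation="relu")(features["x"])
    logits = tf.keras.layers.Dense(logits_dimension)(hidden)
    return adanet.Subnetwork(
        last_layer=hidden,
        logits=logits,
        # A constant stand-in for a real complexity measure r(h).
        complexity=tf.constant(1.0))

  def build_subnetwork_train_op(self, subnetwork, loss, var_list, labels,
                                iteration_step, summary,
                                previous_ensemble=None):
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
    return optimizer.minimize(loss, var_list=var_list)

  @property
  def name(self):
    return "simple_dnn"


class _SimpleGenerator(adanet.subnetwork.Generator):
  """Proposes the same candidate Builder at every iteration."""

  def generate_candidates(self, previous_ensemble, iteration_number,
                          previous_ensemble_reports, all_reports):
    return [_SimpleDNNBuilder()]
```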
93 |
--------------------------------------------------------------------------------
/docs/source/quick_start.md:
--------------------------------------------------------------------------------
1 | # Quick start
2 |
3 |
4 |
5 | If you are already using
6 | [`tf.estimator.Estimator`](https://www.tensorflow.org/guide/estimators), the
7 | fastest way to get up and running with AdaNet is to use the
8 | [`adanet.AutoEnsembleEstimator`](https://adanet.readthedocs.io/en/latest/adanet.html#autoensembleestimator).
9 | This estimator will automatically convert a list of estimators into subnetworks,
10 | and learn to ensemble them for you.
11 |
12 | ## Import AdaNet
13 |
14 | The first step is to import the `adanet` package:
15 |
16 | ```python
17 | import adanet
18 | ```
19 |
20 |
21 | ## AutoEnsembleEstimator
22 |
23 | Next you will want to define which estimators you want to ensemble. For example,
24 | if you don't know whether the best model is a linear model, a neural network, or
25 | some combination of the two, then you can try using `tf.estimator.LinearEstimator`
26 | and `tf.estimator.DNNEstimator` as subnetworks:
27 |
28 | ```python
29 | import adanet
30 | import tensorflow as tf
31 |
32 | # Define the model head for computing loss and evaluation metrics.
33 | head = tf.estimator.MultiClassHead(n_classes=10)
34 |
35 | # Feature columns define how to process examples.
36 | feature_columns = ...
37 |
38 | # Learn to ensemble linear and neural network models.
39 | estimator = adanet.AutoEnsembleEstimator(
40 |     head=head,
41 |     candidate_pool=lambda config: {
42 |         "linear":
43 |             tf.estimator.LinearEstimator(
44 |                 head=head,
45 |                 feature_columns=feature_columns,
46 |                 config=config,
47 |                 optimizer=...),
48 |         "dnn":
49 |             tf.estimator.DNNEstimator(
50 |                 head=head,
51 |                 feature_columns=feature_columns,
52 |                 config=config,
53 |                 optimizer=...,
54 |                 hidden_units=[1000, 500, 100])},
55 |     max_iteration_steps=50)
56 |
57 | estimator.train(input_fn=train_input_fn, steps=100)
58 | metrics = estimator.evaluate(input_fn=eval_input_fn)
59 | predictions = estimator.predict(input_fn=predict_input_fn)
60 | ```
61 |
62 | The above code will train both the `linear` and `dnn` subnetworks in parallel,
63 | and will average their predictions. After `max_iteration_steps=50` steps, the
64 | best subnetwork will compose the ensemble according to its performance on the
65 | *training set*. (A sketch of the input functions and feature columns these
66 | snippets assume appears at the end of this page.)
67 |
68 | ## Ensemble strategies
69 |
70 | The way AdaNet chooses which subnetworks to include in a candidate ensemble is
71 | via **ensemble strategies**.
72 |
73 | ### Grow strategy
74 |
75 | The default ensemble strategy is `adanet.ensemble.GrowStrategy`, which will only
76 | select the subnetwork that most improved the ensemble's performance. The
77 | remaining subnetworks will be pruned from the graph.
78 |
79 | ### All strategy
80 |
81 | Suppose instead of only selecting the *single best* subnetwork, you want to
82 | ensemble *all* of the subnetworks, regardless of their individual performance.
83 | You can pass an instance of the `adanet.ensemble.AllStrategy` to the
84 | `adanet.AutoEnsembleEstimator` constructor:
85 |
86 | ```python
87 | estimator = adanet.AutoEnsembleEstimator(
88 |     [...]
89 |     ensemble_strategies=[adanet.ensemble.AllStrategy()],
90 |     candidate_pool={
91 |         "linear": ...,
92 |         "dnn": ...,
93 |     },
94 |     [...])
95 | ```
96 |
97 |
98 | ## Tutorials
99 |
100 | To play with AdaNet in Colab notebooks, and learn about more advanced features
101 | like customizing AdaNet and training on TPU, see our
102 | [tutorials section](./tutorials.md).
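As mentioned above, the snippets on this page leave `feature_columns`,
`train_input_fn`, `eval_input_fn`, and `predict_input_fn` undefined. Below is a
minimal sketch of what they might look like with in-memory NumPy data. The
feature key `"x"`, the shapes, and the batch size are illustrative assumptions,
not part of the AdaNet API:

```python
import numpy as np
import tensorflow as tf

# Illustrative toy data: 1000 examples with 16 features and 10 classes.
x_train = np.random.rand(1000, 16).astype(np.float32)
y_train = np.random.randint(0, 10, size=(1000, 1))

# One numeric feature column matching the "x" key used below.
feature_columns = [tf.feature_column.numeric_column("x", shape=[16])]

def train_input_fn():
  dataset = tf.data.Dataset.from_tensor_slices(({"x": x_train}, y_train))
  return dataset.shuffle(1000).repeat().batch(32)

def eval_input_fn():
  dataset = tf.data.Dataset.from_tensor_slices(({"x": x_train}, y_train))
  return dataset.batch(32)

# For prediction, reuse the evaluation pipeline without labels.
def predict_input_fn():
  dataset = tf.data.Dataset.from_tensor_slices({"x": x_train})
  return dataset.batch(32)
```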
103 |
104 |
--------------------------------------------------------------------------------
/docs/source/tensorboard.md:
--------------------------------------------------------------------------------
1 | # TensorBoard
2 |
3 | [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard) is
4 | AdaNet's UI.
5 |
6 | From TensorBoard, you can visualize the performance of candidate ensembles and
7 | individual subnetworks over time, visualize their architectures, and monitor
8 | statistics.
9 |
--------------------------------------------------------------------------------
/docs/source/theory.md:
--------------------------------------------------------------------------------
1 | # Theory
2 |
3 | ## Focus on generalization
4 |
5 | Generalization error is what we really want to minimize when we train a model.
6 | Most algorithms minimize generalization error indirectly by minimizing a loss
7 | function that consists of a training loss term and additional penalty terms to
8 | discourage the model from acquiring properties that are associated with
9 | overfitting (e.g., L1 weight norms, L2 weight norms).
10 |
11 | ## Rigorous trade-offs between training loss and complexity
12 |
13 | How do we know what model properties to avoid? Currently, these usually come
14 | from practical experience or industry-accepted best practices. While this has
15 | worked well so far, we would like to minimize the generalization error in a more
16 | principled way.
17 |
18 | AdaNet's approach is to minimize a theoretical upper bound on generalization
19 | error, proven in the DeepBoost paper
20 | [[Cortes et al. '14](https://ai.google/research/pubs/pub42856)]:
21 |
22 | $$R(f) \leq \widehat{R}_{S, \rho}(f) + \frac{4}{\rho} \sum_{k = 1}^{l} \big \| \mathbf{w}_k \big \|_1 \mathfrak{R}_m(\widetilde {\cal H}_k) + \widetilde O\Big(\frac{1}{\rho} \sqrt{\frac{\log l}{m}}\Big)$$
23 |
24 | This generalization bound allows us to make an apples-to-apples comparison
25 | between the complexities of models in an ensemble and the overall training
26 | loss -- allowing us to design an algorithm that makes this trade-off in a
27 | rigorous manner.
28 |
29 | ## Other key insights
30 |
31 | * **Convex combinations can't hurt.** Given a set of already-performant and
32 |   uncorrelated base learners, one can take a linear combination of them with
33 |   weights that sum to 1 to obtain an ensemble that outperforms the best among
34 |   those base learners. But even though this ensemble has more trainable
35 |   parameters, it does not have a greater tendency to overfit.
36 | * **De-emphasize rather than discourage complex models.** If one combines a
37 |   few base learners that are each selected from a different function class
38 |   (e.g., neural networks of different depths and widths), one might expect the
39 |   tendency to overfit to be similar to that of an ensemble composed of base
40 |   learners selected from the union of all the function classes. Remarkably,
41 |   the DeepBoost bound shows that we can actually do better, as long as the
42 |   final ensemble is a weighted average of model logits where each base
43 |   learner's weight is inversely proportional to the Rademacher complexity of
44 |   its function class, and all the weights in the logits layer sum to 1.
45 |   Additionally, at training time, we don't have to discourage the trainer from
46 |   learning complex models -- it is only when we consider how much the model
47 |   should contribute to the ensemble that we take the complexity of the model
48 |   into account.
49 | * **Complexity is not just about the weights.** The Rademacher complexity of a
50 |   neural network does not simply depend on the number of weights or the norm
51 |   of its weights -- it also depends on the number of layers and how they are
52 |   connected. An upper bound on the Rademacher complexity of neural networks
53 |   can be expressed recursively
54 |   [[Cortes et al. '17](https://arxiv.org/abs/1607.01097)], and applies to both
55 |   fully-connected and convolutional neural networks, thus allowing us to
56 |   compute the complexity upper bounds of almost any neural network that can be
57 |   expressed as a directed acyclic graph of layers, including unconventional
58 |   architectures such as those found by NASNet
59 |   [[Zoph et al. '17](https://arxiv.org/abs/1707.07012)]. Rademacher complexity
60 |   is also data-dependent, which means that the same neural network
61 |   architecture can have different generalization behavior on different data
62 |   sets.
63 |
64 | ## AdaNet loss function
65 |
66 | Using these insights, AdaNet seeks to minimize the generalization error more
67 | directly using this loss function:
68 |
69 | $$\begin{align*} &F\left ( w \right ) = \frac{1}{m} \sum_{i=1}^{m} \Phi \left (\sum_{j=1}^{N}w_jh_j(x_i), y_i \right ) + \sum_{j=1}^{N} \left (\lambda r(h_j) + \beta \right )\left | w_j \right |\\ &\text{where }w_j \text{ is the weight of model } j \text{'s contribution to the ensemble,}\\ &h_j \text{ is model } j,\\ &\Phi \text{ is the loss function,}\\ &r(h_j) \text{ is model } j\text{'s complexity, and}\\ &\lambda \text{ and } \beta \text{ are tunable hyperparameters.} \end{align*}$$
70 |
71 | By minimizing this loss function, AdaNet is able to combine base learners of
72 | different complexities in a way that generalizes better than one might expect
73 | from the total size of the base learners.
74 |
--------------------------------------------------------------------------------
/docs/source/tpu.md:
--------------------------------------------------------------------------------
1 | # TPU
2 |
3 |
4 |
5 | AdaNet officially supports TPU training, evaluation, and prediction via the
6 | [`adanet.TPUEstimator`](https://adanet.readthedocs.io/en/latest/adanet.html#tpuestimator).
7 |
8 | To get started, see our
9 | [Colab notebook on TPU](https://colab.research.google.com/github/tensorflow/adanet/blob/master/adanet/examples/tutorials/adanet_tpu.ipynb).
10 |
--------------------------------------------------------------------------------
/docs/source/tutorials.md:
--------------------------------------------------------------------------------
1 | # Tutorials
2 |
3 | ## Notebooks
4 |
5 | Play with AdaNet in our interactive
6 | [Colab notebooks available on GitHub](https://github.com/tensorflow/adanet/tree/master/adanet/examples/tutorials).
7 |
8 |
9 | ## Misc
10 |
11 | To learn more, please visit our [quick start guide](./quick_start.md).
12 |
13 | For more about the underlying algorithm, see the [algorithm](./algorithm.md) and
14 | [theory](./theory.md) pages.
15 | -------------------------------------------------------------------------------- /images/adanet_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/adanet/0364cc46810ff3831b3e4a37125de862a28da9bd/images/adanet_animation.gif -------------------------------------------------------------------------------- /images/adanet_tangram_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/adanet/0364cc46810ff3831b3e4a37125de862a28da9bd/images/adanet_tangram_logo.png -------------------------------------------------------------------------------- /oss_scripts/oss_pip_install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2018 The AdaNet Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | set -v # print commands as they're executed 18 | set -e # fail and exit on any command erroring 19 | 20 | : "${TF_VERSION:?}" 21 | 22 | if [[ "$TF_VERSION" == "tf-nightly" ]]; then 23 | pip install tf-nightly; 24 | elif [[ "$TF_VERSION" == "tf-nightly-2.0-preview" ]]; then 25 | pip install tf-nightly-2.0-preview; 26 | else 27 | pip install -q "tensorflow==$TF_VERSION" 28 | fi 29 | 30 | # Build adanet pip packaging script 31 | bazel build -c opt //... --local_resources 2048,.5,1.0 --force_python=PY3 32 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py>=0.7,<1.0 2 | six>=1.11,<2.0 3 | numpy>=1.15,<2.0 4 | nose>=1.3,<2.0 5 | rednose>=1.3,<2.0 6 | nose-parallel==0.3.1 7 | coverage>=4.5,<5.0 8 | protobuf>=3.6,<4.0 9 | mock>=3.0,<4.0 10 | pytype==2019.10.17 11 | typing==3.7.4.1 12 | keras-tuner>=1.0,<2.0 13 | -------------------------------------------------------------------------------- /research/improve_nas/README.md: -------------------------------------------------------------------------------- 1 | **This is not a Google product.** 2 | 3 | Improving Neural Architecture Search Image Classifiers via Ensemble Learning 4 | ============================ 5 | 6 |
7 | ![adanet_search_space](images/search_space.png)
8 |
9 | ![ensemble](images/ensemble.png)
15 |
16 |
17 | ## Introduction
18 | We present an algorithm that can improve the performance of NASNet models by learning an ensemble of smaller models with minimal hyperparameter tuning.
19 | Interestingly, a simple ensemble of identical architectures trained independently with a uniformly averaged output performs better than the baseline single large model.
20 | Conversely, our adaptive methods show performance gains for applications where we can afford to train the ensemble sequentially.
21 | We were able to achieve near state-of-the-art results by using a combination of learning mixture weights and applying Adaptive Knowledge Distillation.
22 |
23 | This work was done as a part of the Google AI Residency.
24 |
25 |
26 | ## Paper results
27 |
28 |
29 | ![ensemble_accuracy_cifar_10](images/ensemble_accuracy_cif10.png)
30 |
31 | Accuracy of ensemble in CIFAR-10.
32 |
33 |
34 | ![ensemble_accuracy_cifar_100](images/ensemble_accuracy_cif100.png)
35 |
36 | Accuracy of ensemble in CIFAR-100.
37 |
38 |
39 |
40 | Our experiments demonstrate that ensembles of subnetworks improve accuracy over a single neural network with the same number of parameters.
41 | On CIFAR-10 our algorithm achieves error 2.26 and on CIFAR-100 it achieves error 14.58.
42 | To our knowledge, our technique achieves a new state-of-the-art on CIFAR-100 without using additional regularization or data augmentation (e.g., [Shake-Drop](https://arxiv.org/abs/1802.02375) or [AutoAugment](https://arxiv.org/abs/1805.09501)).
43 |
44 |
45 |
46 | ![accuracy_improvements_cifar10](images/cif10_caption.png)
47 |
48 |
49 |
50 | ![accuracy_improvements_cifar100](images/cif100_caption.png)
51 |
52 |
53 |
54 |
55 |
56 | This is the code accompanying the paper: [Improving Neural Architecture Search Image Classifiers via Ensemble Learning](https://arxiv.org/abs/1903.06236), currently under review for ICML 2019.
57 |
58 | For instructions on running the code on Google Cloud, follow [sararob/adanet-ml-engine](https://github.com/sararob/adanet-ml-engine).
59 |
60 |
61 | ## Prerequisites
62 | Follow the instructions in [sararob/adanet-ml-engine](https://github.com/sararob/adanet-ml-engine) to set up your Cloud project, install the gcloud CLI, set up a storage bucket and run the job.
63 |
64 |
65 | ## Reproduction
66 | To reproduce a simple experiment (an ensemble of 10 NASNet(6@768) subnetworks), set up a few environment variables on your local machine:
67 |
68 | ```
69 | export JOBNAME= # improvenastest1
70 | export REGION=us-east1;
71 | export MODULE=trainer.model;
72 | export PACKAGE_PATH=trainer/;
73 | export JOB_DIR=; # gs://improve_nas_bucket/test1
74 | ```
75 |
76 | then go to the `improve_nas` directory and run (still on your local machine):
77 |
78 | ```
79 | gcloud ml-engine jobs submit training $JOBNAME --package-path trainer/ \
80 |   --module-name trainer.trainer --job-dir $JOB_DIR/$JOBNAME --region $REGION \
81 |   --config config.yaml -- --batch_size=1 --data_params=augmentation=basic,cutout=True \
82 |   --dataset=cifar10 --train_steps=10000000 \
83 |   --hparams="adanet_beta=0.0,adanet_lambda=0.0,boosting_iterations=10,force_grow=True,\
84 | knowledge_distillation=none,generator=simple,learn_mixture_weights=False,\
85 | initial_learning_rate=0.025,learning_rate_schedule=cosine,aux_head_weight=0.4,\
86 | clip_gradients=5,data_format=NHWC,dense_dropout_keep_prob=1.0,drop_path_keep_prob=0.6,\
87 | filter_scaling_rate=2.0,label_smoothing=0.1,model_version=cifar,num_cells=6,\
88 | num_conv_filters=32,num_reduction_layers=2,optimizer=momentum,skip_reduction_layer_input=0,\
89 | stem_multiplier=3.0,use_aux_head=False,weight_decay=0.0005" \
90 |   --save_summary_steps 10000 --save_checkpoints_secs 600 --model_dir=$JOB_DIR/$JOBNAME
91 | ```
92 |
93 | To train mixture weights, set the hparam `learn_mixture_weights=True`.
94 | To use knowledge distillation, set the hparam `knowledge_distillation` to `adaptive` or `born_again`.
95 | Finally, to perform architecture search (dynamic generator), set the hparam `generator` to `dynamic` and adjust `num_cells` and `num_conv_filters` to set the initial architecture.
96 |
97 | For testing, use `--config config_test.yaml` (uses only one GPU), change the eval steps with `--eval_steps=1`, and set `--dataset=fake`.
98 |
99 | ## Cite
100 |
101 | ```
102 | @article{macko2019improving,
103 |   title={Improving Neural Architecture Search Image Classifiers via Ensemble Learning},
104 |   author={Macko, Vladimir and Weill, Charles and Mazzawi, Hanna and Gonzalvo, Javier},
105 |   journal={arXiv preprint arXiv:1903.06236},
106 |   year={2019}
107 | }
108 | ```
109 |
110 |
--------------------------------------------------------------------------------
/research/improve_nas/config.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The AdaNet Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | trainingInput: 16 | scaleTier: CUSTOM 17 | masterType: standard_p100 18 | workerType: standard_p100 19 | parameterServerType: standard_p100 20 | workerCount: 10 21 | parameterServerCount: 3 22 | runtimeVersion: '1.12' 23 | pythonVersion: '2.7' 24 | -------------------------------------------------------------------------------- /research/improve_nas/config_test.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The AdaNet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | trainingInput: 16 | scaleTier: BASIC_GPU 17 | -------------------------------------------------------------------------------- /research/improve_nas/images/cif100_caption.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/adanet/0364cc46810ff3831b3e4a37125de862a28da9bd/research/improve_nas/images/cif100_caption.png -------------------------------------------------------------------------------- /research/improve_nas/images/cif10_caption.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/adanet/0364cc46810ff3831b3e4a37125de862a28da9bd/research/improve_nas/images/cif10_caption.png -------------------------------------------------------------------------------- /research/improve_nas/images/ensemble.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/adanet/0364cc46810ff3831b3e4a37125de862a28da9bd/research/improve_nas/images/ensemble.png -------------------------------------------------------------------------------- /research/improve_nas/images/ensemble_accuracy_cif10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/adanet/0364cc46810ff3831b3e4a37125de862a28da9bd/research/improve_nas/images/ensemble_accuracy_cif10.png -------------------------------------------------------------------------------- /research/improve_nas/images/ensemble_accuracy_cif100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/adanet/0364cc46810ff3831b3e4a37125de862a28da9bd/research/improve_nas/images/ensemble_accuracy_cif100.png -------------------------------------------------------------------------------- /research/improve_nas/images/search_space.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/adanet/0364cc46810ff3831b3e4a37125de862a28da9bd/research/improve_nas/images/search_space.png -------------------------------------------------------------------------------- /research/improve_nas/setup.py: -------------------------------------------------------------------------------- 1 | """Setup file. 2 | 3 | Copyright 2019 The AdaNet Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | """ 17 | from setuptools import find_packages 18 | from setuptools import setup 19 | 20 | REQUIRED_PACKAGES = ['tensorflow>=1.12', 21 | 'adanet==0.5.0'] 22 | 23 | setup( 24 | name='trainer', 25 | version='0.1', 26 | install_requires=REQUIRED_PACKAGES, 27 | packages=find_packages(), 28 | include_package_data=True, 29 | description='improve nas model' 30 | ) 31 | -------------------------------------------------------------------------------- /research/improve_nas/trainer/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | py_binary( 4 | name = "trainer", 5 | srcs = ["trainer.py"], 6 | srcs_version = "PY3", 7 | deps = [ 8 | ":adanet_improve_nas", 9 | ":cifar10", 10 | ":fake_data", 11 | "@absl_py//absl/flags", 12 | ], 13 | ) 14 | 15 | py_library( 16 | name = "improve_nas", 17 | srcs = ["improve_nas.py"], 18 | srcs_version = "PY3", 19 | deps = [ 20 | ":nasnet", 21 | ":subnetwork_utils", 22 | "//adanet", 23 | ], 24 | ) 25 | 26 | py_test( 27 | name = "improve_nas_test", 28 | srcs = ["improve_nas_test.py"], 29 | srcs_version = "PY3", 30 | deps = [ 31 | ":fake_data", 32 | ":improve_nas", 33 | "//adanet", 34 | "@absl_py//absl/testing:parameterized", 35 | ], 36 | ) 37 | 38 | py_library( 39 | name = "adanet_improve_nas", 40 | srcs = ["adanet_improve_nas.py"], 41 | srcs_version = "PY3", 42 | deps = [ 43 | ":improve_nas", 44 | ":optimizer", 45 | "//adanet", 46 | ], 47 | ) 48 | 49 | py_test( 50 | name = "adanet_improve_nas_test", 51 | srcs = ["adanet_improve_nas_test.py"], 52 | srcs_version = "PY3", 53 | deps = [ 54 | ":adanet_improve_nas", 55 | ":fake_data", 56 | "//adanet", 57 | "@absl_py//absl/testing:parameterized", 58 | ], 59 | ) 60 | 61 | py_library( 62 | name = "subnetwork_utils", 63 | srcs = ["subnetwork_utils.py"], 64 | srcs_version = "PY3", 65 | ) 66 | 67 | py_library( 68 | name = "optimizer", 69 | srcs = ["optimizer.py"], 70 | srcs_version = "PY3", 71 | ) 72 | 73 | py_library( 74 | name = "cifar10", 75 | srcs = ["cifar10.py"], 76 | srcs_version = "PY3", 77 | deps = [ 78 | ":image_processing", 79 | ], 80 | ) 81 | 82 | py_test( 83 | name = "cifar10_test", 84 | srcs = ["cifar10_test.py"], 85 | srcs_version = "PY3", 86 | deps = [ 87 | ":cifar10", 88 | ], 89 | ) 90 | 91 | py_library( 92 | name = "cifar100", 93 | srcs = ["cifar100.py"], 94 | srcs_version = "PY3", 95 | deps = [ 96 | ":image_processing", 97 | ], 98 | ) 99 | 100 | py_test( 101 | name = 
"cifar100_test", 102 | srcs = ["cifar100_test.py"], 103 | srcs_version = "PY3", 104 | deps = [ 105 | ":cifar100", 106 | ], 107 | ) 108 | 109 | py_library( 110 | name = "fake_data", 111 | srcs = ["fake_data.py"], 112 | srcs_version = "PY3", 113 | deps = [ 114 | ], 115 | ) 116 | 117 | py_library( 118 | name = "nasnet_utils", 119 | srcs = ["nasnet_utils.py"], 120 | srcs_version = "PY3", 121 | ) 122 | 123 | py_library( 124 | name = "image_processing", 125 | srcs = ["image_processing.py"], 126 | srcs_version = "PY3", 127 | ) 128 | 129 | py_library( 130 | name = "nasnet", 131 | srcs = ["nasnet.py"], 132 | srcs_version = "PY3", 133 | deps = [ 134 | ":nasnet_utils", 135 | ], 136 | ) 137 | -------------------------------------------------------------------------------- /research/improve_nas/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | """Copyright 2019 The AdaNet Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | """ 15 | -------------------------------------------------------------------------------- /research/improve_nas/trainer/adanet_improve_nas_test.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | """Tests for improve_nas. 3 | 4 | Copyright 2019 The AdaNet Authors. All Rights Reserved. 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | https://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | """ 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | 24 | from absl import flags 25 | from absl.testing import parameterized 26 | from adanet.research.improve_nas.trainer import adanet_improve_nas 27 | from adanet.research.improve_nas.trainer import fake_data 28 | import tensorflow.compat.v1 as tf 29 | 30 | 31 | class AdaNetQuetzalBuilderTest(parameterized.TestCase, tf.test.TestCase): 32 | 33 | @parameterized.named_parameters({ 34 | "testcase_name": "simple_generator", 35 | "hparams_string": ("optimizer=sgd,boosting_iterations=2,generator=simple," 36 | "initial_learning_rate=.1,use_aux_head=False," 37 | "num_cells=3,num_conv_filters=2,use_evaluator=False"), 38 | }, { 39 | "testcase_name": "dynamic_generator", 40 | "hparams_string": 41 | ("optimizer=sgd,boosting_iterations=1,generator=dynamic," 42 | "initial_learning_rate=.1,use_aux_head=False," 43 | "num_cells=3,num_conv_filters=2,use_evaluator=False"), 44 | }) 45 | def test_estimator(self, 46 | hparams_string, 47 | batch_size=1): 48 | """Structural test to make sure Estimator Builder works.""" 49 | 50 | seed = 42 51 | 52 | # Set up and clean test directory. 53 | model_dir = os.path.join(flags.FLAGS.test_tmpdir, 54 | "AdanetImproveNasBuilderTest") 55 | if tf.gfile.Exists(model_dir): 56 | tf.gfile.DeleteRecursively(model_dir) 57 | tf.gfile.MkDir(model_dir) 58 | 59 | data_provider = fake_data.FakeImageProvider(seed=seed) 60 | estimator_builder = adanet_improve_nas.Builder() 61 | hparams = estimator_builder.hparams( 62 | default_batch_size=3, hparams_string=hparams_string) 63 | run_config = tf.estimator.RunConfig( 64 | tf_random_seed=seed, model_dir=model_dir) 65 | _ = data_provider.get_input_fn( 66 | "train", 67 | tf.estimator.ModeKeys.TRAIN, 68 | batch_size=batch_size) 69 | test_input_fn = data_provider.get_input_fn( 70 | "test", 71 | tf.estimator.ModeKeys.EVAL, 72 | batch_size=batch_size) 73 | 74 | estimator = estimator_builder.estimator( 75 | data_provider=data_provider, 76 | run_config=run_config, 77 | hparams=hparams, 78 | train_steps=10, 79 | seed=seed) 80 | eval_metrics = estimator.evaluate(input_fn=test_input_fn, steps=1) 81 | 82 | self.assertGreater(eval_metrics["loss"], 0.0) 83 | 84 | 85 | if __name__ == "__main__": 86 | tf.test.main() 87 | -------------------------------------------------------------------------------- /research/improve_nas/trainer/cifar100_test.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | """Tests for cifar100 dataset. 3 | 4 | Copyright 2019 The AdaNet Authors. All Rights Reserved. 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | https://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | """ 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from adanet.research.improve_nas.trainer import cifar100 23 | import tensorflow.compat.v1 as tf 24 | 25 | 26 | class Cifar100Test(tf.test.TestCase): 27 | 28 | def _check_dimensions(self, partition): 29 | provider = cifar100.Provider(seed=4) 30 | input_fn = provider.get_input_fn( 31 | partition, tf.contrib.learn.ModeKeys.TRAIN, batch_size=3) 32 | data, labels = input_fn() 33 | self.assertIn(cifar100.FEATURES, data) 34 | features = data[cifar100.FEATURES] 35 | init = tf.group(tf.global_variables_initializer(), 36 | tf.local_variables_initializer()) 37 | with self.test_session() as sess: 38 | sess.run(init) 39 | self.assertEqual((3, 32, 32, 3), sess.run(features).shape) 40 | self.assertEqual((3, 1), sess.run(labels).shape) 41 | 42 | def test_read_cifar100(self): 43 | for partition in ["train", "test"]: 44 | self._check_dimensions(partition) 45 | 46 | def test_no_preprocess(self): 47 | provider = cifar100.Provider(seed=4) 48 | input_fn = provider.get_input_fn( 49 | "train", 50 | tf.contrib.learn.ModeKeys.TRAIN, 51 | batch_size=3, 52 | preprocess=False) 53 | data, label = input_fn() 54 | 55 | init = tf.group(tf.global_variables_initializer(), 56 | tf.local_variables_initializer()) 57 | with self.test_session() as sess: 58 | sess.run(init) 59 | self.assertAllEqual([220, 25, 47], sess.run(data["x"])[0][0][0]) 60 | self.assertAllEqual([[47], [5], [52]], sess.run(label)) 61 | 62 | def test_basic_preprocess(self): 63 | provider = cifar100.Provider( 64 | params_string="augmentation=basic", seed=4) 65 | input_fn = provider.get_input_fn( 66 | "train", 67 | tf.contrib.learn.ModeKeys.TRAIN, 68 | batch_size=3, 69 | preprocess=True) 70 | data, label = input_fn() 71 | 72 | init = tf.group(tf.global_variables_initializer(), 73 | tf.local_variables_initializer()) 74 | with self.test_session() as sess: 75 | sess.run(init) 76 | data_result = sess.run(data["x"]) 77 | self.assertEqual((3, 32, 32, 3), data_result.shape) 78 | self.assertAllEqual([0, 0, 0], data_result[0, 0, 0]) 79 | self.assertAlmostEqual(0.0, data_result[0, -1, 0, 0], places=3) 80 | self.assertAllEqual([[47], [5], [52]], sess.run(label)) 81 | 82 | 83 | if __name__ == "__main__": 84 | tf.test.main() 85 | -------------------------------------------------------------------------------- /research/improve_nas/trainer/cifar10_test.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | """Tests for cifar10 dataset. 3 | 4 | Copyright 2019 The AdaNet Authors. All Rights Reserved. 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | https://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | """ 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from adanet.research.improve_nas.trainer import cifar10 23 | import tensorflow.compat.v1 as tf 24 | 25 | 26 | class Cifar10Test(tf.test.TestCase): 27 | 28 | def _check_dimensions(self, partition): 29 | provider = cifar10.Provider(seed=4) 30 | input_fn = provider.get_input_fn( 31 | partition, tf.contrib.learn.ModeKeys.TRAIN, batch_size=3) 32 | data, labels = input_fn() 33 | self.assertIn(cifar10.FEATURES, data) 34 | features = data[cifar10.FEATURES] 35 | init = tf.group(tf.global_variables_initializer(), 36 | tf.local_variables_initializer()) 37 | with self.test_session() as sess: 38 | sess.run(init) 39 | self.assertEqual((3, 32, 32, 3), sess.run(features).shape) 40 | self.assertEqual((3, 1), sess.run(labels).shape) 41 | 42 | def test_read_cifar10(self): 43 | for partition in ["train", "test"]: 44 | self._check_dimensions(partition) 45 | 46 | def test_no_preprocess(self): 47 | provider = cifar10.Provider(seed=4) 48 | input_fn = provider.get_input_fn( 49 | "train", 50 | tf.contrib.learn.ModeKeys.TRAIN, 51 | batch_size=3, 52 | preprocess=False) 53 | data, label = input_fn() 54 | 55 | init = tf.group(tf.global_variables_initializer(), 56 | tf.local_variables_initializer()) 57 | with self.test_session() as sess: 58 | sess.run(init) 59 | data_result = sess.run(data["x"]) 60 | self.assertEqual((3, 32, 32, 3), data_result.shape) 61 | self.assertAllEqual([148, 141, 174], data_result[0][0][0]) 62 | self.assertAllEqual([[5], [9], [3]], sess.run(label)) 63 | 64 | def test_basic_preprocess(self): 65 | provider = cifar10.Provider( 66 | params_string="augmentation=basic", seed=4) 67 | input_fn = provider.get_input_fn( 68 | "train", 69 | tf.contrib.learn.ModeKeys.TRAIN, 70 | batch_size=3, 71 | preprocess=True) 72 | data, label = input_fn() 73 | 74 | init = tf.group(tf.global_variables_initializer(), 75 | tf.local_variables_initializer()) 76 | with self.test_session() as sess: 77 | sess.run(init) 78 | data_result = sess.run(data["x"]) 79 | self.assertEqual((3, 32, 32, 3), data_result.shape) 80 | self.assertAllEqual([0, 0, 0], data_result[0, 0, 0]) 81 | self.assertAlmostEqual(0.0, data_result[0, -1, 0, 0], places=3) 82 | self.assertAllEqual([[5], [9], [3]], sess.run(label)) 83 | 84 | 85 | if __name__ == "__main__": 86 | tf.test.main() 87 | -------------------------------------------------------------------------------- /research/improve_nas/trainer/fake_data.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | """Fake dataset for testing and debugging. 3 | 4 | Copyright 2019 The AdaNet Authors. All Rights Reserved. 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | https://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | """ 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import numpy as np 23 | import tensorflow.compat.v1 as tf 24 | 25 | 26 | class FakeImageProvider(object): 27 | """A fake image data provider.""" 28 | 29 | def __init__(self, 30 | num_examples=3, 31 | num_classes=3, 32 | image_dim=8, 33 | channels=1, 34 | seed=42): 35 | self._num_examples = num_examples 36 | self._num_classes = num_classes 37 | self._seed = seed 38 | self._channels = channels 39 | self._image_dim = image_dim 40 | 41 | def get_head(self, name=None): 42 | return tf.contrib.estimator.multi_class_head( 43 | self._num_classes, name=name, loss_reduction=tf.losses.Reduction.SUM) 44 | 45 | def _shape(self): 46 | return [self._image_dim, self._image_dim, self._channels] 47 | 48 | def get_input_fn(self, 49 | partition, 50 | mode, 51 | batch_size): 52 | """See `data.Provider` get_input_fn.""" 53 | 54 | del partition 55 | def input_fn(params=None): 56 | """Input_fn to return.""" 57 | 58 | del params # Unused. 59 | 60 | np.random.seed(self._seed) 61 | if mode == tf.estimator.ModeKeys.EVAL: 62 | np.random.seed(self._seed + 1) 63 | 64 | images = tf.to_float( 65 | tf.convert_to_tensor( 66 | np.random.rand(self._num_examples, *self._shape()))) 67 | labels = tf.convert_to_tensor( 68 | np.random.randint(0, high=2, size=(self._num_examples, 1))) 69 | dataset = tf.data.Dataset.from_tensor_slices(({"x": images}, labels)) 70 | if mode == tf.estimator.ModeKeys.TRAIN: 71 | dataset = dataset.repeat() 72 | dataset = dataset.batch(batch_size) 73 | iterator = dataset.make_one_shot_iterator() 74 | return iterator.get_next() 75 | 76 | return input_fn 77 | 78 | def get_feature_columns(self): 79 | feature_columns = [ 80 | tf.feature_column.numeric_column(key="x", shape=self._shape()), 81 | ] 82 | return feature_columns 83 | -------------------------------------------------------------------------------- /research/improve_nas/trainer/image_processing.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | """Image preprocessing and augmentation function for a single image. 3 | 4 | Copyright 2019 The AdaNet Authors. All Rights Reserved. 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | https://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow.compat.v1 as tf 24 | 25 | 26 | class PreprocessingType(object): 27 | """Type of preprocessing to be applied on the image. 28 | 29 | * `INCEPTION`: Preprocessing used in inception. 30 | * `BASIC`: Minimalistic preprocessing used in NasNet for cifar. 
31 | 32 | """ 33 | INCEPTION = "inception" 34 | BASIC = "basic" 35 | 36 | 37 | def basic_augmentation(image, image_height, image_width, seed=None): 38 | """Augment image according to NasNet paper (random flip + random crop).""" 39 | 40 | # source: https://arxiv.org/pdf/1707.07012.pdf appendix A.1 41 | padding = 4 42 | image = tf.image.random_flip_left_right(image, seed=seed) 43 | 44 | image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]]) 45 | image = tf.random_crop(image, [image_height, image_width, 3], seed=seed) 46 | return image 47 | 48 | 49 | def resize_and_normalize(image, height, width): 50 | """Convert image to float, resize and normalize to zero mean and [-1, 1].""" 51 | if image.dtype != tf.float32: 52 | # Rescale pixel values to float in interval [0.0, 1.0]. 53 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 54 | 55 | # Resize the image to the specified height and width. 56 | image = tf.expand_dims(image, 0) 57 | image = tf.image.resize_bilinear(image, [height, width], align_corners=False) 58 | image = tf.squeeze(image, [0]) 59 | # Rescale pixels to range [-0.5, 0.5]. 60 | image = tf.subtract(image, 0.5) 61 | # Rescale pixels to range [-1, 1]. 62 | image = tf.multiply(image, 2.0) 63 | return image 64 | 65 | 66 | def cutout(image, pad_size, replace=0, seed=None): 67 | """Apply cutout (https://arxiv.org/abs/1708.04552) to image. 68 | 69 | Forked from learning/brain/research/meta_architect/image/image_processing.py? 70 | l=1172&rcl=193953073 71 | 72 | Args: 73 | image: Image `Tensor` with shape [height, width, channels]. 74 | pad_size: The cutout shape will be at most [pad_size * 2, pad_size * 2]. 75 | replace: Value for replacing cutout values. 76 | seed: Random seed. 77 | 78 | Returns: 79 | Image `Tensor` with cutout applied. 80 | """ 81 | 82 | with tf.variable_scope("cutout"): 83 | image_height = tf.shape(image)[0] 84 | image_width = tf.shape(image)[1] 85 | image_depth = tf.shape(image)[2] 86 | 87 | # Sample the location in the image where the zero mask will be applied. 88 | cutout_center_height = tf.random_uniform( 89 | shape=[], minval=0, maxval=image_height, seed=seed, dtype=tf.int32) 90 | 91 | cutout_center_width = tf.random_uniform( 92 | shape=[], minval=0, maxval=image_width, seed=seed, dtype=tf.int32) 93 | 94 | lower_pad = tf.maximum(0, cutout_center_height - pad_size) 95 | upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size) 96 | left_pad = tf.maximum(0, cutout_center_width - pad_size) 97 | right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size) 98 | 99 | cutout_shape = [ 100 | image_height - (lower_pad + upper_pad), 101 | image_width - (left_pad + right_pad) 102 | ] 103 | padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] 104 | mask = tf.pad( 105 | tf.zeros(cutout_shape, dtype=image.dtype), 106 | padding_dims, 107 | constant_values=1) 108 | mask = tf.expand_dims(mask, -1) 109 | mask = tf.tile(mask, [1, 1, image_depth]) 110 | image = tf.where( 111 | tf.equal(mask, 0), 112 | tf.ones_like(image, dtype=image.dtype) * replace, image) 113 | return image 114 | -------------------------------------------------------------------------------- /research/improve_nas/trainer/optimizer.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | """Definition of optimizers and learning rate schedules. 3 | 4 | Copyright 2019 The AdaNet Authors. All Rights Reserved. 
5 |
6 | Licensed under the Apache License, Version 2.0 (the "License");
7 | you may not use this file except in compliance with the License.
8 | You may obtain a copy of the License at
9 |
10 |   https://www.apache.org/licenses/LICENSE-2.0
11 |
12 | Unless required by applicable law or agreed to in writing, software
13 | distributed under the License is distributed on an "AS IS" BASIS,
14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | See the License for the specific language governing permissions and
16 | limitations under the License.
17 | """
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import abc
23 | import functools
24 |
25 | import tensorflow.compat.v1 as tf
26 |
27 |
28 | class LearningRateSchedule(object):
29 |   """A learning rate decay schedule interface."""
30 |
31 |   __metaclass__ = abc.ABCMeta
32 |
33 |   @abc.abstractmethod
34 |   def apply(self, learning_rate):
35 |     """Applies the learning rate decay schedule to the given learning rate.
36 |
37 |     Args:
38 |       learning_rate: Float `Tensor` learning rate.
39 |
40 |     Returns:
41 |       Float `Tensor` learning rate with applied decay schedule.
42 |     """
43 |
44 |
45 | class Constant(LearningRateSchedule):
46 |   """A constant schedule."""
47 |
48 |   def apply(self, learning_rate):
49 |     """See `LearningRateSchedule`."""
50 |
51 |     return learning_rate
52 |
53 |
54 | class Cosine(LearningRateSchedule):
55 |   """A cosine decay schedule."""
56 |
57 |   def __init__(self, decay_steps, alpha):
58 |     """Initializes a `Cosine` instance.
59 |
60 |     Args:
61 |       decay_steps: Number of steps to decay over.
62 |       alpha: Minimum learning rate value as a fraction of learning_rate.
63 |     """
64 |
65 |     self._decay_fn = functools.partial(
66 |         tf.train.cosine_decay, decay_steps=decay_steps, alpha=alpha)
67 |
68 |   def apply(self, learning_rate):
69 |     """See `LearningRateSchedule`."""
70 |
71 |     # Start at -1 since we increment before reading.
72 |     global_step = tf.get_variable("decay_step", initializer=-1, trainable=False)
73 |     increment_op = tf.assign_add(global_step, 1)
74 |     with tf.control_dependencies([increment_op]):
75 |       learning_rate = self._decay_fn(
76 |           learning_rate=learning_rate, global_step=global_step.read_value())
77 |     return learning_rate
78 |
79 |
80 | def fn_with_name(optimizer_name,
81 |                  learning_rate_schedule="constant",
82 |                  cosine_decay_steps=None):
83 |   """Returns an optimizer_fn with the given name.
84 |
85 |   Args:
86 |     optimizer_name: Optimizer name string for identifying the optimizer. One of
87 |       'adagrad', 'adam', 'lazy_adam', 'momentum', 'rmsprop', or 'sgd'.
88 |     learning_rate_schedule: Type of learning rate schedule to use. Kept open
89 |       for future extensions.
90 |     cosine_decay_steps: See `Cosine`.
91 |
92 |   Returns:
93 |     An optimizer_fn which takes a `learning_rate` scalar `Tensor` argument and
94 |     returns an (`Optimizer`, decayed learning rate `Tensor`) tuple.
95 |
96 |   Raises:
97 |     ValueError: If `optimizer_name` is invalid.
101 | """ 102 | 103 | optimizers = { 104 | "adagrad": tf.train.AdagradOptimizer, 105 | "adam": tf.train.AdamOptimizer, 106 | "lazy_adam": tf.contrib.opt.LazyAdamOptimizer, 107 | "momentum": functools.partial(tf.train.MomentumOptimizer, momentum=.9), 108 | "rmsprop": tf.train.RMSPropOptimizer, 109 | "sgd": tf.train.GradientDescentOptimizer, 110 | } 111 | optimizer_name = optimizer_name.lower() 112 | if optimizer_name not in optimizers: 113 | raise ValueError("Invalid optimizer '{}'".format(optimizer_name)) 114 | optimizer_fn = optimizers[optimizer_name] 115 | schedules = { 116 | "constant": 117 | Constant(), 118 | "cosine": 119 | Cosine(decay_steps=cosine_decay_steps, alpha=0.0), 120 | } 121 | schedule_name = learning_rate_schedule.lower() 122 | if schedule_name not in schedules: 123 | raise ValueError( 124 | "Invalid learning_rate_schedule '{}'".format(schedule_name)) 125 | schedule = schedules[schedule_name] 126 | 127 | def _optimizer_with_schedule(learning_rate): 128 | learning_rate = schedule.apply(learning_rate) 129 | optimizer = optimizer_fn(learning_rate) 130 | return optimizer, learning_rate 131 | return _optimizer_with_schedule 132 | -------------------------------------------------------------------------------- /research/improve_nas/trainer/subnetwork_utils.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | """Definition of helpful functions to work with AdaNet subnetworks. 3 | 4 | Copyright 2019 The AdaNet Authors. All Rights Reserved. 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | https://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | """ 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | 21 | import copy 22 | import tensorflow.compat.v1 as tf 23 | 24 | 25 | def capture_variables(fn): 26 | """Utility function that captures which tf variables were created by `fn`. 27 | 28 | This function encourages style that is easy to write, resonably easy to 29 | understand but against google codestyle. 30 | 31 | In general, you have an function `f` that takes some arguments (`a` and `b`) 32 | and returns some output. You may enclose it in lambda and get 33 | `fn == lambda: f(a,b)`, which is a function without arguments that does the 34 | same as `f`. 35 | 36 | This idiom makes variable management much easier and less error prone. Usable 37 | for prototyping or debugging. 38 | 39 | Args: 40 | fn: function with no arguments. 41 | 42 | Returns: 43 | tuple: First element of this touple is a list of tf variables created by 44 | fn, second is the actual output of fn 45 | 46 | """ 47 | vars_before_fn = tf.trainable_variables() 48 | fn_return = fn() 49 | vars_after_fn = tf.trainable_variables() 50 | fn_vars = list(set(vars_after_fn) - set(vars_before_fn)) 51 | return set(fn_vars), fn_return 52 | 53 | 54 | def copy_update(hparams, **kwargs): 55 | """Deep copy hparams with values updated by kwargs. 56 | 57 | This enables to use hparams in an immutable manner. 58 | Args: 59 | hparams: hyperparameters. 
60 |     **kwargs: Keyword arguments to change in hparams.
61 |
62 |   Returns:
63 |     Updated hyperparameters object. Changes to this object are not propagated
64 |     to the original hparams.
65 |   """
66 |   values = hparams.values()
67 |   values.update(kwargs)
68 |   values = copy.deepcopy(values)
69 |   hp = tf.contrib.training.HParams(**values)
70 |   return hp
71 |
72 |
73 | def get_persisted_value_from_ensemble(ensemble, key):
74 |   """Returns constant persisted tensor values from the previous subnetwork.
75 |
76 |   Args:
77 |     ensemble: Previous ensemble.
78 |     key: Name of the constant to get from the persisted tensors.
79 |
80 |   Returns:
81 |     int|float value of the constant.
82 |   """
83 |   previous_subnetwork = ensemble.weighted_subnetworks[-1].subnetwork
84 |   persisted_tensor = previous_subnetwork.shared[key]
85 |   return persisted_tensor
86 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [nosetests]
2 | verbosity=3
3 | rednose=1
4 | with-coverage=1
5 | cover-package=adanet
6 | with-parallel=1
--------------------------------------------------------------------------------