├── .github └── workflows │ └── ci.yml ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── WORKSPACE ├── docs ├── .gitignore ├── Makefile ├── _static │ └── favicon.ico ├── api.rst ├── conf.py ├── ext │ ├── BUILD │ ├── link_tf_api.py │ └── link_tf_api_test.py ├── index.rst ├── modules.rst ├── references.bib └── requirements.txt ├── examples ├── BUILD ├── README.md ├── distributed_cifar10.ipynb ├── functional_mlp_mnist.py ├── little_gan_on_mnist.ipynb ├── mlp_on_mnist.ipynb ├── simple_mnist.py ├── simple_mnist_test.py └── vqvae_example.ipynb ├── readthedocs.yml ├── requirements-test.txt ├── requirements-tf.txt ├── requirements.txt ├── setup.py ├── sonnet ├── BUILD ├── __init__.py ├── distribute.py ├── functional.py ├── initializers.py ├── mixed_precision.py ├── nets │ ├── BUILD │ ├── __init__.py │ └── resnet.py ├── optimizers.py ├── pad.py ├── regularizers.py └── src │ ├── BUILD │ ├── __init__.py │ ├── axis_norm.py │ ├── axis_norm_test.py │ ├── base.py │ ├── base_test.py │ ├── batch_apply.py │ ├── batch_apply_test.py │ ├── batch_norm.py │ ├── batch_norm_test.py │ ├── bias.py │ ├── bias_test.py │ ├── build.py │ ├── build_defs.bzl │ ├── build_test.py │ ├── conformance │ ├── BUILD │ ├── __init__.py │ ├── api_test.py │ ├── build_test.py │ ├── checkpoint_test.py │ ├── checkpoints │ │ ├── BUILD │ │ ├── README.md │ │ ├── base_batch_norm_1x2x2x3 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── base_batch_norm_scale_offset_1x2x2x3 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── batch_norm_1x2x2x3 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── batch_norm_scale_offset_1x2x2x3 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── batch_norm_training_1x2x2x3 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── 
bias_3x3x3 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── cifar10_convnet_2x3_2x2_1x3x3x2 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── conv1d_3x3_2x2 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── conv1d_lstm_3x3_2x2 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── conv1d_transpose_3x3_2x2 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── conv2d_3x3_2x2 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── conv2d_lstm_3x3_2x2 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── conv2d_transpose_3x3_2x2 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── conv3d_3x3_2x2 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── conv3d_lstm_3x3_2x2 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── conv3d_transpose_3x3_2x2 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── cross_replica_batch_norm_1x2x2x3 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── depthwise_conv2d_3x3_2x2 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── dropout │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── ema_2 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── embed_100_100 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── generate.py │ │ ├── group_norm_2_1x3x4 │ │ │ ├── checkpoint │ │ │ ├── 
checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── gru_1 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── instance_norm_1_1x3_2 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── layer_norm_1_1x3_2 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── linear_1x1 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── linear_nobias_1x1 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── lstm_1 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── lstm_8_projected_1 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── mean_2x2 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00002 │ │ │ ├── checkpoint-1.data-00001-of-00002 │ │ │ └── checkpoint-1.index │ │ ├── mlp_3x4x5_1x3 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── mlp_nobias_3x4x5_1x3 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── resnet50 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── sum_2x2 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00002 │ │ │ ├── checkpoint-1.data-00001-of-00002 │ │ │ └── checkpoint-1.index │ │ ├── trainable_state │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── unrolled_lstm_1 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── vanilla_rnn_8 │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── vqvae │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ ├── 
vqvae_ema_eval │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ │ └── vqvae_ema_train │ │ │ ├── checkpoint │ │ │ ├── checkpoint-1.data-00000-of-00001 │ │ │ └── checkpoint-1.index │ ├── copy_test.py │ ├── descriptors.py │ ├── descriptors_test.py │ ├── distribute_test.py │ ├── doctest_test.py │ ├── function_test.py │ ├── goldens.py │ ├── goldens_test.py │ ├── keras_test.py │ ├── optimizer_test.py │ ├── pickle_test.py │ ├── saved_model_test.py │ ├── tensorflow1_test.py │ └── xla_test.py │ ├── conv.py │ ├── conv_test.py │ ├── conv_transpose.py │ ├── conv_transpose_test.py │ ├── custom_getter.py │ ├── custom_getter_test.py │ ├── deferred.py │ ├── deferred_test.py │ ├── depthwise_conv.py │ ├── depthwise_conv_test.py │ ├── distribute │ ├── BUILD │ ├── __init__.py │ ├── distributed_batch_norm.py │ ├── distributed_batch_norm_test.py │ ├── replicator.py │ ├── replicator_test.py │ └── replicator_test_utils.py │ ├── dropout.py │ ├── dropout_test.py │ ├── embed.py │ ├── embed_test.py │ ├── functional │ ├── BUILD │ ├── __init__.py │ ├── haiku.py │ ├── haiku_test.py │ ├── jax.py │ ├── jax_test.py │ ├── optimizers.py │ ├── optimizers_test.py │ └── utils.py │ ├── group_norm.py │ ├── group_norm_test.py │ ├── initializers.py │ ├── initializers_test.py │ ├── leaky_clip_by_value.py │ ├── leaky_clip_by_value_test.py │ ├── linear.py │ ├── linear_test.py │ ├── metrics.py │ ├── metrics_test.py │ ├── mixed_precision.py │ ├── mixed_precision_test.py │ ├── moving_averages.py │ ├── moving_averages_test.py │ ├── nets │ ├── BUILD │ ├── __init__.py │ ├── cifar10_convnet.py │ ├── cifar10_convnet_test.py │ ├── dnc │ │ ├── BUILD │ │ ├── __init__.py │ │ ├── control.py │ │ ├── control_test.py │ │ ├── read.py │ │ ├── read_test.py │ │ ├── util.py │ │ ├── util_test.py │ │ ├── write.py │ │ └── write_test.py │ ├── mlp.py │ ├── mlp_test.py │ ├── resnet.py │ ├── resnet_test.py │ ├── vqvae.py │ └── vqvae_test.py │ ├── once.py │ ├── once_test.py │ ├── optimizers │ 
├── BUILD │ ├── __init__.py │ ├── adam.py │ ├── adam_test.py │ ├── momentum.py │ ├── momentum_test.py │ ├── optimizer_tests.py │ ├── optimizer_utils.py │ ├── rmsprop.py │ ├── rmsprop_test.py │ ├── sgd.py │ └── sgd_test.py │ ├── pad.py │ ├── pad_test.py │ ├── parallel_linear.py │ ├── parallel_linear_test.py │ ├── recurrent.py │ ├── recurrent_test.py │ ├── regularizers.py │ ├── regularizers_test.py │ ├── reshape.py │ ├── reshape_test.py │ ├── scale_gradient.py │ ├── scale_gradient_test.py │ ├── sequential.py │ ├── sequential_test.py │ ├── test_utils.py │ ├── types.py │ ├── utils.py │ └── utils_test.py └── test.sh /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: pytest 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - v2 7 | push: 8 | branches: 9 | - v2 10 | 11 | jobs: 12 | test-ubuntu: 13 | name: "pytest on ${{ matrix.python-version }} on ${{ matrix.os }}" 14 | runs-on: "${{ matrix.os }}" 15 | strategy: 16 | matrix: 17 | python-version: [3.9, '3.10', '3.11'] 18 | os: [ubuntu-latest] 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v1 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install -r requirements.txt 29 | pip install -r requirements-tf.txt 30 | pip install -r requirements-test.txt 31 | pip install . 32 | - name: Test with pytest 33 | run: | 34 | pip install pytest pytest-xdist 35 | pytest -n auto sonnet --ignore=sonnet/src/conformance/ 36 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing guidelines 2 | 3 | ## How to become a contributor and submit your own code 4 | 5 | ### Contributor License Agreements 6 | 7 | We'd love to accept your patches! 
Before we can take them, we have to jump a 8 | couple of legal hurdles. 9 | 10 | Please fill out either the individual or corporate Contributor License Agreement 11 | (CLA). 12 | 13 | * If you are an individual writing original source code and you're sure you 14 | own the intellectual property, then you'll need to sign an [individual 15 | CLA](http://code.google.com/legal/individual-cla-v1.0.html). 16 | * If you work for a company that wants to allow you to contribute your work, 17 | then you'll need to sign a [corporate 18 | CLA](http://code.google.com/legal/corporate-cla-v1.0.html). 19 | 20 | Follow either of the two links above to access the appropriate CLA and 21 | instructions for how to sign and return it. Once we receive it, we'll be able to 22 | accept your pull requests. 23 | 24 | ***NOTE***: Only original source code from you and other people that have signed 25 | the CLA can be accepted into the main repository. 26 | 27 | ### Contributing code 28 | 29 | If you have improvements to Sonnet, send us your pull requests! For those just 30 | getting started, Github has a 31 | [howto](https://help.github.com/articles/using-pull-requests/). 32 | 33 | If you want to contribute but you're not sure where to start, take a look at the 34 | [issues with the "contributions welcome" 35 | label](https://github.com/deepmind/sonnet/labels/stat%3Acontributions%20welcome). 36 | These are issues that we believe are particularly well suited for outside 37 | contributions, often because we probably won't get to them right now. If you 38 | decide to start on an issue, leave a comment so that other people know that 39 | you're working on it. If you want to help out, but not alone, use the issue 40 | comment thread to coordinate. 
41 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | workspace(name = "sonnet") 2 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | _build 2 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /docs/_static/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/docs/_static/favicon.ico -------------------------------------------------------------------------------- /docs/ext/BUILD: -------------------------------------------------------------------------------- 1 | load("//sonnet/src:build_defs.bzl", "snt_py_library", "snt_py_test") 2 | 3 | package(default_visibility = ["//sonnet:__subpackages__", "//docs/ext:__subpackages__", "//examples:__subpackages__"]) 4 | 5 | licenses(["notice"]) 6 | 7 | snt_py_library( 8 | name = "link_tf_api", 9 | srcs = ["link_tf_api.py"], 10 | deps = [ 11 | # pip: docutils 12 | # pip: tensorflow 13 | ], 14 | ) 15 | 16 | snt_py_test( 17 | name = "link_tf_api_test", 18 | srcs = ["link_tf_api_test.py"], 19 | gpu = False, 20 | tpu = False, 21 | deps = [ 22 | ":link_tf_api", 23 | # pip: absl/testing:absltest 24 | ], 25 | ) 26 | -------------------------------------------------------------------------------- /docs/ext/link_tf_api_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for ``:tf:`` Sphinx role.""" 16 | 17 | from absl.testing import absltest 18 | from docs.ext import link_tf_api 19 | 20 | DOC_BASE_URL = "https://www.tensorflow.org/versions/r2.0/api_docs/python/tf" 21 | 22 | 23 | class LinkTfApiTest(absltest.TestCase): 24 | 25 | def test_non_existent(self): 26 | self.assertIsNone(link_tf_api.tf_doc_url("tomhennigan")) 27 | self.assertIsNone(link_tf_api.tf_doc_url("autograph.1")) 28 | 29 | def test_link_to_top_level(self): 30 | self.assertEqual( 31 | link_tf_api.tf_doc_url("function"), DOC_BASE_URL + "/function") 32 | self.assertEqual(link_tf_api.tf_doc_url("Module"), DOC_BASE_URL + "/Module") 33 | 34 | def test_link_to_nested_package(self): 35 | self.assertEqual( 36 | link_tf_api.tf_doc_url("autograph.to_code"), 37 | DOC_BASE_URL + "/autograph/to_code") 38 | 39 | def test_link_to_method_of_exported_class(self): 40 | self.assertEqual( 41 | link_tf_api.tf_doc_url("TensorArray.read"), 42 | DOC_BASE_URL + "/TensorArray#read") 43 | 44 | def test_link_to_non_existent_method_of_exported_class(self): 45 | self.assertIsNone(link_tf_api.tf_doc_url("TensorArray.tomhennigan")) 46 | 47 | 48 | if __name__ == "__main__": 49 | absltest.main() 50 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/deepmind/sonnet/tree/v2/docs 2 | 3 | Sonnet Documentation 4 | ==================== 5 | 6 | Sonnet is a library built on top of TensorFlow designed to provide simple, 7 | composable abstractions for machine learning research. 8 | 9 | .. 
code-block:: python 10 | 11 | import sonnet as snt 12 | import tensorflow as tf 13 | 14 | mlp = snt.nets.MLP([1024, 1024, 10]) 15 | logits = mlp(tf.ones([8, 28 * 28])) 16 | 17 | Installation 18 | ------------ 19 | 20 | Install Sonnet by running:: 21 | 22 | $ pip install tensorflow 23 | $ pip install dm-sonnet 24 | 25 | .. toctree:: 26 | :caption: Guides 27 | :maxdepth: 1 28 | 29 | modules 30 | 31 | .. toctree:: 32 | :caption: Package Reference 33 | :maxdepth: 1 34 | 35 | api 36 | 37 | Contribute 38 | ---------- 39 | 40 | - Issue tracker: https://github.com/deepmind/sonnet/issues 41 | - Source code: https://github.com/deepmind/sonnet/tree/v2 42 | 43 | Support 44 | ------- 45 | 46 | If you are having issues, please let us know by filing an issue on our 47 | `issue tracker <https://github.com/deepmind/sonnet/issues>`_. 48 | 49 | License 50 | ------- 51 | 52 | Sonnet is licensed under the Apache 2.0 License. 53 | 54 | Indices and tables 55 | ================== 56 | 57 | * :ref:`genindex` 58 | * :ref:`modindex` 59 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx>=2.0.1 2 | sphinx_rtd_theme>=0.4.3 3 | sphinxcontrib-katex>=0.4.1 4 | sphinxcontrib-bibtex>=0.4.2,<2 5 | sphinx-autodoc-typehints>=1.10.3 6 | -------------------------------------------------------------------------------- /examples/BUILD: -------------------------------------------------------------------------------- 1 | # buildifier: disable=out-of-order-load - Breaks copybara otherwise 2 | load("//third_party/bazel_rules/rules_python/python:py_binary.bzl", "py_binary") 3 | load("//sonnet/src:build_defs.bzl", "snt_py_library", "snt_py_test") 4 | 5 | package(default_visibility = ["//visibility:private"]) 6 | 7 | licenses(["notice"]) 8 | 9 | py_binary( 10 | name = "simple_mnist", 11 | srcs = ["simple_mnist.py"], 12 | deps = [ 13 | # pip: absl:app 14 | "//sonnet", 15 | # pip: tensorflow 16 | # pip: 
tensorflow_datasets 17 | ], 18 | ) 19 | 20 | snt_py_library( 21 | name = "simple_mnist_library", 22 | srcs = ["simple_mnist.py"], 23 | deps = [ 24 | # pip: absl:app 25 | "//sonnet", 26 | # pip: tensorflow 27 | # pip: tensorflow_datasets 28 | ], 29 | ) 30 | 31 | snt_py_test( 32 | name = "simple_mnist_test", 33 | srcs = ["simple_mnist_test.py"], 34 | deps = [ 35 | ":simple_mnist_library", 36 | "//sonnet", 37 | "//sonnet/src:test_utils", 38 | # pip: tensorflow 39 | ], 40 | ) 41 | 42 | py_binary( 43 | name = "functional_mlp_mnist", 44 | srcs = ["functional_mlp_mnist.py"], 45 | deps = [ 46 | # pip: absl:app 47 | # pip: absl/logging 48 | "//sonnet", 49 | # pip: tensorflow 50 | # pip: tensorflow_datasets 51 | ], 52 | ) 53 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | Google Colab notebooks: 5 | 6 | - [Predicting MNIST with an MLP](https://colab.research.google.com/github/deepmind/sonnet/blob/v2/examples/mlp_on_mnist.ipynb) 7 | - [Training a Little GAN on MNIST](https://colab.research.google.com/github/deepmind/sonnet/blob/v2/examples/little_gan_on_mnist.ipynb) 8 | 9 | 10 | Training scripts: 11 | 12 | - [Simple ConvNet on MNIST](simple_mnist.py) 13 | -------------------------------------------------------------------------------- /examples/simple_mnist_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for sonnet.v2.examples.simple_mnist.""" 16 | 17 | import sonnet as snt 18 | from examples import simple_mnist 19 | from sonnet.src import test_utils 20 | import tensorflow as tf 21 | 22 | 23 | class SimpleMnistTest(test_utils.TestCase): 24 | 25 | def setUp(self): 26 | self.ENTER_PRIMARY_DEVICE = False # pylint: disable=invalid-name 27 | super().setUp() 28 | 29 | def test_train_epoch(self): 30 | model = snt.Sequential([ 31 | snt.Flatten(), 32 | snt.Linear(10), 33 | ]) 34 | 35 | optimizer = snt.optimizers.SGD(0.1) 36 | 37 | dataset = tf.data.Dataset.from_tensor_slices( 38 | (tf.random.normal([2, 8, 8, 1]), 39 | tf.ones([2], dtype=tf.int64))).batch(2).repeat(4) 40 | 41 | for _ in range(3): 42 | loss = simple_mnist.train_epoch(model, optimizer, dataset) 43 | self.assertEqual(loss.shape, []) 44 | self.assertEqual(loss.dtype, tf.float32) 45 | 46 | def test_test_accuracy(self): 47 | model = snt.Sequential([ 48 | snt.Flatten(), 49 | snt.Linear(10), 50 | ]) 51 | dataset = tf.data.Dataset.from_tensor_slices( 52 | (tf.random.normal([2, 8, 8, 1]), 53 | tf.ones([2], dtype=tf.int64))).batch(2).repeat(4) 54 | 55 | outputs = simple_mnist.test_accuracy(model, dataset) 56 | self.assertEqual(len(outputs), 2) 57 | 58 | 59 | if __name__ == "__main__": 60 | tf.test.main() 61 | -------------------------------------------------------------------------------- /readthedocs.yml: 
-------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | version: 2 5 | 6 | sphinx: 7 | builder: html 8 | configuration: docs/conf.py 9 | fail_on_warning: false 10 | 11 | python: 12 | version: 3.7 13 | install: 14 | - requirements: requirements.txt 15 | - requirements: requirements-tf.txt 16 | - requirements: docs/requirements.txt 17 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | mock>=3.0.5 2 | tensorflow-datasets>1,<4 3 | docutils 4 | -------------------------------------------------------------------------------- /requirements-tf.txt: -------------------------------------------------------------------------------- 1 | tensorflow==2.12.0rc0 2 | tensorflow-probability==0.12.2 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py>=0.7.1 2 | numpy>=1.16.3 3 | dm-tree>=0.1.1 4 | wrapt>=1.11.1 5 | tabulate>=0.7.5 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setup for pip package.""" 2 | 3 | from setuptools import find_namespace_packages 4 | from setuptools import setup 5 | 6 | 7 | def _get_sonnet_version(): 8 | with open('sonnet/__init__.py') as fp: 9 | for line in fp: 10 | if line.startswith('__version__'): 11 | g = {} 12 | exec(line, g) # pylint: disable=exec-used 13 | return g['__version__'] 14 | raise ValueError('`__version__` not defined in `sonnet/__init__.py`') 15 | 16 | 17 | def _parse_requirements(requirements_txt_path): 18 | with open(requirements_txt_path) as fp: 19 | return 
fp.read().splitlines() 20 | 21 | 22 | _VERSION = _get_sonnet_version() 23 | 24 | EXTRA_PACKAGES = { 25 | 'tensorflow': ['tensorflow>=2'], 26 | 'tensorflow with gpu': ['tensorflow-gpu>=2'], 27 | } 28 | 29 | setup( 30 | name='dm-sonnet', 31 | version=_VERSION, 32 | url='https://github.com/deepmind/sonnet', 33 | license='Apache 2.0', 34 | author='DeepMind', 35 | description=( 36 | 'Sonnet is a library for building neural networks in TensorFlow.'), 37 | long_description=open('README.md').read(), 38 | long_description_content_type='text/markdown', 39 | author_email='sonnet-dev-os@google.com', 40 | # Contained modules and scripts. 41 | packages=find_namespace_packages(exclude=['*_test.py']), 42 | install_requires=_parse_requirements('requirements.txt'), 43 | extras_require=EXTRA_PACKAGES, 44 | tests_require=_parse_requirements('requirements-test.txt'), 45 | requires_python='>=3.6', 46 | include_package_data=True, 47 | zip_safe=False, 48 | # PyPI package information. 49 | classifiers=[ 50 | 'Development Status :: 3 - Alpha', 51 | 'Intended Audience :: Developers', 52 | 'Intended Audience :: Education', 53 | 'Intended Audience :: Science/Research', 54 | 'License :: OSI Approved :: Apache Software License', 55 | 'Programming Language :: Python :: 3', 56 | 'Programming Language :: Python :: 3.9', 57 | 'Programming Language :: Python :: 3.10', 58 | 'Programming Language :: Python :: 3.11', 59 | 'Topic :: Scientific/Engineering :: Mathematics', 60 | 'Topic :: Software Development :: Libraries :: Python Modules', 61 | 'Topic :: Software Development :: Libraries', 62 | ], 63 | ) 64 | -------------------------------------------------------------------------------- /sonnet/BUILD: -------------------------------------------------------------------------------- 1 | load("//sonnet/src:build_defs.bzl", "snt_py_library") 2 | 3 | package(default_visibility = ["//visibility:private"]) 4 | 5 | licenses(["notice"]) 6 | 7 | snt_py_library( 8 | name = "sonnet", 9 | srcs = ["__init__.py"], 10 
| visibility = ["//visibility:public"], 11 | deps = [ 12 | ":distribute", 13 | ":functional", 14 | ":initializers", 15 | ":mixed_precision", 16 | ":optimizers", 17 | ":pad", 18 | ":regularizers", 19 | "//sonnet/nets", 20 | "//sonnet/src:axis_norm", 21 | "//sonnet/src:base", 22 | "//sonnet/src:batch_apply", 23 | "//sonnet/src:batch_norm", 24 | "//sonnet/src:bias", 25 | "//sonnet/src:build", 26 | "//sonnet/src:conv", 27 | "//sonnet/src:conv_transpose", 28 | "//sonnet/src:custom_getter", 29 | "//sonnet/src:deferred", 30 | "//sonnet/src:depthwise_conv", 31 | "//sonnet/src:dropout", 32 | "//sonnet/src:embed", 33 | "//sonnet/src:group_norm", 34 | "//sonnet/src:leaky_clip_by_value", 35 | "//sonnet/src:linear", 36 | "//sonnet/src:metrics", 37 | "//sonnet/src:moving_averages", 38 | "//sonnet/src:once", 39 | "//sonnet/src:recurrent", 40 | "//sonnet/src:reshape", 41 | "//sonnet/src:scale_gradient", 42 | "//sonnet/src:sequential", 43 | "//sonnet/src:utils", 44 | ], 45 | ) 46 | 47 | snt_py_library( 48 | name = "distribute", 49 | srcs = ["distribute.py"], 50 | deps = [ 51 | "//sonnet/src/distribute:distributed_batch_norm", 52 | "//sonnet/src/distribute:replicator", 53 | ], 54 | ) 55 | 56 | snt_py_library( 57 | name = "functional", 58 | srcs = ["functional.py"], 59 | deps = [ 60 | ":optimizers", 61 | "//sonnet/src/functional:haiku", 62 | "//sonnet/src/functional:jax", 63 | "//sonnet/src/functional:optimizers", 64 | ], 65 | ) 66 | 67 | snt_py_library( 68 | name = "initializers", 69 | srcs = ["initializers.py"], 70 | deps = [ 71 | "//sonnet/src:initializers", 72 | ], 73 | ) 74 | 75 | snt_py_library( 76 | name = "mixed_precision", 77 | srcs = ["mixed_precision.py"], 78 | deps = [ 79 | "//sonnet/src:mixed_precision", 80 | ], 81 | ) 82 | 83 | snt_py_library( 84 | name = "optimizers", 85 | srcs = ["optimizers.py"], 86 | deps = [ 87 | "//sonnet/src/optimizers:adam", 88 | "//sonnet/src/optimizers:momentum", 89 | "//sonnet/src/optimizers:rmsprop", 90 | "//sonnet/src/optimizers:sgd", 91 | 
], 92 | ) 93 | 94 | snt_py_library( 95 | name = "pad", 96 | srcs = ["pad.py"], 97 | deps = [ 98 | "//sonnet/src:pad", 99 | ], 100 | ) 101 | 102 | snt_py_library( 103 | name = "regularizers", 104 | srcs = ["regularizers.py"], 105 | deps = [ 106 | "//sonnet/src:regularizers", 107 | ], 108 | ) 109 | -------------------------------------------------------------------------------- /sonnet/distribute.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Utilities for using Sonnet with TensorFlow Distribution Strategy.""" 16 | 17 | from sonnet.src.distribute.distributed_batch_norm import CrossReplicaBatchNorm 18 | from sonnet.src.distribute.replicator import create_variables_eagerly 19 | from sonnet.src.distribute.replicator import Replicator 20 | from sonnet.src.distribute.replicator import TpuReplicator 21 | 22 | __all__ = ( 23 | "create_variables_eagerly", 24 | "Replicator", 25 | "TpuReplicator", 26 | "CrossReplicaBatchNorm", 27 | ) 28 | -------------------------------------------------------------------------------- /sonnet/functional.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Sonnet Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Simple functional APIs for TF2.""" 16 | 17 | from sonnet import optimizers as oo_optimizers 18 | from sonnet.src.functional import haiku 19 | from sonnet.src.functional import jax 20 | from sonnet.src.functional import optimizers 21 | 22 | # Utilities for converting Sonnet code into pure functions. 23 | variables = haiku.variables 24 | transform = haiku.transform 25 | transform_with_state = haiku.transform_with_state 26 | without_state = haiku.without_state 27 | 28 | # Utilities for working with tensors on device. 29 | device_get = jax.device_get 30 | device_put = jax.device_put 31 | 32 | # Utilities for transforming pure functions. 33 | grad = jax.grad 34 | jit = jax.jit 35 | value_and_grad = jax.value_and_grad 36 | 37 | # Optimizers. 38 | optimizer = optimizers.optimizer 39 | sgd = optimizer(oo_optimizers.SGD) 40 | adam = optimizer(oo_optimizers.Adam) 41 | rmsprop = optimizer(oo_optimizers.RMSProp) 42 | momentum = optimizer(oo_optimizers.Momentum) 43 | 44 | # Avoid accidentally exporting the private API. 
45 | del oo_optimizers, haiku, optimizers, jax 46 | 47 | __all__ = ( 48 | "variables", 49 | "transform", 50 | "transform_with_state", 51 | "without_state", 52 | "device_get", 53 | "device_put", 54 | "grad", 55 | "jit", 56 | "value_and_grad", 57 | "optimizer", 58 | "sgd", 59 | "adam", 60 | "rmsprop", 61 | "momentum", 62 | ) 63 | -------------------------------------------------------------------------------- /sonnet/initializers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Initializers.""" 16 | 17 | from sonnet.src.initializers import Constant 18 | from sonnet.src.initializers import Identity 19 | from sonnet.src.initializers import Initializer 20 | from sonnet.src.initializers import Ones 21 | from sonnet.src.initializers import Orthogonal 22 | from sonnet.src.initializers import RandomNormal 23 | from sonnet.src.initializers import RandomUniform 24 | from sonnet.src.initializers import TruncatedNormal 25 | from sonnet.src.initializers import VarianceScaling 26 | from sonnet.src.initializers import Zeros 27 | 28 | __all__ = ( 29 | "Constant", 30 | "Identity", 31 | "Initializer", 32 | "Ones", 33 | "Orthogonal", 34 | "RandomNormal", 35 | "RandomUniform", 36 | "TruncatedNormal", 37 | "VarianceScaling", 38 | "Zeros", 39 | ) 40 | -------------------------------------------------------------------------------- /sonnet/mixed_precision.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Sonnet mixed precision built for TensorFlow 2.""" 16 | 17 | from sonnet.src.mixed_precision import disable 18 | from sonnet.src.mixed_precision import enable 19 | from sonnet.src.mixed_precision import modes 20 | from sonnet.src.mixed_precision import scope 21 | 22 | __all__ = ( 23 | "disable", 24 | "enable", 25 | "modes", 26 | "scope", 27 | ) 28 | -------------------------------------------------------------------------------- /sonnet/nets/BUILD: -------------------------------------------------------------------------------- 1 | load("//sonnet/src:build_defs.bzl", "snt_py_library") 2 | 3 | package(default_visibility = ["//sonnet:__pkg__"]) 4 | 5 | licenses(["notice"]) 6 | 7 | snt_py_library( 8 | name = "nets", 9 | srcs = [ 10 | "__init__.py", 11 | "resnet.py", 12 | ], 13 | deps = [ 14 | "//sonnet/src/nets:cifar10_convnet", 15 | "//sonnet/src/nets:mlp", 16 | "//sonnet/src/nets:resnet", 17 | "//sonnet/src/nets:vqvae", 18 | ], 19 | ) 20 | -------------------------------------------------------------------------------- /sonnet/nets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Common network architectures implemented as Sonnet modules.""" 16 | 17 | from sonnet.nets import resnet 18 | from sonnet.src.nets.cifar10_convnet import Cifar10ConvNet 19 | from sonnet.src.nets.mlp import MLP 20 | from sonnet.src.nets.resnet import ResNet 21 | from sonnet.src.nets.resnet import ResNet50 22 | from sonnet.src.nets.vqvae import VectorQuantizer 23 | from sonnet.src.nets.vqvae import VectorQuantizerEMA 24 | 25 | __all__ = ( 26 | "MLP", 27 | "Cifar10ConvNet", 28 | "resnet", 29 | "ResNet", 30 | "ResNet50", 31 | "VectorQuantizer", 32 | "VectorQuantizerEMA", 33 | ) 34 | -------------------------------------------------------------------------------- /sonnet/nets/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """ResNet components.""" 16 | 17 | from sonnet.src.nets.resnet import BlockGroup 18 | from sonnet.src.nets.resnet import BottleNeckBlockV1 19 | from sonnet.src.nets.resnet import BottleNeckBlockV2 20 | 21 | __all__ = ( 22 | "BlockGroup", 23 | "BottleNeckBlockV1", 24 | "BottleNeckBlockV2", 25 | ) 26 | -------------------------------------------------------------------------------- /sonnet/optimizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Sonnet optimizers built for TensorFlow 2. 16 | 17 | All optimizers implement the `snt.Optimizer` interface. 18 | """ 19 | 20 | from sonnet.src.optimizers.adam import Adam 21 | from sonnet.src.optimizers.momentum import Momentum 22 | from sonnet.src.optimizers.rmsprop import RMSProp 23 | from sonnet.src.optimizers.sgd import SGD 24 | 25 | __all__ = ( 26 | "Adam", 27 | "Momentum", 28 | "RMSProp", 29 | "SGD", 30 | ) 31 | -------------------------------------------------------------------------------- /sonnet/pad.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Paddings.""" 16 | 17 | from sonnet.src.pad import causal 18 | from sonnet.src.pad import create 19 | from sonnet.src.pad import full 20 | from sonnet.src.pad import reverse_causal 21 | from sonnet.src.pad import same 22 | from sonnet.src.pad import valid 23 | 24 | __all__ = ( 25 | "causal", 26 | "create", 27 | "full", 28 | "reverse_causal", 29 | "same", 30 | "valid", 31 | ) 32 | -------------------------------------------------------------------------------- /sonnet/regularizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Regularizers.""" 16 | 17 | from sonnet.src.regularizers import L1 18 | from sonnet.src.regularizers import L2 19 | from sonnet.src.regularizers import OffDiagonalOrthogonal 20 | from sonnet.src.regularizers import Regularizer 21 | 22 | __all__ = [ 23 | "L1", 24 | "L2", 25 | "OffDiagonalOrthogonal", 26 | "Regularizer", 27 | ] 28 | -------------------------------------------------------------------------------- /sonnet/src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | -------------------------------------------------------------------------------- /sonnet/src/bias_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for sonnet.v2.src.bias.""" 16 | 17 | from sonnet.src import bias 18 | from sonnet.src import test_utils 19 | import tensorflow as tf 20 | 21 | 22 | class BiasTest(test_utils.TestCase): 23 | 24 | def test_output_shape(self): 25 | mod = bias.Bias(output_size=(2 * 2,)) 26 | with self.assertRaisesRegex(ValueError, "Input shape must be [(]-1, 4[)]"): 27 | mod(tf.ones([2, 2, 2])) 28 | 29 | def test_output_size_valid(self): 30 | mod = bias.Bias(output_size=(2 * 2,)) 31 | mod(tf.ones([2, 2 * 2])) 32 | 33 | def test_bias_dims_scalar(self): 34 | mod = bias.Bias(bias_dims=()) 35 | mod(tf.ones([1, 2, 3, 4])) 36 | self.assertEmpty(mod.b.shape) 37 | 38 | def test_bias_dims_custom(self): 39 | b, d1, d2, d3 = range(1, 5) 40 | mod = bias.Bias(bias_dims=[1, 3]) 41 | out = mod(tf.ones([b, d1, d2, d3])) 42 | self.assertEqual(mod.b.shape, [d1, 1, d3]) 43 | self.assertEqual(out.shape, [b, d1, d2, d3]) 44 | 45 | def test_bias_dims_negative_out_of_order(self): 46 | mod = bias.Bias(bias_dims=[-1, -2]) 47 | mod(tf.ones([1, 2, 3])) 48 | self.assertEqual(mod.b.shape, [2, 3]) 49 | 50 | def test_bias_dims_invalid(self): 51 | mod = bias.Bias(bias_dims=[1, 5]) 52 | with self.assertRaisesRegex(ValueError, 53 | "5 .* out of range for input of rank 3"): 54 | mod(tf.ones([1, 2, 3])) 55 | 56 | def test_b_init_defaults_to_zeros(self): 57 | mod = bias.Bias() 58 | mod(tf.ones([1, 1])) 59 | self.assertAllEqual(mod.b.read_value(), tf.zeros_like(mod.b)) 60 | 61 | def 
test_b_init_custom(self): 62 | ones_initializer = lambda s, d: tf.ones(s, dtype=d) 63 | mod = bias.Bias(b_init=ones_initializer) 64 | mod(tf.ones([1, 1])) 65 | self.assertAllEqual(mod.b.read_value(), tf.ones_like(mod.b)) 66 | 67 | def test_name(self): 68 | mod = bias.Bias(name="foo") 69 | self.assertEqual(mod.name, "foo") 70 | mod(tf.ones([1, 1])) 71 | self.assertEqual(mod.b.name, "foo/b:0") 72 | 73 | def test_multiplier(self): 74 | ones_initializer = lambda s, d: tf.ones(s, dtype=d) 75 | mod = bias.Bias(b_init=ones_initializer) 76 | out = mod(tf.ones([1, 1]), multiplier=-1) 77 | self.assertAllEqual(tf.reduce_sum(out), 0) 78 | 79 | 80 | if __name__ == "__main__": 81 | tf.test.main() 82 | -------------------------------------------------------------------------------- /sonnet/src/build.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Utility function to build Sonnet modules.""" 16 | 17 | from typing import Any, Callable 18 | 19 | import tensorflow as tf 20 | import tree 21 | 22 | 23 | def _int_or_none(o): 24 | return isinstance(o, (int, type(None))) 25 | 26 | 27 | def _promote_shapes(o): 28 | """Promotes lists of ints/Nones to :tf:`TensorSpec` instances.""" 29 | if isinstance(o, (list, tuple)) and all(_int_or_none(e) for e in o): 30 | return tf.TensorSpec(o) 31 | return o 32 | 33 | 34 | def _maybe_tensor_spec(shape, dtype): 35 | return tf.TensorSpec(shape, dtype) if dtype is not None else None 36 | 37 | 38 | # TODO(tomhennigan) Use TensorNest in types here. 39 | def build( 40 | f: Callable[..., Any], 41 | *args, 42 | **kwargs 43 | ): 44 | r"""Builds a module by creating all parameters but not computing any output. 45 | 46 | >>> mod = snt.nets.MLP([1000, 10]) 47 | >>> snt.build(mod, [None, 28 * 28]) 48 | TensorSpec(shape=(None, 10), dtype=tf.float32, name=None) 49 | >>> mod.variables 50 | (, 51 | , 52 | , 53 | ) 54 | 55 | Args: 56 | f: A function or callable :class:`Module` that will create variables. 57 | *args: Positional arguments to supply to ``f``. Note that positional 58 | arguments that are sequences of None/ints are converted to 59 | :tf:`TensorSpec` instances. 60 | **kwargs: Keyword arguments to pass to the module. 61 | 62 | Returns: 63 | The output of ``f`` with any :tf:`Tensor`\ s replaced by :tf:`TensorSpec`. 64 | """ 65 | f = tf.function(f) 66 | args = map(_promote_shapes, args) 67 | # NOTE: We use a concrete function to ensure that weights are created and 68 | # initialized, but other stateful ops (e.g. updating weights) are not. 
69 | cf = f.get_concrete_function(*args, **kwargs) 70 | return tree.map_structure(_maybe_tensor_spec, cf.output_shapes, 71 | cf.output_dtypes) 72 | -------------------------------------------------------------------------------- /sonnet/src/build_defs.bzl: -------------------------------------------------------------------------------- 1 | """Sonnet specific build rules.""" 2 | 3 | def snt_py_library(name, **kwargs): 4 | """Proxy for py_library. 5 | 6 | Internally we override this to enable type checking via PyType (more 7 | information at https://github.com/google/pytype). 8 | 9 | Args: 10 | name: library name. 11 | **kwargs: keyword args passed straight to py_library. 12 | """ 13 | native.py_library(name = name, **kwargs) 14 | 15 | def snt_py_test( 16 | name, 17 | deps = [], 18 | tags = [], 19 | main = None, 20 | gpu = True, 21 | tpu = True, 22 | **kwargs): 23 | """Runs a py_test. 24 | 25 | Args: 26 | name: test target name to generate suffixed with `test`. 27 | deps: additional dependencies for the test targets. 28 | tags: tags to be assigned to the different test targets. 29 | main: main script to be run for the test. 30 | gpu: Whether the test can be run on GPU. Note ignored by test. 31 | tpu: Whether the test can be run on TPU. Note ignored by test. 32 | **kwargs: extra keyword arguments to the test. 33 | """ 34 | if main == None: 35 | main = name + ".py" 36 | 37 | native.py_test( 38 | name = name, 39 | deps = deps, 40 | tags = tags, 41 | main = main, 42 | python_version = "PY3", 43 | **kwargs 44 | ) 45 | -------------------------------------------------------------------------------- /sonnet/src/build_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for sonnet.v2.src.build.""" 16 | 17 | from sonnet.src import build 18 | from sonnet.src import test_utils 19 | import tensorflow as tf 20 | 21 | 22 | class BuildTest(test_utils.TestCase): 23 | 24 | def test_call_with_shape_lke_object(self): 25 | output_spec = build.build(tensor_identity, [1, None, 3]) 26 | self.assertEqual(output_spec, tf.TensorSpec([1, None, 3])) 27 | 28 | def test_output_spec(self): 29 | dtype = tf.float32 if self.primary_device == "TPU" else tf.float16 30 | inputs = {"foo": [tf.ones([], dtype), None]} 31 | output_spec = build.build(lambda x: x, inputs) 32 | self.assertEqual(output_spec, 33 | {"foo": [tf.TensorSpec([], dtype), None]}) 34 | 35 | def test_does_not_trigger_sideeffects(self): 36 | mod = IncrementsCounter() 37 | output_spec = build.build(mod) 38 | self.assertIsNone(output_spec) 39 | self.assertEqual(mod.counter.numpy(), 0) 40 | 41 | 42 | def tensor_identity(x): 43 | assert isinstance(x, tf.Tensor) 44 | return x 45 | 46 | 47 | class IncrementsCounter(tf.Module): 48 | 49 | def __call__(self): 50 | if not hasattr(self, "counter"): 51 | self.counter = tf.Variable(0) 52 | self.counter.assign_add(1) 53 | 54 | if __name__ == "__main__": 55 | tf.test.main() 56 | -------------------------------------------------------------------------------- /sonnet/src/conformance/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Sonnet Authors. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | -------------------------------------------------------------------------------- /sonnet/src/conformance/api_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Tests for Sonnet's public API.""" 16 | 17 | import importlib 18 | 19 | import sonnet as snt 20 | from sonnet.src import test_utils 21 | import tensorflow as tf 22 | 23 | 24 | class PublicSymbolsTest(test_utils.TestCase): 25 | 26 | def test_src_not_exported(self): 27 | self.assertFalse(hasattr(snt, "src")) 28 | 29 | def test_supports_reload(self): 30 | mysnt = snt 31 | for _ in range(2): 32 | mysnt = importlib.reload(mysnt) 33 | self.assertFalse(hasattr(mysnt, "src")) 34 | 35 | 36 | if __name__ == "__main__": 37 | tf.test.main() 38 | -------------------------------------------------------------------------------- /sonnet/src/conformance/build_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Tests modules support `snt.build`.""" 16 | 17 | from absl.testing import parameterized 18 | import sonnet as snt 19 | from sonnet.src import test_utils 20 | from sonnet.src.conformance import descriptors 21 | import tensorflow as tf 22 | import tree 23 | 24 | BATCH_MODULES = descriptors.BATCH_MODULES 25 | RECURRENT_MODULES = descriptors.RECURRENT_MODULES 26 | 27 | 28 | def if_present(f): 29 | return lambda o: f(o) if o is not None else None 30 | 31 | 32 | class BuildTest(test_utils.TestCase, parameterized.TestCase): 33 | 34 | @parameterized.named_parameters(*(BATCH_MODULES + RECURRENT_MODULES)) 35 | def test_build(self, module_fn, input_shape, dtype): 36 | module = module_fn() 37 | build_output_spec = snt.build(module, tf.TensorSpec(input_shape, dtype)) 38 | actual_output = module(tf.ones(input_shape, dtype)) 39 | actual_output_spec = tree.map_structure( 40 | if_present(lambda t: tf.TensorSpec(t.shape, t.dtype)), actual_output) 41 | tree.map_structure(self.assertCompatible, build_output_spec, 42 | actual_output_spec) 43 | 44 | def assertCompatible(self, a: tf.TensorSpec, b: tf.TensorSpec): 45 | self.assertTrue(a.shape.is_compatible_with(b.shape)) 46 | self.assertEqual(a.dtype, b.dtype) 47 | 48 | 49 | if __name__ == "__main__": 50 | tf.test.main() 51 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/BUILD: -------------------------------------------------------------------------------- 1 | load("//third_party/bazel_rules/rules_python/python:py_binary.bzl", "py_binary") 2 | 3 | package( 4 | default_testonly = True, 5 | default_visibility = ["//sonnet:__subpackages__", "//docs/ext:__subpackages__", "//examples:__subpackages__"], 6 | ) 7 | 8 | licenses(["notice"]) 9 | 10 | py_binary( 11 | name = "generate", 12 | srcs = ["generate.py"], 13 | deps = [ 14 | # pip: absl:app 15 | # pip: absl/flags 16 | # pip: 
absl/logging 17 | "//sonnet/src/conformance:goldens", 18 | # pip: tensorflow 19 | ], 20 | ) 21 | 22 | filegroup( 23 | name = "checkpoints", 24 | srcs = glob(["**/*"]), 25 | ) 26 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/README.md: -------------------------------------------------------------------------------- 1 | # Golden checkpoints 2 | 3 | Golden checkpoints represent checkpoints generated from stable Sonnet code. We 4 | have unit tests that ensure we don't introduce checkpoint breaking changes to 5 | Sonnet. 6 | 7 | To generate a new checkpoint first add an entry in `goldens.py` describing the 8 | module you want to add. For example: 9 | 10 | ```python 11 | @_register_golden(snt.Linear, "linear_32x64") 12 | class Linear32x64(Golden): 13 | """Tests Linear without a bias.""" 14 | 15 | def create_module(self): 16 | return snt.Linear(64) 17 | 18 | def forward(self, module): 19 | x = range_like(tf.TensorSpec([1, 32])) 20 | return module(x) 21 | 22 | def create_all_variables(self, module): 23 | self.forward(module) 24 | return module.w, module.b 25 | ``` 26 | 27 | Then run the `generate` binary to generate new golden checkpoints: 28 | 29 | ```shell 30 | $ bazel run :generate -- --dry_run=false --golden_dir="$PWD" --alsologtostderr 31 | ``` 32 | 33 | At this point your golden checkpoint will be created and registered to run 34 | whenever `goldens_test` runs: 35 | 36 | ```shell 37 | $ bazel test :goldens_test 38 | ``` 39 | 40 | ## Regenerating old checkpoints 41 | 42 | WARNING: In general once a checkpoint is checked in it is only safe to 43 | regenerate it if your module has zero users. If you are making an additive 44 | change to a module (e.g. adding a new parameter) then consider making a new 45 | checkpoint and ensure that you can load from both the old and new checkpoint. 
46 | 47 | If you absolutely need to regenerate the checkpoint and know what you're doing 48 | then you can do so with: 49 | 50 | ```shell 51 | $ bazel run :generate -- --dry_run=false --golden_dir="$PWD" --alsologtostderr --filter=my_checkpoint_name --regenerate 52 | ``` 53 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/base_batch_norm_1x2x2x3/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/base_batch_norm_1x2x2x3/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/base_batch_norm_1x2x2x3/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/base_batch_norm_1x2x2x3/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/base_batch_norm_1x2x2x3/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/base_batch_norm_scale_offset_1x2x2x3/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/base_batch_norm_scale_offset_1x2x2x3/checkpoint-1.data-00000-of-00001: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/base_batch_norm_scale_offset_1x2x2x3/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/base_batch_norm_scale_offset_1x2x2x3/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/base_batch_norm_scale_offset_1x2x2x3/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/batch_norm_1x2x2x3/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/batch_norm_1x2x2x3/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/batch_norm_1x2x2x3/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/batch_norm_1x2x2x3/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/batch_norm_1x2x2x3/checkpoint-1.index -------------------------------------------------------------------------------- 
/sonnet/src/conformance/checkpoints/batch_norm_scale_offset_1x2x2x3/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/batch_norm_scale_offset_1x2x2x3/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/batch_norm_scale_offset_1x2x2x3/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/batch_norm_scale_offset_1x2x2x3/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/batch_norm_scale_offset_1x2x2x3/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/batch_norm_training_1x2x2x3/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/batch_norm_training_1x2x2x3/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/batch_norm_training_1x2x2x3/checkpoint-1.data-00000-of-00001 
-------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/batch_norm_training_1x2x2x3/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/batch_norm_training_1x2x2x3/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/bias_3x3x3/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/bias_3x3x3/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/bias_3x3x3/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/bias_3x3x3/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/bias_3x3x3/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/cifar10_convnet_2x3_2x2_1x3x3x2/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- 
/sonnet/src/conformance/checkpoints/cifar10_convnet_2x3_2x2_1x3x3x2/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/cifar10_convnet_2x3_2x2_1x3x3x2/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/cifar10_convnet_2x3_2x2_1x3x3x2/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/cifar10_convnet_2x3_2x2_1x3x3x2/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv1d_3x3_2x2/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv1d_3x3_2x2/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/conv1d_3x3_2x2/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv1d_3x3_2x2/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/conv1d_3x3_2x2/checkpoint-1.index 
-------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv1d_lstm_3x3_2x2/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv1d_lstm_3x3_2x2/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/conv1d_lstm_3x3_2x2/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv1d_lstm_3x3_2x2/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/conv1d_lstm_3x3_2x2/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv1d_transpose_3x3_2x2/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv1d_transpose_3x3_2x2/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/conv1d_transpose_3x3_2x2/checkpoint-1.data-00000-of-00001 
-------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv1d_transpose_3x3_2x2/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/conv1d_transpose_3x3_2x2/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv2d_3x3_2x2/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv2d_3x3_2x2/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/conv2d_3x3_2x2/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv2d_3x3_2x2/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/conv2d_3x3_2x2/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv2d_lstm_3x3_2x2/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- 
/sonnet/src/conformance/checkpoints/conv2d_lstm_3x3_2x2/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/conv2d_lstm_3x3_2x2/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv2d_lstm_3x3_2x2/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/conv2d_lstm_3x3_2x2/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv2d_transpose_3x3_2x2/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv2d_transpose_3x3_2x2/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/conv2d_transpose_3x3_2x2/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv2d_transpose_3x3_2x2/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/conv2d_transpose_3x3_2x2/checkpoint-1.index 
-------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv3d_3x3_2x2/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv3d_3x3_2x2/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/conv3d_3x3_2x2/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv3d_3x3_2x2/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/conv3d_3x3_2x2/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv3d_lstm_3x3_2x2/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv3d_lstm_3x3_2x2/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/conv3d_lstm_3x3_2x2/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- 
/sonnet/src/conformance/checkpoints/conv3d_lstm_3x3_2x2/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/conv3d_lstm_3x3_2x2/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv3d_transpose_3x3_2x2/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv3d_transpose_3x3_2x2/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/conv3d_transpose_3x3_2x2/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/conv3d_transpose_3x3_2x2/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/conv3d_transpose_3x3_2x2/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/cross_replica_batch_norm_1x2x2x3/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- 
/sonnet/src/conformance/checkpoints/cross_replica_batch_norm_1x2x2x3/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/cross_replica_batch_norm_1x2x2x3/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/cross_replica_batch_norm_1x2x2x3/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/cross_replica_batch_norm_1x2x2x3/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/depthwise_conv2d_3x3_2x2/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/depthwise_conv2d_3x3_2x2/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/depthwise_conv2d_3x3_2x2/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/depthwise_conv2d_3x3_2x2/checkpoint-1.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/depthwise_conv2d_3x3_2x2/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/dropout/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/dropout/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/dropout/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/dropout/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/dropout/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/ema_2/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/ema_2/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/ema_2/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/ema_2/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/ema_2/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/embed_100_100/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/embed_100_100/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/embed_100_100/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/embed_100_100/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/embed_100_100/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/generate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Binary to generate golden checkpoint tests.""" 16 | 17 | import os 18 | import re 19 | 20 | from absl import app 21 | from absl import flags 22 | from absl import logging 23 | from sonnet.src.conformance import goldens 24 | import tensorflow as tf 25 | 26 | FLAGS = flags.FLAGS 27 | 28 | flags.DEFINE_string("golden_dir", 29 | "sonnet/src/conformance/checkpoints/", 30 | "Directory where golden files are to be found.") 31 | flags.DEFINE_string("filter", ".*", "Filter to a specific golden by name.") 32 | flags.DEFINE_bool("regenerate", False, 33 | "Whether to regnerate existing checkpoints.") 34 | flags.DEFINE_bool("dry_run", True, "Whether to actually apply changes.") 35 | 36 | 37 | def safe_mkdir(directory): 38 | if FLAGS.dry_run: 39 | logging.warning("[DRY RUN] Would create %r", directory) 40 | else: 41 | logging.info("Creating %r", directory) 42 | os.mkdir(directory) 43 | 44 | 45 | def safe_unlink(path): 46 | if FLAGS.dry_run: 47 | logging.warning("[DRY RUN] Would delete %r", path) 48 | else: 49 | logging.info("Deleting %r", path) 50 | os.unlink(path) 51 | 52 | 53 | def main(unused_argv): 54 | del unused_argv 55 | 56 | for _, name, cls in goldens.list_goldens(): 57 | if not re.match(FLAGS.filter, name): 58 | continue 59 | 60 | checkpoint_dir = os.path.join(FLAGS.golden_dir, name) 61 | exists 
= os.path.exists(checkpoint_dir) 62 | if exists and not FLAGS.regenerate: 63 | logging.info("Skipping %s since it exists and --regenerate=false", name) 64 | continue 65 | 66 | logging.info("Processing %s", name) 67 | if not exists: 68 | safe_mkdir(checkpoint_dir) 69 | else: 70 | # Clear out old files. 71 | for file_name in os.listdir(checkpoint_dir): 72 | safe_unlink(os.path.join(checkpoint_dir, file_name)) 73 | 74 | # Create the module to checkpoint. 75 | golden = cls() 76 | module = golden.create_module() 77 | golden.create_all_variables(module) 78 | for var in module.variables: 79 | var.assign(goldens.range_like(var)) 80 | 81 | # Create a checkpoint and save the values to it. 82 | checkpoint = tf.train.Checkpoint(module=module) 83 | if FLAGS.dry_run: 84 | logging.warning("[DRY RUN] Would save %r to %r", module, checkpoint_dir) 85 | else: 86 | file_prefix = os.path.join(checkpoint_dir, "checkpoint") 87 | logging.info("Saving to checkpoint %s.", file_prefix) 88 | checkpoint.save(file_prefix=file_prefix) 89 | 90 | 91 | if __name__ == "__main__": 92 | app.run(main) 93 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/group_norm_2_1x3x4/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/group_norm_2_1x3x4/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/group_norm_2_1x3x4/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/group_norm_2_1x3x4/checkpoint-1.index: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/group_norm_2_1x3x4/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/gru_1/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/gru_1/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/gru_1/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/gru_1/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/gru_1/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/instance_norm_1_1x3_2/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/instance_norm_1_1x3_2/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/instance_norm_1_1x3_2/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/instance_norm_1_1x3_2/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/instance_norm_1_1x3_2/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/layer_norm_1_1x3_2/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/layer_norm_1_1x3_2/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/layer_norm_1_1x3_2/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/layer_norm_1_1x3_2/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/layer_norm_1_1x3_2/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/linear_1x1/checkpoint: -------------------------------------------------------------------------------- 1 | 
model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/linear_1x1/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/linear_1x1/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/linear_1x1/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/linear_1x1/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/linear_nobias_1x1/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/linear_nobias_1x1/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/linear_nobias_1x1/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/linear_nobias_1x1/checkpoint-1.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/linear_nobias_1x1/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/lstm_1/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/lstm_1/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/lstm_1/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/lstm_1/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/lstm_1/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/lstm_8_projected_1/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/lstm_8_projected_1/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/lstm_8_projected_1/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/lstm_8_projected_1/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/lstm_8_projected_1/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/mean_2x2/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/mean_2x2/checkpoint-1.data-00000-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/mean_2x2/checkpoint-1.data-00000-of-00002 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/mean_2x2/checkpoint-1.data-00001-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/mean_2x2/checkpoint-1.data-00001-of-00002 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/mean_2x2/checkpoint-1.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/mean_2x2/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/mlp_3x4x5_1x3/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/mlp_3x4x5_1x3/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/mlp_3x4x5_1x3/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/mlp_3x4x5_1x3/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/mlp_3x4x5_1x3/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/mlp_nobias_3x4x5_1x3/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/mlp_nobias_3x4x5_1x3/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/mlp_nobias_3x4x5_1x3/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/mlp_nobias_3x4x5_1x3/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/mlp_nobias_3x4x5_1x3/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/resnet50/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/resnet50/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/resnet50/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/resnet50/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/resnet50/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/sum_2x2/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | 
all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/sum_2x2/checkpoint-1.data-00000-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/sum_2x2/checkpoint-1.data-00000-of-00002 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/sum_2x2/checkpoint-1.data-00001-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/sum_2x2/checkpoint-1.data-00001-of-00002 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/sum_2x2/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/sum_2x2/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/trainable_state/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/trainable_state/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/trainable_state/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/trainable_state/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/trainable_state/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/unrolled_lstm_1/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/unrolled_lstm_1/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/unrolled_lstm_1/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/unrolled_lstm_1/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/unrolled_lstm_1/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/vanilla_rnn_8/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: 
"checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/vanilla_rnn_8/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/vanilla_rnn_8/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/vanilla_rnn_8/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/vanilla_rnn_8/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/vqvae/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/vqvae/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/vqvae/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/vqvae/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/vqvae/checkpoint-1.index 
-------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/vqvae_ema_eval/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/vqvae_ema_eval/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/vqvae_ema_eval/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/vqvae_ema_eval/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/vqvae_ema_eval/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/vqvae_ema_train/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint-1" 2 | all_model_checkpoint_paths: "checkpoint-1" 3 | -------------------------------------------------------------------------------- /sonnet/src/conformance/checkpoints/vqvae_ema_train/checkpoint-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/vqvae_ema_train/checkpoint-1.data-00000-of-00001 -------------------------------------------------------------------------------- 
/sonnet/src/conformance/checkpoints/vqvae_ema_train/checkpoint-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/sonnet/c99b49136210c30fd95bd9c6350fcc3eaf9a72f3/sonnet/src/conformance/checkpoints/vqvae_ema_train/checkpoint-1.index -------------------------------------------------------------------------------- /sonnet/src/conformance/copy_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Tests copying Sonnet modules.""" 17 | 18 | import copy 19 | 20 | from absl.testing import parameterized 21 | from sonnet.src import test_utils 22 | from sonnet.src.conformance import goldens 23 | import tensorflow as tf 24 | import tree 25 | 26 | 27 | class CopyTest(test_utils.TestCase, parameterized.TestCase): 28 | 29 | @goldens.all_goldens 30 | def test_copy(self, golden): 31 | m1 = golden.create_module() 32 | golden.create_all_variables(m1) 33 | m2 = copy.deepcopy(m1) 34 | self.assertIsNot(m1, m2) 35 | 36 | # Check that module variables are recreated with equivalent properties. 
37 | for v1, v2 in zip(m1.variables, m2.variables): 38 | self.assertIsNot(v1, v2) 39 | self.assertEqual(v1.name, v2.name) 40 | self.assertEqual(v1.device, v2.device) 41 | self.assertAllEqual(v1.read_value(), v2.read_value()) 42 | 43 | if golden.deterministic: 44 | y1 = golden.forward(m1) 45 | y2 = golden.forward(m2) 46 | tree.map_structure(self.assertAllEqual, y1, y2) 47 | 48 | if __name__ == "__main__": 49 | tf.test.main() 50 | -------------------------------------------------------------------------------- /sonnet/src/conformance/descriptors_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Tests for sonnet.v2.src.conformance.descriptors.""" 16 | 17 | import sonnet as snt 18 | from sonnet.src import test_utils 19 | from sonnet.src.conformance import descriptors 20 | import tensorflow as tf 21 | 22 | BATCH_MODULES = descriptors.BATCH_MODULES 23 | RECURRENT_MODULES = descriptors.RECURRENT_MODULES 24 | OPTIMIZER_MODULES = descriptors.OPTIMIZER_MODULES 25 | IGNORED_MODULES = descriptors.IGNORED_MODULES 26 | 27 | 28 | class DescriptorsTest(test_utils.TestCase): 29 | 30 | def test_coverage(self): 31 | all_modules = frozenset(test_utils.find_all_sonnet_modules(snt, snt.Module)) 32 | tested_modules = { 33 | type(descriptors.unwrap(d.create())) 34 | for d in BATCH_MODULES + RECURRENT_MODULES + OPTIMIZER_MODULES 35 | } 36 | self.assertEmpty(all_modules - (tested_modules | IGNORED_MODULES)) 37 | 38 | 39 | if __name__ == '__main__': 40 | tf.test.main() 41 | -------------------------------------------------------------------------------- /sonnet/src/conformance/distribute_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Tests Sonnet and TF Distribution Strategy.""" 16 | 17 | from typing import Callable, Tuple 18 | 19 | from absl.testing import parameterized 20 | import sonnet as snt 21 | from sonnet.src import test_utils 22 | from sonnet.src.conformance import descriptors 23 | from sonnet.src.conformance import goldens 24 | from sonnet.src.distribute import replicator as snt_replicator 25 | from sonnet.src.distribute import replicator_test_utils as replicator_utils 26 | import tensorflow as tf 27 | 28 | 29 | class TpuReplicatorTest(test_utils.TestCase, parameterized.TestCase): 30 | 31 | @test_utils.combined_named_parameters(goldens.named_goldens(), 32 | replicator_utils.named_replicators()) 33 | def test_variable_creation_in_replica_context(self, golden, replicator_fn): 34 | tf.random.set_seed(None) 35 | replicator = replicator_fn() 36 | 37 | with replicator.scope(): 38 | mod = golden.create_module() 39 | 40 | @tf.function 41 | def forward(): 42 | step = lambda: golden.create_all_variables(mod) 43 | return replicator.run(step) 44 | 45 | # TODO(b/132329316) Remove when `xla.compile` allows tf.device(TPU). 
46 | with tf.device(None): 47 | variables_per_replica = forward() 48 | 49 | self.assertLen(variables_per_replica, golden.num_variables) 50 | 51 | for per_replica_variable in variables_per_replica: 52 | self.assertSameValuePerReplica(replicator, per_replica_variable) 53 | 54 | def assertSameValuePerReplica(self, replicator, per_replica): 55 | per_replica = replicator.experimental_local_results(per_replica) 56 | first_replica = per_replica[0] 57 | for nth_replica in per_replica[1:]: 58 | self.assertAllEqual(first_replica, nth_replica) 59 | 60 | @test_utils.combined_named_parameters(descriptors.RNN_CORES, 61 | test_utils.named_bools("dynamic"), 62 | replicator_utils.named_replicators()) 63 | def test_unroll( 64 | self, 65 | core_fn: Callable[[], snt.RNNCore], 66 | input_shape: Tuple[int], 67 | dtype: tf.DType, 68 | dynamic: bool, 69 | replicator_fn: tf.distribute.Strategy, 70 | ): 71 | replicator = replicator_fn() 72 | with replicator.scope(): 73 | core = core_fn() 74 | 75 | def step_fn(): 76 | def forward(): 77 | unroll = snt.dynamic_unroll if dynamic else snt.static_unroll 78 | sequence = tf.ones((1,) + input_shape, dtype) 79 | state = core.initial_state(input_shape[0]) 80 | return unroll(core, sequence, state) 81 | 82 | return replicator.run(forward) 83 | 84 | # TpuReplicator doesn't support pure eager mode. 85 | if isinstance(replicator, snt_replicator.TpuReplicator): 86 | step_fn = tf.function(step_fn) 87 | 88 | # TODO(b/132329316) Remove when `xla.compile` allows tf.device(TPU). 
89 | with tf.device(None): 90 | out_sequence, final_state = step_fn() 91 | 92 | self.assertSameValuePerReplica(replicator, out_sequence) 93 | self.assertSameValuePerReplica(replicator, final_state) 94 | 95 | if __name__ == "__main__": 96 | tf.test.main() 97 | -------------------------------------------------------------------------------- /sonnet/src/conformance/doctest_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Ensures that code samples in Sonnet are accurate.""" 16 | 17 | import doctest 18 | import inspect 19 | 20 | from absl.testing import parameterized 21 | import sonnet as snt 22 | from sonnet.src import test_utils 23 | import tensorflow as tf 24 | import tree 25 | 26 | 27 | class DoctestTest(test_utils.TestCase, parameterized.TestCase): 28 | 29 | # Avoid running doctests inside a `with tf.device` block. 30 | ENTER_PRIMARY_DEVICE = False 31 | 32 | def setUp(self): 33 | super().setUp() 34 | if self.primary_device != "TPU": 35 | # `TpuReplicator` cannot be constructed without a TPU, however it has 36 | # exactly the same API as `Replicator` so we can run doctests using that 37 | # instead. 
38 | snt.distribute.TpuReplicator = snt.distribute.Replicator 39 | 40 | @parameterized.named_parameters(test_utils.find_sonnet_python_modules(snt)) 41 | def test_doctest(self, module): 42 | # `snt` et al import all dependencies from `src`, however doctest does not 43 | # test imported deps so we must manually set `__test__` such that imported 44 | # symbols are tested. 45 | # See: docs.python.org/3/library/doctest.html#which-docstrings-are-examined 46 | if not hasattr(module, "__test__") or not module.__test__: 47 | module.__test__ = {} 48 | for name in module.__all__: 49 | value = getattr(module, name) 50 | if not inspect.ismodule(value): 51 | if (inspect.isclass(value) or isinstance(value, str) or 52 | inspect.isfunction(value) or inspect.ismethod(value)): 53 | module.__test__[name] = value 54 | elif hasattr(value, "__doc__"): 55 | module.__test__[name] = value.__doc__ 56 | 57 | num_failed, num_attempted = doctest.testmod( 58 | module, 59 | optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE, 60 | extraglobs={ 61 | "snt": snt, 62 | "tf": tf, 63 | "tree": tree, 64 | }) 65 | if num_attempted == 0: 66 | self.skipTest("No doctests in %s" % module.__name__) 67 | self.assertEqual(num_failed, 0, "{} doctests failed".format(num_failed)) 68 | 69 | 70 | if __name__ == "__main__": 71 | tf.test.main() 72 | -------------------------------------------------------------------------------- /sonnet/src/conformance/goldens_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests goldens cover all modules.""" 16 | 17 | import inspect 18 | 19 | import sonnet as snt 20 | from sonnet.src import test_utils 21 | from sonnet.src.conformance import goldens 22 | import tensorflow as tf 23 | 24 | 25 | class CoverageTest(test_utils.TestCase): 26 | 27 | def test_all_modules_covered(self): 28 | allow_no_checkpoint = set([ 29 | # TODO(petebu): Remove this once optimizer goldens check works. 30 | snt.optimizers.Adam, 31 | snt.optimizers.Momentum, 32 | snt.optimizers.RMSProp, 33 | snt.optimizers.SGD, 34 | 35 | # Stateless or abstract. 36 | snt.BatchApply, 37 | snt.DeepRNN, 38 | snt.Deferred, 39 | snt.Flatten, 40 | snt.Metric, 41 | snt.Module, 42 | snt.Optimizer, 43 | snt.Reshape, 44 | snt.RNNCore, 45 | snt.Sequential, 46 | snt.UnrolledRNN, 47 | 48 | # Tested via snt.nets.ResNet 49 | snt.nets.ResNet50, 50 | snt.nets.resnet.BottleNeckBlockV1, 51 | snt.nets.resnet.BottleNeckBlockV2, 52 | snt.nets.resnet.BlockGroup 53 | ]) 54 | 55 | # Find all the snt.Module types reachable from `import sonnet as snt` 56 | all_sonnet_types = set() 57 | for _, python_module in test_utils.find_sonnet_python_modules(snt): 58 | for _, cls in inspect.getmembers(python_module, inspect.isclass): 59 | if issubclass(cls, snt.Module): 60 | all_sonnet_types.add(cls) 61 | 62 | # Find all the modules that have checkpoint tests. 
63 | tested_modules = {module_cls for module_cls, _, _ in goldens.list_goldens()} 64 | 65 | # Make sure we don't leave entries in allow_no_checkpoint if they are 66 | # actually tested. 67 | self.assertEmpty(tested_modules & allow_no_checkpoint) 68 | 69 | # Make sure everything is covered. 70 | self.assertEqual(tested_modules | allow_no_checkpoint, all_sonnet_types) 71 | 72 | 73 | if __name__ == "__main__": 74 | tf.test.main() 75 | -------------------------------------------------------------------------------- /sonnet/src/conformance/optimizer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Conformance tests for models and optimization.""" 16 | 17 | from absl.testing import parameterized 18 | from sonnet.src import test_utils 19 | from sonnet.src.conformance import descriptors 20 | import tensorflow as tf 21 | 22 | BATCH_MODULES = descriptors.BATCH_MODULES 23 | RECURRENT_MODULES = descriptors.RECURRENT_MODULES 24 | 25 | 26 | class OptimizerConformanceTest(test_utils.TestCase, parameterized.TestCase): 27 | 28 | @test_utils.combined_named_parameters( 29 | BATCH_MODULES + RECURRENT_MODULES, 30 | test_utils.named_bools("construct_module_in_function"), 31 | ) 32 | def test_variable_order_is_constant(self, module_fn, input_shape, dtype, 33 | construct_module_in_function): 34 | """Test that variable access order is consistent in built in modules.""" 35 | logged_variables = [] 36 | mod = [None] 37 | if not construct_module_in_function: 38 | mod[0] = module_fn() 39 | 40 | x = tf.zeros(input_shape, dtype=dtype) 41 | 42 | @tf.function(autograph=False) 43 | def f(): 44 | with tf.GradientTape() as tape: 45 | if not mod[0]: 46 | mod[0] = module_fn() 47 | mod[0](x) # pylint: disable=not-callable 48 | 49 | # Leak out the variables that were used. 50 | logged_variables.append( 51 | [(id(v), v.name) for v in tape.watched_variables()]) 52 | 53 | # NOTE: This will run `f` twice iff `f` creates params. 54 | f() 55 | 56 | if len(logged_variables) == 1: 57 | self.skipTest("Module did not create variables in forward pass.") 58 | else: 59 | assert len(logged_variables) == 2 60 | self.assertCountEqual(logged_variables[0], logged_variables[1]) 61 | 62 | if __name__ == "__main__": 63 | tf.test.main() 64 | -------------------------------------------------------------------------------- /sonnet/src/conformance/pickle_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests pickling Sonnet modules.""" 16 | 17 | import pickle 18 | 19 | from absl.testing import parameterized 20 | from sonnet.src import test_utils 21 | from sonnet.src.conformance import goldens 22 | import tensorflow as tf 23 | import tree 24 | 25 | 26 | class PickleTest(test_utils.TestCase, parameterized.TestCase): 27 | 28 | # TODO(tomhennigan) Add tests with dill and cloudpickle. 29 | 30 | @goldens.all_goldens 31 | def test_pickle(self, golden): 32 | m1 = golden.create_module() 33 | golden.create_all_variables(m1) 34 | m2 = pickle.loads(pickle.dumps(m1)) 35 | self.assertIsNot(m1, m2) 36 | 37 | # Check that module variables are recreated with equivalent properties. 
38 | for v1, v2 in zip(m1.variables, m2.variables): 39 | self.assertIsNot(v1, v2) 40 | self.assertEqual(v1.name, v2.name) 41 | self.assertEqual(v1.device, v2.device) 42 | self.assertAllEqual(v1.read_value(), v2.read_value()) 43 | 44 | if golden.deterministic: 45 | y1 = golden.forward(m1) 46 | y2 = golden.forward(m2) 47 | tree.map_structure(self.assertAllEqual, y1, y2) 48 | 49 | 50 | if __name__ == "__main__": 51 | tf.test.main() 52 | -------------------------------------------------------------------------------- /sonnet/src/conformance/saved_model_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests using tf.saved_model and Sonnet.""" 16 | 17 | import os 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | import sonnet as snt 22 | from sonnet.src import test_utils 23 | from sonnet.src.conformance import goldens 24 | import tensorflow as tf 25 | import tree 26 | 27 | 28 | class SavedModelTest(test_utils.TestCase, parameterized.TestCase): 29 | 30 | @goldens.all_goldens 31 | def test_save_restore_cycle(self, golden): 32 | module = golden.create_module() 33 | 34 | # Create all parameters and set them to sequential (but different) values. 
35 | variables = golden.create_all_variables(module) 36 | for index, variable in enumerate(variables): 37 | variable.assign(goldens.range_like(variable, start=index)) 38 | 39 | @tf.function(input_signature=[golden.input_spec]) 40 | def inference(x): 41 | return golden.forward(module, x) 42 | 43 | # Create a saved model, add a method for inference and a dependency on our 44 | # module such that it can find dependencies. 45 | saved_model = snt.Module() 46 | saved_model._module = module 47 | saved_model.inference = inference 48 | saved_model.all_variables = list(module.variables) 49 | 50 | # Sample input. 51 | x = goldens.range_like(golden.input_spec) 52 | 53 | # Run the saved model and pull variable values. 54 | saved_model.inference(x) 55 | v1 = saved_model.all_variables 56 | 57 | # Save the model to disk and restore it. 58 | tmp_dir = os.path.join(absltest.get_default_test_tmpdir(), golden.name) 59 | tf.saved_model.save(saved_model, tmp_dir) 60 | restored_model = tf.saved_model.load(tmp_dir) 61 | 62 | # Run the loaded model and pull variable values. 63 | v2 = restored_model.all_variables 64 | y2 = restored_model.inference(x) 65 | 66 | if golden.deterministic: 67 | # The output from both the saved and restored model should be close. 68 | y1 = saved_model.inference(x) 69 | # TODO(b/161972382): The restored model doesn't seem to specialize the 70 | # graph with implementation selector, so the original model uses CuDNN 71 | # calls, whereas the restored model uses the non-specialized graph which 72 | # still contains a regular Tanh op. 
73 | tree.map_structure(self.assertAllClose, y1, y2) 74 | 75 | for a, b in zip(v1, v2): 76 | self.assertEqual(a.name, b.name) 77 | self.assertEqual(a.device, b.device) 78 | self.assertAllEqual(a.read_value(), b.read_value()) 79 | 80 | 81 | if __name__ == "__main__": 82 | tf.test.main() 83 | -------------------------------------------------------------------------------- /sonnet/src/conformance/tensorflow1_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class TensorFlow1Test(test_utils.TestCase):
  """Verifies Sonnet fails fast when v2 behavior is disabled."""

  def test_requires_tf2(self):
    # Only meaningful when testing against a TF build from head.
    at_head = tf.version.GIT_VERSION == "unknown"
    if not at_head:
      self.skipTest("This test only runs if testing against TF at head.")

    # Constructing any module under TF1 should raise a clear error.
    with self.assertRaisesRegex(AssertionError, "requires TensorFlow 2"):
      snt.Module()


if __name__ == "__main__":
  tf.disable_v2_behavior()
  tf.test.main()
class XLATest(test_utils.TestCase, parameterized.TestCase):
  """Runs golden modules under XLA compilation and jit scopes."""

  def _tolerance(self, golden):
    # TPU numerics differ enough that each golden carries its own atol.
    return golden.tpu_atol if self.primary_device == "TPU" else 1e-3

  @goldens.all_goldens
  def test_compile(self, golden):
    mod = golden.create_module()
    golden.create_all_variables(mod)

    @tf.function
    def forward():
      out = tf.xla.experimental.compile(lambda: golden.forward(mod))
      if len(out) == 1:
        return out[0]
      return out if out else None

    if self.primary_device == "TPU":
      # TODO(b/132329316) Remove when `xla.compile` allows tf.device(TPU).
      with tf.device(None):
        xla_out = forward()
    else:
      xla_out = forward()

    if golden.deterministic and not golden.has_side_effects:
      expected = golden.forward(mod)
      tree.map_structure(
          functools.partial(self.assertAllClose, atol=self._tolerance(golden)),
          expected, xla_out)

  @goldens.all_goldens
  def test_jit_scope(self, golden):
    mod = golden.create_module()
    golden.create_all_variables(mod)

    @tf.function
    def forward():
      with tf.xla.experimental.jit_scope():
        return golden.forward(mod)

    xla_out = forward()

    if golden.deterministic and not golden.has_side_effects:
      expected = golden.forward(mod)
      tree.map_structure(
          functools.partial(self.assertAllClose, atol=self._tolerance(golden)),
          expected, xla_out)


if __name__ == "__main__":
  tf.test.main()
class CustomVariableGetterTest(test_utils.TestCase):
  """Checks that custom getters only intercept variable reads."""

  def testDoesNotModifyNonVariables(self):

    class MyModule(base.Module):
      # Class-level attributes are shared across instances; fine for a
      # single-instance test.
      v = tf.Variable(21.)
      d = {}

    module = MyModule()
    self.assertEqual(21., self.evaluate(module.v))

    # Inside the context the getter doubles variable reads, while plain
    # attributes (the dict) pass through untouched.
    with custom_getter.custom_variable_getter(lambda var: var * 2):
      self.assertEqual(42., self.evaluate(module.v))
      module.d["foo"] = "bar"

    # Outside the context the original variable value is visible again.
    self.assertEqual(21., self.evaluate(module.v))
    self.assertEqual("bar", module.d["foo"])


class DoctestTest(test_utils.TestCase):
  """Runs the doctests embedded in the custom_getter module."""

  def testDoctest(self):
    failures, attempts = doctest.testmod(
        custom_getter, extraglobs={"snt": base})
    self.assertGreater(attempts, 0, "No doctests found.")
    self.assertEqual(failures, 0, "{} doctests failed".format(failures))


if __name__ == "__main__":
  tf.test.main()
41 | ], 42 | ) 43 | 44 | snt_py_test( 45 | name = "replicator_test", 46 | srcs = ["replicator_test.py"], 47 | deps = [ 48 | ":replicator", 49 | ":replicator_test_utils", 50 | # pip: absl/logging 51 | # pip: absl/testing:parameterized 52 | "//sonnet/src:initializers", 53 | "//sonnet/src:test_utils", 54 | # pip: tensorflow 55 | ], 56 | ) 57 | 58 | snt_py_library( 59 | name = "replicator_test_utils", 60 | testonly = 1, 61 | srcs = ["replicator_test_utils.py"], 62 | deps = [ 63 | ":replicator", 64 | # pip: absl/logging 65 | # pip: tensorflow 66 | ], 67 | ) 68 | -------------------------------------------------------------------------------- /sonnet/src/distribute/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | -------------------------------------------------------------------------------- /sonnet/src/distribute/replicator_test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
def _replicator_primary_device() -> snt_replicator.Replicator:
  """Returns a Replicator mirroring over the best available device type.

  NOTE: The explicit device list is required since currently Replicator
  only considers CPU and GPU devices. This means on TPU by default we only
  mirror on the local CPU.

  Raises:
    AssertionError: If no TPU, GPU or CPU devices are available.
  """
  for device_type in ("TPU", "GPU", "CPU"):
    devices = tf.config.experimental.list_logical_devices(
        device_type=device_type)
    if devices:
      devices = [d.name for d in devices]
      logging.info("Replicating over %s", devices)
      return snt_replicator.Replicator(devices=devices)

  # Raise explicitly rather than `assert False`: assertions are stripped
  # under `python -O`, which would make this function silently return None.
  raise AssertionError("No TPU/GPU or CPU found")


def _tpu_replicator_or_skip_test() -> snt_replicator.TpuReplicator:
  """Returns a TpuReplicator, or skips the calling test if no TPU exists."""
  tpus = tf.config.experimental.list_logical_devices(device_type="TPU")
  if not tpus:
    raise unittest.SkipTest("No TPU available.")

  logging.info("Using TpuReplicator over %s", [t.name for t in tpus])
  return snt_replicator.TpuReplicator()


Strategy = tf.distribute.Strategy


def named_replicators() -> Sequence[Tuple[str, Callable[[], Strategy]]]:
  """Returns (name, factory) pairs for parameterized distribution tests."""
  return (("TpuReplicator", _tpu_replicator_or_skip_test),
          ("Replicator", _replicator_primary_device))
class Dropout(base.Module):
  """Randomly drop units in the input at a given rate.

  See: http://www.cs.toronto.edu/~hinton/absps/dropout.pdf

  Dropout was originally described by Hinton et al. TensorFlow deviates
  slightly from this paper by scaling activations at training time rather
  than test time.
  """

  def __init__(self,
               rate: types.FloatLike,
               noise_shape: Optional[types.ShapeLike] = None,
               seed: Optional[int] = None,
               name: Optional[str] = None):
    """Constructs a Dropout module.

    Args:
      rate: Probability that each element of x is discarded. Must be a
        scalar in the range `[0, 1)`.
      noise_shape: (Optional) Shape vector controlling the shape of the
        random noise used to apply dropout. If not set this will be the
        shape of the input. If set it should be broadcastable to the input
        shape.
      seed: (Optional) Random seed to be passed to TensorFlow ops when
        generating dropout tensor.
      name: (Optional) Name for this module.
    """
    super().__init__(name=name)
    self._rate = rate
    self._noise_shape = noise_shape
    self._seed = seed

  @utils.smart_autograph
  def __call__(self, x: tf.Tensor, is_training: types.BoolLike) -> tf.Tensor:
    """Applies dropout to `x` when training; identity otherwise."""
    if is_training:
      # NOTE: Even with a constant `self._seed` (e.g. `2`) each call draws a
      # different dropout mask: the per-op seed is combined with the global
      # seed and persistent state when generating random values.
      # c.f. https://www.tensorflow.org/api_docs/python/tf/random/set_random_seed
      return tf.nn.dropout(x,
                           rate=self._rate,
                           noise_shape=self._noise_shape,
                           seed=self._seed)
    return x
class DropoutTest(test_utils.TestCase, parameterized.TestCase):
  """Statistical and tf.function integration tests for snt Dropout."""

  def _rtol(self):
    # TPU randomness/numerics warrant a looser tolerance.
    return 0.3 if "TPU" in self.device_types else 0.1

  @parameterized.parameters(np.arange(.0, .9, .1))
  def test_sum_close(self, rate):
    # Dropout rescales kept units, so the expected sum is rate-invariant.
    mod = dropout.Dropout(rate=rate)
    x = tf.ones([1000])
    self.assertAllClose(
        tf.reduce_sum(mod(x, is_training=True)),
        tf.reduce_sum(mod(x, is_training=False)),
        rtol=self._rtol())

  @parameterized.parameters(np.arange(0, .9, .1))
  def test_dropout_rate(self, rate):
    mod = dropout.Dropout(rate=rate)
    y = mod(tf.ones([1000]), is_training=True)

    # We should have dropped something, test we're within 10% of rate
    # (or 30% on a TPU).
    kept = tf.math.count_nonzero(y).numpy()
    self.assertAllClose(kept, 1000 * (1 - rate), rtol=self._rtol())

  def test_dropout_is_actually_random(self):
    mod = dropout.Dropout(rate=0.5)
    x = tf.ones([1000])
    tf.random.set_seed(1)
    first = mod(x, is_training=True)
    second = mod(x, is_training=True)
    self.assertNotAllClose(first, second)

  @parameterized.parameters(True, False)
  def test_with_tf_function_with_booleans(self, autograph):
    """tf.function compilation correctly handles if statement."""
    layer = tf.function(dropout.Dropout(rate=0.5), autograph=autograph)
    inputs = tf.ones([2, 5, 3, 3, 3])

    for is_training in (True, False):
      outputs = layer(inputs, is_training)
      self.assertEqual(outputs.shape, inputs.shape)

  @parameterized.parameters(True, False)
  def test_with_tf_function_with_variables(self, autograph):
    """tf.function correctly handles if statement when argument is Variable."""
    layer = tf.function(dropout.Dropout(rate=0.5), autograph=autograph)
    inputs = tf.ones([2, 5, 3, 3, 3])
    is_training_variable = tf.Variable(False, trainable=False)

    for is_training in (True, False):
      is_training_variable.assign(is_training)
      outputs = layer(inputs, is_training_variable)
      self.assertEqual(outputs.shape, inputs.shape)


if __name__ == "__main__":
  tf.test.main()
class EmbedTest(test_utils.TestCase, parameterized.TestCase):
  """Covers construction options of snt Embed."""

  @parameterized.parameters([1, 10, 100])
  def test_vocab_size(self, vocab_size):
    layer = embed.Embed(vocab_size=vocab_size)
    self.assertEqual(layer.vocab_size, vocab_size)
    self.assertEqual(layer.embeddings.shape[0], vocab_size)

  @parameterized.parameters([1, 10, 100])
  def test_embed_dim(self, embed_dim):
    layer = embed.Embed(vocab_size=100, embed_dim=embed_dim)
    self.assertEqual(layer.embed_dim, embed_dim)
    self.assertEqual(layer.embeddings.shape[1], embed_dim)

  @parameterized.parameters([(1, 1), (10, 10), (100, 100)])
  def test_existing_vocab(self, vocab_size, embed_dim):
    # Size/dim should be inferred from a pre-supplied vocab tensor.
    existing_vocab = tf.ones([vocab_size, embed_dim])
    layer = embed.Embed(existing_vocab=existing_vocab)
    self.assertEqual(layer.vocab_size, vocab_size)
    self.assertEqual(layer.embed_dim, embed_dim)
    self.assertAllEqual(layer.embeddings.read_value(), existing_vocab)

  @parameterized.parameters([True, False])
  def test_densify_gradients(self, densify_gradients):
    layer = embed.Embed(1, densify_gradients=densify_gradients)
    with tf.GradientTape() as tape:
      y = layer([0])
    dy = tape.gradient(y, layer.embeddings)
    # Embedding lookups naturally produce IndexedSlices; densification
    # converts them to a regular dense Tensor.
    expected_type = tf.Tensor if densify_gradients else tf.IndexedSlices
    self.assertIsInstance(dy, expected_type)

  def test_initializer(self):
    layer = embed.Embed(1, 1, initializer=initializers.Constant(28.))
    self.assertAllEqual(layer.embeddings.read_value(), [[28.]])

  def test_pinned_to_cpu(self):
    with tf.device("CPU"):
      layer = embed.Embed(1)
    spec = tf.DeviceSpec.from_string(layer.embeddings.device)
    self.assertEqual(spec.device_type, "CPU")

  @parameterized.parameters([True, False])
  def test_trainable(self, trainable):
    layer = embed.Embed(1, trainable=trainable)
    self.assertEqual(layer.embeddings.trainable, trainable)

  @parameterized.parameters([tf.float32, tf.float16])
  def test_dtype(self, dtype):
    if dtype == tf.float16 and self.primary_device == "TPU":
      self.skipTest("float16 embeddings not supported on TPU.")
    layer = embed.Embed(1, dtype=dtype)
    self.assertEqual(layer.embeddings.dtype, dtype)

  def test_name(self):
    layer = embed.Embed(1, name="my_embedding")
    self.assertEqual(layer.name, "my_embedding")
    self.assertEqual(layer.embeddings.name, "my_embedding/embeddings:0")


if __name__ == "__main__":
  tf.test.main()
snt_py_library( 18 | name = "jax", 19 | srcs = ["jax.py"], 20 | deps = [ 21 | ":utils", 22 | # pip: tensorflow 23 | # pip: tree 24 | ], 25 | ) 26 | 27 | snt_py_library( 28 | name = "optimizers", 29 | srcs = ["optimizers.py"], 30 | deps = [ 31 | ":haiku", 32 | "//sonnet/src:base", 33 | # pip: tensorflow 34 | # pip: tree 35 | ], 36 | ) 37 | 38 | snt_py_library( 39 | name = "utils", 40 | srcs = ["utils.py"], 41 | deps = [ 42 | "//sonnet/src:utils", 43 | # pip: tensorflow 44 | # pip: tree 45 | ], 46 | ) 47 | 48 | snt_py_test( 49 | name = "haiku_test", 50 | srcs = ["haiku_test.py"], 51 | deps = [ 52 | ":haiku", 53 | # pip: absl/testing:parameterized 54 | "//sonnet", 55 | "//sonnet/src:test_utils", 56 | # pip: tensorflow 57 | # pip: tree 58 | ], 59 | ) 60 | 61 | snt_py_test( 62 | name = "jax_test", 63 | srcs = ["jax_test.py"], 64 | deps = [ 65 | ":jax", 66 | # pip: absl/testing:parameterized 67 | "//sonnet/src:test_utils", 68 | # pip: tensorflow 69 | ], 70 | ) 71 | 72 | snt_py_test( 73 | name = "optimizers_test", 74 | srcs = ["optimizers_test.py"], 75 | deps = [ 76 | ":haiku", 77 | ":optimizers", 78 | # pip: absl/testing:parameterized 79 | "//sonnet", 80 | "//sonnet/src:test_utils", 81 | # pip: tensorflow 82 | # pip: tree 83 | ], 84 | ) 85 | -------------------------------------------------------------------------------- /sonnet/src/functional/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def device_put(t, device=None):
  """Copies (a tree of) tensors onto `device`."""
  copy = utils.run_on_device(lambda x: x, device)
  return tree.map_structure(copy, t)


def device_get(t):
  """Fetches (a tree of) tensors back to host as numpy arrays."""
  return tree.map_structure(lambda x: x.numpy(), t)


# TODO(tomhennigan) This should be cached.
def jit(f, device=None):
  """Compiles `f` to run on `device` (defaults to the first accelerator)."""
  if device is None:
    device = utils.get_first_accelerator()
  # TODO(tomhennigan) Enable XLA compilation (experimental_compile=True).
  return tf.function(utils.run_on_device(f, device))


def grad(f, argnums=0, has_aux=False):
  """Returns the gradient function for `f`."""
  vag_fn = value_and_grad(f, argnums=argnums, has_aux=has_aux)

  @functools.wraps(f)
  def wrapper(*args, **kwargs):
    # Delegate to value_and_grad and discard the value part.
    if has_aux:
      (_, aux), grads = vag_fn(*args, **kwargs)
      return grads, aux
    _, grads = vag_fn(*args, **kwargs)
    return grads

  return wrapper


def value_and_grad(f, argnums=0, has_aux=False):
  """Returns the gradient function for `f`."""

  @functools.wraps(f)
  def wrapper(*args, **kwargs):
    """Computes `f` and returns derivatives of the output wrt input(s)."""
    # Select the positional argument(s) to differentiate with respect to.
    params = tree.map_structure(args.__getitem__, argnums)
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      tree.map_structure(tape.watch, params)
      out = f(*args, **kwargs)
    if has_aux:
      out, aux = out
    grads = tape.gradient(out, params)
    return ((out, aux), grads) if has_aux else (out, grads)

  return wrapper
class JaxTest(test_utils.TestCase, parameterized.TestCase):
  """Tests for device-placement helpers in the JAX interop layer."""

  def test_jit_copies_to_device(self):
    accelerators = get_accelerators()
    if not accelerators:
      self.skipTest("No accelerator.")

    with tf.device("CPU"):
      x = tf.ones([])

    # `assertTrue(x.device, ...)` would treat the 2nd arg as a message;
    # use an explicit suffix check for the placement assertions below.
    self.assertTrue(x.device.endswith("CPU:0"))

    for device in accelerators:
      y = jax.jit(lambda x: x, device=device)(x)
      # BUG FIX: was `self.assertTrue(y.device, device)`, which passed
      # `device` as the failure message and succeeded for any non-empty
      # device string, never verifying placement.
      self.assertTrue(y.device.endswith(device))

  def test_device_put(self):
    accelerators = get_accelerators()
    if not accelerators:
      self.skipTest("No accelerator.")

    with tf.device("CPU"):
      x = tf.ones([])

    for device in accelerators:
      y = jax.device_put(x, device=device)
      self.assertTrue(y.device.endswith(device))


class GradTest(test_utils.TestCase, parameterized.TestCase):
  """Tests for jax.grad wrappers."""

  def test_grad(self):
    f = lambda x: x ** 2
    g = jax.grad(f)
    x = tf.constant(4.)
    self.assertAllClose(g(x).numpy(), (2 * x).numpy())

  def test_argnums(self):
    f = lambda x, y: (x ** 2 + y ** 2)
    g = jax.grad(f, argnums=(0, 1))
    x = tf.constant(4.)
    y = tf.constant(5.)
    gx, gy = g(x, y)
    self.assertAllClose(gx.numpy(), (2 * x).numpy())
    self.assertAllClose(gy.numpy(), (2 * y).numpy(), rtol=1e-3)

  def test_has_aux(self):
    f = lambda x: (x ** 2, "aux")
    g = jax.grad(f, has_aux=True)
    x = tf.constant(2.)
    gx, aux = g(x)
    self.assertAllClose(gx.numpy(), (2 * x).numpy())
    self.assertEqual(aux, "aux")


def get_accelerators():
  """Returns device strings for all logical GPU and TPU devices."""
  gpus = tf.config.experimental.list_logical_devices("GPU")
  tpus = tf.config.experimental.list_logical_devices("TPU")
  return [tf.DeviceSpec.from_string(d.name).to_string() for d in gpus + tpus]


if __name__ == "__main__":
  tf.test.main()
"""Tests for functional optimizers."""

from absl.testing import parameterized
import sonnet as snt
from sonnet.src import test_utils
from sonnet.src.functional import haiku
from sonnet.src.functional import optimizers
import tensorflow as tf
import tree

# Functional (purely-parametric) wrappers around the OO Sonnet optimizers.
sgd = optimizers.optimizer(snt.optimizers.SGD)
adam = optimizers.optimizer(snt.optimizers.Adam)


class OptimizersTest(test_utils.TestCase, parameterized.TestCase):

  def test_sgd(self):
    with haiku.variables():
      variables = [tf.Variable(1.)]
      params = {v.ref(): tf.ones_like(v) for v in variables}

    opt = sgd(learning_rate=0.01)
    opt_state = opt.init(params)
    grads = tree.map_structure(tf.ones_like, params)
    params, opt_state = opt.apply(opt_state, grads, params)
    updated, = tree.flatten(params)
    # Single SGD step with a unit gradient: p <- p - lr * g.
    self.assertAllClose(updated.numpy(), 1. - (0.01 * 1))

  def test_adam(self):
    lin = haiku.transform(snt.Linear(1))
    inputs = tf.ones([1, 1])
    params = lin.init(inputs)

    optimizer = adam(learning_rate=0.01)
    opt_state = optimizer.init(params)
    # Step + (m, v) per parameter.
    self.assertLen(tree.flatten(opt_state), 5)

  @parameterized.parameters(True, False)
  def test_adam_with_variable_lr(self, trainable_lr):
    lin = haiku.transform(snt.Linear(1))
    inputs = tf.ones([1, 1])
    initial_params = lin.init(inputs)

    with haiku.variables():
      lr = tf.Variable(0.01, trainable=trainable_lr, name="lr")

    optimizer = adam(learning_rate=lr)
    initial_opt_state = optimizer.init(initial_params)
    # Learning rate, step + (m, v) per parameter.
    self.assertLen(tree.flatten(initial_opt_state), 6)

    grads = tree.map_structure(tf.ones_like, initial_params)
    params, opt_state = optimizer.apply(
        initial_opt_state, grads, initial_params)

    # One step must preserve both the state and parameter structures.
    tree.assert_same_structure(initial_opt_state, opt_state)
    tree.assert_same_structure(initial_params, params)


if __name__ == "__main__":
  tf.test.main()
"""Utility functions for the JAX API in TF2."""

import functools

from sonnet.src import utils
import tensorflow as tf
import tree


def get_first_accelerator():
  """Returns the first TPU device name, else the first GPU, else the CPU."""
  tpus = tf.config.experimental.list_logical_devices("TPU")
  if tpus:
    return tpus[0].name
  gpus = tf.config.experimental.list_logical_devices("GPU")
  return gpus[0].name if gpus else "/device:CPU:0"


def run_on_device(f, device):
  """Runs `f` under a tf.device context on the given device."""
  f = utils.smart_autograph(f)

  @tf.autograph.experimental.do_not_convert
  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    with tf.device(device):
      # `tf.identity` copies every input tensor onto the target device
      # before `f` runs.
      args = tree.map_structure(tf.identity, args)
      kwargs = tree.map_structure(tf.identity, kwargs)
      return f(*args, **kwargs)

  return wrapped


def get_name_scope():
  """Returns the currently active name scope (without a trailing `x/`)."""
  with tf.name_scope("x") as ns:
    # `ns` is the full scope ending in "x/"; strip those two characters to
    # recover the enclosing scope name.
    return ns[:-2]


def first_non_none(*args):
  """Returns the first argument that is not `None`.

  Raises StopIteration when every argument is `None` (matching the behavior
  of `next()` on an exhausted generator).
  """
  for arg in args:
    if arg is not None:
      return arg
  raise StopIteration


def compose(f0, *fs):
  """Composes a sequence of functions.

  >>> f1 = lambda a, b: f"f1({a}, {b})"
  >>> f2 = lambda a: f"f2({a})"
  >>> f3 = lambda a: f"f3({a})"
  >>> f = compose(f1, f2, f3)
  >>> f("a", "b")
  'f3(f2(f1(a, b)))'

  Args:
    f0: The first function to apply.
    *fs: Other functions to apply in sequence.

  Returns:
    A function that is the composition of the input functions.
  """
  def composed(*args, **kwargs):
    out = f0(*args, **kwargs)
    for f in fs:
      out = f(out)
    return out
  return composed
"""Clipping operation with customized gradients."""

from typing import Optional

import tensorflow as tf


@tf.custom_gradient
def leaky_clip_by_value(t: tf.Tensor,
                        clip_value_min: tf.Tensor,
                        clip_value_max: tf.Tensor,
                        name: Optional[str] = None):
  """Clips tensor values to a specified min and max.

  The forward pass is identical to `tf.clip_by_value`. On the backward pass
  the gradient is zeroed only where the tensor value is already out of bounds
  *and* gradient descent would push it even further out of the valid range;
  gradients pointing back towards the valid range pass through unchanged.

  Note that this is assuming a gradient flow for minimization. For
  maximization, flip the gradient before it back-propagates to this op.

  Args:
    t: A Tensor.
    clip_value_min: A 0-D (scalar) Tensor, or a Tensor with the same shape as t.
      The minimum value to clip by.
    clip_value_max: A 0-D (scalar) Tensor, or a Tensor with the same shape as t.
      The maximum value to clip by.
    name: A name for the operation (optional).

  Returns:
    A clipped Tensor.

  Raises:
    ValueError: If the clip tensors would trigger array broadcasting that would
      make the returned tensor larger than the input.
  """
  clipped = tf.clip_by_value(t, clip_value_min, clip_value_max, name=name)

  def grad(dy):
    """Zeroes `dy` where it would push `t` further out of the valid range."""
    # Positive gradient below the minimum (descent would decrease t further),
    # or negative gradient above the maximum (descent would increase t).
    pushing_below = (t < clip_value_min) & (dy > 0)
    pushing_above = (t > clip_value_max) & (dy < 0)
    dy = tf.where(pushing_below | pushing_above, tf.zeros_like(dy), dy)
    # No gradients with respect to the clip bounds.
    return dy, None, None

  return clipped, grad
28 | clip_min = [1.5] 29 | clip_max = [2.5] 30 | clip_t = leaky_clip_by_value.leaky_clip_by_value(t, clip_min, clip_max) 31 | self.assertAllEqual(clip_t.numpy(), [1.5, 2.0, 2.5]) 32 | # Test when min/max are of same sizes as t. 33 | clip_min_array = [0.5, 2.5, 2.5] 34 | clip_max_array = [1.5, 3.0, 3.5] 35 | clip_t_2 = leaky_clip_by_value.leaky_clip_by_value(t, clip_min_array, 36 | clip_max_array) 37 | self.assertAllEqual(clip_t_2.numpy(), [1.0, 2.5, 3.0]) 38 | 39 | @parameterized.parameters([ 40 | (0.5, lambda x: x, [1.0]), 41 | (1.5, lambda x: x, [1.0]), 42 | (1.5, lambda x: -x, [0.0]), 43 | (-.5, lambda x: x, [0.0]), 44 | (-.5, lambda x: -x, [-1.0]), 45 | ]) 46 | def test_leaky_clip_by_value_backward(self, init, fn, expected_grad): 47 | t = tf.Variable([init]) 48 | max_val = 1.0 49 | min_val = 0.0 50 | with tf.GradientTape() as tape: 51 | clip_t = leaky_clip_by_value.leaky_clip_by_value(t, min_val, max_val) 52 | f = fn(clip_t) 53 | grad = tape.gradient(f, t) 54 | clip_t_value = clip_t.numpy() 55 | self.assertAllEqual(grad.numpy(), expected_grad) 56 | self.assertGreaterEqual(clip_t_value, min_val) 57 | self.assertLessEqual(clip_t_value, max_val) 58 | 59 | 60 | if __name__ == "__main__": 61 | tf.test.main() 62 | -------------------------------------------------------------------------------- /sonnet/src/linear.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
"""Linear module."""

import math
from typing import Optional

from sonnet.src import base
from sonnet.src import initializers
from sonnet.src import once
from sonnet.src import utils
import tensorflow as tf


class Linear(base.Module):
  """Linear module, optionally including bias."""

  def __init__(self,
               output_size: int,
               with_bias: bool = True,
               w_init: Optional[initializers.Initializer] = None,
               b_init: Optional[initializers.Initializer] = None,
               name: Optional[str] = None):
    """Constructs a `Linear` module.

    Args:
      output_size: Output dimensionality.
      with_bias: Whether to include bias parameters. Default `True`.
      w_init: Optional initializer for the weights. By default the weights are
        initialized truncated random normal values with a standard deviation of
        `1 / sqrt(input_feature_size)`, which is commonly used when the inputs
        are zero centered (see https://arxiv.org/abs/1502.03167v3).
      b_init: Optional initializer for the bias. By default the bias is
        initialized to zero.
      name: Name of the module.
    """
    super().__init__(name=name)
    self.output_size = output_size
    self.with_bias = with_bias
    self.w_init = w_init
    if with_bias:
      # `b_init` is only defined when a bias is requested.
      self.b_init = initializers.Zeros() if b_init is None else b_init
    elif b_init is not None:
      raise ValueError("When not using a bias the b_init must be None.")

  @once.once
  def _initialize(self, inputs: tf.Tensor):
    """Constructs parameters used by this module (runs once, on first call)."""
    utils.assert_minimum_rank(inputs, 2)

    input_size = inputs.shape[-1]
    if input_size is None:  # Can happen inside an @tf.function.
      raise ValueError("Input size must be specified at module build time.")
    self.input_size = input_size

    if self.w_init is None:
      # See https://arxiv.org/abs/1502.03167v3.
      self.w_init = initializers.TruncatedNormal(
          stddev=1 / math.sqrt(input_size))

    self.w = tf.Variable(
        self.w_init([input_size, self.output_size], inputs.dtype), name="w")
    if self.with_bias:
      self.b = tf.Variable(
          self.b_init([self.output_size], inputs.dtype), name="b")

  def __call__(self, inputs: tf.Tensor) -> tf.Tensor:
    self._initialize(inputs)

    out = tf.matmul(inputs, self.w)
    if self.with_bias:
      out = out + self.b
    return out
"""Tests for sonnet.v2.src.metrics."""

from sonnet.src import metrics
from sonnet.src import test_utils
import tensorflow as tf


class SumTest(test_utils.TestCase):

  def testSimple(self):
    total = metrics.Sum()
    # Each call accumulates elementwise and returns the running total.
    self.assertAllEqual([2., 3.], total(tf.constant([2., 3.])))
    self.assertAllEqual([6., 8.], total(tf.constant([4., 5.])))

  def testInitialize(self):
    total = metrics.Sum()
    total.initialize(tf.constant([1., 2.]))
    # Initialization creates zeroed state without consuming the value.
    self.assertAllEqual([0., 0.], total.value)

  def testReset(self):
    total = metrics.Sum()
    self.assertAllEqual([2., 3.], total(tf.constant([2., 3.])))
    self.assertAllEqual([6., 8.], total(tf.constant([4., 5.])))
    total.reset()
    # After reset the accumulator starts from zero again.
    self.assertAllEqual([7., 8.], total(tf.constant([7., 8.])))


class MeanTest(test_utils.TestCase):

  def testSimple(self):
    avg = metrics.Mean()
    self.assertAllEqual([2., 3.], avg(tf.constant([2., 3.])))
    # Running elementwise mean of the two inputs.
    self.assertAllEqual([3., 4.], avg(tf.constant([4., 5.])))

  def testInitialize(self):
    avg = metrics.Mean()
    avg.initialize(tf.constant([1., 2.]))
    self.assertAllEqual([1., 2.], avg(tf.constant([1., 2.])))

  def testReset(self):
    avg = metrics.Mean()
    self.assertAllEqual([2., 3.], avg(tf.constant([2., 3.])))
    self.assertAllEqual([3., 4.], avg(tf.constant([4., 5.])))
    avg.reset()
    self.assertAllEqual([7., 8.], avg(tf.constant([7., 8.])))


if __name__ == "__main__":
  tf.test.main()
"""Exponential moving average for Sonnet."""

from typing import Optional, cast

from sonnet.src import metrics
from sonnet.src import once
from sonnet.src import types
import tensorflow as tf


class ExponentialMovingAverage(metrics.Metric):
  """Maintains an exponential moving average for a value.

  Note this module uses debiasing by default. If you don't want this please use
  an alternative implementation.

  This module keeps track of a hidden exponential moving average that is
  initialized as a vector of zeros which is then normalized to give the average.
  This gives us a moving average which isn't biased towards either zero or the
  initial value. Reference (https://arxiv.org/pdf/1412.6980.pdf)

  Initially:

      hidden_0 = 0

  Then iteratively:

      hidden_i = hidden_{i-1} - (hidden_{i-1} - value) * (1 - decay)
      average_i = hidden_i / (1 - decay^i)

  (The recurrence was previously stated as
  `hidden_i = (hidden_{i-1} - value) * (1 - decay)`, which is only the
  subtracted delta, not the new hidden value; the corrected form above matches
  the `assign_sub` in `update`.)

  Attributes:
    average: Variable holding average. Note that this is None until the first
      value is passed.
  """

  def __init__(self, decay: types.FloatLike, name: Optional[str] = None):
    """Creates a debiased moving average module.

    Args:
      decay: The decay to use. Note values close to 1 result in a slow decay
        whereas values close to 0 result in faster decay, tracking the input
        values more closely.
      name: Name of the module.
    """
    super().__init__(name=name)
    self._decay = decay
    # int64 step counter, used for the bias correction term (decay^i).
    self._counter = tf.Variable(
        0, trainable=False, dtype=tf.int64, name="counter")

    # State variables are created lazily in `initialize` so that their
    # shape/dtype match the first value seen; `cast` keeps static type
    # checkers happy about the temporary None.
    self._hidden: tf.Variable = cast(tf.Variable, None)
    self.average: tf.Variable = cast(tf.Variable, None)

  def update(self, value: tf.Tensor):
    """Applies EMA to the value given."""
    self.initialize(value)

    self._counter.assign_add(1)
    value = tf.convert_to_tensor(value)
    counter = tf.cast(self._counter, value.dtype)
    # hidden <- hidden - (hidden - value) * (1 - decay)
    self._hidden.assign_sub((self._hidden - value) * (1 - self._decay))
    # Debias: average <- hidden / (1 - decay^i).
    self.average.assign((self._hidden / (1. - tf.pow(self._decay, counter))))

  @property
  def value(self) -> tf.Tensor:
    """Returns the current EMA."""
    return self.average.read_value()

  def reset(self):
    """Resets the EMA."""
    self._counter.assign(tf.zeros_like(self._counter))
    # State may not exist yet if no value has been seen.
    if self._hidden is not None:
      self._hidden.assign(tf.zeros_like(self._hidden))
    if self.average is not None:
      self.average.assign(tf.zeros_like(self.average))

  @once.once
  def initialize(self, value: tf.Tensor):
    """Creates state variables matching `value`'s shape/dtype (runs once)."""
    self._hidden = tf.Variable(
        tf.zeros_like(value), trainable=False, name="hidden")
    self.average = tf.Variable(
        tf.zeros_like(value), trainable=False, name="average")
"""Tests for sonnet.v2.src.moving_averages."""

from absl.testing import parameterized
from sonnet.src import moving_averages
from sonnet.src import test_utils
import tensorflow as tf


class ExponentialMovingAverageTest(test_utils.TestCase, parameterized.TestCase):

  def testCall(self):
    avg = moving_averages.ExponentialMovingAverage(0.50)

    # Debiasing makes the first call return the value itself.
    self.assertAllClose(avg(3.0).numpy(), 3.0)
    self.assertAllClose(avg(6.0).numpy(), 5.0)

  def testUpdateAndValue(self):
    avg = moving_averages.ExponentialMovingAverage(0.50)
    avg.update(3.0)
    self.assertAllClose(avg.value.numpy(), 3.0, atol=1e-3, rtol=1e-5)

    avg.update(6.0)
    self.assertAllClose(avg.value.numpy(), 5.0, atol=1e-3, rtol=1e-5)

  def testReset(self):
    avg = moving_averages.ExponentialMovingAverage(0.90)
    self.assertAllClose(avg(3.0).numpy(), 3.0, atol=1e-3, rtol=1e-5)

    avg.reset()
    self.assertEqual(avg.value.shape, ())
    self.assertEqual(avg.value.numpy(), 0.0)

    # After reset the debiased average starts over.
    self.assertAllClose(avg(3.0).numpy(), 3.0, atol=1e-3, rtol=1e-5)

  def testResetVector(self):
    avg = moving_averages.ExponentialMovingAverage(0.90)
    random_input = tf.random.normal((1, 5))
    avg(random_input)
    avg.reset()
    # Reset keeps the state shape and the int64 counter dtype.
    self.assertEqual(avg.value.shape, (1, 5))
    self.assertAllClose(avg.value.numpy(), tf.zeros_like(random_input))
    self.assertEqual(avg._counter.dtype, tf.int64)

  def testValueEqualsLatestUpdate(self):
    avg = moving_averages.ExponentialMovingAverage(0.50)

    self.assertAllClose(avg(3.0).numpy(), 3.0, atol=1e-3, rtol=1e-5)
    self.assertAllClose(avg.value.numpy(), 3.0, atol=1e-3, rtol=1e-5)

    self.assertAllClose(avg(6.0).numpy(), 5.0, atol=1e-3, rtol=1e-5)
    self.assertAllClose(avg.value.numpy(), 5.0, atol=1e-3, rtol=1e-5)

  @parameterized.parameters(True, False)
  def testWithTFFunction(self, autograph):
    eager_ema = moving_averages.ExponentialMovingAverage(0.95)
    traced_ema = moving_averages.ExponentialMovingAverage(0.95)
    compiled = tf.function(traced_ema, autograph=autograph)

    # Eager and tf.function execution must agree step for step.
    for _ in range(10):
      x = tf.random.uniform((), 0, 10)
      self.assertAllClose(
          eager_ema(x).numpy(), compiled(x).numpy(), atol=1e-3, rtol=1e-5)

  @parameterized.parameters(True, False)
  def testResetWithTFFunction(self, autograph):
    avg = moving_averages.ExponentialMovingAverage(0.90)
    compiled = tf.function(avg, autograph=autograph)
    self.assertAllClose(compiled(3.0).numpy(), 3.0, atol=1e-3, rtol=1e-5)

    avg.reset()
    self.assertEqual(avg.value.numpy(), 0.0)

    self.assertAllClose(compiled(3.0).numpy(), 3.0, atol=1e-3, rtol=1e-5)

  @parameterized.named_parameters(("2D", [2, 2]), ("3D", [1, 1, 3]))
  def testAlternativeShape(self, shape):
    avg = moving_averages.ExponentialMovingAverage(0.90)
    value = tf.random.uniform(shape)
    result = avg(value)
    self.assertEqual(value.shape, result.shape)


if __name__ == "__main__":
  tf.test.main()
"//docs/ext:__subpackages__", "//examples:__subpackages__"]) 4 | 5 | licenses(["notice"]) 6 | 7 | snt_py_library( 8 | name = "mlp", 9 | srcs = ["mlp.py"], 10 | deps = [ 11 | "//sonnet/src:base", 12 | "//sonnet/src:initializers", 13 | "//sonnet/src:linear", 14 | # pip: tensorflow 15 | ], 16 | ) 17 | 18 | snt_py_test( 19 | name = "mlp_test", 20 | srcs = ["mlp_test.py"], 21 | deps = [ 22 | ":mlp", 23 | # pip: absl/testing:parameterized 24 | "//sonnet/src:test_utils", 25 | # pip: tensorflow 26 | ], 27 | ) 28 | 29 | snt_py_library( 30 | name = "cifar10_convnet", 31 | srcs = ["cifar10_convnet.py"], 32 | deps = [ 33 | "//sonnet/src:base", 34 | "//sonnet/src:batch_norm", 35 | "//sonnet/src:conv", 36 | "//sonnet/src:initializers", 37 | "//sonnet/src:linear", 38 | "//sonnet/src:types", 39 | # pip: tensorflow 40 | ], 41 | ) 42 | 43 | snt_py_test( 44 | name = "cifar10_convnet_test", 45 | timeout = "long", 46 | srcs = ["cifar10_convnet_test.py"], 47 | deps = [ 48 | ":cifar10_convnet", 49 | # pip: absl/testing:parameterized 50 | # pip: numpy 51 | "//sonnet/src:test_utils", 52 | # pip: tensorflow 53 | ], 54 | ) 55 | 56 | snt_py_library( 57 | name = "vqvae", 58 | srcs = ["vqvae.py"], 59 | deps = [ 60 | "//sonnet/src:base", 61 | "//sonnet/src:initializers", 62 | "//sonnet/src:moving_averages", 63 | "//sonnet/src:types", 64 | # pip: tensorflow 65 | ], 66 | ) 67 | 68 | snt_py_test( 69 | name = "vqvae_test", 70 | srcs = ["vqvae_test.py"], 71 | deps = [ 72 | ":vqvae", 73 | # pip: absl/testing:parameterized 74 | # pip: numpy 75 | "//sonnet/src:test_utils", 76 | # pip: tensorflow 77 | # pip: tree 78 | ], 79 | ) 80 | 81 | snt_py_library( 82 | name = "resnet", 83 | srcs = ["resnet.py"], 84 | deps = [ 85 | "//sonnet/src:base", 86 | "//sonnet/src:batch_norm", 87 | "//sonnet/src:conv", 88 | "//sonnet/src:initializers", 89 | "//sonnet/src:linear", 90 | "//sonnet/src:pad", 91 | # pip: tensorflow 92 | ], 93 | ) 94 | 95 | snt_py_test( 96 | name = "resnet_test", 97 | srcs = ["resnet_test.py"], 98 
| deps = [ 99 | ":resnet", 100 | # pip: absl/testing:parameterized 101 | "//sonnet/src:test_utils", 102 | # pip: tensorflow 103 | ], 104 | ) 105 | -------------------------------------------------------------------------------- /sonnet/src/nets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | -------------------------------------------------------------------------------- /sonnet/src/nets/dnc/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # Differentiable Neural Computer 3 | 4 | load("//sonnet/src:build_defs.bzl", "snt_py_library", "snt_py_test") 5 | 6 | package(default_visibility = ["//sonnet:__subpackages__", "//docs/ext:__subpackages__", "//examples:__subpackages__"]) 7 | 8 | licenses(["notice"]) 9 | 10 | snt_py_library( 11 | name = "control", 12 | srcs = ["control.py"], 13 | deps = [ 14 | "//sonnet/src:linear", 15 | "//sonnet/src:recurrent", 16 | # pip: tensorflow 17 | ], 18 | ) 19 | 20 | snt_py_test( 21 | name = "control_test", 22 | srcs = ["control_test.py"], 23 | main = "control_test.py", 24 | deps = [ 25 | ":control", 26 | # pip: absl/testing:parameterized 27 | # pip: numpy 28 | "//sonnet/src:recurrent", 29 | 
"//sonnet/src:test_utils", 30 | # pip: tensorflow 31 | # pip: tree 32 | ], 33 | ) 34 | 35 | snt_py_library( 36 | name = "read", 37 | srcs = ["read.py"], 38 | deps = [ 39 | # pip: tensorflow 40 | ], 41 | ) 42 | 43 | snt_py_test( 44 | name = "read_test", 45 | srcs = ["read_test.py"], 46 | main = "read_test.py", 47 | deps = [ 48 | ":read", 49 | # pip: numpy 50 | "//sonnet/src:test_utils", 51 | # pip: tensorflow 52 | ], 53 | ) 54 | 55 | snt_py_library( 56 | name = "util", 57 | srcs = ["util.py"], 58 | deps = [ 59 | # pip: numpy 60 | # pip: tensorflow 61 | # pip: tree 62 | ], 63 | ) 64 | 65 | snt_py_test( 66 | name = "util_test", 67 | srcs = ["util_test.py"], 68 | main = "util_test.py", 69 | deps = [ 70 | ":util", 71 | # pip: absl/testing:parameterized 72 | # pip: numpy 73 | "//sonnet/src:linear", 74 | "//sonnet/src:test_utils", 75 | # pip: tensorflow 76 | # pip: tree 77 | ], 78 | ) 79 | 80 | snt_py_library( 81 | name = "write", 82 | srcs = ["write.py"], 83 | deps = [ 84 | # pip: tensorflow 85 | ], 86 | ) 87 | 88 | snt_py_test( 89 | name = "write_test", 90 | srcs = ["write_test.py"], 91 | main = "write_test.py", 92 | deps = [ 93 | ":write", 94 | # pip: numpy 95 | "//sonnet/src:test_utils", 96 | # pip: tensorflow 97 | ], 98 | ) 99 | -------------------------------------------------------------------------------- /sonnet/src/nets/dnc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
"""Read modules."""

import tensorflow as tf


def read(memory,
         weights,
         squash_op=tf.nn.tanh,
         squash_before_access=True,
         squash_after_access=False):
  """Read from the NTM memory.

  Args:
    memory: 3D Tensor [batch_size, memory_size, word_size].
    weights: 3D Tensor [batch_size, num_reads, memory_size].
    squash_op: op to perform squashing of memory or read word.
    squash_before_access: squash memory before read, default True.
    squash_after_access: squash read word, default False.

  Returns:
    3D Tensor [batch_size, num_reads, word_size].
  """
  with tf.name_scope("read_memory"):
    if squash_before_access:
      # BUG FIX: this previously computed `squash_op(weights)` and discarded
      # the result, so `squash_before_access` had no effect at all. Per the
      # docstring ("squash memory before read"), apply the squash to the
      # memory that is about to be read.
      memory = squash_op(memory)
    # Weighted sum over memory rows for each read head:
    # [B, R, M] x [B, M, W] -> [B, R, W].
    read_word = tf.matmul(weights, memory)
    if squash_after_access:
      read_word = squash_op(read_word)
    return read_word
class ReadTest(test_utils.TestCase):

  def testShape(self):
    """Reading [B, R, M] weights from [B, M, W] memory yields [B, R, W]."""
    batch_size, num_reads, memory_size, word_size = 4, 2, 5, 3

    mem = tf.random.uniform([batch_size, memory_size, word_size])
    weights = tf.random.uniform([batch_size, num_reads, memory_size])
    self.assertAllEqual(
        read.read(mem, weights).shape.as_list(),
        [batch_size, num_reads, word_size])

  def testValues(self):
    """One-hot read weights must select the addressed memory rows exactly."""
    num_reads, memory_size, word_size = 2, 5, 3

    # Random memory and random addresses (batch_size=1).
    mem = tf.random.uniform([1, memory_size, word_size])
    indices = np.random.randint(0, memory_size, size=num_reads)
    # One-hot representation of the addressed rows.
    one_hot = np.expand_dims(np.eye(memory_size)[indices], axis=0)
    read_weights = tf.constant(one_hot, dtype=tf.float32)

    read_values = read.read(mem, read_weights, squash_op=tf.identity)
    self.assertAllClose(
        mem.numpy()[0, indices, :], read_values.numpy()[0, ...], atol=2e-3)
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for sonnet.v2.src.nets.resnet.""" 16 | 17 | from absl.testing import parameterized 18 | from sonnet.src import test_utils 19 | from sonnet.src.nets import resnet 20 | import tensorflow as tf 21 | 22 | 23 | class ResnetTest(test_utils.TestCase, parameterized.TestCase): 24 | 25 | @parameterized.parameters(True, False) 26 | def test_simple(self, resnet_v2): 27 | image = tf.random.normal([2, 64, 64, 3]) 28 | model = resnet.ResNet([1, 1, 1, 1], 10, resnet_v2=resnet_v2) 29 | 30 | logits = model(image, is_training=True) 31 | self.assertIsNotNone(logits) 32 | self.assertEqual(logits.shape, [2, 10]) 33 | 34 | @parameterized.parameters(True, False) 35 | def test_tf_function(self, resnet_v2): 36 | image = tf.random.normal([2, 64, 64, 3]) 37 | model = resnet.ResNet( 38 | [1, 1, 1, 1], 39 | 10, 40 | resnet_v2=resnet_v2, 41 | ) 42 | f = tf.function(model) 43 | 44 | logits = f(image, is_training=True) 45 | self.assertIsNotNone(logits) 46 | self.assertEqual(logits.shape, [2, 10]) 47 | self.assertAllEqual(model(image, is_training=True).numpy(), logits.numpy()) 48 | 49 | @parameterized.parameters(3, 5) 50 | def test_error_incorrect_args_block_list(self, list_length): 51 | block_list = [i for i in range(list_length)] 52 | with self.assertRaisesRegex( 53 | ValueError, "blocks_per_group_list` must be of length 4 not {}".format( 54 | list_length)): 55 | resnet.ResNet(block_list, 10, {"decay_rate": 0.9, "eps": 1e-5}) 56 | 57 | 
_ONCE_PROPERTY = "_snt_once"  # Attribute holding the per-instance set of seen ids.


def _check_no_output(output):
  # A @snt.once function's return value would be silently dropped on every
  # call after the first, so returning anything is treated as an error.
  if output is not None:
    raise ValueError("@snt.once decorated functions cannot return values")


def once(f):
  """Decorator which ensures a wrapped method is only ever run once.

  >>> @snt.once
  ... def f():
  ...   print('Hello, world!')
  >>> f()
  Hello, world!
  >>> f()
  >>> f()

  If `f` is a method then it will be evaluated once per instance:

  >>> class MyObject:
  ...   @snt.once
  ...   def f(self):
  ...     print('Hello, world!')

  >>> o = MyObject()
  >>> o.f()
  Hello, world!
  >>> o.f()

  >>> o2 = MyObject()
  >>> o2.f()
  Hello, world!
  >>> o.f()
  >>> o2.f()

  If an error is raised during execution of `f` it will be raised to the user.
  Next time the method is run, it will be treated as not having run before.

  Args:
    f: A function to wrap which should only be called once.

  Returns:
    Wrapped version of `f` which will only evaluate `f` the first time it is
    called.
  """

  # Each decorated function gets its own id so multiple @once methods on the
  # same instance are tracked independently in the shared `seen` set.
  # TODO(tomhennigan) Perhaps some more human friendly identifier?
  once_id = uuid.uuid4()

  @utils.decorator
  def wrapper(wrapped, instance, args, kwargs):
    """Decorator which ensures a wrapped method is only ever run once."""
    if instance is None:
      # Plain function (no bound instance): track "has run" with a single
      # flag stored on the wrapper itself.
      # NOTE: We can't use the weakset since you can't weakref None.
      if not wrapper.seen_none:
        _check_no_output(wrapped(*args, **kwargs))
        # Mark as seen only *after* a successful run so that an exception
        # leaves the function eligible to run again.
        wrapper.seen_none = True
      return

    # Get or set the `seen` set for this object.
    seen = getattr(instance, _ONCE_PROPERTY, None)
    if seen is None:
      seen = set()
      setattr(instance, _ONCE_PROPERTY, seen)

    if once_id not in seen:
      _check_no_output(wrapped(*args, **kwargs))
      # As above: record success only after `wrapped` completes cleanly.
      seen.add(once_id)

  wrapper.seen_none = False

  decorated = wrapper(f)  # pylint: disable=no-value-for-parameter,assignment-from-none
  decorated.__snt_once_wrapped__ = f
  return decorated
class OnceTest(parameterized.TestCase):

  def test_runs_once(self):
    """The wrapped function body executes exactly once."""
    calls = []

    @once.once
    def f():
      calls.append(None)

    f()
    f()
    f()

    self.assertEqual(calls, [None])

  def test_always_returns_none(self):
    """Returning a value from a @once function is an error."""
    f = once.once(lambda: "Hello, world!")
    with self.assertRaisesRegex(ValueError, "snt.once .* cannot return"):
      f()

  def test_does_not_cache_on_error(self):
    """A raising function is retried on every call, not marked as run."""

    @once.once
    def f():
      raise ValueError

    for _ in range(2):
      with self.assertRaises(ValueError):
        f()

  def test_method(self):
    """Each instance gets its own once-per-instance evaluation."""
    first = Counter()
    second = Counter()
    for _ in range(10):
      first.increment()
      second.increment()

    self.assertEqual(first.call_count, 1)
    self.assertEqual(second.call_count, 1)

  def test_method_does_not_cache_on_error(self):

    class Dummy:

      @once.once
      def f(self):
        raise ValueError

    obj = Dummy()
    for _ in range(2):
      with self.assertRaises(ValueError):
        obj.f()

  def test_pickle_method_before_evaluation(self):
    """Pickling before first call leaves both copies un-run."""
    original = Counter()
    clone = pickle.loads(pickle.dumps(original))
    original.increment()
    self.assertEqual(original.call_count, 1)
    self.assertEqual(clone.call_count, 0)
    clone.increment()
    self.assertEqual(original.call_count, 1)
    self.assertEqual(clone.call_count, 1)

  def test_pickle_method_already_evaluated(self):
    """Pickling after the first call carries the 'already run' state along."""
    original = Counter()
    original.increment()
    self.assertEqual(original.call_count, 1)
    clone = pickle.loads(pickle.dumps(original))
    self.assertEqual(clone.call_count, 1)
    clone.increment()
    self.assertEqual(clone.call_count, 1)

  def test_inline(self):
    calls = []
    f = once.once(lambda: calls.append(None))
    for _ in range(10):
      f()
    self.assertEqual(calls, [None])

  @parameterized.named_parameters(
      ("lambda", lambda: lambda: None), ("function", lambda: nop),
      ("method", lambda: NoOpCallable().nop),
      ("special_method", lambda: NoOpCallable().__call__),
      ("object", lambda: NoOpCallable()))  # pylint: disable=unnecessary-lambda
  def test_adds_property(self, factory):
    """The wrapper exposes the original callable as `__snt_once_wrapped__`."""
    f = factory()
    self.assertIs(once.once(f).__snt_once_wrapped__, f)


def nop():
  pass


class NoOpCallable:

  def nop(self):
    pass

  def __call__(self):
    pass


class Counter:
  # Class-level default; `increment` shadows it with an instance attribute.
  call_count = 0

  @once.once
  def increment(self):
    self.call_count += 1


if __name__ == "__main__":
  absltest.main()
9 | testonly = 1, 10 | srcs = ["optimizer_tests.py"], 11 | deps = [ 12 | # pip: absl/testing:parameterized 13 | # pip: numpy 14 | "//sonnet/src:base", 15 | "//sonnet/src:test_utils", 16 | # pip: tensorflow 17 | # pip: tree 18 | ], 19 | ) 20 | 21 | snt_py_library( 22 | name = "adam", 23 | srcs = ["adam.py"], 24 | deps = [ 25 | ":optimizer_utils", 26 | "//sonnet/src:base", 27 | "//sonnet/src:once", 28 | "//sonnet/src:types", 29 | "//sonnet/src:utils", 30 | # pip: tensorflow 31 | ], 32 | ) 33 | 34 | snt_py_test( 35 | name = "adam_test", 36 | srcs = ["adam_test.py"], 37 | shard_count = 10, 38 | deps = [ 39 | ":adam", 40 | ":optimizer_tests", 41 | "//sonnet/src:test_utils", 42 | # pip: tensorflow 43 | ], 44 | ) 45 | 46 | snt_py_library( 47 | name = "momentum", 48 | srcs = ["momentum.py"], 49 | deps = [ 50 | ":optimizer_utils", 51 | "//sonnet/src:base", 52 | "//sonnet/src:once", 53 | "//sonnet/src:types", 54 | "//sonnet/src:utils", 55 | # pip: tensorflow 56 | ], 57 | ) 58 | 59 | snt_py_test( 60 | name = "momentum_test", 61 | srcs = ["momentum_test.py"], 62 | shard_count = 10, 63 | deps = [ 64 | ":momentum", 65 | ":optimizer_tests", 66 | "//sonnet/src:test_utils", 67 | # pip: tensorflow 68 | ], 69 | ) 70 | 71 | snt_py_library( 72 | name = "rmsprop", 73 | srcs = ["rmsprop.py"], 74 | deps = [ 75 | ":optimizer_utils", 76 | "//sonnet/src:base", 77 | "//sonnet/src:once", 78 | "//sonnet/src:types", 79 | "//sonnet/src:utils", 80 | # pip: tensorflow 81 | ], 82 | ) 83 | 84 | snt_py_test( 85 | name = "rmsprop_test", 86 | srcs = ["rmsprop_test.py"], 87 | shard_count = 10, 88 | deps = [ 89 | ":optimizer_tests", 90 | ":rmsprop", 91 | "//sonnet/src:test_utils", 92 | # pip: tensorflow 93 | ], 94 | ) 95 | 96 | snt_py_library( 97 | name = "sgd", 98 | srcs = ["sgd.py"], 99 | deps = [ 100 | ":optimizer_utils", 101 | "//sonnet/src:base", 102 | "//sonnet/src:types", 103 | # pip: tensorflow 104 | ], 105 | ) 106 | 107 | snt_py_test( 108 | name = "sgd_test", 109 | srcs = ["sgd_test.py"], 110 | 
shard_count = 10, 111 | deps = [ 112 | ":optimizer_tests", 113 | ":sgd", 114 | # pip: tensorflow 115 | ], 116 | ) 117 | 118 | snt_py_library( 119 | name = "optimizer_utils", 120 | srcs = ["optimizer_utils.py"], 121 | deps = [ 122 | "//sonnet/src:types", 123 | "//sonnet/src/distribute:replicator", 124 | # pip: tensorflow 125 | ], 126 | ) 127 | -------------------------------------------------------------------------------- /sonnet/src/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | -------------------------------------------------------------------------------- /sonnet/src/optimizers/optimizer_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# Sonnet only supports a subset of distribution strategies since it makes use
# of a simplified update model and replica local variables.
# TODO(cjfj,petebu,tomhennigan) Add async parameter server strategy when needed.
# TODO(cjfj,petebu,tomhennigan) Add sync multi-worker GPU strategy when needed.
_SUPPORTED_STRATEGIES = (
    tf.distribute.OneDeviceStrategy,
    replicator.Replicator,
    replicator.TpuReplicator,
)


def check_distribution_strategy():
  """Raises `ValueError` if running under an unsupported distribution strategy."""
  if not tf.distribute.has_strategy():
    return
  strategy = tf.distribute.get_strategy()
  if isinstance(strategy, _SUPPORTED_STRATEGIES):
    return
  supported_names = "`, `".join(s.__name__ for s in _SUPPORTED_STRATEGIES)
  raise ValueError("Sonnet optimizers are not compatible with `{}`. "
                   "Please use one of `{}` instead.".format(
                       strategy.__class__.__name__, supported_names))


def check_updates_parameters(updates: Sequence[types.ParameterUpdate],
                             parameters: Sequence[tf.Variable]):
  """Validates that `updates` and `parameters` line up and are non-trivial."""
  if len(updates) != len(parameters):
    raise ValueError("`updates` and `parameters` must be the same length.")
  if not parameters:
    raise ValueError("`parameters` cannot be empty.")
  if all(update is None for update in updates):
    raise ValueError("No updates provided for any parameter.")


def check_same_dtype(update: types.ParameterUpdate, parameter: tf.Variable):
  """Raises `ValueError` when an update's dtype differs from its parameter's."""
  if update.dtype != parameter.dtype:
    raise ValueError(
        "DType of update {!r} is not equal to that of parameter {!r}".format(
            update, parameter))


def deduplicate_indexed_slices(indexed_slice: tf.IndexedSlices):
  """Sums `values` associated with any non-unique `indices`.

  Args:
    indexed_slice: An indexed slice with potentially duplicated indices.

  Returns:
    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
    de-duplicated version of `indices` and `summed_values` contains the sum of
    `values` slices associated with each unique index.
  """
  unique_indices, positions = tf.unique(indexed_slice.indices)
  num_unique = tf.shape(unique_indices)[0]
  summed_values = tf.math.unsorted_segment_sum(indexed_slice.values, positions,
                                               num_unique)
  return summed_values, unique_indices
class SGD(base.Optimizer):
  """Stochastic Gradient Descent (SGD) module.

  Attributes:
    learning_rate: Learning rate.
  """

  def __init__(self,
               learning_rate: Union[types.FloatLike, tf.Variable],
               name: Optional[str] = None):
    """Constructs an `SGD` module.

    Args:
      learning_rate: Learning rate.
      name: Name of the module.
    """
    super().__init__(name)
    self.learning_rate = learning_rate

  def apply(self, updates: Sequence[types.ParameterUpdate],
            parameters: Sequence[tf.Variable]):
    """Applies updates to parameters.

    Args:
      updates: A list of updates to apply to parameters. Updates are often
        gradients as returned by `tf.GradientTape.gradient`.
      parameters: A list of parameters.

    Raises:
      ValueError: If `updates` and `parameters` are empty, have different
        lengths, or have inconsistent types.
    """
    optimizer_utils.check_distribution_strategy()
    optimizer_utils.check_updates_parameters(updates, parameters)
    for update, parameter in zip(updates, parameters):
      if update is None:
        # A missing update means this parameter is left untouched.
        continue
      optimizer_utils.check_same_dtype(update, parameter)
      # Cast per-update so mixed-precision parameter lists are handled.
      lr = tf.cast(self.learning_rate, update.dtype)
      if isinstance(update, tf.IndexedSlices):
        # Sparse update: only subtract along the rows that were touched.
        scaled = tf.IndexedSlices(update.values * lr, update.indices)
        parameter.scatter_sub(scaled)
      else:
        parameter.assign_sub(update * lr)
Padding = Callable[[int], Sequence[int]]
Paddings = Union[Padding, Sequence[Padding]]


def valid(effective_kernel_size: int):  # pylint: disable=unused-argument
  """No padding."""
  return [0, 0]


def same(effective_kernel_size: int):
  """Pads such that the output size matches input size for stride=1."""
  before = (effective_kernel_size - 1) // 2
  after = effective_kernel_size // 2
  return [before, after]


def full(effective_kernel_size: int):
  """Maximal padding whilst not convolving over just padded elements."""
  pad = effective_kernel_size - 1
  return [pad, pad]


def causal(effective_kernel_size: int):
  """Pre-padding such that output has no dependence on the future."""
  return [effective_kernel_size - 1, 0]


def reverse_causal(effective_kernel_size: int):
  """Post-padding such that output has no dependence on the past."""
  return [0, effective_kernel_size - 1]


def create(
    padding: Paddings,
    kernel: Union[int, Sequence[int]],
    rate: Union[int, Sequence[int]],
    n: int,
    channel_index: int,
):
  """Generates the padding required for a given padding algorithm.

  Args:
    padding: callable or list of callables of length n. The callables take an
      integer representing the effective kernel size (kernel size when the rate
      is 1) and return a list of two integers representing the padding before
      and padding after for that dimension.
    kernel: int or list of ints of length n. The size of the kernel for each
      dimension. If it is an int it will be replicated for the non channel and
      batch dimensions.
    rate: int or list of ints of length n. The dilation rate for each dimension.
      If it is an int it will be replicated for the non channel and batch
      dimensions.
    n: the number of spatial dimensions.
    channel_index: the channel position of the input to which the padding will
      be applied.

  Returns:
    A list of length n+2 containing the padding for each element. These are of
    the form [pad_before, pad_after].
  """
  kernels = utils.replicate(kernel, n, "kernel")
  rates = utils.replicate(rate, n, "rate")
  padding_fns = utils.replicate(padding, n, "padding")
  # The effective kernel size includes any holes/gaps introduced by the
  # dilation rate. It's equal to kernel_size when rate == 1.
  spatial = [
      fn((k - 1) * r + 1) for fn, k, r in zip(padding_fns, kernels, rates)
  ]
  if channel_index == 1:  # N, C, ...
    return [[0, 0], [0, 0]] + spatial
  else:  # channel_index == -1: N, ..., C
    return [[0, 0]] + spatial + [[0, 0]]
class ParallelLinearTest(test_utils.TestCase):

  def test_output_size_correct(self):
    """The trailing dimension becomes the requested output size."""
    layer = parallel_linear.ParallelLinears(3)
    outputs = layer(tf.ones([4, 2, 6]))
    self.assertEqual(outputs.shape, [4, 2, 3])

  def test_behaves_same_as_stacked_linears(self):
    """A batched ParallelLinears must match independently stacked Linears."""
    w_init = tf.random.normal((3, 5, 7))
    b_init = tf.random.normal((3, 1, 7))
    inputs = tf.random.normal((3, 2, 5))

    parallel = parallel_linear.ParallelLinears(
        7, w_init=lambda s, d: w_init, b_init=lambda s, d: b_init)
    parallel_outputs = parallel(inputs)

    per_layer_outputs = []
    for i in range(3):
      # Bind `i` as a default so each layer sees its own weight slice.
      single = linear.Linear(
          7,
          w_init=lambda s, d, i=i: w_init[i],
          b_init=lambda s, d, i=i: b_init[i])
      per_layer_outputs.append(single(inputs[i]))
    stacked_outputs = tf.stack(per_layer_outputs, axis=0)

    self.assertAllClose(parallel_outputs.numpy(), stacked_outputs.numpy())
class L1Test(test_utils.TestCase):

  def testAgainstNumPy(self):
    """snt L1 must equal scale * sum(|t|) computed with NumPy."""
    regularizer = regularizers.L1(0.01)
    tensors = [tf.random.uniform([42]), tf.random.uniform([24])]

    expected = sum(
        regularizer.scale * np.abs(self.evaluate(t)).sum() for t in tensors)
    self.assertAllClose(regularizer(tensors), expected)

  def testNegativeScale(self):
    with self.assertRaises(ValueError):
      regularizers.L1(-1.0)

  def testEmpty(self):
    # An empty list of tensors regularizes to zero.
    self.assertAllClose(regularizers.L1(0.01)([]), 0.0)


class L2Test(test_utils.TestCase):

  def testAgainstNumPy(self):
    """snt L2 must equal scale * sum(t**2) computed with NumPy."""
    regularizer = regularizers.L2(0.01)
    tensors = [tf.random.uniform([42]), tf.random.uniform([24])]

    expected = sum(
        regularizer.scale * np.square(self.evaluate(t)).sum() for t in tensors)
    self.assertAllClose(regularizer(tensors), expected)

  def testNegativeScale(self):
    with self.assertRaises(ValueError):
      regularizers.L2(-1.0)

  def testEmpty(self):
    # An empty list of tensors regularizes to zero.
    self.assertAllClose(regularizers.L2(0.01)([]), 0.0)
regularizer = regularizers.OffDiagonalOrthogonal(0.01) 69 | tensors = [tf.random.uniform([4, 2]), tf.random.uniform([2, 4])] 70 | 71 | def odo(scale, t): 72 | t2 = np.square(np.dot(t.T, t)) 73 | return scale * (t2.sum() - np.trace(t2)) 74 | 75 | atol = 1e-3 if self.primary_device == "TPU" else 1e-6 76 | self.assertAllClose( 77 | regularizer(tensors), 78 | sum(odo(regularizer.scale, self.evaluate(t)) for t in tensors), 79 | atol=atol) 80 | 81 | def testNegativeScale(self): 82 | with self.assertRaises(ValueError): 83 | regularizers.OffDiagonalOrthogonal(-1.0) 84 | 85 | def testEmpty(self): 86 | self.assertAllClose(regularizers.OffDiagonalOrthogonal(0.01)([]), 0.0) 87 | 88 | 89 | if __name__ == "__main__": 90 | tf.test.main() 91 | -------------------------------------------------------------------------------- /sonnet/src/scale_gradient.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Sonnet Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
@tf.custom_gradient
def scale_gradient(
    t: tf.Tensor, scale: types.FloatLike
) -> Tuple[tf.Tensor, types.GradFn]:
  """Scales gradients for the backwards pass.

  Args:
    t: A Tensor.
    scale: The scale factor for the gradient on the backwards pass.

  Returns:
    A Tensor same as input, with scaled backward gradient.
  """

  # Forward pass is the identity; only the cotangent is rescaled.
  def grad(dy: tf.Tensor) -> Tuple[tf.Tensor, None]:
    """Scaled gradient."""
    # `None` is the gradient w.r.t. `scale`, which is treated as a constant.
    return scale * dy, None

  return t, grad
# ============================================================================
"""Tests for sonnet.v2.src.scale_gradient."""

import itertools

from absl.testing import parameterized
from sonnet.src import scale_gradient
from sonnet.src import test_utils
import tensorflow as tf


class ScaleGradientTest(test_utils.TestCase, parameterized.TestCase):
  """Checks forward identity and backward gradient scaling."""

  @parameterized.parameters(
      *itertools.product([-1.0, 0.0, 1.0], [-0.5, 0.0, 0.5, 2.0]))
  def test_scale(self, value, scale):
    # With output = y**2 and y = scale_gradient(v, scale), the gradient
    # reaching v must be 2 * v * scale while the forward value is v**2.
    v = tf.Variable([value])
    with tf.GradientTape() as tape:
      scaled = scale_gradient.scale_gradient(v, scale)
      output = scaled * scaled
    gradient = tape.gradient(output, v)
    self.assertAllEqual(gradient.numpy(), [2 * value * scale])
    self.assertAllEqual(output.numpy(), [value**2])


if __name__ == "__main__":
  tf.test.main()
# ============================================================================
"""Sequential applies a linear sequence of layers."""

from typing import Any, Callable, Iterable, Optional

from sonnet.src import base


class Sequential(base.Module):
  """Sequential applies a linear chain of modules / callables.

  >>> mlp = snt.Sequential([
  ...     snt.Linear(1024),
  ...     tf.nn.relu,
  ...     snt.Linear(10),
  ... ])
  >>> mlp(tf.random.normal([8, 100]))

  Note that `Sequential` is limited in the range of possible architectures
  it can handle. This is a deliberate design decision; `Sequential` is only
  meant to be used for the simple case of fusing together modules/ops where
  the input of a particular module/op is the output of the previous one.

  Another restriction is that it is not possible to have extra arguments in the
  `__call__` method that are passed to the constituents of the module - for
  example, if there is a `BatchNorm` module in `Sequential` and the user wishes
  to switch the `is_training` flag. If this is the desired use case, the
  recommended solution is to subclass `snt.Module` and implement `__call__`:

  >>> class CustomModule(snt.Module):
  ...   def __init__(self, name=None):
  ...     super(CustomModule, self).__init__(name=name)
  ...     self.conv2d = snt.Conv2D(32, 4, 2)
  ...     self.bn = snt.BatchNorm()
  ...
  ...   def __call__(self, inputs, is_training):
  ...     outputs = self.conv2d(inputs)
  ...     outputs = self.bn(outputs, is_training=is_training)
  ...     outputs = tf.nn.relu(outputs)
  ...     return outputs
  """

  def __init__(self,
               layers: Optional[Iterable[Callable[..., Any]]] = None,
               name: Optional[str] = None):
    """Constructs a Sequential module from an iterable of callables."""
    super().__init__(name=name)
    # Materialize the iterable once so it can be traversed on every call.
    self._layers = [] if layers is None else list(layers)

  def __call__(self, inputs, *args, **kwargs):
    """Threads `inputs` through each layer in order.

    Extra positional/keyword arguments are forwarded to the first layer
    only; every subsequent layer receives just the previous layer's output.
    With no layers, `inputs` is returned unchanged (extra args are dropped).
    """
    if not self._layers:
      return inputs
    outputs = self._layers[0](inputs, *args, **kwargs)
    for layer in self._layers[1:]:
      outputs = layer(outputs)
    return outputs
# ============================================================================
"""Tests for sonnet.v2.src.sequential."""

from absl.testing import parameterized
from sonnet.src import sequential
from sonnet.src import test_utils
import tensorflow as tf

input_parameters = parameterized.parameters(object(), ([[[1.]]],), ({1, 2, 3},),
                                            None, "str", 1)


class SequentialTest(test_utils.TestCase, parameterized.TestCase):
  """Behavioral tests for `Sequential` with assorted input types."""

  @input_parameters
  def test_empty(self, value):
    # An empty Sequential must act as the identity (same object back).
    empty_net = sequential.Sequential()
    self.assertIs(empty_net(value), value)

  @input_parameters
  def test_empty_drops_varargs_varkwargs(self, value):
    # Extra call arguments are silently dropped when there are no layers.
    empty_net = sequential.Sequential()
    self.assertIs(empty_net(value, object(), keyword=object()), value)

  @input_parameters
  def test_identity_chain(self, value):
    # A chain of identities must return the very same object.
    chain = sequential.Sequential([identity] * 3)
    self.assertIs(chain(value), value)

  def test_call(self):
    seq = sequential.Sequential([append_character(ch) for ch in "rocks!"])
    self.assertEqual(seq("Sonnet "), "Sonnet rocks!")

  def test_varargs_varkwargs_to_call(self):
    # Only the first layer receives the extra positional/keyword arguments.
    def first(a, b, c):
      return a + b + c, c + b + a

    def second(pair):
      return pair[0] + "," + pair[1]

    net = sequential.Sequential([first, second])
    self.assertEqual(net("a", "b", c="c"), "abc,cba")


def identity(v):
  """Returns its argument unchanged."""
  return v


def append_character(c):
  """Returns a callable that appends `c` to its input."""

  def _append(v):
    return v + c

  return _append


if __name__ == "__main__":
  tf.test.main()
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Type aliases for Sonnet."""

from typing import Callable, Iterable, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import tensorflow as tf

# Parameter update type, used by optimizers. `None` signals "no update for
# this parameter"; `tf.IndexedSlices` covers sparse gradients.
ParameterUpdate = Optional[Union[tf.Tensor, tf.IndexedSlices]]

# Objects that can be treated like tensors (in TF2).
TensorLike = Union[np.ndarray, tf.Tensor, tf.Variable]

# Scalar-ish aliases accepting a Python scalar, a NumPy scalar, or a tensor.
# Note that we have no way of statically verifying the tensor's shape, so a
# TensorLike of any shape is accepted.
BoolLike = Union[bool, np.bool_, TensorLike]
IntegerLike = Union[int, np.integer, TensorLike]
FloatLike = Union[float, np.floating, TensorLike]

# Anything that can describe a tensor shape: an int, a sequence of ints, or
# a `tf.TensorShape`.
ShapeLike = Union[int, Sequence[int], tf.TensorShape]

# Arbitrarily nested structure of tensors (lists/mappings of TensorNest).
# Note that this is effectively treated as `Any`; see b/109648354.
TensorNest = Union[TensorLike, Iterable['TensorNest'],
                   Mapping[str, 'TensorNest'],]  # pytype: disable=not-supported-yet

# Elementwise activation function, e.g. `tf.nn.relu`.
ActivationFn = Callable[[TensorLike], TensorLike]
# Axis specification for reductions / indexing.
Axis = Union[int, slice, Sequence[int]]
# Backward function for custom gradients: maps the upstream gradient `dy` to
# (gradient w.r.t. the input, gradient w.r.t. a second argument or None).
GradFn = Callable[[tf.Tensor], Tuple[tf.Tensor, Optional[tf.Tensor]]]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# Pip installs the relevant dependencies and runs the Sonnet tests on CPU

set -e
set -x

if command -v use_bazel.sh > /dev/null ; then
  # When running internally ensure the correct version of Bazel is used
  use_bazel.sh 0.26.1
fi

virtualenv -p python3 .
source bin/activate
python3 --version

# Run setup.py, install dependencies first to use pip install
python3 -m pip install -r requirements.txt
python3 setup.py install

# CPU count on macos or linux
if [ "$(uname)" == "Darwin" ]; then
  N_JOBS=$(sysctl -n hw.logicalcpu)
else
  N_JOBS=$(grep -c ^processor /proc/cpuinfo)
fi

echo ""
echo "Bazel will use ${N_JOBS} concurrent job(s)."
echo ""

# Python test dependencies.
python3 -m pip install -r requirements-test.txt
python3 -m pip install -r requirements-tf.txt
python3 -c 'import tensorflow as tf; print(tf.__version__)'

# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
  --build_tests_only --test_output=errors \
  --cache_test_results=no \
  -- //...

# Test docs still build. Use `python3 -m pip` (not bare `pip`) so the install
# targets the same interpreter as every other install in this script.
cd docs/
python3 -m pip install -r requirements.txt
make doctest html

deactivate