├── .bazeliskrc ├── .devcontainer ├── Dockerfile └── devcontainer.json ├── .dockerignore ├── .flake8 ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── bug-performance-report.md │ └── feature-request.md ├── boring-cyborg.yml ├── pull_request_template.md ├── release-template.yml └── workflows │ ├── backport.yml │ ├── ci_test.yml │ ├── github_build_dev_container.sh │ ├── make_wheel_Linux_x86.sh │ ├── make_wheel_Windows_x86.sh │ ├── make_wheel_macOS_arm64.sh │ ├── make_wheel_macOS_x86.sh │ ├── notify_codeowners.py │ ├── notify_codeowners.yml │ ├── release-drafter.yml │ ├── release.yml │ └── validate_codeowners.yml ├── .gitignore ├── BUILD ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── MIGRATION_TO_CORE.md ├── README.md ├── STYLE_GUIDE.md ├── WORKSPACE ├── build_deps ├── build_pip_pkg.sh ├── tf_dependency │ ├── BUILD │ ├── BUILD.tpl │ ├── build_defs.bzl.tpl │ ├── tf.patch │ └── tf_configure.bzl └── toolchains │ └── gpu │ ├── BUILD │ ├── crosstool │ ├── BUILD │ ├── BUILD.tpl │ ├── CROSSTOOL.tpl │ ├── cc_toolchain_config.bzl.tpl │ ├── clang │ │ └── bin │ │ │ └── crosstool_wrapper_driver_is_not_gcc.tpl │ └── windows │ │ └── msvc_wrapper_for_nvcc.py.tpl │ ├── cub.BUILD │ ├── cuda │ ├── BUILD │ ├── BUILD.tpl │ ├── BUILD.windows.tpl │ ├── build_defs.bzl.tpl │ └── cuda_config.h.tpl │ ├── cuda_configure.bzl │ └── find_cuda_config.py ├── configure.py ├── docs ├── README.md ├── overview.md └── tutorials │ ├── README.md │ ├── _template.ipynb │ ├── _toc.yaml │ ├── average_optimizers_callback.ipynb │ ├── image_ops.ipynb │ ├── layers_normalizations.ipynb │ ├── layers_weightnormalization.ipynb │ ├── losses_triplet.ipynb │ ├── networks_seq2seq_nmt.ipynb │ ├── optimizers_conditionalgradient.ipynb │ ├── optimizers_cyclicallearningrate.ipynb │ ├── optimizers_lazyadam.ipynb │ ├── time_stopping.ipynb │ └── tqdm_progress_bar.ipynb ├── pyproject.toml ├── pytest.ini ├── requirements.txt ├── setup.py ├── tensorflow_addons ├── BUILD ├── __init__.py ├── activations │ ├── BUILD │ ├── README.md │ ├── __init__.py │ ├── gelu.py │ ├── hardshrink.py │ ├── lisht.py │ ├── mish.py │ ├── rrelu.py │ ├── snake.py │ ├── softshrink.py │ ├── sparsemax.py │ ├── tanhshrink.py │ └── tests │ │ ├── __init__.py │ │ ├── activations_test.py │ │ ├── gelu_test.py │ │ ├── hardshrink_test.py │ │ ├── lisht_test.py │ │ ├── mish_test.py │ │ ├── rrelu_test.py │ │ ├── run_all_test.py │ │ ├── snake_test.py │ │ ├── softshrink_test.py │ │ ├── sparsemax_test.py │ │ └── tanhshrink_test.py ├── callbacks │ ├── BUILD │ ├── README.md │ ├── __init__.py │ ├── average_model_checkpoint.py │ ├── tests │ │ ├── __init__.py │ │ ├── avg_model_checkpoint_test.py │ │ ├── run_all_test.py │ │ ├── time_stopping_test.py │ │ └── tqdm_progress_bar_test.py │ ├── time_stopping.py │ └── tqdm_progress_bar.py ├── conftest.py ├── custom_ops │ ├── README.md │ ├── image │ │ ├── BUILD │ │ └── cc │ │ │ ├── kernels │ │ │ ├── adjust_hsv_in_yiq_op.cc │ │ │ ├── adjust_hsv_in_yiq_op.h │ │ │ ├── adjust_hsv_in_yiq_op_gpu.cu.cc │ │ │ ├── connected_components.cc │ │ │ ├── connected_components.h │ │ │ ├── euclidean_distance_transform_op.cc │ │ │ ├── euclidean_distance_transform_op.h │ │ │ ├── euclidean_distance_transform_op_gpu.cu.cc │ │ │ ├── resampler_ops.cc │ │ │ ├── resampler_ops.h │ │ │ └── resampler_ops_gpu.cu.cc │ │ │ └── ops │ │ │ ├── distort_image_ops.cc │ │ │ ├── image_ops.cc │ │ │ └── resampler_ops.cc │ ├── layers │ │ ├── BUILD │ │ └── cc │ │ │ ├── kernels │ │ │ ├── correlation_cost_op.cc │ │ │ ├── correlation_cost_op.h │ │ │ ├── 
correlation_cost_op_gpu.cu.cc │ │ │ ├── embedding_bag_backward_kernels.cu.cc │ │ │ ├── embedding_bag_ops.cc │ │ │ ├── embedding_bag_ops.h │ │ │ └── embedding_bag_ops_gpu.cu.cc │ │ │ └── ops │ │ │ ├── correlation_cost_op.cc │ │ │ └── embedding_bag_ops.cc │ ├── seq2seq │ │ ├── BUILD │ │ └── cc │ │ │ ├── kernels │ │ │ ├── beam_search_ops.cc │ │ │ ├── beam_search_ops.h │ │ │ └── beam_search_ops_gpu.cu.cc │ │ │ └── ops │ │ │ └── beam_search_ops.cc │ └── text │ │ ├── BUILD │ │ └── cc │ │ ├── kernels │ │ ├── parse_time_kernel.cc │ │ └── skip_gram_kernels.cc │ │ └── ops │ │ ├── parse_time_op.cc │ │ └── skip_gram_ops.cc ├── image │ ├── BUILD │ ├── README.md │ ├── __init__.py │ ├── color_ops.py │ ├── compose_ops.py │ ├── connected_components.py │ ├── cutout_ops.py │ ├── dense_image_warp.py │ ├── distance_transform.py │ ├── distort_image_ops.py │ ├── filters.py │ ├── interpolate_spline.py │ ├── resampler_ops.py │ ├── sparse_image_warp.py │ ├── tests │ │ ├── __init__.py │ │ ├── color_ops_test.py │ │ ├── compose_ops_test.py │ │ ├── connected_components_test.py │ │ ├── cutout_ops_test.py │ │ ├── dense_image_warp_test.py │ │ ├── distance_transform_test.py │ │ ├── distort_image_ops_test.py │ │ ├── filters_test.py │ │ ├── interpolate_spline_test.py │ │ ├── resampler_ops_test.py │ │ ├── run_all_test.py │ │ ├── sparse_image_warp_test.py │ │ ├── test_data │ │ │ ├── Yellow_Smiley_Face.png │ │ │ ├── Yellow_Smiley_Face_Warp-interp-1-clamp-0.png │ │ │ ├── Yellow_Smiley_Face_Warp-interp-1-clamp-1.png │ │ │ ├── Yellow_Smiley_Face_Warp-interp-1-clamp-4.png │ │ │ ├── Yellow_Smiley_Face_Warp-interp-2-clamp-0.png │ │ │ ├── Yellow_Smiley_Face_Warp-interp-2-clamp-1.png │ │ │ ├── Yellow_Smiley_Face_Warp-interp-2-clamp-4.png │ │ │ ├── Yellow_Smiley_Face_Warp-interp-3-clamp-0.png │ │ │ ├── Yellow_Smiley_Face_Warp-interp-3-clamp-1.png │ │ │ └── Yellow_Smiley_Face_Warp-interp-3-clamp-4.png │ │ ├── transform_ops_test.py │ │ ├── translate_ops_test.py │ │ └── utils_test.py │ ├── transform_ops.py │ ├── translate_ops.py │ └── utils.py ├── layers │ ├── BUILD │ ├── README.md │ ├── __init__.py │ ├── adaptive_pooling.py │ ├── crf.py │ ├── embedding_bag.py │ ├── esn.py │ ├── gelu.py │ ├── max_unpooling_2d.py │ ├── max_unpooling_2d_v2.py │ ├── maxout.py │ ├── multihead_attention.py │ ├── netvlad.py │ ├── noisy_dense.py │ ├── normalizations.py │ ├── optical_flow.py │ ├── poincare.py │ ├── polynomial.py │ ├── snake.py │ ├── sparsemax.py │ ├── spatial_pyramid_pooling.py │ ├── spectral_normalization.py │ ├── stochastic_depth.py │ ├── tests │ │ ├── __init__.py │ │ ├── adaptive_pooling_test.py │ │ ├── crf_test.py │ │ ├── embedding_bag_test.py │ │ ├── esn_test.py │ │ ├── gelu_test.py │ │ ├── max_unpooling_2d_test.py │ │ ├── max_unpooling_2d_v2_test.py │ │ ├── maxout_test.py │ │ ├── multihead_attention_test.py │ │ ├── netvlad_test.py │ │ ├── noisy_dense_test.py │ │ ├── normalizations_test.py │ │ ├── optical_flow_test.py │ │ ├── poincare_test.py │ │ ├── polynomial_test.py │ │ ├── run_all_test.py │ │ ├── snake_test.py │ │ ├── sparsemax_test.py │ │ ├── spatial_pyramid_pooling_test.py │ │ ├── spectral_normalization_test.py │ │ ├── stochastic_depth_test.py │ │ ├── tlu_test.py │ │ └── wrappers_test.py │ ├── tlu.py │ └── wrappers.py ├── losses │ ├── BUILD │ ├── README.md │ ├── __init__.py │ ├── contrastive.py │ ├── focal_loss.py │ ├── giou_loss.py │ ├── kappa_loss.py │ ├── lifted.py │ ├── metric_learning.py │ ├── npairs.py │ ├── quantiles.py │ ├── sparsemax_loss.py │ ├── tests │ │ ├── __init__.py │ │ ├── contrastive_test.py │ │ ├── focal_loss_test.py 
│ │ ├── giou_loss_test.py │ │ ├── kappa_loss_test.py │ │ ├── lifted_test.py │ │ ├── metric_test.py │ │ ├── npairs_test.py │ │ ├── quantiles_test.py │ │ ├── run_all_test.py │ │ ├── sparsemax_loss_test.py │ │ └── triplet_test.py │ └── triplet.py ├── metrics │ ├── BUILD │ ├── README.md │ ├── __init__.py │ ├── cohens_kappa.py │ ├── f_scores.py │ ├── geometric_mean.py │ ├── hamming.py │ ├── harmonic_mean.py │ ├── matthews_correlation_coefficient.py │ ├── multilabel_confusion_matrix.py │ ├── r_square.py │ ├── streaming_correlations.py │ ├── tests │ │ ├── __init__.py │ │ ├── cohens_kappa_test.py │ │ ├── f_scores_test.py │ │ ├── geometric_mean_test.py │ │ ├── hamming_test.py │ │ ├── harmonic_mean_test.py │ │ ├── matthews_correlation_coefficient_test.py │ │ ├── metrics_test.py │ │ ├── multilabel_confusion_matrix_test.py │ │ ├── r_square_test.py │ │ ├── run_all_test.py │ │ └── streaming_correlations_test.py │ └── utils.py ├── optimizers │ ├── BUILD │ ├── README.md │ ├── __init__.py │ ├── adabelief.py │ ├── average_wrapper.py │ ├── cocob.py │ ├── conditional_gradient.py │ ├── constants.py │ ├── cyclical_learning_rate.py │ ├── discriminative_layer_training.py │ ├── lamb.py │ ├── lazy_adam.py │ ├── lookahead.py │ ├── moving_average.py │ ├── novograd.py │ ├── proximal_adagrad.py │ ├── rectified_adam.py │ ├── stochastic_weight_averaging.py │ ├── tests │ │ ├── __init__.py │ │ ├── adabelief_test.py │ │ ├── cocob_test.py │ │ ├── conditional_gradient_test.py │ │ ├── cyclical_learning_rate_test.py │ │ ├── discriminative_layer_training_test.py │ │ ├── lamb_test.py │ │ ├── lazy_adam_test.py │ │ ├── lookahead_test.py │ │ ├── moving_average_test.py │ │ ├── novograd_test.py │ │ ├── proximal_adagrad_test.py │ │ ├── rectified_adam_test.py │ │ ├── run_all_test.py │ │ ├── standard_test.py │ │ ├── stochastic_weight_averaging_test.py │ │ ├── weight_decay_optimizers_test.py │ │ └── yogi_test.py │ ├── utils.py │ ├── weight_decay_optimizers.py │ └── yogi.py ├── options.py ├── register.py ├── rnn │ ├── BUILD │ ├── README.md │ ├── __init__.py │ ├── abstract_rnn_cell.py │ ├── esn_cell.py │ ├── layer_norm_lstm_cell.py │ ├── layer_norm_simple_rnn_cell.py │ ├── nas_cell.py │ ├── peephole_lstm_cell.py │ └── tests │ │ ├── __init__.py │ │ ├── esn_cell_test.py │ │ ├── layer_norm_lstm_cell_test.py │ │ ├── layer_norm_simple_rnn_cell_test.py │ │ ├── nas_cell_test.py │ │ ├── peephole_lstm_cell_test.py │ │ └── run_all_test.py ├── seq2seq │ ├── BUILD │ ├── README.md │ ├── __init__.py │ ├── attention_wrapper.py │ ├── basic_decoder.py │ ├── beam_search_decoder.py │ ├── decoder.py │ ├── loss.py │ ├── sampler.py │ └── tests │ │ ├── __init__.py │ │ ├── attention_wrapper_test.py │ │ ├── basic_decoder_test.py │ │ ├── beam_search_decoder_test.py │ │ ├── beam_search_ops_test.py │ │ ├── decoder_test.py │ │ ├── loss_test.py │ │ └── run_all_test.py ├── tensorflow_addons.bzl ├── testing │ ├── BUILD │ ├── __init__.py │ ├── serialization.py │ └── tests │ │ ├── __init__.py │ │ ├── run_all_test.py │ │ └── serialization_test.py ├── tests │ ├── __init__.py │ ├── register_test.py │ └── run_all_test.py ├── text │ ├── BUILD │ ├── README.md │ ├── __init__.py │ ├── crf.py │ ├── crf_wrapper.py │ ├── parse_time_op.py │ ├── skip_gram_ops.py │ └── tests │ │ ├── __init__.py │ │ ├── crf_test.py │ │ ├── crf_wrapper_test.py │ │ ├── parse_time_op_test.py │ │ ├── run_all_test.py │ │ └── skip_gram_ops_test.py ├── utils │ ├── BUILD │ ├── README.md │ ├── __init__.py │ ├── ensure_tf_install.py │ ├── keras_utils.py │ ├── resource_loader.py │ ├── test_utils.py │ ├── tests │ │ 
├── __init__.py │ │ ├── keras_utils_test.py │ │ ├── run_all_test.py │ │ └── test_utils_test.py │ ├── tf_inspect.py │ ├── tf_test_utils.py │ ├── tfa_eol_msg.py │ └── types.py └── version.py └── tools ├── build_dev_container.sh ├── docker ├── build_wheel.Dockerfile ├── cpu_tests.Dockerfile ├── dev_container.Dockerfile ├── pre-commit.Dockerfile └── sanity_check.Dockerfile ├── docs ├── BUILD ├── Readme.md └── build_docs.py ├── format.py ├── install_deps ├── black.txt ├── buildifier.sh ├── clang-format.sh ├── doc_requirements.txt ├── flake8.txt ├── install_bazelisk.sh ├── pytest.txt ├── tensorflow-cpu.txt ├── tensorflow.txt └── typedapi.txt ├── install_so_files.sh ├── pre-commit.sh ├── releases └── tf_auditwheel_patch.sh ├── run_build.sh ├── run_cpu_tests.sh ├── run_google_cloud_tests.sh ├── run_gpu_tests.sh ├── run_sanity_check.sh ├── testing ├── build_and_run_tests.sh ├── parallel_gpu_execute.sh └── source_code_test.py └── update_release_version.sh /.bazeliskrc: -------------------------------------------------------------------------------- 1 | USE_BAZEL_VERSION=6.1.0 2 | -------------------------------------------------------------------------------- /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG IMAGE_TYPE=latest-cpu 2 | FROM tfaddons/dev_container:$IMAGE_TYPE 3 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Tensorflow Addons SIG Dev Container", 3 | 4 | // Uncomment this for GPU image 5 | // "build": { 6 | // "args": { 7 | // "IMAGE_TYPE": "latest" 8 | // } 9 | // }, 10 | 11 | "dockerFile": "Dockerfile", 12 | 13 | // Set *default* container specific settings.json values on container create. 14 | "settings": { 15 | "terminal.integrated.shell.linux": null, 16 | "python.formatting.provider": "black", 17 | "python.linting.flake8Enabled": true, 18 | "python.testing.pytestEnabled": true, 19 | "python.testing.pytestArgs": [ 20 | "./tensorflow_addons" 21 | ], 22 | "C_Cpp.clang_format_style": "{BasedOnStyle: Google}", 23 | "C_Cpp.default.includePath": [ 24 | "${workspaceFolder}/**", 25 | "/usr/local/lib/python3.6/dist-packages/tensorflow/include/" 26 | ], 27 | }, 28 | "remoteEnv": { 29 | "TF_CPP_MIN_LOG_LEVEL": "1", 30 | "TF_NEED_CUDA":"0" 31 | 32 | }, 33 | // Add the IDs of extensions you want installed when the container is created. 34 | "extensions": [ 35 | "ms-python.python", 36 | "ms-vscode.cpptools", 37 | "mine.cpplint" 38 | ], 39 | 40 | // Use 'forwardPorts' to make a list of ports inside the container available locally. 41 | // "forwardPorts": [], 42 | 43 | // Use 'postCreateCommand' to run commands after the container is created. 44 | // "postCreateCommand": "uname -a", 45 | 46 | // Uncomment to use Docker from inside the container. See https://aka.ms/vscode-remote/samples/docker-in-docker. 47 | // "mounts": [ "source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind" ], 48 | 49 | // Uncomment when using a ptrace-based debugger like C++, Go, and Rust 50 | "runArgs": [ 51 | "--cap-add=SYS_PTRACE", 52 | // Uncomment this to enable Nvidia support 53 | //"--runtime=nvidia", 54 | "--security-opt", 55 | "seccomp=unconfined" ], 56 | 57 | // Uncomment to connect as a non-root user. See https://aka.ms/vscode-remote/pytest ./tensorflow_addons/layerscontainers/non-root. 
58 | //"remoteUser": "vscode" 59 | } -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .git 2 | .github 3 | *.Dockerfile 4 | .coverage* 5 | # C extensions 6 | *.so 7 | wheelhouse/ 8 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | 3 | ignore = 4 | # defaults flake8 ignores 5 | E121,E123,E126,E226,E24,E704,W503,W504 6 | # whitespace before ':' 7 | # https://black.readthedocs.io/en/stable/the_black_code_style.html#slices 8 | E203 9 | # missing whitespace after ',' 10 | # black takes care of that. Sometimes it may 11 | # add a comma at the end of lists. 12 | E231 13 | # Line too long 14 | # We use black, no need to enforce line length 15 | E501 16 | # lowercase ... imported as non lowercase 17 | # Useful to ignore for "import keras.backend as K" 18 | N812 19 | 20 | per-file-ignores = 21 | # imported but unused in __init__.py, that's ok. 22 | **/__init__.py:F401 23 | # import not at top okay due to TF installation check 24 | tensorflow_addons/__init__.py:F401,E402 25 | # function name should be lowercase 26 | tensorflow_addons/image/utils.py:N802 27 | tensorflow_addons/image/tests/utils_test.py:N802 28 | tensorflow_addons/image/tests/color_ops_test.py:N802 29 | tensorflow_addons/optimizers/tests/conditional_gradient_test.py:N802 30 | # variable ... in function should be lowercase 31 | tensorflow_addons/callbacks/tests/time_stopping_test.py:N806 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-performance-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug/Performance Issue 3 | about: Use this template for reporting a bug or performance issue. 4 | 5 | --- 6 | 7 | **System information** 8 | - OS Platform and Distribution (e.g., Linux Ubuntu 20.04): 9 | - TensorFlow version and how it was installed (source or binary): 10 | - TensorFlow-Addons version and how it was installed (source or binary): 11 | - Python version: 12 | - Is GPU used? (yes/no): 13 | 14 | **Describe the bug** 15 | 16 | A clear and concise description of what the bug is. 17 | 18 | **Code to reproduce the issue** 19 | 20 | Provide a reproducible test case that is the bare minimum necessary to generate the problem. 21 | 22 | **Other info / logs** 23 | 24 | Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached. 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Use this template for raising a feature request 4 | 5 | --- 6 | 7 | :exclamation: :exclamation: :exclamation: 8 | 9 | TensorFlow Addons is transitioning to a minimal maintenance and release mode. New features will not be added to this repository. 
For more information, please see our public messaging on this decision: 10 | [TensorFlow Addons Wind Down](https://github.com/tensorflow/addons/issues/2807) 11 | 12 | Please consider sending feature requests / contributions to other repositories in the TF community with a similar charters to TFA: 13 | [Keras](https://github.com/keras-team/keras) 14 | [Keras-CV](https://github.com/keras-team/keras-cv) 15 | [Keras-NLP](https://github.com/keras-team/keras-nlp) 16 | 17 | 18 | :exclamation: :exclamation: :exclamation: -------------------------------------------------------------------------------- /.github/boring-cyborg.yml: -------------------------------------------------------------------------------- 1 | labelPRBasedOnFilePath: 2 | # Subpackages 3 | activations: 4 | - tensorflow_addons/activations/**/* 5 | 6 | callbacks: 7 | - tensorflow_addons/callbacks/**/* 8 | 9 | custom-ops: 10 | - tensorflow_addons/custom_ops/**/* 11 | 12 | image: 13 | - tensorflow_addons/image/**/* 14 | 15 | layers: 16 | - tensorflow_addons/layers/**/* 17 | 18 | losses: 19 | - tensorflow_addons/losses/**/* 20 | 21 | metrics: 22 | - tensorflow_addons/metrics/**/* 23 | 24 | optimizers: 25 | - tensorflow_addons/optimizers/**/* 26 | 27 | seq2seq: 28 | - tensorflow_addons/seq2seq/**/* 29 | 30 | text: 31 | - tensorflow_addons/text/**/* 32 | 33 | # Other labels 34 | build: 35 | - build_deps/**/* 36 | - tools/releases/**/* 37 | 38 | documentation: 39 | - docs/**/* 40 | 41 | tutorials: 42 | - docs/tutorials/**/* 43 | 44 | test-cases: 45 | - tools/testing/**/ 46 | 47 | style: 48 | - STYLE_GUIDE.md 49 | 50 | github: 51 | - .github/**/* 52 | -------------------------------------------------------------------------------- /.github/release-template.yml: -------------------------------------------------------------------------------- 1 | template: | 2 | ## Release Notes 3 | 4 | $CHANGES 5 | 6 | ## Thanks to our Contributors 7 | 8 | $CONTRIBUTORS 9 | -------------------------------------------------------------------------------- /.github/workflows/backport.yml: -------------------------------------------------------------------------------- 1 | name: Backport 2 | on: 3 | pull_request: 4 | types: 5 | - closed 6 | - labeled 7 | 8 | permissions: {} 9 | 10 | jobs: 11 | backport: 12 | runs-on: ubuntu-20.04 13 | name: Backport 14 | permissions: 15 | contents: write 16 | steps: 17 | - name: Backport Bot 18 | if: github.event.pull_request.merged && ( ( github.event.action == 'closed' && contains( join( github.event.pull_request.labels.*.name ), 'backport') ) || contains( github.event.label.name, 'backport' ) ) 19 | uses: Gaurav0/backport@d69fd1d2469762a7b4007f671857e4f94deed0af # Version 1.0.24 20 | with: 21 | bot_username: bot-of-gabrieldemarmiesse 22 | bot_token: 1353d990cdb8b8ceb1b73d301dce83cc0da3db29 23 | bot_token_key: a1b2c3d47311f8e29e204f85a81b4df4a44e252c 24 | github_token: ${{ secrets.GITHUB_TOKEN }} 25 | -------------------------------------------------------------------------------- /.github/workflows/github_build_dev_container.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x -e 4 | 5 | df -h 6 | docker info 7 | # to get more disk space 8 | rm -rf /usr/share/dotnet & 9 | 10 | tools/build_dev_container.sh 11 | -------------------------------------------------------------------------------- /.github/workflows/make_wheel_Linux_x86.sh: -------------------------------------------------------------------------------- 1 | set -e -x 2 | 3 | df -h 4 | docker info 
5 | 6 | # to get more disk space 7 | rm -rf "$AGENT_TOOLSDIRECTORY" & 8 | rm -rf /opt/ghc & 9 | rm -rf "/usr/local/share/boost" & 10 | rm -rf /usr/share/dotnet 11 | 12 | # Tests are ran as part of make_wheel target 13 | DOCKER_BUILDKIT=1 docker build \ 14 | -f tools/docker/build_wheel.Dockerfile \ 15 | --output type=local,dest=wheelhouse \ 16 | --build-arg PY_VERSION \ 17 | --build-arg TF_VERSION \ 18 | --build-arg NIGHTLY_FLAG \ 19 | --build-arg NIGHTLY_TIME \ 20 | --build-arg SKIP_CUSTOM_OP_TESTS \ 21 | ./ 22 | -------------------------------------------------------------------------------- /.github/workflows/make_wheel_Windows_x86.sh: -------------------------------------------------------------------------------- 1 | set -e -x 2 | 3 | export TF_NEED_CUDA=0 4 | export PYTHON_BIN_PATH=$(which python) 5 | export BAZEL_VC="C:/Program Files (x86)/Microsoft Visual Studio/2019/Enterprise/VC/" 6 | 7 | # Install Deps 8 | python --version 9 | python -m pip install --default-timeout=1000 wheel setuptools tensorflow==$TF_VERSION 10 | 11 | # Test 12 | bash ./tools/testing/build_and_run_tests.sh $SKIP_CUSTOM_OP_TESTS 13 | 14 | # Clean 15 | bazel clean 16 | 17 | # Build 18 | python configure.py 19 | 20 | bazel.exe build \ 21 | --noshow_progress \ 22 | --noshow_loading_progress \ 23 | --verbose_failures \ 24 | --test_output=errors \ 25 | build_pip_pkg 26 | bazel-bin/build_pip_pkg wheelhouse $NIGHTLY_FLAG 27 | -------------------------------------------------------------------------------- /.github/workflows/make_wheel_macOS_arm64.sh: -------------------------------------------------------------------------------- 1 | set -e -x 2 | 3 | export TF_NEED_CUDA=0 4 | 5 | python --version 6 | python -m pip install --default-timeout=1000 delocate==0.10.3 wheel setuptools tensorflow==$TF_VERSION 7 | 8 | python configure.py 9 | # Setting DYLD_LIBRARY_PATH to help delocate finding tensorflow after the rpath invalidation 10 | export DYLD_LIBRARY_PATH=$DYLD_LIBRARY_PATH:$(python -c 'import configure; print(configure.get_tf_shared_lib_dir())') 11 | 12 | # For dynamic linking, we want the ARM version of TensorFlow. 
13 | # Since we cannot run it on x86 so we need to force pip to install it regardless 14 | python -m pip install \ 15 | --platform=macosx_12_0_arm64 \ 16 | --no-deps \ 17 | --target=$(python -c 'import site; print(site.getsitepackages()[0])') \ 18 | --upgrade \ 19 | tensorflow-macos==$TF_VERSION 20 | 21 | bazel build \ 22 | --cpu=darwin_arm64 \ 23 | --copt -mmacosx-version-min=12.0 \ 24 | --linkopt -mmacosx-version-min=12.0 \ 25 | --noshow_progress \ 26 | --noshow_loading_progress \ 27 | --verbose_failures \ 28 | --test_output=errors \ 29 | build_pip_pkg 30 | 31 | bazel-bin/build_pip_pkg artifacts "--plat-name macosx_11_0_arm64 $NIGHTLY_FLAG" 32 | delocate-wheel -w wheelhouse -v --ignore-missing-dependencies artifacts/*.whl 33 | 34 | -------------------------------------------------------------------------------- /.github/workflows/make_wheel_macOS_x86.sh: -------------------------------------------------------------------------------- 1 | set -e -x 2 | 3 | export TF_NEED_CUDA=0 4 | 5 | # Install Deps 6 | python --version 7 | python -m pip install --default-timeout=1000 delocate==0.10.3 wheel setuptools tensorflow==$TF_VERSION 8 | 9 | # Test 10 | bash ./tools/testing/build_and_run_tests.sh $SKIP_CUSTOM_OP_TESTS 11 | 12 | # Clean 13 | bazel clean 14 | 15 | # Build 16 | python configure.py 17 | 18 | bazel build \ 19 | --copt=-mmacosx-version-min=10.14 \ 20 | --linkopt=-mmacosx-version-min=10.14 \ 21 | --noshow_progress \ 22 | --noshow_loading_progress \ 23 | --verbose_failures \ 24 | --test_output=errors \ 25 | build_pip_pkg 26 | 27 | bazel-bin/build_pip_pkg artifacts $NIGHTLY_FLAG 28 | 29 | # Setting DYLD_LIBRARY_PATH to help delocate finding tensorflow after the rpath invalidation 30 | export DYLD_LIBRARY_PATH=$DYLD_LIBRARY_PATH:$(python -c 'import configure; print(configure.get_tf_shared_lib_dir())') 31 | delocate-wheel -w wheelhouse -v --ignore-missing-dependencies artifacts/*.whl 32 | 33 | -------------------------------------------------------------------------------- /.github/workflows/notify_codeowners.yml: -------------------------------------------------------------------------------- 1 | name: Notify codeowners 2 | 3 | on: 4 | pull_request_target: 5 | types: [opened] 6 | 7 | 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | notify-codeowners: 13 | name: Notify codeowners 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v2 17 | - uses: actions/setup-python@v1 18 | with: 19 | python-version: 3.9 20 | - run: pip install pygithub click 21 | - name: Drop a message for codeowners 22 | env: 23 | PR: ${{ steps.findPr.outputs.pr }} 24 | BOT_TOKEN: ${{ secrets.BOT_TOKEN }} 25 | run: | 26 | python .github/workflows/notify_codeowners.py \ 27 | --pull-request-id=auto \ 28 | --no-dry-run 29 | -------------------------------------------------------------------------------- /.github/workflows/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name: release-drafter 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | - r* 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | update_release_draft: 14 | permissions: 15 | contents: write # for release-drafter/release-drafter to create a github release 16 | pull-requests: write # for release-drafter/release-drafter to add label to PR 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: release-drafter/release-drafter@74e7c423dafbb406c9c18b1638334f67a7c891c3 # Version 5.7.0 20 | with: 21 | config-name: release-template.yml 22 | env: 23 | GITHUB_TOKEN: ${{ 
secrets.GITHUB_TOKEN }} 24 | -------------------------------------------------------------------------------- /.github/workflows/validate_codeowners.yml: -------------------------------------------------------------------------------- 1 | name: Validate codeowners 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | - r* 8 | 9 | # pull_request_target: 10 | # branches: 11 | # - master 12 | # - r* 13 | # Enable pull_request_target when notify_codeowners.py can validate the codeowners file 14 | # of the commit that triggered the workflow, not the commit the workflow is runnng on. 15 | # Otherwise, it's useless, it just check the codeowners file from the latest commit in master 16 | 17 | 18 | permissions: 19 | contents: read 20 | 21 | jobs: 22 | validate-codeowners: 23 | name: Check that the CODEOWNERS is valid 24 | runs-on: ubuntu-latest 25 | steps: 26 | - uses: actions/checkout@v2 27 | - uses: actions/setup-python@v1 28 | with: 29 | python-version: 3.9 30 | - run: pip install pygithub click 31 | - name: Check that the CODEOWNERS is valid 32 | env: 33 | BOT_TOKEN: ${{ secrets.BOT_TOKEN }} 34 | run: python .github/workflows/notify_codeowners.py -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | pip-wheel-metadata 27 | 28 | # Jupyter Notebook 29 | .ipynb_checkpoints 30 | 31 | # IDE 32 | .vscode/ 33 | .idea/ 34 | *.iml 35 | 36 | # Build 37 | /.bazelrc 38 | /bazel-* 39 | /artifacts 40 | .bazelrc 41 | 42 | .coverage* 43 | htmlcov 44 | 45 | wheelhouse/ 46 | -------------------------------------------------------------------------------- /BUILD: -------------------------------------------------------------------------------- 1 | sh_binary( 2 | name = "build_pip_pkg", 3 | srcs = ["build_deps/build_pip_pkg.sh"], 4 | data = [ 5 | "LICENSE", 6 | "MANIFEST.in", 7 | "requirements.txt", 8 | "setup.py", 9 | "//tensorflow_addons", 10 | ], 11 | ) 12 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include tensorflow_addons *.so 2 | include docs/* -------------------------------------------------------------------------------- /MIGRATION_TO_CORE.md: -------------------------------------------------------------------------------- 1 | # Migration From TF-Addons To TensorFlow Core / Keras 2 | 3 | **Given the challenges of external SIG coordinating with internal roadmaps, a new 4 | process has been put in place for the core TF and Keras teams to handle migration 5 | and deprecation of Addons components. 
If you believe there is a strong candidate for 6 | migration please post an issue and we'll escalate it to the respective team members.** 7 | 8 | ### Criteria for Migration 9 | * The addition is widely used throughout the community 10 | * The addition is unlikely to have API changes as time progresses 11 | * The addition is well written / tested 12 | -------------------------------------------------------------------------------- /STYLE_GUIDE.md: -------------------------------------------------------------------------------- 1 | #### C++ 2 | C++ code should conform to [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html). 3 | 4 | Addons uses [clang-format](https://clang.llvm.org/docs/ClangFormat.html) 5 | to check your C/C++ changes. Sometimes you have some manually formatted 6 | code that you don’t want clang-format to touch. 7 | You can disable formatting like this: 8 | 9 | ```cpp 10 | int formatted_code; 11 | // clang-format off 12 | void unformatted_code ; 13 | // clang-format on 14 | void formatted_code_again; 15 | ``` 16 | 17 | Install Clang-format 9 with: 18 | 19 | ```bash 20 | wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - 21 | sudo add-apt-repository -u 'http://apt.llvm.org/bionic/ llvm-toolchain-bionic-9 main' 22 | sudo apt install clang-format-9 23 | ``` 24 | 25 | format all with: 26 | ```bash 27 | clang-format-9 -i --style=google **/*.cc **/*.h 28 | ``` 29 | 30 | #### Python 31 | 32 | Addons uses [flake8](http://flake8.pycqa.org/en/latest/) to check pep8 compliance and 33 | code analysis. 34 | 35 | Addons use [Black](https://black.readthedocs.io/en/stable/) to format our code. 36 | The continuous integration check will fail if you do not use it. 37 | 38 | Install them with: 39 | ``` 40 | pip install flake8 black 41 | ``` 42 | 43 | Be sure to run them both before you push your commits, otherwise the CI will fail! 44 | 45 | ``` 46 | python -m black ./ 47 | python -m flake8 48 | ``` 49 | 50 | #### TensorFlow Conventions 51 | 52 | Follow the guidance in the [TensorFlow Style Guide - Conventions](https://www.tensorflow.org/community/contribute/code_style#tensorflow_conventions_and_special_uses). 53 | 54 | Please note that Addons follows the conventions of the TensorFlow library, but formats our code using [PEP8](https://www.python.org/dev/peps/pep-0008/) guidelines. 
55 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 2 | load("//build_deps/tf_dependency:tf_configure.bzl", "tf_configure") 3 | load("//build_deps/toolchains/gpu:cuda_configure.bzl", "cuda_configure") 4 | 5 | tf_configure( 6 | name = "local_config_tf", 7 | ) 8 | 9 | http_archive( 10 | name = "org_tensorflow", 11 | patches = [ 12 | "//build_deps/tf_dependency:tf.patch", 13 | ], 14 | sha256 = "9cec5acb0ecf2d47b16891f8bc5bc6fbfdffe1700bdadc0d9ebe27ea34f0c220", 15 | strip_prefix = "tensorflow-2.15.0", 16 | urls = [ 17 | "https://github.com/tensorflow/tensorflow/archive/refs/tags/v2.15.0.tar.gz", 18 | ], 19 | ) 20 | 21 | load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3") 22 | 23 | tf_workspace3() 24 | 25 | load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2") 26 | 27 | tf_workspace2() 28 | 29 | load("@org_tensorflow//tensorflow:workspace1.bzl", "tf_workspace1") 30 | 31 | tf_workspace1() 32 | 33 | load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0") 34 | 35 | tf_workspace0() 36 | -------------------------------------------------------------------------------- /build_deps/tf_dependency/BUILD: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/build_deps/tf_dependency/BUILD -------------------------------------------------------------------------------- /build_deps/tf_dependency/BUILD.tpl: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//visibility:public"]) 2 | 3 | cc_library( 4 | name = "tf_header_lib", 5 | hdrs = [":tf_header_include"], 6 | includes = ["include"], 7 | visibility = ["//visibility:public"], 8 | ) 9 | 10 | 11 | cc_library( 12 | name = "libtensorflow_framework", 13 | srcs = ["%{TF_SHARED_LIBRARY_NAME}"], 14 | visibility = ["//visibility:public"], 15 | ) 16 | 17 | %{TF_HEADER_GENRULE} 18 | %{TF_SHARED_LIBRARY_GENRULE} -------------------------------------------------------------------------------- /build_deps/tf_dependency/build_defs.bzl.tpl: -------------------------------------------------------------------------------- 1 | # Addons Build Definitions inherited from TensorFlow Core 2 | 3 | D_GLIBCXX_USE_CXX11_ABI = "%{tf_cx11_abi}" 4 | CPLUSPLUS_VERSION = "%{tf_cplusplus_ver}" 5 | -------------------------------------------------------------------------------- /build_deps/toolchains/gpu/BUILD: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/build_deps/toolchains/gpu/BUILD -------------------------------------------------------------------------------- /build_deps/toolchains/gpu/crosstool/BUILD: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/build_deps/toolchains/gpu/crosstool/BUILD -------------------------------------------------------------------------------- /build_deps/toolchains/gpu/crosstool/BUILD.tpl: -------------------------------------------------------------------------------- 1 | licenses(["restricted"]) 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | 
load(":cc_toolchain_config.bzl", "cc_toolchain_config") 6 | 7 | 8 | toolchain( 9 | name = "toolchain-linux-x86_64", 10 | exec_compatible_with = [ 11 | "@platforms//os:linux", 12 | "@platforms//cpu:x86_64", 13 | ], 14 | target_compatible_with = [ 15 | "@platforms//os:linux", 16 | "@platforms//cpu:x86_64", 17 | ], 18 | toolchain = ":cc-compiler-local", 19 | toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", 20 | ) 21 | 22 | cc_toolchain_suite( 23 | name = "toolchain", 24 | toolchains = { 25 | "local|compiler": ":cc-compiler-local", 26 | "k8": ":cc-compiler-local", 27 | "ppc": ":cc-compiler-local", 28 | "aarch64": ":cc-compiler-local", 29 | }, 30 | ) 31 | 32 | cc_toolchain( 33 | name = "cc-compiler-local", 34 | all_files = "%{linker_files}", 35 | compiler_files = ":empty", 36 | dwp_files = ":empty", 37 | linker_files = "%{linker_files}", 38 | objcopy_files = ":empty", 39 | strip_files = ":empty", 40 | # To support linker flags that need to go to the start of command line 41 | # we need the toolchain to support parameter files. Parameter files are 42 | # last on the command line and contain all shared libraries to link, so all 43 | # regular options will be left of them. 44 | supports_param_files = 1, 45 | toolchain_config = ":cc-compiler-local-config", 46 | toolchain_identifier = "local_linux", 47 | ) 48 | 49 | cc_toolchain_config( 50 | name = "cc-compiler-local-config", 51 | cpu = "local", 52 | builtin_include_directories = "%{cxx_builtin_include_directories}".split(","), 53 | extra_no_canonical_prefixes_flags = ["-fno-canonical-system-headers"], 54 | host_compiler_path = "clang/bin/crosstool_wrapper_driver_is_not_gcc", 55 | host_compiler_prefix = "/usr/bin", 56 | host_compiler_warnings = [], 57 | host_unfiltered_compile_flags = [], 58 | linker_bin_path = "/usr/bin", 59 | ) 60 | 61 | filegroup( 62 | name = "empty", 63 | srcs = [], 64 | ) 65 | 66 | filegroup( 67 | name = "crosstool_wrapper_driver_is_not_gcc", 68 | srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"], 69 | ) 70 | -------------------------------------------------------------------------------- /build_deps/toolchains/gpu/cub.BUILD: -------------------------------------------------------------------------------- 1 | # Description: CUB library which is a set of primitives for GPU programming. 2 | 3 | load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda") 4 | 5 | package( 6 | default_visibility = ["//visibility:public"], 7 | ) 8 | 9 | licenses(["notice"]) # BSD 10 | 11 | filegroup( 12 | name = "cub_header_files", 13 | srcs = glob([ 14 | "cub/**", 15 | ]), 16 | ) 17 | 18 | cc_library( 19 | name = "cub", 20 | hdrs = if_cuda([":cub_header_files"]), 21 | include_prefix = "gpu", 22 | deps = [ 23 | "@local_config_cuda//cuda:cuda_headers", 24 | ], 25 | ) 26 | -------------------------------------------------------------------------------- /build_deps/toolchains/gpu/cuda/BUILD: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/build_deps/toolchains/gpu/cuda/BUILD -------------------------------------------------------------------------------- /build_deps/toolchains/gpu/cuda/build_defs.bzl.tpl: -------------------------------------------------------------------------------- 1 | # Macros for building CUDA code. 2 | def if_cuda(if_true, if_false = []): 3 | """Shorthand for select()'ing on whether we're building with CUDA. 
4 | 5 | Returns a select statement which evaluates to if_true if we're building 6 | with CUDA enabled. Otherwise, the select statement evaluates to if_false. 7 | 8 | """ 9 | return select({ 10 | "@local_config_cuda//cuda:using_nvcc": if_true, 11 | "@local_config_cuda//cuda:using_clang": if_true, 12 | "//conditions:default": if_false 13 | }) 14 | 15 | 16 | def cuda_default_copts(): 17 | """Default options for all CUDA compilations.""" 18 | return if_cuda(["-x", "cuda", "-DGOOGLE_CUDA=1"] + %{cuda_extra_copts}) 19 | 20 | 21 | def cuda_is_configured(): 22 | """Returns true if CUDA was enabled during the configure process.""" 23 | return %{cuda_is_configured} 24 | 25 | def if_cuda_is_configured(x): 26 | """Tests if the CUDA was enabled during the configure process. 27 | 28 | Unlike if_cuda(), this does not require that we are building with 29 | --config=cuda. Used to allow non-CUDA code to depend on CUDA libraries. 30 | """ 31 | if cuda_is_configured(): 32 | return x 33 | return [] 34 | 35 | def cuda_header_library( 36 | name, 37 | hdrs, 38 | include_prefix = None, 39 | strip_include_prefix = None, 40 | deps = [], 41 | **kwargs): 42 | """Generates a cc_library containing both virtual and system include paths. 43 | 44 | Generates both a header-only target with virtual includes plus the full 45 | target without virtual includes. This works around the fact that bazel can't 46 | mix 'includes' and 'include_prefix' in the same target.""" 47 | 48 | native.cc_library( 49 | name = name + "_virtual", 50 | hdrs = hdrs, 51 | include_prefix = include_prefix, 52 | strip_include_prefix = strip_include_prefix, 53 | deps = deps, 54 | visibility = ["//visibility:private"], 55 | ) 56 | 57 | native.cc_library( 58 | name = name, 59 | textual_hdrs = hdrs, 60 | deps = deps + [":%s_virtual" % name], 61 | **kwargs 62 | ) 63 | -------------------------------------------------------------------------------- /build_deps/toolchains/gpu/cuda/cuda_config.h.tpl: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef CUDA_CUDA_CONFIG_H_ 17 | #define CUDA_CUDA_CONFIG_H_ 18 | 19 | #define TF_CUDA_CAPABILITIES %{cuda_compute_capabilities} 20 | 21 | #define TF_CUDA_VERSION "%{cuda_version}" 22 | #define TF_CUDNN_VERSION "%{cudnn_version}" 23 | 24 | #define TF_CUDA_TOOLKIT_PATH "%{cuda_toolkit_path}" 25 | 26 | #endif // CUDA_CUDA_CONFIG_H_ 27 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # [tensorflow.org/addons](https://tensorflow.org/addons) 2 | 3 | This directory contains the source for [tensorflow.org/addons](https://tensorflow.org/addons). 4 | 5 | It comprises two main components: 6 | 7 | ## 1. 
Narrative Docs 8 | 9 | Any markdown or notebook files in this directory will be published to tensorflow.org/addons. 10 | 11 | `tutorials/_toc.yaml` controls the left-nav on the tutorials tab. Make sure to keep that file up to date. 12 | Notify the tensorflow/docs team if you need to major changes. 13 | 14 | The preferred formatting for TensorFlow notebooks is to use the [tensorflow/docs](https://github.com/tensorflow/docs) [`nbfmt` tool](https://github.com/tensorflow/docs/tree/master/tools/tensorflow_docs/tools). If modifying a tutorial gives you 15 | an unreadable diff use the following commands to re-apply the standard formatting: 16 | 17 | ``` 18 | pip install git+https://github.com/tensorflow/docs 19 | python -m tensorflow_docs.tools.nbfmt {path to notebook file or directory} 20 | ``` 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /docs/tutorials/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Addons Tutorials 2 | 3 | TensorFlow Addons welcomes and highly encourages tutorial contributions. 4 | 5 | 6 | ## How To Contribute 7 | 8 | Addons tutorials are created using [Google Colab](https://colab.research.google.com/) 9 | and the jupyter notebooks are saved to this directory in the repository. To do 10 | this, follow the below steps: 11 | 12 | 1. Create a new branch on your fork of TensorFlow Addons 13 | 2. Goto [Google Colab](https://colab.research.google.com/) and start a new 14 | notebook using addons example template: 15 | [docs/tutorials/_template.ipynb](_template.ipynb) 16 | 3. Edit the links for the "View source on GitHub" and "Run in Google Colab" 17 | URL boxes so that they match the name of your new example notebook 18 | 4. Follow the guidelines of the template 19 | 5. "Save a copy in GitHub" and select your new branch. The notebook should be 20 | named `subpackage_submodule` 21 | 6. 
Submit the branch as a PR on the TF-Addons GitHub 22 | -------------------------------------------------------------------------------- /docs/tutorials/_toc.yaml: -------------------------------------------------------------------------------- 1 | toc: 2 | - title: Overview 3 | path: /addons/overview 4 | - heading: Tutorials 5 | - title: Triplet loss 6 | path: /addons/tutorials/losses_triplet 7 | - title: Image Ops 8 | path: /addons/tutorials/image_ops 9 | - title: Normalization layers 10 | path: /addons/tutorials/layers_normalizations 11 | - title: Weight normalization layer 12 | path: /addons/tutorials/layers_weightnormalization 13 | - title: LazyAdam optimizer 14 | path: /addons/tutorials/optimizers_lazyadam 15 | - title: ConditionalGradient Optimizer 16 | path: /addons/tutorials/optimizers_conditionalgradient 17 | - title: CyclicalLearningRate Schedule 18 | path: /addons/tutorials/optimizers_cyclicallearningrate 19 | - title: TQDM Progress Bar 20 | path: /addons/tutorials/tqdm_progress_bar 21 | - title: Seq2Seq for Translation 22 | path: /addons/tutorials/networks_seq2seq_nmt 23 | - title: Moving Average Optimizer Checkpoint 24 | path: /addons/tutorials/average_optimizers_callback 25 | - title: Time Stopping Callback 26 | path: /addons/tutorials/time_stopping 27 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | target-version = ['py37', 'py38', 'py39', 'py310'] 3 | exclude = ''' 4 | ( 5 | /( 6 | \.eggs # exclude a few common directories in the 7 | | \.git # root of the project 8 | | \.hg 9 | | \.mypy_cache 10 | | \.tox 11 | | \.venv 12 | | _build 13 | | buck-out 14 | | build 15 | | dist 16 | )/ 17 | ) 18 | ''' 19 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = -ra 3 | doctest_optionflags = ELLIPSIS NORMALIZE_WHITESPACE IGNORE_EXCEPTION_DETAIL DONT_ACCEPT_BLANKLINE 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | typeguard>=2.7,<3.0.0 2 | packaging 3 | -------------------------------------------------------------------------------- /tensorflow_addons/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | config_setting( 6 | name = "windows", 7 | constraint_values = ["@platforms//os:windows"], 8 | ) 9 | 10 | py_library( 11 | name = "tensorflow_addons", 12 | srcs = glob(["*.py"]), 13 | deps = [ 14 | "//tensorflow_addons/activations", 15 | "//tensorflow_addons/callbacks", 16 | "//tensorflow_addons/image", 17 | "//tensorflow_addons/layers", 18 | "//tensorflow_addons/losses", 19 | "//tensorflow_addons/metrics", 20 | "//tensorflow_addons/optimizers", 21 | "//tensorflow_addons/rnn", 22 | "//tensorflow_addons/seq2seq", 23 | "//tensorflow_addons/testing", 24 | "//tensorflow_addons/text", 25 | "//tensorflow_addons/utils", 26 | ], 27 | ) 28 | 29 | py_test( 30 | name = "tensorflow_addons_test", 31 | size = "small", 32 | srcs = glob(["tests/*"]), 33 | main = "tests/run_all_test.py", 34 | deps = [ 35 | ":tensorflow_addons", 36 | ], 37 | ) 38 | -------------------------------------------------------------------------------- 
/tensorflow_addons/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Useful extra functionality for TensorFlow maintained by SIG-addons.""" 16 | from tensorflow_addons.utils.ensure_tf_install import _check_tf_version 17 | from tensorflow_addons.utils.tfa_eol_msg import _print_eol_warning 18 | 19 | _print_eol_warning() 20 | _check_tf_version() 21 | 22 | # Local project imports 23 | from tensorflow_addons import activations 24 | from tensorflow_addons import callbacks 25 | from tensorflow_addons import image 26 | from tensorflow_addons import layers 27 | from tensorflow_addons import losses 28 | from tensorflow_addons import metrics 29 | from tensorflow_addons import optimizers 30 | from tensorflow_addons import rnn 31 | from tensorflow_addons import seq2seq 32 | from tensorflow_addons import text 33 | from tensorflow_addons import options 34 | from tensorflow_addons.register import register_all 35 | from tensorflow_addons.utils import types 36 | 37 | from tensorflow_addons.version import __version__ 38 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "activations", 7 | srcs = glob(["*.py"]), 8 | data = [ 9 | "//tensorflow_addons:options.py", 10 | "//tensorflow_addons/testing", 11 | "//tensorflow_addons/utils", 12 | ], 13 | ) 14 | 15 | py_test( 16 | name = "activations_test", 17 | size = "small", 18 | srcs = glob(["tests/*"]), 19 | main = "run_all_test.py", 20 | deps = [ 21 | ":activations", 22 | ], 23 | ) 24 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/README.md: -------------------------------------------------------------------------------- 1 | # Addons - Activations 2 | 3 | ## Contents 4 | https://www.tensorflow.org/addons/api_docs/python/tfa/activations 5 | 6 | ## Contribution Guidelines 7 | #### Standard API 8 | In order to conform with the current API standard, all activations 9 | must: 10 | * Be a `tf.function` unless it is a straightforward call to a custom op or likely to be retraced. 11 | * Register as a keras global object so it can be serialized properly: `@tf.keras.utils.register_keras_serializable(package='Addons')` 12 | 13 | #### Testing Requirements 14 | * Simple unittests that demonstrate the layer is behaving as expected. 15 | * To run your `tf.functions` in eager mode and graph mode in the tests, 16 | you can use the `@pytest.mark.usefixtures("maybe_run_functions_eagerly")` 17 | decorator. 
This will run the tests twice, once normally, and once 18 | with `tf.config.run_functions_eagerly(True)`. 19 | * Add activation name to [activations_test.py](https://github.com/tensorflow/addons/tree/master/tensorflow_addons/activations/tests/activations_test.py) to test serialization. 20 | 21 | #### Documentation Requirements 22 | * Update the [CODEOWNERS file](https://github.com/tensorflow/addons/blob/master/.github/CODEOWNERS) 23 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Additional activation functions.""" 16 | 17 | from tensorflow_addons.activations.gelu import gelu 18 | from tensorflow_addons.activations.hardshrink import hardshrink 19 | from tensorflow_addons.activations.lisht import lisht 20 | from tensorflow_addons.activations.mish import mish 21 | from tensorflow_addons.activations.softshrink import softshrink 22 | from tensorflow_addons.activations.rrelu import rrelu 23 | from tensorflow_addons.activations.snake import snake 24 | from tensorflow_addons.activations.sparsemax import sparsemax 25 | from tensorflow_addons.activations.tanhshrink import tanhshrink 26 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import tensorflow as tf 17 | import warnings 18 | 19 | from tensorflow_addons.utils.types import TensorLike 20 | 21 | 22 | @tf.keras.utils.register_keras_serializable(package="Addons") 23 | def gelu(x: TensorLike, approximate: bool = True) -> tf.Tensor: 24 | r"""Gaussian Error Linear Unit. 
25 | 26 | Computes gaussian error linear: 27 | 28 | $$ 29 | \mathrm{gelu}(x) = x \Phi(x), 30 | $$ 31 | 32 | where 33 | 34 | $$ 35 | \Phi(x) = \frac{1}{2} \left[ 1 + \mathrm{erf}(\frac{x}{\sqrt{2}}) \right]$ 36 | $$ 37 | 38 | when `approximate` is `False`; or 39 | 40 | $$ 41 | \Phi(x) = \frac{x}{2} \left[ 1 + \tanh(\sqrt{\frac{2}{\pi}} \cdot (x + 0.044715 \cdot x^3)) \right] 42 | $$ 43 | 44 | when `approximate` is `True`. 45 | 46 | See [Gaussian Error Linear Units (GELUs)](https://arxiv.org/abs/1606.08415) 47 | and [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805). 48 | 49 | Consider using `tf.nn.gelu` instead. 50 | Note that the default of `approximate` changed to `False` in `tf.nn.gelu`. 51 | 52 | Usage: 53 | 54 | >>> x = tf.constant([0.0, 0.0, 1.0]) 55 | >>> tfa.activations.gelu(x, approximate=False) 56 | 57 | >>> tfa.activations.gelu(x, approximate=True) 58 | 59 | 60 | Args: 61 | x: A `Tensor`. Must be one of the following types: 62 | `float16`, `float32`, `float64`. 63 | approximate: bool, whether to enable approximation. 64 | Returns: 65 | A `Tensor`. Has the same type as `x`. 66 | """ 67 | warnings.warn( 68 | "gelu activation has been migrated to core TensorFlow, " 69 | "and will be deprecated in Addons 0.13. " 70 | "Note that the default of `approximate` changed to `False` in `tf.nn.gelu`.", 71 | DeprecationWarning, 72 | ) 73 | 74 | return tf.nn.gelu(x, approximate) 75 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/hardshrink.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import tensorflow as tf 17 | from tensorflow_addons.utils.types import Number, TensorLike 18 | 19 | 20 | @tf.keras.utils.register_keras_serializable(package="Addons") 21 | def hardshrink(x: TensorLike, lower: Number = -0.5, upper: Number = 0.5) -> tf.Tensor: 22 | r"""Hard shrink function. 23 | 24 | Computes hard shrink function: 25 | 26 | $$ 27 | \mathrm{hardshrink}(x) = 28 | \begin{cases} 29 | x & \text{if } x < \text{lower} \\ 30 | x & \text{if } x > \text{upper} \\ 31 | 0 & \text{otherwise} 32 | \end{cases}. 33 | $$ 34 | 35 | Usage: 36 | 37 | >>> x = tf.constant([1.0, 0.0, 1.0]) 38 | >>> tfa.activations.hardshrink(x) 39 | 40 | 41 | Args: 42 | x: A `Tensor`. Must be one of the following types: 43 | `bfloat16`, `float16`, `float32`, `float64`. 44 | lower: `float`, lower bound for setting values to zeros. 45 | upper: `float`, upper bound for setting values to zeros. 46 | Returns: 47 | A `Tensor`. Has the same type as `x`. 
48 | """ 49 | if lower > upper: 50 | raise ValueError( 51 | "The value of lower is {} and should" 52 | " not be higher than the value " 53 | "variable upper, which is {} .".format(lower, upper) 54 | ) 55 | x = tf.convert_to_tensor(x) 56 | mask_lower = x < lower 57 | mask_upper = upper < x 58 | mask = tf.logical_or(mask_lower, mask_upper) 59 | mask = tf.cast(mask, x.dtype) 60 | return x * mask 61 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/lisht.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import tensorflow as tf 17 | 18 | from tensorflow_addons.utils.types import TensorLike 19 | 20 | 21 | @tf.keras.utils.register_keras_serializable(package="Addons") 22 | def lisht(x: TensorLike) -> tf.Tensor: 23 | r"""LiSHT: Non-Parameteric Linearly Scaled Hyperbolic Tangent Activation Function. 24 | 25 | Computes linearly scaled hyperbolic tangent (LiSHT): 26 | 27 | $$ 28 | \mathrm{lisht}(x) = x * \tanh(x). 29 | $$ 30 | 31 | See [LiSHT: Non-Parameteric Linearly Scaled Hyperbolic Tangent Activation Function for Neural Networks](https://arxiv.org/abs/1901.05894). 32 | 33 | Usage: 34 | 35 | >>> x = tf.constant([1.0, 0.0, 1.0]) 36 | >>> tfa.activations.lisht(x) 37 | 38 | 39 | Args: 40 | x: A `Tensor`. Must be one of the following types: 41 | `bfloat16`, `float16`, `float32`, `float64`. 42 | Returns: 43 | A `Tensor`. Has the same type as `x`. 44 | """ 45 | x = tf.convert_to_tensor(x) 46 | return x * tf.math.tanh(x) 47 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/mish.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import tensorflow as tf 17 | 18 | from tensorflow_addons.utils.types import TensorLike 19 | 20 | 21 | @tf.keras.utils.register_keras_serializable(package="Addons") 22 | def mish(x: TensorLike) -> tf.Tensor: 23 | r"""Mish: A Self Regularized Non-Monotonic Neural Activation Function. 
24 | 25 | Computes mish activation: 26 | 27 | $$ 28 | \mathrm{mish}(x) = x \cdot \tanh(\mathrm{softplus}(x)). 29 | $$ 30 | 31 | See [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681). 32 | 33 | Usage: 34 | 35 | >>> x = tf.constant([1.0, 0.0, 1.0]) 36 | >>> tfa.activations.mish(x) 37 | 38 | 39 | Args: 40 | x: A `Tensor`. Must be one of the following types: 41 | `bfloat16`, `float16`, `float32`, `float64`. 42 | Returns: 43 | A `Tensor`. Has the same type as `x`. 44 | """ 45 | x = tf.convert_to_tensor(x) 46 | return x * tf.math.tanh(tf.math.softplus(x)) 47 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/snake.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import tensorflow as tf 17 | 18 | from tensorflow_addons.utils import types 19 | 20 | 21 | @tf.keras.utils.register_keras_serializable(package="Addons") 22 | def snake(x: types.TensorLike, frequency: types.Number = 1) -> tf.Tensor: 23 | r"""Snake activation to learn periodic functions. 24 | 25 | Computes snake activation: 26 | 27 | $$ 28 | \mathrm{snake}(x) = \mathrm{x} + \frac{1 - \cos(2 \cdot \mathrm{frequency} \cdot x)}{2 \cdot \mathrm{frequency}}. 29 | $$ 30 | 31 | See [Neural Networks Fail to Learn Periodic Functions and How to Fix It](https://arxiv.org/abs/2006.08195). 32 | 33 | Usage: 34 | 35 | >>> x = tf.constant([-1.0, 0.0, 1.0]) 36 | >>> tfa.activations.snake(x) 37 | 38 | 39 | Args: 40 | x: A `Tensor`. 41 | frequency: A scalar, frequency of the periodic part. 42 | Returns: 43 | A `Tensor`. Has the same type as `x`. 44 | """ 45 | x = tf.convert_to_tensor(x) 46 | frequency = tf.cast(frequency, x.dtype) 47 | 48 | return x + (1 - tf.cos(2 * frequency * x)) / (2 * frequency) 49 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/softshrink.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import tensorflow as tf 17 | from tensorflow_addons.utils.types import Number, TensorLike 18 | 19 | 20 | @tf.keras.utils.register_keras_serializable(package="Addons") 21 | def softshrink(x: TensorLike, lower: Number = -0.5, upper: Number = 0.5) -> tf.Tensor: 22 | r"""Soft shrink function. 23 | 24 | Computes soft shrink function: 25 | 26 | $$ 27 | \mathrm{softshrink}(x) = 28 | \begin{cases} 29 | x - \mathrm{lower} & \text{if } x < \mathrm{lower} \\ 30 | x - \mathrm{upper} & \text{if } x > \mathrm{upper} \\ 31 | 0 & \text{otherwise} 32 | \end{cases}. 33 | $$ 34 | 35 | Usage: 36 | 37 | >>> x = tf.constant([-1.0, 0.0, 1.0]) 38 | >>> tfa.activations.softshrink(x) 39 | 40 | 41 | Args: 42 | x: A `Tensor`. Must be one of the following types: 43 | `bfloat16`, `float16`, `float32`, `float64`. 44 | lower: `float`, lower bound for setting values to zeros. 45 | upper: `float`, upper bound for setting values to zeros. 46 | Returns: 47 | A `Tensor`. Has the same type as `x`. 48 | """ 49 | if lower > upper: 50 | raise ValueError( 51 | "The value of lower is {} and should" 52 | " not be higher than the value " 53 | "variable upper, which is {} .".format(lower, upper) 54 | ) 55 | x = tf.convert_to_tensor(x) 56 | values_below_lower = tf.where(x < lower, x - lower, 0) 57 | values_above_upper = tf.where(upper < x, x - upper, 0) 58 | return values_below_lower + values_above_upper 59 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tanhshrink.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import tensorflow as tf 17 | 18 | from tensorflow_addons.utils.types import TensorLike 19 | 20 | 21 | @tf.keras.utils.register_keras_serializable(package="Addons") 22 | def tanhshrink(x: TensorLike) -> tf.Tensor: 23 | r"""Tanh shrink function. 24 | 25 | Applies the element-wise function: 26 | 27 | $$ 28 | \mathrm{tanhshrink}(x) = x - \tanh(x). 29 | $$ 30 | 31 | Usage: 32 | 33 | >>> x = tf.constant([-1.0, 0.0, 1.0]) 34 | >>> tfa.activations.tanhshrink(x) 35 | 36 | 37 | Args: 38 | x: A `Tensor`. Must be one of the following types: 39 | `bfloat16`, `float16`, `float32`, `float64`. 40 | Returns: 41 | A `Tensor`. Has the same type as `x`. 
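    For example, `tanhshrink([-1.0, 0.0, 1.0])` is approximately
    `[-0.2384, 0.0, 0.2384]` (since `tanh(1.0) ≈ 0.7616`), matching the expected
    values in `tests/tanhshrink_test.py` below.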
42 | """ 43 | x = tf.convert_to_tensor(x) 44 | return x - tf.math.tanh(x) 45 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/activations/tests/__init__.py -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/activations_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import pytest 17 | import tensorflow as tf 18 | from tensorflow_addons import activations 19 | 20 | 21 | ALL_ACTIVATIONS = [ 22 | "gelu", 23 | "hardshrink", 24 | "lisht", 25 | "mish", 26 | "rrelu", 27 | "softshrink", 28 | "sparsemax", 29 | "tanhshrink", 30 | "snake", 31 | ] 32 | 33 | 34 | @pytest.mark.parametrize("name", ALL_ACTIVATIONS) 35 | def test_serialization(name): 36 | fn = tf.keras.activations.get("Addons>" + name) 37 | ref_fn = getattr(activations, name) 38 | assert fn.__class__ == ref_fn.__class__ 39 | config = tf.keras.activations.serialize(fn) 40 | fn = tf.keras.activations.deserialize(config) 41 | assert fn.__class__ == ref_fn.__class__ 42 | 43 | 44 | @pytest.mark.parametrize("name", ALL_ACTIVATIONS) 45 | def test_serialization_with_layers(name): 46 | layer = tf.keras.layers.Dense(3, activation=getattr(activations, name)) 47 | config = tf.keras.layers.serialize(layer) 48 | deserialized_layer = tf.keras.layers.deserialize(config) 49 | assert deserialized_layer.__class__.__name__ == layer.__class__.__name__ 50 | assert deserialized_layer.activation.__name__ == name 51 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/gelu_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import pytest 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | from tensorflow_addons.activations import gelu 21 | from tensorflow_addons.utils import test_utils 22 | 23 | 24 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 25 | def test_gelu(dtype): 26 | x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype) 27 | expected_result = tf.constant( 28 | [-0.04540229, -0.158808, 0.0, 0.841192, 1.9545977], dtype=dtype 29 | ) 30 | test_utils.assert_allclose_according_to_type(gelu(x), expected_result) 31 | 32 | expected_result = tf.constant( 33 | [-0.04550028, -0.15865526, 0.0, 0.8413447, 1.9544997], dtype=dtype 34 | ) 35 | test_utils.assert_allclose_according_to_type(gelu(x, False), expected_result) 36 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/hardshrink_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import pytest 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | from tensorflow_addons.activations.hardshrink import hardshrink 21 | from tensorflow_addons.utils import test_utils 22 | 23 | 24 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 25 | def test_hardshrink(dtype): 26 | x = tf.constant([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=dtype) 27 | expected_result = tf.constant([-2.0, 0.0, 0.0, 0.0, 2.0], dtype=dtype) 28 | test_utils.assert_allclose_according_to_type(hardshrink(x), expected_result) 29 | 30 | expected_result = tf.constant([-2.0, 0.0, 0.0, 0.0, 2.0], dtype=dtype) 31 | test_utils.assert_allclose_according_to_type( 32 | hardshrink(x, lower=-1.0, upper=1.0), expected_result 33 | ) 34 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/lisht_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import pytest 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | from tensorflow_addons.activations import lisht 21 | from tensorflow_addons.utils import test_utils 22 | 23 | 24 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 25 | def test_lisht(dtype): 26 | x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype) 27 | expected_result = tf.constant( 28 | [1.9280552, 0.7615942, 0.0, 0.7615942, 1.9280552], dtype=dtype 29 | ) 30 | test_utils.assert_allclose_according_to_type(lisht(x), expected_result) 31 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/mish_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import pytest 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | from tensorflow_addons.activations import mish 21 | from tensorflow_addons.utils import test_utils 22 | 23 | 24 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 25 | def test_mish(dtype): 26 | x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype) 27 | expected_result = tf.constant( 28 | [-0.2525015, -0.30340144, 0.0, 0.86509836, 1.943959], dtype=dtype 29 | ) 30 | test_utils.assert_allclose_according_to_type(mish(x), expected_result) 31 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/snake_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import pytest 17 | 18 | import numpy as np 19 | from tensorflow_addons.activations import snake 20 | from tensorflow_addons.utils import test_utils 21 | 22 | 23 | @pytest.mark.usefixtures("maybe_run_functions_eagerly") 24 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 25 | def test_activation(dtype): 26 | x = dtype(np.random.rand(2, 5)) 27 | a = dtype(np.random.randn()) 28 | expected_result = x + np.power(np.sin(a * x), 2) / a 29 | test_utils.assert_allclose_according_to_type(snake(x, a), expected_result) 30 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/softshrink_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import pytest 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | from tensorflow_addons.activations import softshrink 21 | 22 | from tensorflow_addons.utils import test_utils 23 | 24 | 25 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 26 | def test_softshrink(dtype): 27 | x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype) 28 | expected_result = tf.constant([-1.5, -0.5, 0.0, 0.5, 1.5], dtype=dtype) 29 | test_utils.assert_allclose_according_to_type(softshrink(x), expected_result) 30 | 31 | expected_result = tf.constant([-1.0, 0.0, 0.0, 0.0, 1.0], dtype=dtype) 32 | test_utils.assert_allclose_according_to_type( 33 | softshrink(x, lower=-1.0, upper=1.0), expected_result 34 | ) 35 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/tanhshrink_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import pytest 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | from tensorflow_addons.activations import tanhshrink 21 | from tensorflow_addons.utils import test_utils 22 | 23 | 24 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 25 | def test_tanh(dtype): 26 | x = tf.constant([-1.0, 0.0, 1.0], dtype=dtype) 27 | expected_result = tf.constant([-0.23840582, 0.0, 0.238405825], dtype=dtype) 28 | test_utils.assert_allclose_according_to_type(tanhshrink(x), expected_result) 29 | -------------------------------------------------------------------------------- /tensorflow_addons/callbacks/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "callbacks", 7 | srcs = glob(["*.py"]), 8 | deps = [ 9 | "//tensorflow_addons/optimizers", 10 | "//tensorflow_addons/testing", 11 | "//tensorflow_addons/utils", 12 | ], 13 | ) 14 | 15 | py_test( 16 | name = "callbacks_test", 17 | size = "small", 18 | srcs = glob(["tests/*"]), 19 | main = "tests/run_all_test.py", 20 | deps = [ 21 | ":callbacks", 22 | ], 23 | ) 24 | -------------------------------------------------------------------------------- /tensorflow_addons/callbacks/README.md: -------------------------------------------------------------------------------- 1 | # Addons - Callbacks 2 | 3 | ## Contents 4 | https://www.tensorflow.org/addons/api_docs/python/tfa/callbacks 5 | 6 | ## Contribution Guidelines 7 | #### Standard API 8 | In order to conform with the current API standard, all callbacks 9 | must: 10 | * Inherit from `tf.keras.callbacks.Callback`. 11 | * Register as a keras global object so it can be serialized properly: `@tf.keras.utils.register_keras_serializable(package='Addons')` 12 | 13 | #### Testing Requirements 14 | * Simple unittests that demonstrate the callback is behaving as expected. 15 | * To run your `tf.functions` in eager mode and graph mode in the tests, 16 | you can use the `@pytest.mark.usefixtures("maybe_run_functions_eagerly")` 17 | decorator. This will run the tests twice, once normally, and once 18 | with `tf.config.run_functions_eagerly(True)`. 19 | 20 | #### Documentation Requirements 21 | * Update the [CODEOWNERS file](https://github.com/tensorflow/addons/blob/master/.github/CODEOWNERS) 22 | -------------------------------------------------------------------------------- /tensorflow_addons/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Additional callbacks that conform to Keras API.""" 16 | 17 | from tensorflow_addons.callbacks.average_model_checkpoint import AverageModelCheckpoint 18 | from tensorflow_addons.callbacks.time_stopping import TimeStopping 19 | from tensorflow_addons.callbacks.tqdm_progress_bar import TQDMProgressBar 20 | -------------------------------------------------------------------------------- /tensorflow_addons/callbacks/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/callbacks/tests/__init__.py -------------------------------------------------------------------------------- /tensorflow_addons/callbacks/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/callbacks/tests/time_stopping_test.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import pytest 4 | import numpy as np 5 | import tensorflow as tf 6 | from tensorflow.keras.models import Sequential 7 | from tensorflow.keras.layers import Dense 8 | 9 | from tensorflow_addons.callbacks.time_stopping import TimeStopping 10 | 11 | 12 | class SleepLayer(tf.keras.layers.Layer): 13 | def __init__(self, secs): 14 | self.secs = secs 15 | super().__init__(dynamic=True) 16 | 17 | def call(self, inputs): 18 | time.sleep(self.secs) 19 | return inputs 20 | 21 | 22 | def get_data_and_model(secs): 23 | X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]]) 24 | y = np.array([[0], [1], [1], [0]]) 25 | 26 | model = Sequential() 27 | model.add(SleepLayer(secs)) 28 | model.add(Dense(1)) 29 | model.compile(loss="mean_squared_error") 30 | 31 | # In case there is some initialization going on. 32 | model.fit(X, y, epochs=1, verbose=0) 33 | return X, y, model 34 | 35 | 36 | def test_stop_at_the_right_time(): 37 | X, y, model = get_data_and_model(0.1) 38 | 39 | time_stopping = TimeStopping(2, verbose=0) 40 | history = model.fit(X, y, epochs=30, verbose=0, callbacks=[time_stopping]) 41 | 42 | assert len(history.epoch) <= 20 43 | 44 | 45 | def test_default_value(): 46 | X, y, model = get_data_and_model(0.1) 47 | 48 | time_stopping = TimeStopping() 49 | history = model.fit(X, y, epochs=15, verbose=0, callbacks=[time_stopping]) 50 | 51 | assert len(history.epoch) == 15 52 | 53 | 54 | @pytest.mark.parametrize("verbose", [0, 1]) 55 | def test_time_stopping_verbose(capsys, verbose): 56 | X, y, model = get_data_and_model(0.25) 57 | 58 | time_stopping = TimeStopping(1, verbose=verbose) 59 | 60 | capsys.readouterr() # flush the stdout/stderr buffer. 
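    # Each of the 10 requested epochs takes roughly 0.25 s (a single batch passes
    # through SleepLayer(0.25)), so the 1-second TimeStopping budget should halt
    # training after only a few epochs. The assertions below verify both the early
    # stop and that the "Timed stopping at epoch N" message reaches stdout only
    # when the callback's verbose=1.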
61 | history = model.fit(X, y, epochs=10, verbose=0, callbacks=[time_stopping]) 62 | fit_stdout = capsys.readouterr().out 63 | nb_epochs_run = len(history.epoch) 64 | message = "Timed stopping at epoch " + str(nb_epochs_run) 65 | if verbose: 66 | assert message in fit_stdout 67 | else: 68 | assert message not in fit_stdout 69 | assert len(history.epoch) <= 4 70 | -------------------------------------------------------------------------------- /tensorflow_addons/callbacks/time_stopping.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Callback that stops training when a specified amount of time has passed.""" 16 | 17 | import datetime 18 | import time 19 | from typeguard import typechecked 20 | 21 | import tensorflow as tf 22 | from tensorflow.keras.callbacks import Callback 23 | 24 | 25 | @tf.keras.utils.register_keras_serializable(package="Addons") 26 | class TimeStopping(Callback): 27 | """Stop training when a specified amount of time has passed. 28 | 29 | Args: 30 | seconds: maximum amount of time before stopping. 31 | Defaults to 86400 (1 day). 32 | verbose: verbosity mode. Defaults to 0. 
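    Usage (a minimal sketch; the tiny model and random data below are
    illustrative only and not part of this callback's API):

        import tensorflow as tf
        from tensorflow_addons.callbacks import TimeStopping

        x = tf.random.normal((8, 4))
        y = tf.random.normal((8, 1))

        model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
        model.compile(optimizer="sgd", loss="mse")

        # Stop training once about 5 seconds have elapsed; verbose=1 makes
        # the callback print the epoch at which training was stopped.
        model.fit(x, y, epochs=1000, verbose=0,
                  callbacks=[TimeStopping(seconds=5, verbose=1)])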
33 | """ 34 | 35 | @typechecked 36 | def __init__(self, seconds: int = 86400, verbose: int = 0): 37 | super().__init__() 38 | 39 | self.seconds = seconds 40 | self.verbose = verbose 41 | self.stopped_epoch = None 42 | 43 | def on_train_begin(self, logs=None): 44 | self.stopping_time = time.time() + self.seconds 45 | 46 | def on_epoch_end(self, epoch, logs={}): 47 | if time.time() >= self.stopping_time: 48 | self.model.stop_training = True 49 | self.stopped_epoch = epoch 50 | 51 | def on_train_end(self, logs=None): 52 | if self.stopped_epoch is not None and self.verbose > 0: 53 | formatted_time = datetime.timedelta(seconds=self.seconds) 54 | msg = "Timed stopping at epoch {} after training for {}".format( 55 | self.stopped_epoch + 1, formatted_time 56 | ) 57 | print(msg) 58 | 59 | def get_config(self): 60 | config = { 61 | "seconds": self.seconds, 62 | "verbose": self.verbose, 63 | } 64 | 65 | base_config = super().get_config() 66 | return {**base_config, **config} 67 | -------------------------------------------------------------------------------- /tensorflow_addons/conftest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import tensorflow as tf 4 | 5 | import tensorflow_addons as tfa 6 | 7 | from tensorflow_addons.utils.test_utils import ( # noqa: F401 8 | maybe_run_functions_eagerly, 9 | only_run_functions_eagerly, 10 | run_custom_and_py_ops, 11 | run_with_mixed_precision_policy, 12 | pytest_make_parametrize_id, 13 | data_format, 14 | set_seeds, 15 | pytest_addoption, 16 | set_global_variables, 17 | pytest_configure, 18 | device, 19 | pytest_generate_tests, 20 | pytest_collection_modifyitems, 21 | ) 22 | 23 | # fixtures present in this file will be available 24 | # when running tests and can be referenced with strings 25 | # https://docs.pytest.org/en/latest/fixture.html#conftest-py-sharing-fixture-functions 26 | 27 | 28 | @pytest.fixture(autouse=True) 29 | def add_doctest_namespace(doctest_namespace): 30 | doctest_namespace["np"] = np 31 | doctest_namespace["tf"] = tf 32 | doctest_namespace["tfa"] = tfa 33 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/README.md: -------------------------------------------------------------------------------- 1 | # Addons - Custom Ops 2 | 3 | ## Contents 4 | | Sub-Package | Description | 5 | |:----------------------- |:-----------------------------| 6 | | Image | Ops for image manipulation | 7 | | Seq2seq | Ops for seq2seq encoder-decoder framework | 8 | | Text | Ops for text processing | 9 | | Layers | Ops for model layers | 10 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/image/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | load("//tensorflow_addons:tensorflow_addons.bzl", "custom_op_library") 6 | 7 | custom_op_library( 8 | name = "_distort_image_ops.so", 9 | srcs = [ 10 | "cc/kernels/adjust_hsv_in_yiq_op.cc", 11 | "cc/kernels/adjust_hsv_in_yiq_op.h", 12 | "cc/ops/distort_image_ops.cc", 13 | ], 14 | cuda_srcs = [ 15 | "cc/kernels/adjust_hsv_in_yiq_op.h", 16 | "cc/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc", 17 | ], 18 | ) 19 | 20 | custom_op_library( 21 | name = "_image_ops.so", 22 | srcs = [ 23 | "cc/kernels/connected_components.cc", 24 | "cc/kernels/connected_components.h", 25 | 
"cc/kernels/euclidean_distance_transform_op.cc", 26 | "cc/kernels/euclidean_distance_transform_op.h", 27 | "cc/ops/image_ops.cc", 28 | ], 29 | cuda_srcs = [ 30 | "cc/kernels/euclidean_distance_transform_op.h", 31 | "cc/kernels/euclidean_distance_transform_op_gpu.cu.cc", 32 | ], 33 | ) 34 | 35 | custom_op_library( 36 | name = "_resampler_ops.so", 37 | srcs = [ 38 | "cc/kernels/resampler_ops.cc", 39 | "cc/kernels/resampler_ops.h", 40 | "cc/ops/resampler_ops.cc", 41 | ], 42 | cuda_srcs = [ 43 | "cc/kernels/resampler_ops.h", 44 | "cc/kernels/resampler_ops_gpu.cu.cc", 45 | ], 46 | ) 47 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.h: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | // ============================================================================= 15 | 16 | #ifndef TENSORFLOW_ADDONS_IMAGE_KERNELS_RESAMPLER_OPS_H_ 17 | #define TENSORFLOW_ADDONS_IMAGE_KERNELS_RESAMPLER_OPS_H_ 18 | 19 | #if PLATFORM_WINDOWS 20 | #define __restrict__ __restrict 21 | #endif 22 | 23 | #include "tensorflow/core/framework/op_kernel.h" 24 | 25 | namespace tensorflow { 26 | namespace addons { 27 | namespace functor { 28 | 29 | // Helper functor for the Resampler Op in 2D 30 | template 31 | struct Resampler2DFunctor { 32 | void operator()(OpKernelContext* ctx, const Device& d, 33 | const T* __restrict__ data, const T* __restrict__ warp, 34 | T* __restrict__ output, const int batch_size, 35 | const int data_height, const int data_width, 36 | const int data_channels, const int num_sampling_points); 37 | }; 38 | 39 | // Helper functor for the Resampler Gradient Op in 2D 40 | template 41 | struct ResamplerGrad2DFunctor { 42 | void operator()(OpKernelContext* ctx, const Device& d, 43 | const T* __restrict__ data, const T* __restrict__ warp, 44 | const T* __restrict__ grad_output, T* __restrict__ grad_data, 45 | T* __restrict__ grad_warp, const int batch_size, 46 | const int data_height, const int data_width, 47 | const int data_channels, const int num_sampling_points); 48 | }; 49 | 50 | } // namespace functor 51 | } // namespace addons 52 | } // namespace tensorflow 53 | #endif // TENSORFLOW_ADDONS_IMAGE_KERNELS_RESAMPLER_OPS_H_ -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/image/cc/ops/resampler_ops.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "tensorflow/core/framework/common_shape_fns.h" 17 | #include "tensorflow/core/framework/op.h" 18 | #include "tensorflow/core/framework/shape_inference.h" 19 | 20 | namespace tensorflow { 21 | namespace addons { 22 | 23 | using shape_inference::DimensionHandle; 24 | using shape_inference::InferenceContext; 25 | using shape_inference::ShapeHandle; 26 | 27 | // -------------------------------------------------------------------------- 28 | REGISTER_OP("Addons>Resampler") 29 | .Input("data: T") 30 | .Input("warp: T") 31 | .Output("output: T") 32 | .Attr("T: {bfloat16, half, float, double}") 33 | .SetShapeFn([](InferenceContext* c) { 34 | ShapeHandle data; 35 | ShapeHandle warp; 36 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data)); 37 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &warp)); 38 | 39 | ShapeHandle output; // will be warp[:-1] + [data[-1]] 40 | TF_RETURN_IF_ERROR(c->Subshape(warp, 0, -1, &output)); 41 | TF_RETURN_IF_ERROR( 42 | c->Concatenate(output, c->Vector(c->Dim(data, -1)), &output)); 43 | 44 | c->set_output(0, output); 45 | return Status(); 46 | }) 47 | .Doc(R"doc(Resampler op.)doc"); 48 | 49 | // -------------------------------------------------------------------------- 50 | REGISTER_OP("Addons>ResamplerGrad") 51 | .Input("data: T") 52 | .Input("warp: T") 53 | .Input("grad_output: T") 54 | .Output("grad_data: T") 55 | .Output("grad_warp: T") 56 | .Attr("T: {bfloat16, half, float, double}") 57 | .SetShapeFn([](InferenceContext* c) { 58 | c->set_output(0, c->input(0)); 59 | c->set_output(1, c->input(1)); 60 | return Status(); 61 | }) 62 | .Doc(R"doc(Resampler Grad op.)doc"); 63 | 64 | } // namespace addons 65 | } // namespace tensorflow 66 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/layers/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | load("//tensorflow_addons:tensorflow_addons.bzl", "custom_op_library") 6 | 7 | custom_op_library( 8 | name = "_correlation_cost_ops.so", 9 | srcs = [ 10 | "cc/kernels/correlation_cost_op.cc", 11 | "cc/kernels/correlation_cost_op.h", 12 | "cc/ops/correlation_cost_op.cc", 13 | ], 14 | cuda_deps = [ 15 | "@cub_archive//:cub", 16 | ], 17 | cuda_srcs = [ 18 | "cc/kernels/correlation_cost_op.h", 19 | "cc/kernels/correlation_cost_op_gpu.cu.cc", 20 | ], 21 | ) 22 | 23 | custom_op_library( 24 | name = "_embedding_bag_ops.so", 25 | srcs = [ 26 | "cc/kernels/embedding_bag_ops.cc", 27 | "cc/kernels/embedding_bag_ops.h", 28 | "cc/ops/embedding_bag_ops.cc", 29 | ], 30 | cuda_srcs = [ 31 | "cc/kernels/embedding_bag_ops.h", 32 | "cc/kernels/embedding_bag_ops_gpu.cu.cc", 33 | "cc/kernels/embedding_bag_backward_kernels.cu.cc", 34 | ], 35 | ) 36 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.h: 
-------------------------------------------------------------------------------- 1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_ADDONS_LAYERS_KERNELS_CORRELATION_COST_OP_H_ 17 | #define TENSORFLOW_ADDONS_LAYERS_KERNELS_CORRELATION_COST_OP_H_ 18 | 19 | #include "tensorflow/core/framework/op_kernel.h" 20 | #include "tensorflow/core/util/tensor_format.h" 21 | 22 | namespace tensorflow { 23 | namespace addons { 24 | namespace functor { 25 | 26 | template 27 | struct CorrelationCostFunctor { 28 | Status operator()(OpKernelContext* context, const Tensor& input_a_t, 29 | const Tensor& input_b_t, Tensor* output_t, 30 | /* params */ 31 | int kernel_size, int max_displacement, int stride_1, 32 | int stride_2, int pad, TensorFormat data_format); 33 | }; 34 | 35 | template 36 | struct CorrelationCostGradFunctor { 37 | Status operator()(OpKernelContext* context, const Tensor& input_a_t, 38 | const Tensor& input_b_t, const Tensor& topdiff_t, 39 | Tensor* output_a_gradient_t, Tensor* output_b_gradient_t, 40 | /* params */ 41 | int kernel_size, int max_displacement, int stride_1, 42 | int stride_2, int pad, TensorFormat data_format); 43 | }; 44 | 45 | } // namespace functor 46 | } // namespace addons 47 | } // namespace tensorflow 48 | 49 | #endif // TENSORFLOW_ADDONS_LAYERS_KERNELS_CORRELATION_COST_OP_H_ 50 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/layers/cc/kernels/embedding_bag_ops.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_ADDONS_LAYERS_KERNELS_EMBEDDING_BAG_OPS_H_ 17 | #define TENSORFLOW_ADDONS_LAYERS_KERNELS_EMBEDDING_BAG_OPS_H_ 18 | 19 | #include "tensorflow/core/framework/op_kernel.h" 20 | #include "tensorflow/core/framework/tensor_types.h" 21 | 22 | namespace tensorflow { 23 | namespace addons { 24 | 25 | enum class Combiner { 26 | kSum, 27 | kMean, 28 | }; 29 | 30 | namespace functor { 31 | 32 | template 33 | struct EmbeddingBagFunctor { 34 | void operator()(const Device &device, 35 | typename TTypes::ConstTensor indices, 36 | typename TTypes::ConstTensor params, 37 | typename TTypes::ConstTensor weights, 38 | typename TTypes::Tensor output, Combiner combiner); 39 | }; 40 | 41 | template 42 | struct EmbeddingBagBackwardFunctor { 43 | void operator()(const Device &device, 44 | typename TTypes::ConstTensor indices, 45 | typename TTypes::ConstTensor params, 46 | typename TTypes::ConstTensor weights, 47 | typename TTypes::ConstTensor grads, 48 | typename TTypes::Tensor params_grads, 49 | typename TTypes::Tensor weights_grads, 50 | Combiner combiner, OpKernelContext *context); 51 | }; 52 | 53 | } // namespace functor 54 | } // namespace addons 55 | } // namespace tensorflow 56 | 57 | #endif // TENSORFLOW_ADDONS_LAYERS_KERNELS_EMBEDDING_BAG_OPS_H_ 58 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/seq2seq/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | load("//tensorflow_addons:tensorflow_addons.bzl", "custom_op_library") 6 | 7 | custom_op_library( 8 | name = "_beam_search_ops.so", 9 | srcs = [ 10 | "cc/kernels/beam_search_ops.cc", 11 | "cc/kernels/beam_search_ops.h", 12 | "cc/ops/beam_search_ops.cc", 13 | ], 14 | cuda_srcs = [ 15 | "cc/kernels/beam_search_ops.h", 16 | "cc/kernels/beam_search_ops_gpu.cu.cc", 17 | ], 18 | ) 19 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/seq2seq/cc/kernels/beam_search_ops.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_ADDONS_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ 17 | #define TENSORFLOW_ADDONS_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ 18 | 19 | #include "tensorflow/core/framework/tensor_types.h" 20 | #include "tensorflow/core/platform/types.h" 21 | #include "tensorflow/core/public/version.h" 22 | #if TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION >= 16 23 | #include "unsupported/Eigen/CXX11/Tensor" 24 | #else 25 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 26 | #endif 27 | 28 | namespace tensorflow { 29 | class OpKernelContext; 30 | 31 | namespace addons { 32 | 33 | namespace functor { 34 | 35 | template 36 | struct GatherTree { 37 | void operator()(OpKernelContext* ctx, const Device& d, 38 | typename TTypes::ConstTensor step_ids, 39 | typename TTypes::ConstTensor parent_ids, 40 | TTypes::ConstVec max_sequence_lengths, 41 | const T end_token, typename TTypes::Tensor beams); 42 | }; 43 | 44 | } // namespace functor 45 | } // end namespace addons 46 | } // namespace tensorflow 47 | 48 | #endif // TENSORFLOW_ADDONS_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ 49 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/text/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | load("//tensorflow_addons:tensorflow_addons.bzl", "custom_op_library") 6 | 7 | custom_op_library( 8 | name = "_skip_gram_ops.so", 9 | srcs = [ 10 | "cc/kernels/skip_gram_kernels.cc", 11 | "cc/ops/skip_gram_ops.cc", 12 | ], 13 | ) 14 | 15 | custom_op_library( 16 | name = "_parse_time_op.so", 17 | srcs = select({ 18 | "//tensorflow_addons:windows": [], 19 | "//conditions:default": [ 20 | "cc/kernels/parse_time_kernel.cc", 21 | "cc/ops/parse_time_op.cc", 22 | ], 23 | }), 24 | ) 25 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/text/cc/ops/skip_gram_ops.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tensorflow/core/framework/op.h" 17 | #include "tensorflow/core/framework/shape_inference.h" 18 | 19 | namespace tensorflow { 20 | namespace addons { 21 | REGISTER_OP("Addons>SkipGramGenerateCandidates") 22 | .Input("input_tensor: T") 23 | .Input("min_skips: int32") 24 | .Input("max_skips: int32") 25 | .Input("start: int32") 26 | .Input("limit: int32") 27 | .Input("emit_self_as_target: bool") 28 | .Output("tokens: T") 29 | .Output("labels: T") 30 | .Attr("T: type") 31 | // The seed attributes are needed by GuardedPhiloxRandom 32 | .Attr("seed: int = 0") 33 | .Attr("seed2: int = 0") 34 | .SetIsStateful() 35 | .SetShapeFn([](shape_inference::InferenceContext* c) { 36 | shape_inference::ShapeHandle unused; 37 | // input_tensor must be of rank-1. 38 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); 39 | // All other args must be scalar. 40 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 41 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 42 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 43 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 44 | 45 | // Due to possible randomness in selecting skips, we only know that the 46 | // outputs will be of rank-1, but not their sizes. 47 | c->set_output(0, c->Vector(c->UnknownDim())); 48 | c->set_output(1, c->Vector(c->UnknownDim())); 49 | return Status(); 50 | }) 51 | .Doc(R"doc( 52 | Generates skip-gram token and label paired Tensors from the input tensor. 53 | See docs for the public-facing skip_gram_sample() Python op for more details. 54 | )doc"); 55 | } // end namespace addons 56 | } // namespace tensorflow -------------------------------------------------------------------------------- /tensorflow_addons/image/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "image", 7 | srcs = glob(["*.py"]), 8 | data = [ 9 | ":sparse_image_warp_test_data", 10 | "//tensorflow_addons/custom_ops/image:_distort_image_ops.so", 11 | "//tensorflow_addons/custom_ops/image:_image_ops.so", 12 | "//tensorflow_addons/custom_ops/image:_resampler_ops.so", 13 | "//tensorflow_addons/testing", 14 | "//tensorflow_addons/utils", 15 | ], 16 | ) 17 | 18 | filegroup( 19 | name = "sparse_image_warp_test_data", 20 | srcs = glob(["tests/test_data/*.png"]), 21 | ) 22 | 23 | py_test( 24 | name = "image_test", 25 | size = "small", 26 | srcs = glob(["tests/*"]), 27 | main = "tests/run_all_test.py", 28 | deps = [ 29 | ":image", 30 | ], 31 | ) 32 | -------------------------------------------------------------------------------- /tensorflow_addons/image/README.md: -------------------------------------------------------------------------------- 1 | # Addons - Image 2 | 3 | ## Components 4 | https://www.tensorflow.org/addons/api_docs/python/tfa/image 5 | 6 | 7 | ## Contribution Guidelines 8 | #### Standard API 9 | In order to conform with the current API standard, all image ops 10 | must: 11 | * Be a standard image processing technique 12 | * Must be impossible to implement in one of the other API 13 | standards (Layers, Losses, etc.). 14 | 15 | #### Testing Requirements 16 | * Simple unittests that demonstrate the image op is behaving as 17 | expected. 
18 | * To run your `tf.functions` in eager mode and graph mode in the tests, 19 | you can use the `@pytest.mark.usefixtures("maybe_run_functions_eagerly")` 20 | decorator. This will run the tests twice, once normally, and once 21 | with `tf.config.run_functions_eagerly(True)`. 22 | 23 | #### Documentation Requirements 24 | * Update the [CODEOWNERS file](https://github.com/tensorflow/addons/blob/master/.github/CODEOWNERS) 25 | -------------------------------------------------------------------------------- /tensorflow_addons/image/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Additional image manipulation ops.""" 16 | 17 | from tensorflow_addons.image.distort_image_ops import adjust_hsv_in_yiq 18 | from tensorflow_addons.image.compose_ops import blend 19 | from tensorflow_addons.image.color_ops import equalize 20 | from tensorflow_addons.image.color_ops import sharpness 21 | from tensorflow_addons.image.connected_components import connected_components 22 | from tensorflow_addons.image.cutout_ops import cutout 23 | from tensorflow_addons.image.dense_image_warp import dense_image_warp 24 | from tensorflow_addons.image.distance_transform import euclidean_dist_transform 25 | from tensorflow_addons.image.dense_image_warp import interpolate_bilinear 26 | from tensorflow_addons.image.interpolate_spline import interpolate_spline 27 | from tensorflow_addons.image.filters import gaussian_filter2d 28 | from tensorflow_addons.image.filters import mean_filter2d 29 | from tensorflow_addons.image.filters import median_filter2d 30 | from tensorflow_addons.image.cutout_ops import random_cutout 31 | from tensorflow_addons.image.distort_image_ops import random_hsv_in_yiq 32 | from tensorflow_addons.image.resampler_ops import resampler 33 | from tensorflow_addons.image.transform_ops import rotate 34 | from tensorflow_addons.image.transform_ops import shear_x 35 | from tensorflow_addons.image.transform_ops import shear_y 36 | from tensorflow_addons.image.sparse_image_warp import sparse_image_warp 37 | from tensorflow_addons.image.transform_ops import compose_transforms 38 | from tensorflow_addons.image.transform_ops import angles_to_projective_transforms 39 | from tensorflow_addons.image.transform_ops import transform 40 | from tensorflow_addons.image.translate_ops import translate 41 | from tensorflow_addons.image.translate_ops import translate_xy 42 | from tensorflow_addons.image.translate_ops import translations_to_projective_transforms 43 | -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/image/tests/__init__.py -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face.png -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png 
-------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png -------------------------------------------------------------------------------- /tensorflow_addons/layers/BUILD: -------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = ["//visibility:public"], 3 | licenses = ["notice"], # Apache 2.0 4 | ) 5 | 6 | py_library( 7 | name = "layers", 8 | srcs = glob(["*.py"]), 9 | data = [ 10 | "//tensorflow_addons/custom_ops/layers:_correlation_cost_ops.so", 11 | "//tensorflow_addons/custom_ops/layers:_embedding_bag_ops.so", 12 | ], 13 | deps = [ 14 | "//tensorflow_addons/activations", 15 | "//tensorflow_addons/rnn", 16 | "//tensorflow_addons/testing", 17 | "//tensorflow_addons/text", 18 | "//tensorflow_addons/utils", 19 | ], 20 | ) 21 | 22 | py_test( 23 | name = "layers_test", 24 | size = "large", 25 | srcs = glob(["tests/*"]), 26 | main = "tests/run_all_test.py", 27 | deps = [ 28 | ":layers", 29 | ], 30 | ) 31 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/README.md: -------------------------------------------------------------------------------- 1 | # Addons - Layers 2 | 3 | ## Components 4 | https://www.tensorflow.org/addons/api_docs/python/tfa/layers 5 | 6 | ## Contribution Guidelines 7 | #### Standard API 8 | In order to conform with the current API standard, all layers 9 | must: 10 | * Inherit from either `keras.layers.Layer` or its subclasses. 11 | * Register as a keras global object so it can be serialized properly: `@tf.keras.utils.register_keras_serializable(package='Addons')` 12 | 13 | #### Testing Requirements 14 | * Simple unittests that demonstrate the layer is behaving as expected. 15 | * To run your `tf.functions` in eager mode and graph mode in the tests, 16 | you can use the `@pytest.mark.usefixtures("maybe_run_functions_eagerly")` 17 | decorator. This will run the tests twice, once normally, and once 18 | with `tf.config.run_functions_eagerly(True)`. 19 | * Run `layer_test` on the layer. 
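For illustration, a minimal test that follows this pattern (essentially what `tests/maxout_test.py` further down does) might look like the sketch below; `Maxout` and its `num_units` argument stand in for whichever layer is being contributed:

```python
import pytest

from tensorflow_addons.layers.maxout import Maxout
from tensorflow_addons.utils import test_utils

# Runs every test in this module twice: once as written, and once with
# tf.config.run_functions_eagerly(True), via the fixture described above.
pytestmark = pytest.mark.usefixtures("maybe_run_functions_eagerly")


def test_simple():
    # layer_test builds the layer, feeds it random data of the given shape,
    # and checks the output shape and the config serialization round-trip.
    test_utils.layer_test(Maxout, kwargs={"num_units": 3}, input_shape=(5, 4, 2, 18))
```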
20 | 21 | #### Documentation Requirements 22 | * Update the [CODEOWNERS file](https://github.com/tensorflow/addons/blob/master/.github/CODEOWNERS) 23 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Additional layers that conform to Keras API.""" 16 | 17 | from tensorflow_addons.layers.adaptive_pooling import ( 18 | AdaptiveAveragePooling1D, 19 | AdaptiveMaxPooling1D, 20 | AdaptiveAveragePooling2D, 21 | AdaptiveMaxPooling2D, 22 | AdaptiveAveragePooling3D, 23 | AdaptiveMaxPooling3D, 24 | ) 25 | 26 | from tensorflow_addons.layers.embedding_bag import EmbeddingBag 27 | from tensorflow_addons.layers.gelu import GELU 28 | from tensorflow_addons.layers.max_unpooling_2d import MaxUnpooling2D 29 | from tensorflow_addons.layers.max_unpooling_2d_v2 import MaxUnpooling2DV2 30 | from tensorflow_addons.layers.maxout import Maxout 31 | from tensorflow_addons.layers.multihead_attention import MultiHeadAttention 32 | from tensorflow_addons.layers.normalizations import FilterResponseNormalization 33 | from tensorflow_addons.layers.normalizations import GroupNormalization 34 | from tensorflow_addons.layers.normalizations import InstanceNormalization 35 | from tensorflow_addons.layers.optical_flow import CorrelationCost 36 | from tensorflow_addons.layers.poincare import PoincareNormalize 37 | from tensorflow_addons.layers.polynomial import PolynomialCrossing 38 | from tensorflow_addons.layers.snake import Snake 39 | from tensorflow_addons.layers.sparsemax import Sparsemax 40 | from tensorflow_addons.layers.spectral_normalization import SpectralNormalization 41 | from tensorflow_addons.layers.spatial_pyramid_pooling import SpatialPyramidPooling2D 42 | from tensorflow_addons.layers.tlu import TLU 43 | from tensorflow_addons.layers.wrappers import WeightNormalization 44 | from tensorflow_addons.layers.esn import ESN 45 | from tensorflow_addons.layers.stochastic_depth import StochasticDepth 46 | from tensorflow_addons.layers.noisy_dense import NoisyDense 47 | from tensorflow_addons.layers.crf import CRF 48 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implements GELU activation.""" 16 | 17 | import tensorflow as tf 18 | from tensorflow_addons.activations import gelu 19 | from typeguard import typechecked 20 | 21 | 22 | @tf.keras.utils.register_keras_serializable(package="Addons") 23 | class GELU(tf.keras.layers.Layer): 24 | """Gaussian Error Linear Unit. 25 | 26 | A smoother version of ReLU generally used 27 | in the BERT or BERT architecture based models. 28 | Original paper: https://arxiv.org/abs/1606.08415 29 | 30 | Input shape: 31 | Arbitrary. Use the keyword argument `input_shape` 32 | (tuple of integers, does not include the samples axis) 33 | when using this layer as the first layer in a model. 34 | 35 | Output shape: 36 | Same shape as the input. 37 | """ 38 | 39 | @typechecked 40 | def __init__(self, approximate: bool = True, **kwargs): 41 | super().__init__(**kwargs) 42 | self.approximate = approximate 43 | self.supports_masking = True 44 | 45 | def call(self, inputs): 46 | return gelu(inputs, approximate=self.approximate) 47 | 48 | def get_config(self): 49 | config = {"approximate": self.approximate} 50 | base_config = super().get_config() 51 | return {**base_config, **config} 52 | 53 | def compute_output_shape(self, input_shape): 54 | return input_shape 55 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/poincare.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implementing PoincareNormalize layer.""" 16 | 17 | import tensorflow as tf 18 | from typeguard import typechecked 19 | from typing import Union, List 20 | 21 | 22 | @tf.keras.utils.register_keras_serializable(package="Addons") 23 | class PoincareNormalize(tf.keras.layers.Layer): 24 | """Project into the Poincare ball with `norm <= 1.0 - epsilon`. 25 | 26 | See [Poincaré Embeddings for Learning Hierarchical Representations](https://arxiv.org/pdf/1705.08039.pdf), 27 | and [wiki](https://en.wikipedia.org/wiki/Poincare_ball_model). 28 | 29 | For a 1-D tensor with `axis = 0`, computes 30 | 31 | (x * (1 - epsilon)) / ||x|| if ||x|| > 1 - epsilon 32 | output = 33 | x otherwise 34 | 35 | For `x` with more dimensions, independently normalizes each 1-D slice along 36 | dimension `axis`. 
37 | 38 | Args: 39 | axis: Axis along which to normalize. A scalar or a vector of integers. 40 | epsilon: A small deviation from the edge of the unit sphere for 41 | numerical stability. 42 | """ 43 | 44 | @typechecked 45 | def __init__( 46 | self, axis: Union[None, int, List[int]] = 1, epsilon: float = 1e-5, **kwargs 47 | ): 48 | super().__init__(**kwargs) 49 | self.axis = axis 50 | self.epsilon = epsilon 51 | 52 | def call(self, inputs): 53 | x = tf.convert_to_tensor(inputs) 54 | square_sum = tf.math.reduce_sum(tf.math.square(x), self.axis, keepdims=True) 55 | x_inv_norm = tf.math.rsqrt(square_sum) 56 | x_inv_norm = tf.math.minimum((1.0 - self.epsilon) * x_inv_norm, 1.0) 57 | outputs = tf.math.multiply(x, x_inv_norm) 58 | return outputs 59 | 60 | def compute_output_shape(self, input_shape): 61 | return input_shape 62 | 63 | def get_config(self): 64 | config = {"axis": self.axis, "epsilon": self.epsilon} 65 | base_config = super().get_config() 66 | return {**base_config, **config} 67 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/snake.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implements Snake layer.""" 16 | 17 | import tensorflow as tf 18 | from typeguard import typechecked 19 | 20 | from tensorflow_addons.activations.snake import snake 21 | 22 | from tensorflow_addons.utils import types 23 | 24 | 25 | @tf.keras.utils.register_keras_serializable(package="Addons") 26 | class Snake(tf.keras.layers.Layer): 27 | """Snake layer to learn periodic functions with the trainable `frequency` scalar. 28 | 29 | See [Neural Networks Fail to Learn Periodic Functions and How to Fix It](https://arxiv.org/abs/2006.08195). 30 | 31 | Args: 32 | frequency_initializer: Initializer for the `frequency` scalar. 33 | """ 34 | 35 | @typechecked 36 | def __init__(self, frequency_initializer: types.Initializer = "ones", **kwargs): 37 | super().__init__(**kwargs) 38 | self.frequency_initializer = tf.keras.initializers.get(frequency_initializer) 39 | self.frequency = self.add_weight( 40 | initializer=frequency_initializer, trainable=True 41 | ) 42 | 43 | def call(self, inputs): 44 | return snake(inputs, self.frequency) 45 | 46 | def get_config(self): 47 | config = { 48 | "frequency_initializer": tf.keras.initializers.serialize( 49 | self.frequency_initializer 50 | ), 51 | } 52 | base_config = super().get_config() 53 | return {**base_config, **config} 54 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/sparsemax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import tensorflow as tf 17 | from tensorflow_addons.activations.sparsemax import sparsemax 18 | from typeguard import typechecked 19 | 20 | 21 | @tf.keras.utils.register_keras_serializable(package="Addons") 22 | class Sparsemax(tf.keras.layers.Layer): 23 | """Sparsemax activation function. 24 | 25 | The output shape is the same as the input shape. 26 | 27 | See [From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification](https://arxiv.org/abs/1602.02068). 28 | 29 | Args: 30 | axis: Integer, axis along which the sparsemax normalization is applied. 31 | """ 32 | 33 | @typechecked 34 | def __init__(self, axis: int = -1, **kwargs): 35 | super().__init__(**kwargs) 36 | self.supports_masking = True 37 | self.axis = axis 38 | 39 | def call(self, inputs): 40 | return sparsemax(inputs, axis=self.axis) 41 | 42 | def get_config(self): 43 | config = {"axis": self.axis} 44 | base_config = super().get_config() 45 | return {**base_config, **config} 46 | 47 | def compute_output_shape(self, input_shape): 48 | return input_shape 49 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/layers/tests/__init__.py -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/esn_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for Echo State recurrent Network (ESN).""" 16 | 17 | import pytest 18 | import numpy as np 19 | import tensorflow as tf 20 | from tensorflow_addons.layers.esn import ESN 21 | from tensorflow_addons.utils import test_utils 22 | 23 | 24 | @pytest.mark.usefixtures("maybe_run_functions_eagerly") 25 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 26 | def layer_test_esn(dtype): 27 | inp = np.asanyarray( 28 | [[[1.0, 1.0, 1.0, 1.0]], [[2.0, 2.0, 2.0, 2.0]], [[3.0, 3.0, 3.0, 3.0]]] 29 | ).astype(dtype) 30 | out = np.asarray([[2.5, 2.5, 2.5], [4.5, 4.5, 4.5], [6.5, 6.5, 6.5]]).astype(dtype) 31 | 32 | const_initializer = tf.constant_initializer(0.5) 33 | kwargs = { 34 | "units": 3, 35 | "connectivity": 1, 36 | "leaky": 1, 37 | "spectral_radius": 0.9, 38 | "use_norm2": True, 39 | "use_bias": True, 40 | "activation": None, 41 | "kernel_initializer": const_initializer, 42 | "recurrent_initializer": const_initializer, 43 | "bias_initializer": const_initializer, 44 | "dtype": dtype, 45 | } 46 | 47 | test_utils.layer_test(ESN, kwargs=kwargs, input_data=inp, expected_output=out) 48 | 49 | 50 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 51 | def test_serialization(dtype): 52 | esn = ESN( 53 | units=3, 54 | connectivity=1, 55 | leaky=1, 56 | spectral_radius=0.9, 57 | use_norm2=False, 58 | use_bias=True, 59 | activation=None, 60 | kernel_initializer="ones", 61 | recurrent_initializer="ones", 62 | bias_initializer="ones", 63 | ) 64 | serialized_esn = tf.keras.layers.serialize(esn) 65 | new_layer = tf.keras.layers.deserialize(serialized_esn) 66 | assert esn.get_config() == new_layer.get_config() 67 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/gelu_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for GELU activation.""" 16 | 17 | 18 | import pytest 19 | import numpy as np 20 | from tensorflow_addons.layers.gelu import GELU 21 | from tensorflow_addons.utils import test_utils 22 | 23 | 24 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 25 | def test_random(dtype): 26 | x = np.array([[0.5, 1.2, -0.3]]).astype(dtype) 27 | val = np.array([[0.345714, 1.0617027, -0.11462909]]).astype(dtype) 28 | test_utils.layer_test( 29 | GELU, kwargs={"dtype": dtype}, input_data=x, expected_output=val 30 | ) 31 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/maxout_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for Maxout layer.""" 16 | 17 | 18 | import pytest 19 | import numpy as np 20 | 21 | from tensorflow_addons.layers.maxout import Maxout 22 | from tensorflow_addons.utils import test_utils 23 | 24 | 25 | pytestmark = pytest.mark.usefixtures("maybe_run_functions_eagerly") 26 | 27 | 28 | def test_simple(): 29 | test_utils.layer_test(Maxout, kwargs={"num_units": 3}, input_shape=(5, 4, 2, 18)) 30 | 31 | 32 | def test_nchw(): 33 | test_utils.layer_test( 34 | Maxout, kwargs={"num_units": 4, "axis": 1}, input_shape=(2, 20, 3, 6) 35 | ) 36 | 37 | test_utils.layer_test( 38 | Maxout, kwargs={"num_units": 4, "axis": -3}, input_shape=(2, 20, 3, 6) 39 | ) 40 | 41 | 42 | def test_unknown(): 43 | inputs = np.random.random((5, 4, 2, 18)).astype("float32") 44 | test_utils.layer_test( 45 | Maxout, kwargs={"num_units": 3}, input_shape=(5, 4, 2, None), input_data=inputs 46 | ) 47 | 48 | test_utils.layer_test( 49 | Maxout, 50 | kwargs={"num_units": 3}, 51 | input_shape=(None, None, None, None), 52 | input_data=inputs, 53 | ) 54 | 55 | 56 | def test_invalid_shape(): 57 | with pytest.raises(ValueError, match="number of features"): 58 | test_utils.layer_test(Maxout, kwargs={"num_units": 3}, input_shape=(5, 4, 2, 7)) 59 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/netvlad_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for NetVLAD layer.""" 16 | 17 | 18 | import pytest 19 | import numpy as np 20 | from tensorflow_addons.layers.netvlad import NetVLAD 21 | from tensorflow_addons.utils import test_utils 22 | 23 | 24 | pytestmark = pytest.mark.usefixtures("maybe_run_functions_eagerly") 25 | 26 | 27 | @pytest.mark.parametrize("num_clusters", [1, 4]) 28 | def test_simple(num_clusters): 29 | test_utils.layer_test( 30 | NetVLAD, 31 | kwargs={"num_clusters": num_clusters}, 32 | input_shape=(5, 4, 100), 33 | expected_output_shape=(None, num_clusters * 100), 34 | ) 35 | 36 | 37 | def test_unknown(): 38 | inputs = np.random.random((5, 4, 100)).astype("float32") 39 | test_utils.layer_test( 40 | NetVLAD, 41 | kwargs={"num_clusters": 3}, 42 | input_shape=(None, None, 100), 43 | input_data=inputs, 44 | expected_output_shape=(None, 3 * 100), 45 | ) 46 | 47 | 48 | def test_invalid_shape(): 49 | with pytest.raises(ValueError) as exception_info: 50 | test_utils.layer_test( 51 | NetVLAD, kwargs={"num_clusters": 0}, input_shape=(5, 4, 20) 52 | ) 53 | assert "`num_clusters` must be greater than 1" in str(exception_info.value) 54 | 55 | with pytest.raises(ValueError) as exception_info: 56 | test_utils.layer_test( 57 | NetVLAD, kwargs={"num_clusters": 2}, input_shape=(5, 4, 4, 20) 58 | ) 59 | assert "must have rank 3" in str(exception_info.value) 60 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/snake_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for Snake layer.""" 16 | 17 | import pytest 18 | 19 | import numpy as np 20 | import tensorflow as tf 21 | 22 | from tensorflow_addons.layers.snake import Snake 23 | from tensorflow_addons.activations.snake import snake 24 | 25 | from tensorflow_addons.utils import test_utils 26 | 27 | 28 | @pytest.mark.usefixtures("maybe_run_functions_eagerly") 29 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 30 | def test_layer(dtype): 31 | x = np.random.rand(2, 5).astype(dtype) 32 | a = np.random.randn() 33 | val = snake(x, a) 34 | test_utils.layer_test( 35 | Snake, 36 | kwargs={"frequency_initializer": tf.constant_initializer(a), "dtype": dtype}, 37 | input_data=x, 38 | expected_output=val, 39 | ) 40 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/sparsemax_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | 17 | import pytest 18 | import numpy as np 19 | 20 | from tensorflow_addons.layers import Sparsemax 21 | from tensorflow_addons.utils import test_utils 22 | 23 | test_obs = 17 24 | 25 | 26 | def _np_sparsemax(z): 27 | z = z - np.mean(z, axis=1)[:, np.newaxis] 28 | 29 | # sort z 30 | z_sorted = np.sort(z, axis=1)[:, ::-1] 31 | 32 | # calculate k(z) 33 | z_cumsum = np.cumsum(z_sorted, axis=1) 34 | k = np.arange(1, z.shape[1] + 1) 35 | z_check = 1 + k * z_sorted > z_cumsum 36 | # use argmax to get the index by row as .nonzero() doesn't 37 | # take an axis argument. np.argmax return the first index, but the last 38 | # index is required here, use np.flip to get the last index and 39 | # `z.shape[axis]` to compensate for np.flip afterwards. 
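    # As a concrete example, a z_check row of [True, True, False, False]
    # flips to [False, False, True, True]; np.argmax then returns 2 (the
    # first True after flipping), and z.shape[1] - 2 = 2 recovers k(z),
    # i.e. the number of leading True entries in the original row.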
40 | k_z = z.shape[1] - np.argmax(z_check[:, ::-1], axis=1) 41 | 42 | # calculate tau(z) 43 | tau_sum = z_cumsum[np.arange(0, z.shape[0]), k_z - 1] 44 | tau_z = ((tau_sum - 1) / k_z).reshape(-1, 1) 45 | 46 | # calculate p 47 | return np.maximum(0, z - tau_z) 48 | 49 | 50 | @pytest.mark.usefixtures("maybe_run_functions_eagerly") 51 | @pytest.mark.parametrize("dtype", [np.float32, np.float64]) 52 | def test_sparsemax_layer_against_numpy(dtype): 53 | """check sparsemax kernel against numpy.""" 54 | random = np.random.RandomState(1) 55 | 56 | z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype) 57 | 58 | test_utils.layer_test( 59 | Sparsemax, 60 | kwargs={"dtype": dtype}, 61 | input_data=z, 62 | expected_output=_np_sparsemax(z).astype(dtype), 63 | ) 64 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/stochastic_depth_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from tensorflow_addons.layers.stochastic_depth import StochasticDepth 6 | from tensorflow_addons.utils import test_utils 7 | 8 | _KEEP_SEED = 1111 9 | _DROP_SEED = 2222 10 | 11 | 12 | @pytest.mark.parametrize("seed", [_KEEP_SEED, _DROP_SEED]) 13 | @pytest.mark.parametrize("training", [True, False]) 14 | def stochastic_depth_test(seed, training): 15 | np.random.seed(seed) 16 | tf.random.set_seed(seed) 17 | 18 | survival_probability = 0.5 19 | 20 | shortcut = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32) 21 | residual = np.asarray([[0.2, 0.4, 0.5]]).astype(np.float32) 22 | 23 | if training: 24 | if seed == _KEEP_SEED: 25 | # shortcut + residual 26 | expected_output = np.asarray([[0.4, 0.5, 0.9]]).astype(np.float32) 27 | elif seed == _DROP_SEED: 28 | # shortcut 29 | expected_output = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32) 30 | else: 31 | # shortcut + p_l * residual 32 | expected_output = np.asarray([[0.3, 0.3, 0.65]]).astype(np.float32) 33 | 34 | test_utils.layer_test( 35 | StochasticDepth, 36 | kwargs={"survival_probability": survival_probability}, 37 | input_data=[shortcut, residual], 38 | expected_output=expected_output, 39 | ) 40 | 41 | 42 | @pytest.mark.usefixtures("run_with_mixed_precision_policy") 43 | def test_with_mixed_precision_policy(): 44 | policy = tf.keras.mixed_precision.global_policy() 45 | 46 | shortcut = np.asarray([[0.2, 0.1, 0.4]]) 47 | residual = np.asarray([[0.2, 0.4, 0.5]]) 48 | 49 | output = StochasticDepth()([shortcut, residual]) 50 | assert output.dtype == policy.compute_dtype 51 | 52 | output = StochasticDepth()([shortcut, residual], training=True) 53 | assert output.dtype == policy.compute_dtype 54 | 55 | 56 | def test_serialization(): 57 | stoch_depth = StochasticDepth(survival_probability=0.5) 58 | serialized_stoch_depth = tf.keras.layers.serialize(stoch_depth) 59 | new_layer = tf.keras.layers.deserialize(serialized_stoch_depth) 60 | assert stoch_depth.get_config() == new_layer.get_config() 61 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/tlu_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for TLU activation.""" 16 | 17 | 18 | import pytest 19 | import numpy as np 20 | import tensorflow as tf 21 | 22 | from tensorflow_addons.layers.tlu import TLU 23 | from tensorflow_addons.utils import test_utils 24 | 25 | 26 | @pytest.mark.usefixtures("maybe_run_functions_eagerly") 27 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 28 | def test_random(dtype): 29 | x = np.array([[-2.5, 0.0, 0.3]]).astype(dtype) 30 | val = np.array([[0.0, 0.0, 0.3]]).astype(dtype) 31 | test_utils.layer_test( 32 | TLU, kwargs={"dtype": dtype}, input_data=x, expected_output=val 33 | ) 34 | 35 | 36 | @pytest.mark.usefixtures("maybe_run_functions_eagerly") 37 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 38 | def test_affine(dtype): 39 | x = np.array([[-2.5, 0.0, 0.3]]).astype(dtype) 40 | val = np.array([[-1.5, 1.0, 1.3]]).astype(dtype) 41 | test_utils.layer_test( 42 | TLU, 43 | kwargs={ 44 | "affine": True, 45 | "dtype": dtype, 46 | "alpha_initializer": "ones", 47 | "tau_initializer": "ones", 48 | }, 49 | input_data=x, 50 | expected_output=val, 51 | ) 52 | 53 | 54 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 55 | def test_serialization(dtype): 56 | tlu = TLU( 57 | affine=True, alpha_initializer="ones", tau_initializer="ones", dtype=dtype 58 | ) 59 | serialized_tlu = tf.keras.layers.serialize(tlu) 60 | new_layer = tf.keras.layers.deserialize(serialized_tlu) 61 | assert tlu.get_config() == new_layer.get_config() 62 | -------------------------------------------------------------------------------- /tensorflow_addons/losses/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "losses", 7 | srcs = glob(["*.py"]), 8 | deps = [ 9 | "//tensorflow_addons/activations", 10 | "//tensorflow_addons/testing", 11 | "//tensorflow_addons/utils", 12 | ], 13 | ) 14 | 15 | py_test( 16 | name = "losses_test", 17 | size = "small", 18 | srcs = glob(["tests/*"]), 19 | main = "tests/run_all_test.py", 20 | deps = [ 21 | ":losses", 22 | ], 23 | ) 24 | -------------------------------------------------------------------------------- /tensorflow_addons/losses/README.md: -------------------------------------------------------------------------------- 1 | # Addons - Losses 2 | 3 | ## Components 4 | https://www.tensorflow.org/addons/api_docs/python/tfa/losses 5 | 6 | ## Contribution Guidelines 7 | #### Standard API 8 | In order to conform with the current API standard, all losses 9 | must: 10 | * Inherit from `keras.losses.Loss`. 11 | * Register as a keras global object so it can be serialized properly: `@tf.keras.utils.register_keras_serializable(package='Addons')` 12 | 13 | #### Testing Requirements 14 | * Simple unittests that demonstrate the loss is behaving as expected on 15 | some set of known inputs and outputs. 
16 | * To run your `tf.functions` in eager mode and graph mode in the tests, 17 | you can use the `@pytest.mark.usefixtures("maybe_run_functions_eagerly")` 18 | decorator. This will run the tests twice, once normally, and once 19 | with `tf.config.run_functions_eagerly(True)`. 20 | 21 | #### Documentation Requirements 22 | * Update the [CODEOWNERS file](https://github.com/tensorflow/addons/blob/master/.github/CODEOWNERS) 23 | 24 | -------------------------------------------------------------------------------- /tensorflow_addons/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Additional losses that conform to Keras API.""" 16 | 17 | from tensorflow_addons.losses.contrastive import contrastive_loss, ContrastiveLoss 18 | from tensorflow_addons.losses.focal_loss import ( 19 | sigmoid_focal_crossentropy, 20 | SigmoidFocalCrossEntropy, 21 | ) 22 | from tensorflow_addons.losses.giou_loss import giou_loss, GIoULoss 23 | from tensorflow_addons.losses.lifted import lifted_struct_loss, LiftedStructLoss 24 | from tensorflow_addons.losses.sparsemax_loss import sparsemax_loss, SparsemaxLoss 25 | from tensorflow_addons.losses.triplet import ( 26 | triplet_semihard_loss, 27 | triplet_hard_loss, 28 | TripletSemiHardLoss, 29 | TripletHardLoss, 30 | ) 31 | from tensorflow_addons.losses.quantiles import pinball_loss, PinballLoss 32 | 33 | 34 | from tensorflow_addons.losses.npairs import ( 35 | npairs_loss, 36 | NpairsLoss, 37 | npairs_multilabel_loss, 38 | NpairsMultilabelLoss, 39 | ) 40 | from tensorflow_addons.losses.kappa_loss import WeightedKappaLoss 41 | -------------------------------------------------------------------------------- /tensorflow_addons/losses/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/losses/tests/__init__.py -------------------------------------------------------------------------------- /tensorflow_addons/losses/tests/metric_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for metric learning."""
16 | 
17 | 
18 | import pytest
19 | import numpy as np
20 | import tensorflow as tf
21 | from tensorflow_addons.losses.metric_learning import pairwise_distance
22 | 
23 | 
24 | def test_zero_distance():
25 |     """Test that equal embeddings have a pairwise distance of 0."""
26 |     equal_embeddings = tf.constant([[1.0, 0.5], [1.0, 0.5]])
27 | 
28 |     distances = pairwise_distance(equal_embeddings, squared=False)
29 |     np.testing.assert_allclose(tf.math.reduce_sum(distances), 0, 1e-6, 1e-6)
30 | 
31 | 
32 | def test_positive_distances():
33 |     """Test that the pairwise distances are always positive."""
34 | 
35 |     # Create embeddings very close to each other in [1.0 - 2e-7, 1.0 + 2e-7]
36 |     # This will encourage errors in the computation
37 |     embeddings = 1.0 + 2e-7 * tf.random.uniform([64, 6], dtype=tf.float32)
38 |     distances = pairwise_distance(embeddings, squared=False)
39 |     assert np.all(distances >= 0)
40 | 
41 | 
42 | def test_correct_distance():
43 |     """Compare against numpy calculation."""
44 |     tf_embeddings = tf.constant([[0.5, 0.5], [1.0, 1.0]])
45 | 
46 |     expected_distance = np.array([[0, np.sqrt(2) / 2], [np.sqrt(2) / 2, 0]])
47 | 
48 |     distances = pairwise_distance(tf_embeddings, squared=False)
49 |     np.testing.assert_allclose(expected_distance, distances, 1e-6, 1e-6)
50 | 
51 | 
52 | @pytest.mark.usefixtures("maybe_run_functions_eagerly")
53 | def test_correct_distance_squared():
54 |     """Compare against numpy calculation for squared distances."""
55 |     tf_embeddings = tf.constant([[0.5, 0.5], [1.0, 1.0]])
56 | 
57 |     expected_distance = np.array([[0, 0.5], [0.5, 0]])
58 | 
59 |     distances = pairwise_distance(tf_embeddings, squared=True)
60 |     np.testing.assert_allclose(expected_distance, distances, 1e-6, 1e-6)
61 | 
--------------------------------------------------------------------------------
/tensorflow_addons/losses/tests/run_all_test.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import sys
3 | 
4 | import pytest
5 | 
6 | if __name__ == "__main__":
7 |     dirname = Path(__file__).absolute().parent
8 |     sys.exit(pytest.main([str(dirname)]))
9 | 
--------------------------------------------------------------------------------
/tensorflow_addons/metrics/BUILD:
--------------------------------------------------------------------------------
1 | licenses(["notice"]) # Apache 2.0
2 | 
3 | package(default_visibility = ["//visibility:public"])
4 | 
5 | py_library(
6 |     name = "metrics",
7 |     srcs = glob(["*.py"]),
8 |     deps = [
9 |         "//tensorflow_addons/testing",
10 |         "//tensorflow_addons/utils",
11 |     ],
12 | )
13 | 
14 | py_test(
15 |     name = "metrics_test",
16 |     size = "small",
17 |     srcs = glob(["tests/*"]),
18 |     main = "tests/run_all_test.py",
19 |     deps = [
20 |         ":metrics",
21 |     ],
22 | )
23 | 
--------------------------------------------------------------------------------
/tensorflow_addons/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Additional metrics that conform to Keras API.""" 16 | 17 | from tensorflow_addons.metrics.cohens_kappa import CohenKappa 18 | from tensorflow_addons.metrics.f_scores import F1Score, FBetaScore 19 | from tensorflow_addons.metrics.hamming import ( 20 | HammingLoss, 21 | hamming_distance, 22 | hamming_loss_fn, 23 | ) 24 | from tensorflow_addons.metrics.utils import MeanMetricWrapper 25 | from tensorflow_addons.metrics.matthews_correlation_coefficient import ( 26 | MatthewsCorrelationCoefficient, 27 | ) 28 | from tensorflow_addons.metrics.multilabel_confusion_matrix import ( 29 | MultiLabelConfusionMatrix, 30 | ) 31 | from tensorflow_addons.metrics.r_square import RSquare 32 | from tensorflow_addons.metrics.geometric_mean import GeometricMean 33 | from tensorflow_addons.metrics.harmonic_mean import HarmonicMean 34 | from tensorflow_addons.metrics.streaming_correlations import ( 35 | KendallsTauB, 36 | KendallsTauC, 37 | PearsonsCorrelation, 38 | SpearmansRank, 39 | ) 40 | -------------------------------------------------------------------------------- /tensorflow_addons/metrics/harmonic_mean.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implements HarmonicMean.""" 16 | 17 | import tensorflow as tf 18 | 19 | from typeguard import typechecked 20 | from tensorflow_addons.utils.types import AcceptableDTypes 21 | 22 | 23 | @tf.keras.utils.register_keras_serializable(package="Addons") 24 | class HarmonicMean(tf.keras.metrics.Mean): 25 | """Compute Harmonic Mean 26 | The harmonic mean is a kind of mean. It can be expressed as the reciprocal of 27 | the arithmetic mean of the reciprocals of the given set of numbers. 28 | Note: `tfa.metrics.HarmonicMean` can be used the same as `tf.keras.metrics.Mean`. 29 | Args: 30 | name: (Optional) String name of the metric instance. 31 | dtype: (Optional) Data type of the metric result. 
32 | Usage: 33 | >>> metric = tfa.metrics.HarmonicMean() 34 | >>> metric.update_state([1, 4, 4]) 35 | >>> metric.result().numpy() 36 | 2.0 37 | """ 38 | 39 | @typechecked 40 | def __init__( 41 | self, name: str = "harmonic_mean", dtype: AcceptableDTypes = None, **kwargs 42 | ): 43 | super().__init__(name=name, dtype=dtype, **kwargs) 44 | 45 | def update_state(self, values, sample_weight=None) -> None: 46 | values = tf.cast(values, dtype=self.dtype) 47 | super().update_state(tf.math.reciprocal(values), sample_weight) 48 | 49 | def result(self) -> tf.Tensor: 50 | return tf.math.reciprocal_no_nan(super().result()) 51 | -------------------------------------------------------------------------------- /tensorflow_addons/metrics/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/metrics/tests/__init__.py -------------------------------------------------------------------------------- /tensorflow_addons/metrics/tests/harmonic_mean_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests HarmonicMean metrics.""" 16 | 17 | import numpy as np 18 | import pytest 19 | import tensorflow as tf 20 | 21 | from tensorflow_addons.metrics import HarmonicMean 22 | 23 | 24 | def get_test_data(): 25 | return [ 26 | ([np.inf] * 2, 0), 27 | ([0, 0, 0, 0], 0), 28 | ([1, 4, 4], 2.0), 29 | ([0, 0, 0, 0, 0, 0, 0, 1, 2, 6], 0), 30 | ([0.2, 0.5, 0.3, 0.6, 0.1, 0.7], 0.25609756), 31 | ([8, 4, 1, 7, 2, 11, 9, 22, 52], 3.9394846), 32 | ([8.2, 9.7, 9.1, 2.7, 1.1, 2.0], 2.8376906), 33 | ([0.6666666, 0.215213, 0.15167], 0.23548213), 34 | ] 35 | 36 | 37 | def assert_result(expected, result): 38 | np.testing.assert_allclose(expected, result, atol=1e-6) 39 | 40 | 41 | def check_result(obj, expected_result, expected_count): 42 | result = obj.result().numpy() 43 | count = obj.count.numpy() 44 | assert_result(expected_result, result) 45 | np.testing.assert_equal(expected_count, count) 46 | 47 | 48 | @pytest.mark.parametrize("values, expected", get_test_data()) 49 | def test_vector_update_state_hmean(values, expected): 50 | obj = HarmonicMean() 51 | values = tf.constant(values, tf.float32) 52 | obj.update_state(values) 53 | check_result(obj, expected, len(values)) 54 | 55 | 56 | @pytest.mark.parametrize("values, expected", get_test_data()) 57 | def test_call_hmean(values, expected): 58 | obj = HarmonicMean() 59 | result = obj(tf.constant(values, tf.float32)) 60 | count = obj.count.numpy() 61 | assert_result(expected, result) 62 | np.testing.assert_equal(len(values), count) 63 | 64 | 65 | @pytest.mark.parametrize( 66 | "values, sample_weight, expected", 67 | [ 68 | ([1, 2, 3, 4, 5], 1, 2.1897807), 69 | ([2.1, 4.6, 7.1], [1, 2, 3], 4.499409), 70 | ([9.6, 1.8, 8.2], [0.2, 0.5, 0.3], 2.9833248), 71 | ], 72 | ) 73 | def test_sample_weight_hmean(values, sample_weight, expected): 74 | obj = HarmonicMean() 75 | obj.update_state(values, sample_weight=sample_weight) 76 | assert_result(expected, obj.result().numpy()) 77 | -------------------------------------------------------------------------------- /tensorflow_addons/metrics/tests/metrics_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | import inspect 16 | 17 | from tensorflow.keras.metrics import Metric 18 | from tensorflow_addons import metrics 19 | 20 | 21 | def test_update_state_signature(): 22 | public_params = ["sample_weight"] 23 | params_comb = [["y_true", "y_pred"], ["values"]] 24 | for name, obj in inspect.getmembers(metrics): 25 | if inspect.isclass(obj) and issubclass(obj, Metric): 26 | check_update_state_signature(obj, public_params, params_comb) 27 | 28 | 29 | def check_update_state_signature(metric_class, public_params, case_list): 30 | error_msg = ( 31 | "Class {} is missing the parameter {} in the `update_state` " 32 | "method. If the method doesn't use this argument, declare " 33 | "it anyway and raise a UserWarning if it is " 34 | "not None." 35 | ) 36 | 37 | update_state_signature = inspect.signature(metric_class.update_state) 38 | 39 | for expected_param in public_params: 40 | if expected_param not in update_state_signature.parameters.keys(): 41 | raise ValueError(error_msg.format(metric_class.__name__, expected_param)) 42 | 43 | missing_params = [] 44 | for case in case_list: 45 | case_miss_params = [] 46 | case_check = True 47 | for expected_param in case: 48 | if expected_param not in update_state_signature.parameters.keys(): 49 | case_miss_params.append(expected_param) 50 | case_check = False 51 | break 52 | if case_check: 53 | return 54 | missing_params.append(case_miss_params) 55 | missing_params = [", ".join(p) for p in missing_params] 56 | missing_params = " or ".join(missing_params) 57 | raise ValueError(error_msg.format(metric_class.__name__, missing_params)) 58 | -------------------------------------------------------------------------------- /tensorflow_addons/metrics/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/optimizers/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "optimizers", 7 | srcs = glob(["*.py"]), 8 | deps = [ 9 | "//tensorflow_addons/testing", 10 | "//tensorflow_addons/utils", 11 | ], 12 | ) 13 | 14 | py_test( 15 | name = "optimizers_test", 16 | size = "small", 17 | srcs = glob(["tests/*"]), 18 | main = "tests/run_all_test.py", 19 | deps = [ 20 | ":optimizers", 21 | ], 22 | ) 23 | -------------------------------------------------------------------------------- /tensorflow_addons/optimizers/README.md: -------------------------------------------------------------------------------- 1 | # Addons - Optimizers 2 | 3 | ## Components 4 | https://www.tensorflow.org/addons/api_docs/python/tfa/optimizers 5 | 6 | ## Contribution Guidelines 7 | #### Standard API 8 | In order to conform with the current API standard, all optimizers 9 | must: 10 | * Inherit from either `keras.optimizer_v2.OptimizerV2` or its subclasses. 
11 | * Register as a keras global object so it can be serialized properly: `@tf.keras.utils.register_keras_serializable(package='Addons')` 12 | 13 | #### Testing Requirements 14 | * To run your `tf.functions` in eager mode and graph mode in the tests, 15 | you can use the `@pytest.mark.usefixtures("maybe_run_functions_eagerly")` 16 | decorator. This will run the tests twice, once normally, and once 17 | with `tf.config.run_functions_eagerly(True)`. 18 | 19 | #### Documentation Requirements 20 | * Update the [CODEOWNERS file](https://github.com/tensorflow/addons/blob/master/.github/CODEOWNERS) 21 | -------------------------------------------------------------------------------- /tensorflow_addons/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Additional optimizers that conform to Keras API.""" 16 | 17 | from tensorflow_addons.optimizers.constants import KerasLegacyOptimizer 18 | from tensorflow_addons.optimizers.average_wrapper import AveragedOptimizerWrapper 19 | from tensorflow_addons.optimizers.conditional_gradient import ConditionalGradient 20 | from tensorflow_addons.optimizers.cyclical_learning_rate import CyclicalLearningRate 21 | from tensorflow_addons.optimizers.cyclical_learning_rate import ( 22 | TriangularCyclicalLearningRate, 23 | ) 24 | from tensorflow_addons.optimizers.cyclical_learning_rate import ( 25 | Triangular2CyclicalLearningRate, 26 | ) 27 | from tensorflow_addons.optimizers.cyclical_learning_rate import ( 28 | ExponentialCyclicalLearningRate, 29 | ) 30 | from tensorflow_addons.optimizers.discriminative_layer_training import ( 31 | MultiOptimizer, 32 | ) 33 | from tensorflow_addons.optimizers.lamb import LAMB 34 | from tensorflow_addons.optimizers.lazy_adam import LazyAdam 35 | from tensorflow_addons.optimizers.lookahead import Lookahead 36 | from tensorflow_addons.optimizers.moving_average import MovingAverage 37 | from tensorflow_addons.optimizers.novograd import NovoGrad 38 | from tensorflow_addons.optimizers.proximal_adagrad import ProximalAdagrad 39 | from tensorflow_addons.optimizers.rectified_adam import RectifiedAdam 40 | from tensorflow_addons.optimizers.stochastic_weight_averaging import SWA 41 | from tensorflow_addons.optimizers.weight_decay_optimizers import AdamW 42 | from tensorflow_addons.optimizers.adabelief import AdaBelief 43 | from tensorflow_addons.optimizers.weight_decay_optimizers import SGDW 44 | from tensorflow_addons.optimizers.weight_decay_optimizers import ( 45 | extend_with_decoupled_weight_decay, 46 | ) 47 | from tensorflow_addons.optimizers.weight_decay_optimizers import ( 48 | DecoupledWeightDecayExtension, 49 | ) 50 | from tensorflow_addons.optimizers.yogi import Yogi 51 | from tensorflow_addons.optimizers.cocob import COCOB 52 | 
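# Usage sketch (not part of this module): any of the optimizers exported above
# drops into the standard Keras training workflow. The two-layer model below is
# purely hypothetical; `LazyAdam` and `Lookahead` are the classes imported above.
#
#     import tensorflow as tf
#     import tensorflow_addons as tfa
#
#     model = tf.keras.Sequential(
#         [tf.keras.layers.Dense(64, activation="relu"), tf.keras.layers.Dense(10)]
#     )
#     model.compile(
#         optimizer=tfa.optimizers.Lookahead(tfa.optimizers.LazyAdam(1e-3)),
#         loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
#     )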
-------------------------------------------------------------------------------- /tensorflow_addons/optimizers/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import tensorflow as tf 16 | 17 | if ( 18 | hasattr(tf.keras.optimizers, "experimental") 19 | and tf.keras.optimizers.Optimizer.__module__ 20 | == tf.keras.optimizers.experimental.Optimizer.__module__ 21 | ): 22 | # If the default optimizer points to new Keras optimizer, addon optimizers 23 | # should use the legacy path. 24 | KerasLegacyOptimizer = tf.keras.optimizers.legacy.Optimizer 25 | else: 26 | KerasLegacyOptimizer = tf.keras.optimizers.Optimizer 27 | -------------------------------------------------------------------------------- /tensorflow_addons/optimizers/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/optimizers/tests/__init__.py -------------------------------------------------------------------------------- /tensorflow_addons/optimizers/tests/cocob_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for COntinuous COin Betting (COCOB) Backprop optimizer""" 16 | 17 | import numpy as np 18 | import tensorflow as tf 19 | from tensorflow_addons.optimizers import COCOB 20 | 21 | 22 | def run_dense_sample(iterations, expected, optimizer): 23 | var_0 = tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32) 24 | var_1 = tf.Variable([3.0, 4.0], dtype=tf.dtypes.float32) 25 | 26 | grad_0 = tf.constant([0.1, 0.2], dtype=tf.dtypes.float32) 27 | grad_1 = tf.constant([0.03, 0.04], dtype=tf.dtypes.float32) 28 | 29 | grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1])) 30 | 31 | for _ in range(iterations): 32 | optimizer.apply_gradients(grads_and_vars) 33 | 34 | np.testing.assert_allclose(var_0.read_value(), expected[0], atol=2e-4) 35 | np.testing.assert_allclose(var_1.read_value(), expected[1], atol=2e-4) 36 | 37 | 38 | def test_dense_sample_with_default_alpha(): 39 | run_dense_sample( 40 | iterations=10, 41 | expected=[[0.84528893, 1.845289], [2.845289, 3.845289]], 42 | optimizer=COCOB(), 43 | ) 44 | 45 | 46 | def test_dense_sample_with_custom_int_alpha(): 47 | run_dense_sample( 48 | iterations=7, 49 | expected=[[0.09346926, 1.0934693], [2.0934694, 3.0934694]], 50 | optimizer=COCOB(20), 51 | ) 52 | 53 | 54 | def test_dense_sample_with_custom_float_alpha(): 55 | run_dense_sample( 56 | iterations=5, 57 | expected=[[0.89307845, 1.8930784], [2.8930783, 3.8930783]], 58 | optimizer=COCOB(55.7), 59 | ) 60 | -------------------------------------------------------------------------------- /tensorflow_addons/optimizers/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/optimizers/tests/standard_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | import pytest 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | from tensorflow_addons import optimizers 21 | from tensorflow_addons.optimizers import KerasLegacyOptimizer 22 | from tensorflow_addons.utils.test_utils import discover_classes 23 | 24 | class_exceptions = [ 25 | "MultiOptimizer", # is wrapper 26 | "SGDW", # is wrapper 27 | "AdamW", # is wrapper 28 | "SWA", # is wrapper 29 | "AveragedOptimizerWrapper", # is wrapper 30 | "ConditionalGradient", # is wrapper 31 | "Lookahead", # is wrapper 32 | "MovingAverage", # is wrapper 33 | "KerasLegacyOptimizer", # is a constant 34 | ] 35 | 36 | classes_to_test = discover_classes(optimizers, KerasLegacyOptimizer, class_exceptions) 37 | 38 | 39 | @pytest.mark.parametrize("optimizer", classes_to_test) 40 | @pytest.mark.parametrize("serialize", [True, False]) 41 | def test_optimizer_minimize_serialize(optimizer, serialize, tmpdir): 42 | """ 43 | The purpose of this test is to confirm that the optimizer can minimize the loss in toy conditions. 44 | It also tests serialization as a parameter. 45 | """ 46 | model = tf.keras.Sequential([tf.keras.Input(shape=[1]), tf.keras.layers.Dense(1)]) 47 | 48 | x = np.array(np.ones([1])) 49 | y = np.array(np.zeros([1])) 50 | 51 | opt = optimizer() 52 | loss = tf.keras.losses.MSE 53 | 54 | model.compile(optimizer=opt, loss=loss) 55 | 56 | # serialize whole model including optimizer, clear the session, then reload the whole model. 57 | # successfully serialized optimizers should not require a compile before training 58 | if serialize: 59 | model.save(str(tmpdir), save_format="tf") 60 | tf.keras.backend.clear_session() 61 | model = tf.keras.models.load_model(str(tmpdir)) 62 | 63 | history = model.fit(x, y, batch_size=1, epochs=10) 64 | 65 | loss_values = history.history["loss"] 66 | 67 | np.testing.assert_array_less(loss_values[-1], loss_values[0]) 68 | -------------------------------------------------------------------------------- /tensorflow_addons/optimizers/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Additional Utilities used for tfa.optimizers.""" 16 | 17 | import re 18 | import tensorflow as tf 19 | from typing import List 20 | 21 | 22 | def fit_bn(model, *args, **kwargs): 23 | """Resets batch normalization layers of model, and recalculates the 24 | statistics for each batchnorm layer by running a pass on the data.
25 | 26 | Args: 27 | model: An instance of tf.keras.Model 28 | *args, **kwargs: Params that'll be passed to `.fit` method of model 29 | """ 30 | kwargs["epochs"] = 1 31 | if not isinstance(model, tf.keras.Model): 32 | raise TypeError("model must be an instance of tf.keras.Model") 33 | 34 | if not model.built: 35 | raise ValueError("Call `fit_bn` after the model is built and trained") 36 | 37 | assign_ops = [] 38 | for layer in model.layers: 39 | if isinstance(layer, tf.keras.layers.BatchNormalization): 40 | assign_ops.extend( 41 | [ 42 | layer.moving_mean.assign(tf.zeros_like(layer.moving_mean)), 43 | layer.moving_variance.assign(tf.ones_like(layer.moving_variance)), 44 | ] 45 | ) 46 | 47 | _trainable = model.trainable 48 | _metrics = model._metrics 49 | model.trainable = False 50 | model._metrics = [] 51 | 52 | model.fit(*args, **kwargs) 53 | 54 | model.trainable = _trainable 55 | model._metrics = _metrics 56 | 57 | 58 | def get_variable_name(variable) -> str: 59 | """Get the variable name from the variable tensor.""" 60 | param_name = variable.name 61 | m = re.match("^(.*):\\d+$", param_name) 62 | if m is not None: 63 | param_name = m.group(1) 64 | return param_name 65 | 66 | 67 | def is_variable_matched_by_regexes(variable, regexes: List[str]) -> bool: 68 | """Whether variable is matched in regexes list by its name.""" 69 | if regexes: 70 | # var_name = get_variable_name(variable) 71 | var_name = variable.name 72 | for r in regexes: 73 | if re.search(r, var_name): 74 | return True 75 | return False 76 | -------------------------------------------------------------------------------- /tensorflow_addons/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import warnings 4 | import traceback 5 | 6 | try: 7 | _TF_ADDONS_PY_OPS = bool(int(os.environ["TF_ADDONS_PY_OPS"])) 8 | except KeyError: 9 | if platform.system() == "Linux": 10 | _TF_ADDONS_PY_OPS = False 11 | else: 12 | _TF_ADDONS_PY_OPS = True 13 | 14 | _FALLBACK_WARNING_TEMPLATE = """{} 15 | 16 | The {} C++/CUDA custom op could not be loaded. 17 | For this reason, Addons will fall back to an implementation written 18 | in Python with public TensorFlow ops. The worst you might experience with 19 | this is a moderate slowdown on GPU. There can be multiple 20 | reasons for this loading error; one of them may be an ABI incompatibility between 21 | the TensorFlow installed on your system and the TensorFlow used to compile 22 | TensorFlow Addons' custom ops. The stacktrace generated when loading the 23 | shared object file was displayed above. 24 | 25 | If you want this warning to disappear, either make sure the TensorFlow installed 26 | is compatible with this version of Addons, or tell TensorFlow Addons to 27 | prefer using Python implementations and not custom C++/CUDA ones. You can do that 28 | by setting the environment variable `TF_ADDONS_PY_OPS=1`: 29 | ```bash 30 | TF_ADDONS_PY_OPS=1 python my_script.py 31 | ``` 32 | or run `tfa.options.disable_custom_kernel()` in your code, after your imports: 33 | ```python 34 | import tensorflow_addons as tfa 35 | import ... 36 | import ... 37 | 38 | tfa.options.disable_custom_kernel() 39 | ``` 40 | """ 41 | 42 | 43 | def warn_fallback(op_name): 44 | warning_msg = _FALLBACK_WARNING_TEMPLATE.format(traceback.format_exc(), op_name) 45 | warnings.warn(warning_msg, RuntimeWarning) 46 | disable_custom_kernel() 47 | 48 | 49 | def enable_custom_kernel(): 50 | """Prefer custom C++/CUDA kernel to pure python operations.
51 | 52 | Enable using custom C++/CUDA kernel instead of pure python operations. 53 | It has the same effect as setting environment variable `TF_ADDONS_PY_OPS=0`. 54 | """ 55 | global _TF_ADDONS_PY_OPS 56 | _TF_ADDONS_PY_OPS = False 57 | 58 | 59 | def disable_custom_kernel(): 60 | """Prefer pure python operations to custom C++/CUDA kernel. 61 | 62 | Disable using custom C++/CUDA kernel instead of pure python operations. 63 | It has the same effect as setting environment variable `TF_ADDONS_PY_OPS=1`. 64 | """ 65 | global _TF_ADDONS_PY_OPS 66 | _TF_ADDONS_PY_OPS = True 67 | 68 | 69 | def is_custom_kernel_disabled(): 70 | """Return whether custom C++/CUDA kernel is disabled.""" 71 | return _TF_ADDONS_PY_OPS 72 | -------------------------------------------------------------------------------- /tensorflow_addons/rnn/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "rnn", 7 | srcs = glob(["*.py"]), 8 | deps = [ 9 | "//tensorflow_addons/testing", 10 | "//tensorflow_addons/utils", 11 | ], 12 | ) 13 | 14 | py_test( 15 | name = "rnn_test", 16 | size = "small", 17 | srcs = glob(["tests/*"]), 18 | main = "tests/run_all_test.py", 19 | deps = [ 20 | ":rnn", 21 | ], 22 | ) 23 | -------------------------------------------------------------------------------- /tensorflow_addons/rnn/README.md: -------------------------------------------------------------------------------- 1 | # Addons - RNN 2 | 3 | ## Components 4 | https://www.tensorflow.org/addons/api_docs/python/tfa/rnn 5 | 6 | ## Contribution Guidelines 7 | #### Prerequisites 8 | * For any cell based on a research paper, the original paper has to be well recognized. 9 | The criterion here is >= 100 citations on Google Scholar. If the contributor feels 10 | this requirement needs to be overruled, please give a detailed justification in the 11 | PR. 12 | 13 | #### Standard API 14 | In order to conform with the current API standard, all cells must: 15 | * Inherit from either `keras.layers.AbstractRNNCell` or `keras.layers.Layer` with 16 | required properties. 17 | * Register as a keras global object so it can be serialized properly: `@tf.keras.utils.register_keras_serializable(package='Addons')` 18 | 19 | #### Testing Requirements 20 | * To run your `tf.functions` in eager mode and graph mode in the tests, 21 | you can use the `@pytest.mark.usefixtures("maybe_run_functions_eagerly")` 22 | decorator. This will run the tests twice, once normally, and once 23 | with `tf.config.run_functions_eagerly(True)`. 24 | 25 | #### Documentation Requirements 26 | * Update the [CODEOWNERS file](https://github.com/tensorflow/addons/blob/master/.github/CODEOWNERS) 27 | -------------------------------------------------------------------------------- /tensorflow_addons/rnn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Additional RNN cells that conform to Keras API.""" 16 | 17 | from tensorflow_addons.rnn.nas_cell import NASCell 18 | from tensorflow_addons.rnn.layer_norm_lstm_cell import LayerNormLSTMCell 19 | from tensorflow_addons.rnn.layer_norm_simple_rnn_cell import LayerNormSimpleRNNCell 20 | from tensorflow_addons.rnn.esn_cell import ESNCell 21 | from tensorflow_addons.rnn.peephole_lstm_cell import PeepholeLSTMCell 22 | -------------------------------------------------------------------------------- /tensorflow_addons/rnn/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/rnn/tests/__init__.py -------------------------------------------------------------------------------- /tensorflow_addons/rnn/tests/peephole_lstm_cell_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | """Tests for Peephole Cell.""" 16 | 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | from tensorflow_addons.rnn import PeepholeLSTMCell 21 | 22 | 23 | def test_peephole_lstm_cell(): 24 | def _run_cell(cell_fn, **kwargs): 25 | inputs = tf.one_hot([1, 2, 3, 4], 4) 26 | cell = cell_fn(5, **kwargs) 27 | cell.build(inputs.shape) 28 | initial_state = cell.get_initial_state( 29 | inputs=inputs, batch_size=4, dtype=tf.float32 30 | ) 31 | output, _ = cell(inputs, initial_state) 32 | return output 33 | 34 | tf.random.set_seed(12345) 35 | first_implementation_output = _run_cell( 36 | PeepholeLSTMCell, 37 | kernel_initializer="ones", 38 | recurrent_activation="sigmoid", 39 | implementation=1, 40 | ) 41 | second_implementation_output = _run_cell( 42 | PeepholeLSTMCell, 43 | kernel_initializer="ones", 44 | recurrent_activation="sigmoid", 45 | implementation=2, 46 | ) 47 | expected_output = np.asarray( 48 | [ 49 | [0.417551, 0.417551, 0.417551, 0.417551, 0.417551], 50 | [0.417551, 0.417551, 0.417551, 0.417551, 0.417551], 51 | [0.417551, 0.417551, 0.417551, 0.417551, 0.417551], 52 | [0.0, 0.0, 0.0, 0.0, 0.0], 53 | ], 54 | dtype=np.float32, 55 | ) 56 | np.testing.assert_allclose( 57 | first_implementation_output, second_implementation_output 58 | ) 59 | np.testing.assert_allclose( 60 | first_implementation_output, expected_output, rtol=1e-6, atol=1e-6 61 | ) 62 | -------------------------------------------------------------------------------- /tensorflow_addons/rnn/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/seq2seq/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "seq2seq", 7 | srcs = glob(["*.py"]), 8 | data = [ 9 | "//tensorflow_addons:options.py", 10 | "//tensorflow_addons/custom_ops/seq2seq:_beam_search_ops.so", 11 | ], 12 | deps = [ 13 | "//tensorflow_addons/rnn", 14 | "//tensorflow_addons/testing", 15 | "//tensorflow_addons/utils", 16 | ], 17 | ) 18 | 19 | py_test( 20 | name = "seq2seq_test", 21 | size = "medium", 22 | srcs = glob(["tests/*"]), 23 | main = "tests/run_all_test.py", 24 | deps = [ 25 | ":seq2seq", 26 | ], 27 | ) 28 | -------------------------------------------------------------------------------- /tensorflow_addons/seq2seq/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/seq2seq/tests/__init__.py -------------------------------------------------------------------------------- /tensorflow_addons/seq2seq/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- 
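Editor's note: every package in this dump wires its Bazel `py_test` to a `tests/run_all_test.py` stub like the ones above, which simply hands its own directory to `pytest.main`. As a hypothetical variant (not present in the repository), the same stub can forward extra command-line flags to pytest:

```python
# Hypothetical extension of the run_all_test.py stub used throughout the repo:
# forward any extra CLI arguments (e.g. "-k beam_search" or "-x") to pytest.
from pathlib import Path
import sys

import pytest

if __name__ == "__main__":
    dirname = Path(__file__).absolute().parent
    sys.exit(pytest.main([str(dirname)] + sys.argv[1:]))
```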
/tensorflow_addons/tensorflow_addons.bzl: -------------------------------------------------------------------------------- 1 | load("@local_config_tf//:build_defs.bzl", "CPLUSPLUS_VERSION", "D_GLIBCXX_USE_CXX11_ABI") 2 | load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda", "if_cuda_is_configured") 3 | 4 | def custom_op_library( 5 | name, 6 | srcs = [], 7 | cuda_srcs = [], 8 | deps = [], 9 | cuda_deps = [], 10 | copts = [], 11 | **kwargs): 12 | deps = deps + [ 13 | "@local_config_tf//:libtensorflow_framework", 14 | "@local_config_tf//:tf_header_lib", 15 | ] 16 | 17 | if cuda_srcs: 18 | copts = copts + if_cuda(["-DGOOGLE_CUDA=1"]) 19 | cuda_copts = copts + if_cuda_is_configured([ 20 | "-x cuda", 21 | "-nvcc_options=relaxed-constexpr", 22 | "-nvcc_options=ftz=true", 23 | ]) 24 | cuda_deps = deps + if_cuda_is_configured(cuda_deps) + if_cuda_is_configured([ 25 | "@local_config_cuda//cuda:cuda_headers", 26 | "@local_config_cuda//cuda:cudart_static", 27 | ]) 28 | basename = name.split(".")[0] 29 | native.cc_library( 30 | name = basename + "_gpu", 31 | srcs = cuda_srcs, 32 | deps = cuda_deps, 33 | copts = cuda_copts, 34 | alwayslink = 1, 35 | **kwargs 36 | ) 37 | deps = deps + if_cuda_is_configured([":" + basename + "_gpu"]) 38 | 39 | copts = copts + select({ 40 | "//tensorflow_addons:windows": [ 41 | "/DEIGEN_STRONG_INLINE=inline", 42 | "-DTENSORFLOW_MONOLITHIC_BUILD", 43 | "/D_USE_MATH_DEFINES", 44 | "/DPLATFORM_WINDOWS", 45 | "/DEIGEN_HAS_C99_MATH", 46 | "/DTENSORFLOW_USE_EIGEN_THREADPOOL", 47 | "/DEIGEN_AVOID_STL_ARRAY", 48 | "/Iexternal/gemmlowp", 49 | "/wd4018", 50 | "/wd4577", 51 | "/DNOGDI", 52 | "/UTF_COMPILE_LIBRARY", 53 | ], 54 | "//conditions:default": ["-pthread", CPLUSPLUS_VERSION, D_GLIBCXX_USE_CXX11_ABI], 55 | }) 56 | 57 | native.cc_binary( 58 | name = name, 59 | srcs = srcs, 60 | copts = copts, 61 | linkshared = 1, 62 | features = select({ 63 | "//tensorflow_addons:windows": ["windows_export_all_symbols"], 64 | "//conditions:default": [], 65 | }), 66 | deps = deps, 67 | **kwargs 68 | ) 69 | -------------------------------------------------------------------------------- /tensorflow_addons/testing/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "testing", 7 | srcs = glob(["*.py"]), 8 | ) 9 | 10 | py_test( 11 | name = "serialization_test", 12 | size = "small", 13 | srcs = glob(["tests/*"]), 14 | main = "tests/run_all_test.py", 15 | deps = [ 16 | ":testing", 17 | ], 18 | ) 19 | -------------------------------------------------------------------------------- /tensorflow_addons/testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/testing/__init__.py -------------------------------------------------------------------------------- /tensorflow_addons/testing/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/testing/tests/__init__.py -------------------------------------------------------------------------------- /tensorflow_addons/testing/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import 
pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/testing/tests/serialization_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import tensorflow as tf 4 | 5 | from tensorflow.keras.metrics import MeanAbsoluteError, TrueNegatives, Metric 6 | from tensorflow_addons.testing.serialization import check_metric_serialization 7 | 8 | 9 | def test_check_metric_serialization_mae(): 10 | check_metric_serialization(MeanAbsoluteError(), (2, 2), (2, 2)) 11 | check_metric_serialization(MeanAbsoluteError(name="hello"), (2, 2), (2, 2)) 12 | check_metric_serialization(MeanAbsoluteError(), (2, 2, 2), (2, 2, 2)) 13 | check_metric_serialization(MeanAbsoluteError(), (2, 2, 2), (2, 2, 2), (2, 2, 1)) 14 | 15 | 16 | def get_random_booleans(): 17 | return np.random.uniform(0, 2, size=(2, 2)) 18 | 19 | 20 | def test_check_metric_serialization_true_negative(): 21 | check_metric_serialization( 22 | TrueNegatives(0.8), 23 | np.random.uniform(0, 2, size=(2, 2)).astype(bool), 24 | np.random.uniform(0, 1, size=(2, 2)).astype(np.float32), 25 | ) 26 | 27 | 28 | class MyDummyMetric(Metric): 29 | def __init__(self, stuff, name): 30 | super().__init__(name) 31 | self.stuff = stuff 32 | 33 | def update_state(self, y_true, y_pred, sample_weights): 34 | pass 35 | 36 | def get_config(self): 37 | return super().get_config() 38 | 39 | def result(self): 40 | return 3 41 | 42 | 43 | def test_missing_arg(): 44 | with pytest.raises(KeyError) as exception_info: 45 | check_metric_serialization(MyDummyMetric("dodo", "dada"), (2,), (2,)) 46 | 47 | assert "stuff" in str(exception_info.value) 48 | 49 | 50 | class MyOtherDummyMetric(Metric): 51 | def __init__(self, to_add, name=None, dtype=None): 52 | super().__init__(name, dtype) 53 | self.to_add = to_add 54 | self.sum_of_y_pred = self.add_weight(name="my_sum", initializer="zeros") 55 | 56 | def update_state(self, y_true, y_pred, sample_weights=None): 57 | self.sum_of_y_pred.assign_add(tf.math.reduce_sum(y_pred) + self.to_add) 58 | 59 | def get_config(self): 60 | config = {"to_add": self.to_add + 1} 61 | config.update(super().get_config()) 62 | return config 63 | 64 | def result(self): 65 | return self.sum_of_y_pred 66 | 67 | 68 | def test_wrong_serialization(): 69 | with pytest.raises(AssertionError): 70 | check_metric_serialization(MyOtherDummyMetric(5), (2,), (2,)) 71 | -------------------------------------------------------------------------------- /tensorflow_addons/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/tests/__init__.py -------------------------------------------------------------------------------- /tensorflow_addons/tests/register_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import pytest 4 | import tensorflow as tf 5 | from tensorflow_addons.register import register_all, _get_all_shared_objects 6 | from tensorflow_addons.utils import resource_loader 7 | 8 | 9 | def test_multiple_register(): 10 | if resource_loader.SKIP_CUSTOM_OPS: 11 | pytest.skip( 12 | "Skipping the test because a custom ops " 13 | "was being loaded while --skip-custom-ops was set." 
14 | ) 15 | register_all() 16 | register_all() 17 | 18 | 19 | def test_get_all_shared_objects(): 20 | if resource_loader.SKIP_CUSTOM_OPS: 21 | pytest.skip( 22 | "Skipping the test because a custom ops " 23 | "was being loaded while --skip-custom-ops was set." 24 | ) 25 | all_shared_objects = _get_all_shared_objects() 26 | assert len(all_shared_objects) >= 4 27 | 28 | for file in all_shared_objects: 29 | tf.load_op_library(file) 30 | 31 | 32 | if __name__ == "__main__": 33 | sys.exit(pytest.main([__file__])) 34 | -------------------------------------------------------------------------------- /tensorflow_addons/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/text/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | # TODO: Once TF exports symbols in a DLL we can enable parse_time_op for windows 6 | # https://github.com/tensorflow/addons/issues/782 7 | py_library( 8 | name = "text", 9 | srcs = glob(["*.py"]), 10 | data = [ 11 | "//tensorflow_addons/custom_ops/text:_skip_gram_ops.so", 12 | "//tensorflow_addons/rnn", 13 | "//tensorflow_addons/testing", 14 | "//tensorflow_addons/utils", 15 | ] + select({ 16 | "//tensorflow_addons:windows": [], 17 | "//conditions:default": [ 18 | "//tensorflow_addons/custom_ops/text:_parse_time_op.so", 19 | ], 20 | }), 21 | ) 22 | 23 | py_test( 24 | name = "text_test", 25 | size = "small", 26 | srcs = glob(["tests/*"]), 27 | main = "tests/run_all_test.py", 28 | deps = [ 29 | ":text", 30 | "//tensorflow_addons/layers", 31 | ], 32 | ) 33 | -------------------------------------------------------------------------------- /tensorflow_addons/text/README.md: -------------------------------------------------------------------------------- 1 | # Addons - Text 2 | 3 | ## Components 4 | https://www.tensorflow.org/addons/api_docs/python/tfa/text 5 | 6 | ## Contribution Guidelines 7 | #### Standard API 8 | In order to conform with the current API standard, all text ops 9 | must: 10 | * Be impossible to implement in one of the other API 11 | standards (Layers, Losses, etc.). 12 | * Be related to text processing. 13 | 14 | #### Testing Requirements 15 | * Simple unittests that demonstrate the text op is behaving as 16 | expected. 17 | * To run your `tf.functions` in eager mode and graph mode in the tests, 18 | you can use the `@pytest.mark.usefixtures("maybe_run_functions_eagerly")` 19 | decorator. This will run the tests twice, once normally, and once 20 | with `tf.config.run_functions_eagerly(True)`. 21 | 22 | #### Documentation Requirements 23 | * Update the [CODEOWNERS file](https://github.com/tensorflow/addons/blob/master/.github/CODEOWNERS) 24 | -------------------------------------------------------------------------------- /tensorflow_addons/text/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Additional text-processing ops.""" 16 | 17 | # Conditional Random Field 18 | from tensorflow_addons.text import crf 19 | from tensorflow_addons.text.crf import CrfDecodeForwardRnnCell 20 | from tensorflow_addons.text.crf import crf_binary_score 21 | from tensorflow_addons.text.crf import crf_constrained_decode 22 | from tensorflow_addons.text.crf import crf_decode 23 | from tensorflow_addons.text.crf import crf_decode_backward 24 | from tensorflow_addons.text.crf import crf_decode_forward 25 | from tensorflow_addons.text.crf import crf_filtered_inputs 26 | from tensorflow_addons.text.crf import crf_forward 27 | from tensorflow_addons.text.crf import crf_log_likelihood 28 | from tensorflow_addons.text.crf import crf_log_norm 29 | from tensorflow_addons.text.crf import crf_multitag_sequence_score 30 | from tensorflow_addons.text.crf import crf_sequence_score 31 | from tensorflow_addons.text.crf import crf_unary_score 32 | from tensorflow_addons.text.crf import viterbi_decode 33 | from tensorflow_addons.text.crf_wrapper import CRFModelWrapper 34 | from tensorflow_addons.text.parse_time_op import parse_time 35 | 36 | # Skip Gram Sampling 37 | from tensorflow_addons.text.skip_gram_ops import skip_gram_sample 38 | from tensorflow_addons.text.skip_gram_ops import skip_gram_sample_with_text_vocab 39 | -------------------------------------------------------------------------------- /tensorflow_addons/text/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/text/tests/__init__.py -------------------------------------------------------------------------------- /tensorflow_addons/text/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/utils/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "utils", 7 | srcs = glob(["*.py"]), 8 | data = [ 9 | "//tensorflow_addons:conftest.py", 10 | "//tensorflow_addons:options.py", 11 | ], 12 | ) 13 | 14 | py_test( 15 | name = "keras_utils_test", 16 | size = "small", 17 | srcs = glob(["tests/*"]), 18 | main = "tests/run_all_test.py", 19 | deps = [ 20 | ":utils", 21 | ], 22 | ) 23 | -------------------------------------------------------------------------------- /tensorflow_addons/utils/README.md: -------------------------------------------------------------------------------- 1 | # Addons Utils 2 | 
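Editor's note: the `tfa.text` exports listed in `tensorflow_addons/text/__init__.py` above are dominated by the linear-chain CRF helpers. The sketch below is a rough usage illustration only; the tensor shapes and values are made-up stand-ins for a real tagger's unary scores and labels, not taken from the repository's tests.

```python
# Minimal sketch of the CRF helpers exported by tfa.text: score a batch of
# random tag sequences, then Viterbi-decode with the learned transitions.
import tensorflow as tf
import tensorflow_addons as tfa

batch, max_len, num_tags = 2, 5, 4
unary_scores = tf.random.normal([batch, max_len, num_tags])
tags = tf.random.uniform([batch, max_len], maxval=num_tags, dtype=tf.int32)
seq_lens = tf.constant([5, 3])

log_likelihood, transition_params = tfa.text.crf_log_likelihood(
    unary_scores, tags, seq_lens
)
decoded_tags, best_score = tfa.text.crf_decode(
    unary_scores, transition_params, seq_lens
)
```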
-------------------------------------------------------------------------------- /tensorflow_addons/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/utils/__init__.py -------------------------------------------------------------------------------- /tensorflow_addons/utils/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/addons/d208d752e98c310280938efa939117bf635a60a8/tensorflow_addons/utils/tests/__init__.py -------------------------------------------------------------------------------- /tensorflow_addons/utils/tests/keras_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for Keras utils.""" 16 | 17 | import sys 18 | 19 | import pytest 20 | import tensorflow as tf 21 | 22 | from tensorflow_addons.utils import keras_utils 23 | 24 | 25 | def test_normalize_data_format(): 26 | assert keras_utils.normalize_data_format("Channels_Last") == "channels_last" 27 | assert keras_utils.normalize_data_format("CHANNELS_FIRST") == "channels_first" 28 | 29 | with pytest.raises(ValueError, match="The `data_format` argument must be one of"): 30 | keras_utils.normalize_data_format("invalid") 31 | 32 | 33 | def test_normalize_tuple(): 34 | assert (2, 2, 2) == keras_utils.normalize_tuple(2, n=3, name="strides") 35 | assert (2, 1, 2) == keras_utils.normalize_tuple((2, 1, 2), n=3, name="strides") 36 | 37 | with pytest.raises(ValueError): 38 | keras_utils.normalize_tuple((2, 1), n=3, name="strides") 39 | 40 | with pytest.raises(TypeError): 41 | keras_utils.normalize_tuple(None, n=3, name="strides") 42 | 43 | 44 | def test_standard_cell(): 45 | keras_utils.assert_like_rnncell("cell", tf.keras.layers.LSTMCell(10)) 46 | 47 | 48 | def test_non_cell(): 49 | with pytest.raises(TypeError): 50 | keras_utils.assert_like_rnncell("cell", tf.keras.layers.Dense(10)) 51 | 52 | 53 | def test_custom_cell(): 54 | class CustomCell(tf.keras.layers.AbstractRNNCell): 55 | @property 56 | def output_size(self): 57 | raise ValueError("assert_like_rnncell should not run code") 58 | 59 | keras_utils.assert_like_rnncell("cell", CustomCell()) 60 | 61 | 62 | if __name__ == "__main__": 63 | sys.exit(pytest.main([__file__])) 64 | -------------------------------------------------------------------------------- /tensorflow_addons/utils/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | 
sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/utils/tests/test_utils_test.py: -------------------------------------------------------------------------------- 1 | import random 2 | from packaging.version import Version 3 | 4 | import numpy as np 5 | import pytest 6 | import tensorflow as tf 7 | from tensorflow_addons.utils import test_utils 8 | 9 | 10 | def test_seed_is_set(): 11 | assert random.randint(0, 10000) == 6311 12 | assert np.random.randint(0, 10000) == 2732 13 | assert tf.random.uniform([], 0, 10000, dtype=tf.int64).numpy() == 9457 14 | 15 | 16 | @pytest.mark.with_device(["cpu", "gpu", tf.distribute.MirroredStrategy]) 17 | def test_all_scopes(device): 18 | assert isinstance(device, str) or isinstance(device, tf.distribute.Strategy) 19 | 20 | 21 | def train_small_model(): 22 | model_input = tf.keras.layers.Input((3,)) 23 | model_output = tf.keras.layers.Dense(4)(model_input) 24 | model = tf.keras.Model(model_input, model_output) 25 | model.compile(loss="mse") 26 | 27 | x = np.random.uniform(size=(5, 3)) 28 | y = np.random.uniform(size=(5, 4)) 29 | model.fit(x, y, epochs=1) 30 | 31 | 32 | @pytest.mark.skipif( 33 | Version(tf.__version__) >= Version("2.13"), 34 | reason="TF2.13 breakage: https://github.com/tensorflow/addons/pull/2835#issuecomment-1629772331", 35 | ) 36 | @pytest.mark.with_device([tf.distribute.MirroredStrategy]) 37 | def test_distributed_strategy(device): 38 | assert isinstance(device, tf.distribute.Strategy) 39 | train_small_model() 40 | 41 | 42 | @pytest.mark.skipif( 43 | Version(tf.__version__) >= Version("2.13"), 44 | reason="TF2.13 breakage: https://github.com/tensorflow/addons/pull/2835#issuecomment-1629772331", 45 | ) 46 | @pytest.mark.with_device(["no_device"]) 47 | @pytest.mark.needs_gpu 48 | def test_custom_device_placement(): 49 | with tf.device(test_utils.gpus_for_testing()[0]): 50 | train_small_model() 51 | 52 | strategy = tf.distribute.MirroredStrategy(test_utils.gpus_for_testing()) 53 | with strategy.scope(): 54 | train_small_model() 55 | -------------------------------------------------------------------------------- /tensorflow_addons/utils/tfa_eol_msg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import warnings 17 | 18 | 19 | def _print_eol_warning(): 20 | """ 21 | Prints TensorFlow Addons End of Life Warning 22 | """ 23 | warnings.warn( 24 | "\n\nTensorFlow Addons (TFA) has ended development and introduction of new features.\n" 25 | "TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.\n" 26 | "Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community " 27 | "(e.g. 
Keras, Keras-CV, and Keras-NLP). \n\n" 28 | "For more information see: https://github.com/tensorflow/addons/issues/2807 \n", 29 | UserWarning, 30 | ) 31 | 32 | 33 | _print_eol_warning() 34 | -------------------------------------------------------------------------------- /tensorflow_addons/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Define TensorFlow Addons version information.""" 16 | 17 | # Required TensorFlow version [min, max) 18 | INCLUSIVE_MIN_TF_VERSION = "2.13.0" 19 | EXCLUSIVE_MAX_TF_VERSION = "2.16.0" 20 | 21 | # We follow Semantic Versioning (https://semver.org/) 22 | _MAJOR_VERSION = "0" 23 | _MINOR_VERSION = "23" 24 | _PATCH_VERSION = "0" 25 | 26 | # When building releases, we can update this value on the release branch to 27 | # reflect the current release candidate ('rc0', 'rc1') or, finally, the official 28 | # stable release (indicated by `_VERSION_SUFFIX = ''`). Outside the context of a 29 | # release branch, the current version is by default assumed to be a 30 | # 'development' version, labeled 'dev'. 31 | _VERSION_SUFFIX = "dev" 32 | 33 | # Example, '0.1.0-dev' 34 | __version__ = ".".join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION]) 35 | if _VERSION_SUFFIX: 36 | __version__ = "{}-{}".format(__version__, _VERSION_SUFFIX) 37 | -------------------------------------------------------------------------------- /tools/build_dev_container.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x -e 4 | 5 | docker build \ 6 | -f tools/docker/dev_container.Dockerfile \ 7 | --build-arg TF_VERSION=2.15.0 \ 8 | --build-arg TF_PACKAGE=tensorflow \ 9 | --build-arg PY_VERSION=$PY_VERSION \ 10 | --no-cache \ 11 | --target dev_container \ 12 | -t tfaddons/dev_container:latest-gpu ./ 13 | -------------------------------------------------------------------------------- /tools/docker/build_wheel.Dockerfile: -------------------------------------------------------------------------------- 1 | #syntax=docker/dockerfile:1.1.5-experimental 2 | ARG PY_VERSION 3 | FROM tensorflow/build:2.15-python$PY_VERSION as base_install 4 | 5 | ENV TF_NEED_CUDA="1" 6 | ARG PY_VERSION 7 | ARG TF_VERSION 8 | 9 | # TODO: Temporary due to build bug https://github.com/pypa/pip/issues/11770 10 | RUN python -m pip install pip==22.3.1 11 | 12 | # TODO: Remove this if tensorflow/build container removes their keras-nightly install 13 | # https://github.com/tensorflow/build/issues/78 14 | RUN python -m pip uninstall -y keras-nightly 15 | 16 | RUN python -m pip install --default-timeout=1000 tensorflow==$TF_VERSION 17 | 18 | COPY tools/install_deps/ /install_deps 19 | RUN python -m pip install -r /install_deps/pytest.txt 20 | 21 | COPY requirements.txt . 
22 | RUN python -m pip install -r requirements.txt 23 | 24 | COPY ./ /addons 25 | WORKDIR /addons 26 | 27 | # ------------------------------------------------------------------- 28 | FROM base_install as tfa_gpu_tests 29 | CMD ["bash", "tools/testing/build_and_run_tests.sh"] 30 | 31 | # ------------------------------------------------------------------- 32 | FROM base_install as make_wheel 33 | ARG NIGHTLY_FLAG 34 | ARG NIGHTLY_TIME 35 | ARG SKIP_CUSTOM_OP_TESTS 36 | 37 | RUN python configure.py 38 | 39 | # Test Before Building 40 | RUN bash tools/testing/build_and_run_tests.sh $SKIP_CUSTOM_OP_TESTS 41 | 42 | # Build 43 | RUN bazel build \ 44 | --noshow_progress \ 45 | --noshow_loading_progress \ 46 | --verbose_failures \ 47 | --test_output=errors \ 48 | --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda//crosstool:toolchain \ 49 | build_pip_pkg && \ 50 | # Package Whl 51 | bazel-bin/build_pip_pkg artifacts $NIGHTLY_FLAG 52 | 53 | RUN bash tools/releases/tf_auditwheel_patch.sh 54 | RUN python -m auditwheel repair --plat manylinux2014_x86_64 artifacts/*.whl 55 | RUN ls -al wheelhouse/ 56 | 57 | # ------------------------------------------------------------------- 58 | 59 | FROM python:$PY_VERSION as test_wheel_in_fresh_environment 60 | 61 | ARG TF_VERSION 62 | ARG SKIP_CUSTOM_OP_TESTS 63 | 64 | RUN python -m pip install --default-timeout=1000 tensorflow==$TF_VERSION 65 | 66 | COPY --from=make_wheel /addons/wheelhouse/ /addons/wheelhouse/ 67 | RUN pip install /addons/wheelhouse/*.whl 68 | 69 | RUN if [[ -z "$SKIP_CUSTOM_OP_TESTS" ]] ; then python -c "import tensorflow_addons as tfa; print(tfa.register_all())" ; else python -c "import tensorflow_addons as tfa; print(tfa.register_all(custom_kernels=False))" ; fi 70 | 71 | # ------------------------------------------------------------------- 72 | FROM scratch as output 73 | 74 | COPY --from=test_wheel_in_fresh_environment /addons/wheelhouse/ . 75 | -------------------------------------------------------------------------------- /tools/docker/cpu_tests.Dockerfile: -------------------------------------------------------------------------------- 1 | #syntax=docker/dockerfile:1.1.5-experimental 2 | FROM python:3.9 as build_wheel 3 | 4 | ARG TF_VERSION=2.15.0 5 | RUN pip install --default-timeout=1000 tensorflow-cpu==$TF_VERSION 6 | 7 | RUN apt-get update && apt-get install -y sudo rsync 8 | COPY tools/install_deps/install_bazelisk.sh .bazeliskrc ./ 9 | RUN bash install_bazelisk.sh 10 | 11 | COPY requirements.txt ./ 12 | RUN pip install -r requirements.txt 13 | 14 | COPY tools/install_deps/pytest.txt ./ 15 | RUN pip install -r pytest.txt pytest-cov 16 | 17 | COPY ./ /addons 18 | WORKDIR addons 19 | RUN python configure.py 20 | RUN pip install -e ./ 21 | RUN --mount=type=cache,id=cache_bazel,target=/root/.cache/bazel \ 22 | bash tools/install_so_files.sh 23 | RUN pytest -v -n auto --durations=25 --doctest-modules ./tensorflow_addons \ 24 | --cov=tensorflow_addons ./tensorflow_addons/ 25 | 26 | RUN bazel build --enable_runfiles build_pip_pkg 27 | RUN bazel-bin/build_pip_pkg artifacts 28 | 29 | 30 | FROM python:3.9 31 | 32 | COPY tools/install_deps/tensorflow-cpu.txt ./ 33 | RUN pip install --default-timeout=1000 -r tensorflow-cpu.txt 34 | 35 | COPY --from=0 /addons/artifacts /artifacts 36 | 37 | RUN pip install /artifacts/tensorflow_addons-*.whl 38 | 39 | # check that we didn't forget to add a py file to 40 | # the corresponding BUILD file.
41 | # Also test that the wheel works in a fresh environment 42 | RUN python -c "import tensorflow_addons as tfa; print(tfa.register_all())" 43 | -------------------------------------------------------------------------------- /tools/docker/dev_container.Dockerfile: -------------------------------------------------------------------------------- 1 | #syntax=docker/dockerfile:1.1.5-experimental 2 | ARG PY_VERSION 3 | ARG IMAGE_TYPE 4 | 5 | # Currently all of our dev images are GPU capable but at a cost of being quite large. 6 | # See https://github.com/tensorflow/build/pull/47 7 | FROM tensorflow/build:latest-python$PY_VERSION as dev_container 8 | ARG TF_PACKAGE 9 | ARG TF_VERSION 10 | 11 | RUN pip install --default-timeout=1000 $TF_PACKAGE==$TF_VERSION 12 | 13 | COPY tools/install_deps /install_deps 14 | COPY requirements.txt /tmp/requirements.txt 15 | RUN pip install -r /install_deps/black.txt \ 16 | -r /install_deps/flake8.txt \ 17 | -r /install_deps/pytest.txt \ 18 | -r /install_deps/typedapi.txt \ 19 | -r /tmp/requirements.txt 20 | 21 | RUN bash /install_deps/buildifier.sh 22 | RUN bash /install_deps/clang-format.sh 23 | 24 | ENV ADDONS_DEV_CONTAINER="1" 25 | 26 | # Clean up 27 | RUN apt-get autoremove -y \ 28 | && apt-get clean -y \ 29 | && rm -rf /var/lib/apt/lists/* 30 | -------------------------------------------------------------------------------- /tools/docker/pre-commit.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.9 2 | 3 | COPY tools/install_deps /install_deps 4 | RUN pip install -r /install_deps/black.txt -r /install_deps/flake8.txt 5 | 6 | COPY tools/install_deps/buildifier.sh ./buildifier.sh 7 | RUN bash buildifier.sh 8 | 9 | COPY tools/install_deps/clang-format.sh ./clang-format.sh 10 | RUN bash clang-format.sh 11 | 12 | WORKDIR /addons 13 | 14 | 15 | CMD ["python", "tools/format.py"] 16 | -------------------------------------------------------------------------------- /tools/docs/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # Doc generator 3 | 4 | licenses(["notice"]) # Apache 2.0 5 | 6 | exports_files(["LICENSE"]) 7 | 8 | package( 9 | default_visibility = ["//tensorflow_addons:__subpackages__"], 10 | ) 11 | 12 | py_binary( 13 | name = "build_docs", 14 | srcs = ["build_docs.py"], 15 | deps = [ 16 | "//tensorflow_addons", 17 | ], 18 | ) 19 | -------------------------------------------------------------------------------- /tools/docs/Readme.md: -------------------------------------------------------------------------------- 1 | ## 1. Generated API docs 2 | 3 | [tensorflow.org/addons/api_docs/python/tfa](https://tensorflow.org/addons/api_docs/python/tfa) 4 | 5 | `build_docs.py` controls this docs generation.
To test-run it: 6 | 7 | ```bash 8 | # Install dependencies: 9 | pip install -r tools/install_deps/doc_requirements.txt 10 | 11 | # Build tool: 12 | bazel build //tools/docs:build_docs 13 | 14 | # Generate API doc: 15 | # Use current branch 16 | bazel-bin/tools/docs/build_docs --git_branch=$(git rev-parse --abbrev-ref HEAD) 17 | # or specified explicitly 18 | bazel-bin/tools/docs/build_docs --git_branch=master --output_dir=/tmp/tfa_api 19 | ``` 20 | -------------------------------------------------------------------------------- /tools/format.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from subprocess import check_call, CalledProcessError 3 | 4 | 5 | def check_bash_call(string): 6 | check_call(["bash", "-c", string]) 7 | 8 | 9 | def _run_format_and_flake8(): 10 | files_changed = False 11 | 12 | try: 13 | check_bash_call("python -m black --check ./") 14 | except CalledProcessError: 15 | check_bash_call("python -m black ./") 16 | files_changed = True 17 | 18 | try: 19 | check_bash_call("buildifier -mode=check -r .") 20 | except CalledProcessError: 21 | check_bash_call("buildifier -r .") 22 | files_changed = True 23 | 24 | # todo: find a way to check if files changed 25 | # see https://github.com/DoozyX/clang-format-lint-action for inspiration 26 | check_bash_call( 27 | "shopt -s globstar && clang-format-9 -i --style=google **/*.cc **/*.h", 28 | ) 29 | 30 | if files_changed: 31 | print("Some files have changed.") 32 | print("Please do git add and git commit again") 33 | else: 34 | print("No formatting needed.") 35 | 36 | print("Running flake8.") 37 | check_bash_call("flake8") 38 | print("Done") 39 | 40 | if files_changed: 41 | exit(1) 42 | 43 | 44 | def run_format_and_flake8(): 45 | try: 46 | _run_format_and_flake8() 47 | except CalledProcessError as error: 48 | print("Pre-commit returned exit code", error.returncode) 49 | exit(error.returncode) 50 | 51 | 52 | if __name__ == "__main__": 53 | run_format_and_flake8() 54 | -------------------------------------------------------------------------------- /tools/install_deps/black.txt: -------------------------------------------------------------------------------- 1 | black==22.3.0 2 | -------------------------------------------------------------------------------- /tools/install_deps/buildifier.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
--------------------------------------------------------------------------------
/tools/install_deps/black.txt:
--------------------------------------------------------------------------------
black==22.3.0
--------------------------------------------------------------------------------
/tools/install_deps/buildifier.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

wget -O /usr/local/bin/buildifier https://github.com/bazelbuild/buildtools/releases/download/0.29.0/buildifier
chmod +x /usr/local/bin/buildifier
--------------------------------------------------------------------------------
/tools/install_deps/clang-format.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


wget -O /usr/local/bin/clang-format-9 https://github.com/DoozyX/clang-format-lint-action/raw/master/clang-format/clang-format9
chmod +x /usr/local/bin/clang-format-9
ln -s /usr/local/bin/clang-format-9 /usr/local/bin/clang-format
--------------------------------------------------------------------------------
/tools/install_deps/doc_requirements.txt:
--------------------------------------------------------------------------------
git+https://github.com/tensorflow/docs@99113f26039f6c042df7f2898e05019dbcdf3675
pyyaml
--------------------------------------------------------------------------------
/tools/install_deps/flake8.txt:
--------------------------------------------------------------------------------
flake8~=4.0
pep8-naming~=0.12.1
--------------------------------------------------------------------------------
/tools/install_deps/install_bazelisk.sh:
--------------------------------------------------------------------------------
# Downloads bazelisk to ${output_dir} as `bazel`.
date

output_dir=${1:-"/usr/local/bin"}

case "$(uname -s)" in
  Darwin) name=bazelisk-darwin-amd64 ;;
  Linux) name=bazelisk-linux-amd64 ;;
  *) name=bazelisk-windows-amd64 ;;
esac

mkdir -p "${output_dir}"
curl -LO "https://github.com/bazelbuild/bazelisk/releases/download/v1.18.0/${name}"

mv "${name}" "${output_dir}/bazel"
chmod u+x "${output_dir}/bazel"

if [[ ! ":$PATH:" =~ :${output_dir}/?: ]]; then
  PATH="${output_dir}:$PATH"
fi

which bazel
date
--------------------------------------------------------------------------------
/tools/install_deps/pytest.txt:
--------------------------------------------------------------------------------
pytest~=6.2.5
pytest-xdist~=1.31
pytest-extra-durations~=0.1.3
scikit-learn~=1.2.2
scikit-image~=0.20.0
Pillow~=9.4.0
tqdm>=4.36.1
--------------------------------------------------------------------------------
/tools/install_deps/tensorflow-cpu.txt:
--------------------------------------------------------------------------------
tensorflow-cpu~=2.15.0
--------------------------------------------------------------------------------
/tools/install_deps/tensorflow.txt:
--------------------------------------------------------------------------------
tensorflow~=2.15.0
--------------------------------------------------------------------------------
/tools/install_deps/typedapi.txt:
--------------------------------------------------------------------------------
typedapi~=0.2.0
--------------------------------------------------------------------------------
/tools/install_so_files.sh:
--------------------------------------------------------------------------------
set -e -x

if [ "$TF_NEED_CUDA" == "1" ]; then
  CUDA_FLAG="--crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda//crosstool:toolchain"
fi

bazel build $CUDA_FLAG //tensorflow_addons/...
cp ./bazel-bin/tensorflow_addons/custom_ops/image/_*_ops.so ./tensorflow_addons/custom_ops/image/
cp ./bazel-bin/tensorflow_addons/custom_ops/layers/_*_ops.so ./tensorflow_addons/custom_ops/layers/
cp ./bazel-bin/tensorflow_addons/custom_ops/seq2seq/_*_ops.so ./tensorflow_addons/custom_ops/seq2seq/
cp ./bazel-bin/tensorflow_addons/custom_ops/text/_*_ops.so ./tensorflow_addons/custom_ops/text/
cp ./bazel-bin/tensorflow_addons/custom_ops/text/_parse_time_op.so ./tensorflow_addons/custom_ops/text/
--------------------------------------------------------------------------------
/tools/pre-commit.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# usage: bash tools/pre-commit.sh


set -e

if [ -z "${ADDONS_DEV_CONTAINER}" ]; then
  export DOCKER_BUILDKIT=1
  docker build -t tf_addons_formatting -f tools/docker/pre-commit.Dockerfile .

  export MSYS_NO_PATHCONV=1
  docker run --rm -t -v "$(pwd -P):/addons" tf_addons_formatting
else
  python tools/format.py
fi
--------------------------------------------------------------------------------
/tools/releases/tf_auditwheel_patch.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

set -e -x

SITE_PKG_LOCATION=$(python -c "import site; print(site.getsitepackages()[0])")
TF_SHARED_LIBRARY_NAME=$(grep -r TF_SHARED_LIBRARY_NAME .bazelrc | awk -F= '{print$2}')
POLICY_JSON="${SITE_PKG_LOCATION}/auditwheel/policy/manylinux-policy.json"
sed -i "s/libresolv.so.2\"/libresolv.so.2\", $TF_SHARED_LIBRARY_NAME/g" $POLICY_JSON
--------------------------------------------------------------------------------
/tools/run_build.sh:
--------------------------------------------------------------------------------
# usage: bash tools/run_build.sh
# by default uses docker buildkit.
# to disable it:
# DOCKER_BUILDKIT=0 bash tools/run_build.sh
set -e

export DOCKER_BUILDKIT=1
docker build -f tools/docker/sanity_check.Dockerfile --target=${1} ./
--------------------------------------------------------------------------------
/tools/run_cpu_tests.sh:
--------------------------------------------------------------------------------
# usage: bash tools/run_cpu_tests.sh

set -e

export DOCKER_BUILDKIT=1
docker build --progress=plain -f tools/docker/cpu_tests.Dockerfile ./
--------------------------------------------------------------------------------
/tools/run_google_cloud_tests.sh:
--------------------------------------------------------------------------------
set -x -e

bash tools/run_gpu_tests.sh
--------------------------------------------------------------------------------
/tools/run_gpu_tests.sh:
--------------------------------------------------------------------------------
# usage: bash tools/run_gpu_tests.sh

set -x -e

export DOCKER_BUILDKIT=1
docker build \
  -f tools/docker/build_wheel.Dockerfile \
  --target tfa_gpu_tests \
  --build-arg TF_VERSION=2.15.0 \
  --build-arg PY_VERSION=3.9 \
  -t tfa_gpu_tests ./
docker run --rm -t --gpus=all --shm-size=512m tfa_gpu_tests
--------------------------------------------------------------------------------
/tools/run_sanity_check.sh:
--------------------------------------------------------------------------------
# usage: bash tools/run_sanity_check.sh

set -e

export DOCKER_BUILDKIT=1
docker build -f tools/docker/sanity_check.Dockerfile ./
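
run_sanity_check.sh builds the sanity-check image end to end (CI's default), while run_build.sh passes its first argument as --target so a single stage of tools/docker/sanity_check.Dockerfile can be built in isolation. A small sketch; the stage name below is only an illustrative placeholder, since the real stage names are defined in sanity_check.Dockerfile, which is not reproduced here:

```bash
# Run every sanity check, as CI does:
bash tools/run_sanity_check.sh

# Build only one stage of the sanity-check image, e.g. a hypothetical "flake8-test" stage:
bash tools/run_build.sh flake8-test
```
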
--------------------------------------------------------------------------------
/tools/testing/build_and_run_tests.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
# usage: bash tools/testing/build_and_run_tests.sh

set -x -e

SKIP_CUSTOM_OP_TESTS_FLAG=${1}

python -m pip install -r tools/install_deps/pytest.txt -e ./
python ./configure.py
bash tools/install_so_files.sh
python -c "import tensorflow as tf; print(tf.config.list_physical_devices())"

bazel clean
python -m pytest -v --functions-durations=20 --modules-durations=5 $SKIP_CUSTOM_OP_TESTS_FLAG ./tensorflow_addons
--------------------------------------------------------------------------------
/tools/update_release_version.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Usage
if [ $# -lt 1 ]; then
    echo "Usage: bash tools/update_release_version.sh <version> [<version> ...]"
    echo "e.g. bash tools/update_release_version.sh 2.3.0 2.3.1"
    exit 1
fi

last_version=${BASH_ARGV[0]}
tf_version=''
for ver in $@
do
    if [ -z $tf_version ]; then
        tf_version="'$ver'"
    else
        tf_version="$tf_version, '$ver'"
    fi
done
echo $tf_version
echo $last_version
sed -ri "s/(tf-version: \[)'.+'/\1$tf_version/g" \
    .github/workflows/release.yml
sed -ri "s/(tensorflow(-cpu)*(~|=)=)[0-9]+[a-zA-Z0-9_.-]+/\1$1/g" \
    CONTRIBUTING.md \
    tools/install_deps/tensorflow-cpu.txt \
    tools/install_deps/tensorflow.txt
sed -ri "s/(TF_VERSION=)\S+/\1$last_version/g" \
    tools/docker/cpu_tests.Dockerfile \
    tools/run_gpu_tests.sh \
    tools/build_dev_container.sh
--------------------------------------------------------------------------------
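
The version-bump seds above rewrite the pinned versions in place using GNU sed's extended regexes. For clarity, a minimal sketch of the requirements-pin substitution run against a throwaway copy; the version 2.16.1 here is just an example value, not a supported pin:

```bash
# Reproduce the tensorflow pin rewrite on a scratch copy of the requirements file:
cp tools/install_deps/tensorflow.txt /tmp/tensorflow.txt
sed -ri "s/(tensorflow(-cpu)*(~|=)=)[0-9]+[a-zA-Z0-9_.-]+/\12.16.1/g" /tmp/tensorflow.txt
cat /tmp/tensorflow.txt   # prints: tensorflow~=2.16.1
```

Because sed back-references are single digits, "\12.16.1" is read as group 1 followed by the literal text "2.16.1", which is the same way the script's "\1$1" behaves after the shell expands the argument.
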