├── .venv ├── lib64 ├── bin │ ├── python │ ├── python3 │ ├── python3.8 │ ├── pip │ ├── pip3 │ ├── pip3.8 │ ├── easy_install │ ├── easy_install-3.8 │ ├── activate.csh │ ├── activate │ └── activate.fish └── pyvenv.cfg ├── .bazelversion ├── requirements.txt ├── tensorflow_addons ├── testing │ ├── __init__.py │ ├── tests │ │ ├── __init__.py │ │ ├── run_all_test.py │ │ └── serialization_test.py │ └── BUILD ├── tests │ ├── __init__.py │ ├── run_all_test.py │ └── register_test.py ├── utils │ ├── __init__.py │ ├── tests │ │ ├── __init__.py │ │ ├── run_all_test.py │ │ ├── test_utils_test.py │ │ └── keras_utils_test.py │ ├── README.md │ ├── BUILD │ └── types.py ├── image │ ├── tests │ │ ├── __init__.py │ │ ├── test_data │ │ │ ├── Yellow_Smiley_Face.png │ │ │ ├── Yellow_Smiley_Face_Warp-interp-1-clamp-0.png │ │ │ ├── Yellow_Smiley_Face_Warp-interp-1-clamp-1.png │ │ │ ├── Yellow_Smiley_Face_Warp-interp-1-clamp-4.png │ │ │ ├── Yellow_Smiley_Face_Warp-interp-2-clamp-0.png │ │ │ ├── Yellow_Smiley_Face_Warp-interp-2-clamp-1.png │ │ │ ├── Yellow_Smiley_Face_Warp-interp-2-clamp-4.png │ │ │ ├── Yellow_Smiley_Face_Warp-interp-3-clamp-0.png │ │ │ ├── Yellow_Smiley_Face_Warp-interp-3-clamp-1.png │ │ │ └── Yellow_Smiley_Face_Warp-interp-3-clamp-4.png │ │ ├── run_all_test.py │ │ └── compose_ops_test.py │ ├── BUILD │ ├── README.md │ ├── __init__.py │ └── compose_ops.py ├── layers │ ├── tests │ │ ├── __init__.py │ │ ├── run_all_test.py │ │ ├── gelu_test.py │ │ ├── snake_test.py │ │ ├── maxout_test.py │ │ ├── stochastic_depth_test.py │ │ ├── netvlad_test.py │ │ ├── sparsemax_test.py │ │ ├── tlu_test.py │ │ └── esn_test.py │ ├── BUILD │ ├── README.md │ ├── sparsemax.py │ ├── gelu.py │ ├── snake.py │ ├── __init__.py │ └── poincare.py ├── losses │ ├── tests │ │ ├── __init__.py │ │ ├── run_all_test.py │ │ └── metric_test.py │ ├── BUILD │ ├── README.md │ ├── __init__.py │ └── giou_loss.ipynb ├── metrics │ ├── tests │ │ ├── __init__.py │ │ ├── run_all_test.py │ │ ├── metrics_test.py │ │ └── harmonic_mean_test.py │ ├── BUILD │ ├── __init__.py │ └── harmonic_mean.py ├── rnn │ ├── tests │ │ ├── __init__.py │ │ ├── run_all_test.py │ │ └── peephole_lstm_cell_test.py │ ├── BUILD │ ├── __init__.py │ └── README.md ├── seq2seq │ ├── tests │ │ ├── __init__.py │ │ └── run_all_test.py │ └── BUILD ├── text │ ├── tests │ │ ├── __init__.py │ │ └── run_all_test.py │ ├── README.md │ ├── BUILD │ └── __init__.py ├── activations │ ├── tests │ │ ├── __init__.py │ │ ├── run_all_test.py │ │ ├── tanhshrink_test.py │ │ ├── snake_test.py │ │ ├── lisht_test.py │ │ ├── mish_test.py │ │ ├── hardshrink_test.py │ │ ├── softshrink_test.py │ │ ├── gelu_test.py │ │ ├── activations_test.py │ │ └── rrelu_test.py │ ├── BUILD │ ├── README.md │ ├── __init__.py │ ├── tanhshrink.py │ ├── mish.py │ ├── lisht.py │ ├── snake.py │ ├── hardshrink.py │ └── softshrink.py ├── callbacks │ ├── tests │ │ ├── __init__.py │ │ ├── run_all_test.py │ │ └── time_stopping_test.py │ ├── BUILD │ ├── README.md │ ├── __init__.py │ └── time_stopping.py ├── optimizers │ ├── tests │ │ ├── __init__.py │ │ ├── run_all_test.py │ │ ├── cocob_test.py │ │ └── standard_test.py │ ├── BUILD │ ├── README.md │ ├── utils.py │ └── __init__.py ├── custom_ops │ ├── README.md │ ├── seq2seq │ │ ├── BUILD │ │ └── cc │ │ │ └── kernels │ │ │ └── beam_search_ops.h │ ├── layers │ │ ├── BUILD │ │ └── cc │ │ │ └── kernels │ │ │ └── correlation_cost_op.h │ ├── text │ │ ├── BUILD │ │ └── cc │ │ │ └── ops │ │ │ └── skip_gram_ops.cc │ └── image │ │ ├── BUILD │ │ └── cc │ │ ├── kernels │ │ └── resampler_ops.h │ │ 
└── ops │ │ └── resampler_ops.cc ├── conftest.py ├── BUILD ├── __init__.py ├── version.py ├── tensorflow_addons.bzl └── options.py ├── tools ├── install_deps │ ├── black.txt │ ├── tensorflow.txt │ ├── typedapi.txt │ ├── tensorflow-cpu.txt │ ├── flake8.txt │ ├── doc_requirements.txt │ ├── pytest.txt │ ├── install_bazelisk.sh │ ├── buildifier.sh │ └── clang-format.sh ├── run_google_cloud_tests.sh ├── run_sanity_check.sh ├── run_cpu_tests.sh ├── run_build.sh ├── build_dev_container.sh ├── run_gpu_tests.sh ├── docs │ ├── BUILD │ └── Readme.md ├── pre-commit.sh ├── docker │ ├── pre-commit.Dockerfile │ ├── dev_container.Dockerfile │ ├── cpu_tests.Dockerfile │ └── build_wheel.Dockerfile ├── install_so_files.sh ├── releases │ └── tf_auditwheel_patch.sh ├── format.py ├── testing │ └── build_and_run_tests.sh └── update_release_version.sh ├── MANIFEST.in ├── .dockerignore ├── work-items.md ├── pytest.ini ├── .github ├── release-template.yml ├── workflows │ ├── github_build_dev_container.sh │ ├── make_wheel_Linux.sh │ ├── release-drafter.yml │ ├── make_wheel_Windows.sh │ ├── make_wheel_macOS.sh │ ├── notify_codeowners.yml │ └── backport.yml ├── ISSUE_TEMPLATE │ ├── bug-performance-report.md │ └── feature-request.md ├── boring-cyborg.yml └── pull_request_template.md ├── BUILD ├── pyproject.toml ├── .gitignore ├── MIGRATION_TO_CORE.md ├── WORKSPACE ├── docs ├── tutorials │ ├── _toc.yaml │ └── README.md └── README.md ├── .flake8 ├── STYLE_GUIDE.md └── .tours └── giou-loss-overview.tour /.venv/lib64: -------------------------------------------------------------------------------- 1 | lib -------------------------------------------------------------------------------- /.bazelversion: -------------------------------------------------------------------------------- 1 | 3.7.2 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | typeguard>=2.7 2 | -------------------------------------------------------------------------------- /tensorflow_addons/testing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_addons/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_addons/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_addons/losses/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_addons/metrics/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/tensorflow_addons/rnn/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_addons/seq2seq/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_addons/testing/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_addons/text/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_addons/utils/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_addons/callbacks/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_addons/optimizers/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/install_deps/black.txt: -------------------------------------------------------------------------------- 1 | black==20.8b1 2 | -------------------------------------------------------------------------------- /tensorflow_addons/utils/README.md: -------------------------------------------------------------------------------- 1 | # Addons Utils 2 | -------------------------------------------------------------------------------- /tools/install_deps/tensorflow.txt: -------------------------------------------------------------------------------- 1 | tensorflow~=2.5.0 -------------------------------------------------------------------------------- /tools/install_deps/typedapi.txt: -------------------------------------------------------------------------------- 1 | typedapi~=0.2.0 2 | -------------------------------------------------------------------------------- /tools/install_deps/tensorflow-cpu.txt: -------------------------------------------------------------------------------- 1 | tensorflow-cpu~=2.5.0 2 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include tensorflow_addons *.so 2 | include docs/* -------------------------------------------------------------------------------- /tools/install_deps/flake8.txt: -------------------------------------------------------------------------------- 1 | flake8~=3.7.9 2 | pep8-naming~=0.10.0 3 | -------------------------------------------------------------------------------- /tools/run_google_cloud_tests.sh: -------------------------------------------------------------------------------- 1 | set -x -e 2 | 3 | bash tools/run_gpu_tests.sh 4 | -------------------------------------------------------------------------------- 
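The requirement files under `tools/install_deps` above pin their dependencies with the compatible-release operator, e.g. `tensorflow~=2.5.0`, which accepts patch releases in the 2.5.x series but rejects 2.6. A minimal sketch of how such a specifier resolves, using the third-party `packaging` library purely for illustration (it is not a dependency of this repository):

```python
from packaging.specifiers import SpecifierSet

# "~=2.5.0" is shorthand for ">=2.5.0, ==2.5.*".
spec = SpecifierSet("~=2.5.0")

print("2.5.3" in spec)  # True  -- patch upgrades stay within the pin
print("2.6.0" in spec)  # False -- minor bumps need an explicit update
```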
/tools/install_deps/doc_requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/tensorflow/docs 2 | pyyaml -------------------------------------------------------------------------------- /.venv/bin/python: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamicwebpaige/addons/master/.venv/bin/python -------------------------------------------------------------------------------- /.venv/bin/python3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamicwebpaige/addons/master/.venv/bin/python3 -------------------------------------------------------------------------------- /.venv/bin/python3.8: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamicwebpaige/addons/master/.venv/bin/python3.8 -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .git 2 | .github 3 | *.Dockerfile 4 | .coverage* 5 | # C extensions 6 | *.so 7 | wheelhouse/ 8 | -------------------------------------------------------------------------------- /.venv/pyvenv.cfg: -------------------------------------------------------------------------------- 1 | home = /opt/python/3.8.6/bin 2 | include-system-site-packages = true 3 | version = 3.8.6 4 | -------------------------------------------------------------------------------- /work-items.md: -------------------------------------------------------------------------------- 1 | https://giou.stanford.edu/ 2 | 3 | Figure out how to condense this into an understandable CodeTour. 
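The GIoU link in `work-items.md` refers to the generalized IoU loss that already ships in this repository (see `tensorflow_addons/losses/giou_loss.ipynb` and `.tours/giou-loss-overview.tour` in the tree above). A minimal usage sketch, with arbitrary example boxes encoded as `[y_min, x_min, y_max, x_max]`:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Arbitrary example boxes in [y_min, x_min, y_max, x_max] format.
boxes_true = tf.constant([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 2.0, 2.0]])
boxes_pred = tf.constant([[0.1, 0.1, 1.1, 1.1], [0.5, 0.5, 2.5, 2.5]])

giou_loss = tfa.losses.GIoULoss()          # mode="giou" is the default
loss = giou_loss(boxes_true, boxes_pred)   # scalar, averaged over the boxes
print(loss.numpy())
```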
4 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = -ra 3 | doctest_optionflags = ELLIPSIS NORMALIZE_WHITESPACE IGNORE_EXCEPTION_DETAIL DONT_ACCEPT_BLANKLINE 4 | -------------------------------------------------------------------------------- /.github/release-template.yml: -------------------------------------------------------------------------------- 1 | template: | 2 | ## Release Notes 3 | 4 | $CHANGES 5 | 6 | ## Thanks to our Contributors 7 | 8 | $CONTRIBUTORS 9 | -------------------------------------------------------------------------------- /tools/run_sanity_check.sh: -------------------------------------------------------------------------------- 1 | # usage: bash tools/run_sanity_check.sh 2 | 3 | set -e 4 | 5 | export DOCKER_BUILDKIT=1 6 | docker build -f tools/docker/sanity_check.Dockerfile ./ 7 | -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamicwebpaige/addons/master/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face.png -------------------------------------------------------------------------------- /tools/install_deps/pytest.txt: -------------------------------------------------------------------------------- 1 | pytest~=5.3 2 | pytest-xdist~=1.31 3 | pytest-extra-durations~=0.1.3 4 | scikit-learn~=0.22 5 | scikit-image~=0.17.0 6 | Pillow~=8.0.0 7 | tqdm>=4.36.1 8 | -------------------------------------------------------------------------------- /tools/run_cpu_tests.sh: -------------------------------------------------------------------------------- 1 | # usage: bash tools/run_cpu_tests.sh 2 | 3 | set -e 4 | 5 | export DOCKER_BUILDKIT=1 6 | docker build --progress=plain -f tools/docker/cpu_tests.Dockerfile ./ 7 | -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamicwebpaige/addons/master/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamicwebpaige/addons/master/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamicwebpaige/addons/master/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dynamicwebpaige/addons/master/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamicwebpaige/addons/master/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamicwebpaige/addons/master/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamicwebpaige/addons/master/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamicwebpaige/addons/master/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamicwebpaige/addons/master/tensorflow_addons/image/tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png -------------------------------------------------------------------------------- /.github/workflows/github_build_dev_container.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x -e 4 | 5 | df -h 6 | docker info 7 | # to get more disk space 8 | rm -rf /usr/share/dotnet & 9 | 10 | tools/build_dev_container.sh 11 | -------------------------------------------------------------------------------- /tensorflow_addons/rnn/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib 
import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/losses/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/metrics/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/seq2seq/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/testing/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/text/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/utils/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/callbacks/tests/run_all_test.py: 
-------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tensorflow_addons/optimizers/tests/run_all_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | import pytest 5 | 6 | if __name__ == "__main__": 7 | dirname = Path(__file__).absolute().parent 8 | sys.exit(pytest.main([str(dirname)])) 9 | -------------------------------------------------------------------------------- /tools/run_build.sh: -------------------------------------------------------------------------------- 1 | # usage: bash tools/run_build.sh 2 | # by default uses docker buildkit. 3 | # to disable it: 4 | # DOCKER_BUILDKIT=0 bash tools/run_build.sh 5 | set -e 6 | 7 | export DOCKER_BUILDKIT=1 8 | docker build -f tools/docker/sanity_check.Dockerfile --target=${1} ./ 9 | -------------------------------------------------------------------------------- /.venv/bin/pip: -------------------------------------------------------------------------------- 1 | #!/workspaces/addons/.venv/bin/python3.8 2 | # -*- coding: utf-8 -*- 3 | import re 4 | import sys 5 | from pip._internal.cli.main import main 6 | if __name__ == '__main__': 7 | sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) 8 | sys.exit(main()) 9 | -------------------------------------------------------------------------------- /BUILD: -------------------------------------------------------------------------------- 1 | sh_binary( 2 | name = "build_pip_pkg", 3 | srcs = ["build_deps/build_pip_pkg.sh"], 4 | data = [ 5 | "LICENSE", 6 | "MANIFEST.in", 7 | "requirements.txt", 8 | "setup.py", 9 | "//tensorflow_addons", 10 | ], 11 | ) 12 | -------------------------------------------------------------------------------- /.venv/bin/pip3: -------------------------------------------------------------------------------- 1 | #!/workspaces/addons/.venv/bin/python3.8 2 | # -*- coding: utf-8 -*- 3 | import re 4 | import sys 5 | from pip._internal.cli.main import main 6 | if __name__ == '__main__': 7 | sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) 8 | sys.exit(main()) 9 | -------------------------------------------------------------------------------- /.venv/bin/pip3.8: -------------------------------------------------------------------------------- 1 | #!/workspaces/addons/.venv/bin/python3.8 2 | # -*- coding: utf-8 -*- 3 | import re 4 | import sys 5 | from pip._internal.cli.main import main 6 | if __name__ == '__main__': 7 | sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) 8 | sys.exit(main()) 9 | -------------------------------------------------------------------------------- /.venv/bin/easy_install: -------------------------------------------------------------------------------- 1 | #!/workspaces/addons/.venv/bin/python3.8 2 | # -*- coding: utf-8 -*- 3 | import re 4 | import sys 5 | from setuptools.command.easy_install import main 6 | if __name__ == '__main__': 7 | sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) 8 | sys.exit(main()) 9 | -------------------------------------------------------------------------------- /.venv/bin/easy_install-3.8: -------------------------------------------------------------------------------- 1 | 
#!/workspaces/addons/.venv/bin/python3.8 2 | # -*- coding: utf-8 -*- 3 | import re 4 | import sys 5 | from setuptools.command.easy_install import main 6 | if __name__ == '__main__': 7 | sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) 8 | sys.exit(main()) 9 | -------------------------------------------------------------------------------- /tools/build_dev_container.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x -e 4 | 5 | DOCKER_BUILDKIT=1 docker build \ 6 | -f tools/docker/dev_container.Dockerfile \ 7 | --build-arg TF_VERSION=2.5.0 \ 8 | --build-arg TF_PACKAGE=tensorflow-cpu \ 9 | --target dev_container_cpu \ 10 | -t tfaddons/dev_container:latest-cpu ./ 11 | -------------------------------------------------------------------------------- /tools/run_gpu_tests.sh: -------------------------------------------------------------------------------- 1 | # usage: bash tools/run_gpu_tests.sh 2 | 3 | set -x -e 4 | 5 | export DOCKER_BUILDKIT=1 6 | docker build \ 7 | -f tools/docker/build_wheel.Dockerfile \ 8 | --target tfa_gpu_tests \ 9 | --build-arg TF_VERSION=2.5.0 \ 10 | --build-arg PY_VERSION=3.6 \ 11 | -t tfa_gpu_tests ./ 12 | docker run --rm -t --gpus=all tfa_gpu_tests 13 | -------------------------------------------------------------------------------- /tools/docs/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # Doc generator 3 | 4 | licenses(["notice"]) # Apache 2.0 5 | 6 | exports_files(["LICENSE"]) 7 | 8 | package( 9 | default_visibility = ["//tensorflow_addons:__subpackages__"], 10 | ) 11 | 12 | py_binary( 13 | name = "build_docs", 14 | srcs = ["build_docs.py"], 15 | deps = [ 16 | "//tensorflow_addons", 17 | ], 18 | ) 19 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | target-version = ['py36', 'py37', 'py38'] 3 | exclude = ''' 4 | ( 5 | /( 6 | \.eggs # exclude a few common directories in the 7 | | \.git # root of the project 8 | | \.hg 9 | | \.mypy_cache 10 | | \.tox 11 | | \.venv 12 | | _build 13 | | buck-out 14 | | build 15 | | dist 16 | )/ 17 | ) 18 | ''' 19 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/README.md: -------------------------------------------------------------------------------- 1 | # Addons - Custom Ops 2 | 3 | ## Contents 4 | | Sub-Package | Description | 5 | |:----------------------- |:-----------------------------| 6 | | Image | Ops for image manipulation | 7 | | Seq2seq | Ops for seq2seq encoder-decoder framework | 8 | | Text | Ops for text processing | 9 | | Layers | Ops for model layers | 10 | -------------------------------------------------------------------------------- /tools/pre-commit.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # usage: bash tools/pre-commit.sh 3 | 4 | 5 | set -e 6 | 7 | if [ -z "${ADDONS_DEV_CONTAINER}" ]; then 8 | export DOCKER_BUILDKIT=1 9 | docker build -t tf_addons_formatting -f tools/docker/pre-commit.Dockerfile . 
10 | 11 | export MSYS_NO_PATHCONV=1 12 | docker run --rm -t -v "$(pwd -P):/addons" tf_addons_formatting 13 | else 14 | python tools/format.py 15 | fi 16 | -------------------------------------------------------------------------------- /.github/workflows/make_wheel_Linux.sh: -------------------------------------------------------------------------------- 1 | set -e -x 2 | 3 | df -h 4 | docker info 5 | # to get more disk space 6 | rm -rf /usr/share/dotnet & 7 | 8 | DOCKER_BUILDKIT=1 docker build \ 9 | -f tools/docker/build_wheel.Dockerfile \ 10 | --output type=local,dest=wheelhouse \ 11 | --build-arg PY_VERSION \ 12 | --build-arg TF_VERSION \ 13 | --build-arg NIGHTLY_FLAG \ 14 | --build-arg NIGHTLY_TIME \ 15 | ./ 16 | -------------------------------------------------------------------------------- /tensorflow_addons/testing/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "testing", 7 | srcs = glob(["*.py"]), 8 | ) 9 | 10 | py_test( 11 | name = "serialization_test", 12 | size = "small", 13 | srcs = glob(["tests/*"]), 14 | main = "tests/run_all_test.py", 15 | deps = [ 16 | ":testing", 17 | ], 18 | ) 19 | -------------------------------------------------------------------------------- /tools/docker/pre-commit.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.6 2 | 3 | COPY tools/install_deps /install_deps 4 | RUN pip install -r /install_deps/black.txt -r /install_deps/flake8.txt 5 | 6 | COPY tools/install_deps/buildifier.sh ./buildifier.sh 7 | RUN bash buildifier.sh 8 | 9 | COPY tools/install_deps/clang-format.sh ./clang-format.sh 10 | RUN bash clang-format.sh 11 | 12 | WORKDIR /addons 13 | 14 | 15 | CMD ["python", "tools/format.py"] 16 | -------------------------------------------------------------------------------- /.github/workflows/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name: release-drafter 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | - r* 8 | 9 | jobs: 10 | update_release_draft: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: release-drafter/release-drafter@74e7c423dafbb406c9c18b1638334f67a7c891c3 # Version 5.7.0 14 | with: 15 | config-name: release-template.yml 16 | env: 17 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 18 | -------------------------------------------------------------------------------- /tensorflow_addons/rnn/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "rnn", 7 | srcs = glob(["*.py"]), 8 | deps = [ 9 | "//tensorflow_addons/testing", 10 | "//tensorflow_addons/utils", 11 | ], 12 | ) 13 | 14 | py_test( 15 | name = "rnn_test", 16 | size = "small", 17 | srcs = glob(["tests/*"]), 18 | main = "tests/run_all_test.py", 19 | deps = [ 20 | ":rnn", 21 | ], 22 | ) 23 | -------------------------------------------------------------------------------- /tensorflow_addons/metrics/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "metrics", 7 | srcs = glob(["*.py"]), 8 | deps = [ 9 | "//tensorflow_addons/testing", 10 | 
"//tensorflow_addons/utils", 11 | ], 12 | ) 13 | 14 | py_test( 15 | name = "metrics_test", 16 | size = "small", 17 | srcs = glob(["tests/*"]), 18 | main = "tests/run_all_test.py", 19 | deps = [ 20 | ":metrics", 21 | ], 22 | ) 23 | -------------------------------------------------------------------------------- /tensorflow_addons/utils/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "utils", 7 | srcs = glob(["*.py"]), 8 | data = [ 9 | "//tensorflow_addons:conftest.py", 10 | "//tensorflow_addons:options.py", 11 | ], 12 | ) 13 | 14 | py_test( 15 | name = "keras_utils_test", 16 | size = "small", 17 | srcs = glob(["tests/*"]), 18 | main = "tests/run_all_test.py", 19 | deps = [ 20 | ":utils", 21 | ], 22 | ) 23 | -------------------------------------------------------------------------------- /tensorflow_addons/optimizers/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "optimizers", 7 | srcs = glob(["*.py"]), 8 | deps = [ 9 | "//tensorflow_addons/testing", 10 | "//tensorflow_addons/utils", 11 | ], 12 | ) 13 | 14 | py_test( 15 | name = "optimizers_test", 16 | size = "small", 17 | srcs = glob(["tests/*"]), 18 | main = "tests/run_all_test.py", 19 | deps = [ 20 | ":optimizers", 21 | ], 22 | ) 23 | -------------------------------------------------------------------------------- /.github/workflows/make_wheel_Windows.sh: -------------------------------------------------------------------------------- 1 | set -e -x 2 | 3 | export TF_NEED_CUDA=0 4 | export BAZEL_VC="C:/Program Files (x86)/Microsoft Visual Studio/2019/Enterprise/VC/" 5 | 6 | python -m pip install --default-timeout=1000 wheel setuptools tensorflow==$TF_VERSION 7 | bash ./tools/testing/build_and_run_tests.sh 8 | 9 | python configure.py 10 | 11 | bazel.exe build \ 12 | --noshow_progress \ 13 | --noshow_loading_progress \ 14 | --verbose_failures \ 15 | --test_output=errors \ 16 | build_pip_pkg 17 | bazel-bin/build_pip_pkg wheelhouse $NIGHTLY_FLAG 18 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/seq2seq/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | load("//tensorflow_addons:tensorflow_addons.bzl", "custom_op_library") 6 | 7 | custom_op_library( 8 | name = "_beam_search_ops.so", 9 | srcs = [ 10 | "cc/kernels/beam_search_ops.cc", 11 | "cc/kernels/beam_search_ops.h", 12 | "cc/ops/beam_search_ops.cc", 13 | ], 14 | cuda_srcs = [ 15 | "cc/kernels/beam_search_ops.h", 16 | "cc/kernels/beam_search_ops_gpu.cu.cc", 17 | ], 18 | ) 19 | -------------------------------------------------------------------------------- /tensorflow_addons/losses/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "losses", 7 | srcs = glob(["*.py"]), 8 | deps = [ 9 | "//tensorflow_addons/activations", 10 | "//tensorflow_addons/testing", 11 | "//tensorflow_addons/utils", 12 | ], 13 | ) 14 | 15 | py_test( 16 | name = "losses_test", 17 | size = "small", 18 | srcs = 
glob(["tests/*"]), 19 | main = "tests/run_all_test.py", 20 | deps = [ 21 | ":losses", 22 | ], 23 | ) 24 | -------------------------------------------------------------------------------- /tensorflow_addons/callbacks/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "callbacks", 7 | srcs = glob(["*.py"]), 8 | deps = [ 9 | "//tensorflow_addons/optimizers", 10 | "//tensorflow_addons/testing", 11 | "//tensorflow_addons/utils", 12 | ], 13 | ) 14 | 15 | py_test( 16 | name = "callbacks_test", 17 | size = "small", 18 | srcs = glob(["tests/*"]), 19 | main = "tests/run_all_test.py", 20 | deps = [ 21 | ":callbacks", 22 | ], 23 | ) 24 | -------------------------------------------------------------------------------- /.github/workflows/make_wheel_macOS.sh: -------------------------------------------------------------------------------- 1 | set -e -x 2 | 3 | export TF_NEED_CUDA=0 4 | 5 | python --version 6 | python -m pip install --default-timeout=1000 delocate wheel setuptools tensorflow==$TF_VERSION 7 | 8 | python configure.py 9 | 10 | bazel build \ 11 | --copt -mmacosx-version-min=10.13 \ 12 | --linkopt -mmacosx-version-min=10.13 \ 13 | --noshow_progress \ 14 | --noshow_loading_progress \ 15 | --verbose_failures \ 16 | --test_output=errors \ 17 | build_pip_pkg 18 | 19 | bazel-bin/build_pip_pkg artifacts $NIGHTLY_FLAG 20 | delocate-wheel -w wheelhouse artifacts/*.whl 21 | 22 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "activations", 7 | srcs = glob(["*.py"]), 8 | data = [ 9 | "//tensorflow_addons:options.py", 10 | "//tensorflow_addons/testing", 11 | "//tensorflow_addons/utils", 12 | ], 13 | ) 14 | 15 | py_test( 16 | name = "activations_test", 17 | size = "small", 18 | srcs = glob(["tests/*"]), 19 | main = "run_all_test.py", 20 | deps = [ 21 | ":activations", 22 | ], 23 | ) 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | pip-wheel-metadata 27 | 28 | # Jupyter Notebook 29 | .ipynb_checkpoints 30 | 31 | # IDE 32 | .vscode/ 33 | .idea/ 34 | *.iml 35 | 36 | # Build 37 | /.bazelrc 38 | /bazel-* 39 | /artifacts 40 | .bazelrc 41 | 42 | .coverage* 43 | htmlcov 44 | 45 | wheelhouse/ 46 | -------------------------------------------------------------------------------- /MIGRATION_TO_CORE.md: -------------------------------------------------------------------------------- 1 | # Migration From TF-Addons To TensorFlow Core / Keras 2 | 3 | **Given the challenges of external SIG coordinating with internal roadmaps, a new 4 | process has been put in place for the core TF and Keras teams to handle migration 5 | and deprecation of Addons components. 
If you believe there is a strong candidate for 6 | migration please post an issue and we'll escalate it to the respective team members.** 7 | 8 | ### Criteria for Migration 9 | * The addition is widely used throughout the community 10 | * The addition is unlikely to have API changes as time progresses 11 | * The addition is well written / tested 12 | -------------------------------------------------------------------------------- /tools/install_deps/install_bazelisk.sh: -------------------------------------------------------------------------------- 1 | # Downloads bazelisk to ${output_dir} as `bazel`. 2 | date 3 | 4 | output_dir=${1:-"/usr/local/bin"} 5 | 6 | case "$(uname -s)" in 7 | Darwin) name=bazelisk-darwin-amd64 ;; 8 | Linux) name=bazelisk-linux-amd64 ;; 9 | *) name=bazelisk-windows-amd64 ;; 10 | esac 11 | 12 | mkdir -p "${output_dir}" 13 | curl -LO "https://github.com/bazelbuild/bazelisk/releases/download/v1.3.0/${name}" 14 | 15 | mv "${name}" "${output_dir}/bazel" 16 | chmod u+x "${output_dir}/bazel" 17 | 18 | if [[ ! ":$PATH:" =~ :${output_dir}/?: ]]; then 19 | PATH="${output_dir}:$PATH" 20 | fi 21 | 22 | which bazel 23 | date 24 | -------------------------------------------------------------------------------- /tools/docs/Readme.md: -------------------------------------------------------------------------------- 1 | ## 1. Generated API docs 2 | 3 | [tensorflow.org/addons/api_docs/python/tfa](https://tensorflow.org/addons/api_docs/python/tfa) 4 | 5 | `build_docs.py` controls executed this docs generation. To test-run it: 6 | 7 | ```bash 8 | # Install dependencies: 9 | pip install -r tools/install_deps/doc_requirements.txt 10 | 11 | # Build tool: 12 | bazel build //tools/docs:build_docs 13 | 14 | # Generate API doc: 15 | # Use current branch 16 | bazel-bin/tools/docs/build_docs --git_branch=$(git rev-parse --abbrev-ref HEAD) 17 | # or specified explicitly 18 | bazel-bin/tools/docs/build_docs --git_branch=master --output_dir=/tmp/tfa_api 19 | ``` 20 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/layers/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | load("//tensorflow_addons:tensorflow_addons.bzl", "custom_op_library") 6 | 7 | custom_op_library( 8 | name = "_correlation_cost_ops.so", 9 | srcs = [ 10 | "cc/kernels/correlation_cost_op.cc", 11 | "cc/kernels/correlation_cost_op.h", 12 | "cc/ops/correlation_cost_op.cc", 13 | ], 14 | cuda_deps = [ 15 | "@cub_archive//:cub", 16 | ], 17 | cuda_srcs = [ 18 | "cc/kernels/correlation_cost_op.h", 19 | "cc/kernels/correlation_cost_op_gpu.cu.cc", 20 | ], 21 | ) 22 | -------------------------------------------------------------------------------- /tensorflow_addons/seq2seq/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "seq2seq", 7 | srcs = glob(["*.py"]), 8 | data = [ 9 | "//tensorflow_addons:options.py", 10 | "//tensorflow_addons/custom_ops/seq2seq:_beam_search_ops.so", 11 | ], 12 | deps = [ 13 | "//tensorflow_addons/testing", 14 | "//tensorflow_addons/utils", 15 | ], 16 | ) 17 | 18 | py_test( 19 | name = "seq2seq_test", 20 | size = "medium", 21 | srcs = glob(["tests/*"]), 22 | main = "tests/run_all_test.py", 23 | deps = [ 24 | ":seq2seq", 25 | ], 26 | ) 27 | 
-------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/text/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | load("//tensorflow_addons:tensorflow_addons.bzl", "custom_op_library") 6 | 7 | custom_op_library( 8 | name = "_skip_gram_ops.so", 9 | srcs = [ 10 | "cc/kernels/skip_gram_kernels.cc", 11 | "cc/ops/skip_gram_ops.cc", 12 | ], 13 | ) 14 | 15 | custom_op_library( 16 | name = "_parse_time_op.so", 17 | srcs = select({ 18 | "//tensorflow_addons:windows": [], 19 | "//conditions:default": [ 20 | "cc/kernels/parse_time_kernel.cc", 21 | "cc/ops/parse_time_op.cc", 22 | ], 23 | }), 24 | ) 25 | -------------------------------------------------------------------------------- /tools/install_so_files.sh: -------------------------------------------------------------------------------- 1 | set -e -x 2 | 3 | if [ "$TF_NEED_CUDA" == "1" ]; then 4 | CUDA_FLAG="--crosstool_top=//build_deps/toolchains/gcc7_manylinux2010-nvcc-cuda11:toolchain" 5 | fi 6 | 7 | bazel build $CUDA_FLAG //tensorflow_addons/... 8 | cp ./bazel-bin/tensorflow_addons/custom_ops/image/_*_ops.so ./tensorflow_addons/custom_ops/image/ 9 | cp ./bazel-bin/tensorflow_addons/custom_ops/layers/_*_ops.so ./tensorflow_addons/custom_ops/layers/ 10 | cp ./bazel-bin/tensorflow_addons/custom_ops/seq2seq/_*_ops.so ./tensorflow_addons/custom_ops/seq2seq/ 11 | cp ./bazel-bin/tensorflow_addons/custom_ops/text/_*_ops.so ./tensorflow_addons/custom_ops/text/ 12 | cp ./bazel-bin/tensorflow_addons/custom_ops/text/_parse_time_op.so ./tensorflow_addons/custom_ops/text/ 13 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 2 | load("//build_deps/tf_dependency:tf_configure.bzl", "tf_configure") 3 | load("//build_deps/toolchains/gpu:cuda_configure.bzl", "cuda_configure") 4 | 5 | http_archive( 6 | name = "cub_archive", 7 | build_file = "//build_deps/toolchains/gpu:cub.BUILD", 8 | sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3", 9 | strip_prefix = "cub-1.8.0", 10 | urls = [ 11 | "https://storage.googleapis.com/mirror.tensorflow.org/github.com/NVlabs/cub/archive/1.8.0.zip", 12 | "https://github.com/NVlabs/cub/archive/1.8.0.zip", 13 | ], 14 | ) 15 | 16 | tf_configure( 17 | name = "local_config_tf", 18 | ) 19 | 20 | cuda_configure(name = "local_config_cuda") 21 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/BUILD: -------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = ["//visibility:public"], 3 | licenses = ["notice"], # Apache 2.0 4 | ) 5 | 6 | py_library( 7 | name = "layers", 8 | srcs = glob(["*.py"]), 9 | data = [ 10 | "//tensorflow_addons/custom_ops/layers:_correlation_cost_ops.so", 11 | ], 12 | deps = [ 13 | "//tensorflow_addons/activations", 14 | "//tensorflow_addons/rnn", 15 | "//tensorflow_addons/testing", 16 | "//tensorflow_addons/text", 17 | "//tensorflow_addons/utils", 18 | ], 19 | ) 20 | 21 | py_test( 22 | name = "layers_test", 23 | size = "large", 24 | srcs = glob(["tests/*"]), 25 | main = "tests/run_all_test.py", 26 | deps = [ 27 | ":layers", 28 | ], 29 | ) 30 | 
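Beyond the Bazel wiring, the modules this `layers` target globs up are ordinary Keras layers once the wheel is installed; for example, `sparsemax.py` listed in the tree provides `tfa.layers.Sparsemax`. A short usage sketch (not part of the BUILD file, shown only to connect the target to its Python API):

```python
import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, input_shape=(4,)),
    tfa.layers.Sparsemax(),  # from tensorflow_addons/layers/sparsemax.py
])

x = tf.random.normal([2, 4])
probs = model(x)  # each row sums to 1, with exact zeros for small logits
```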
-------------------------------------------------------------------------------- /.github/workflows/notify_codeowners.yml: -------------------------------------------------------------------------------- 1 | name: Notify codeowners 2 | 3 | on: 4 | pull_request: 5 | types: [opened] 6 | 7 | 8 | jobs: 9 | notify-codeowners: 10 | name: Notify codeowners 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - uses: actions/setup-python@v1 15 | with: 16 | python-version: 3.7 17 | - run: pip install pygithub click 18 | - name: Drop a message for codeowners 19 | env: 20 | PR: ${{ steps.findPr.outputs.pr }} 21 | run: | 22 | python .github/workflows/notify_codeowners.py \ 23 | --pull-request-id=auto \ 24 | --no-dry-run \ 25 | https://raw.githubusercontent.com/tensorflow/addons/master/.github/CODEOWNERS 26 | -------------------------------------------------------------------------------- /.github/workflows/backport.yml: -------------------------------------------------------------------------------- 1 | name: Backport 2 | on: 3 | pull_request: 4 | types: 5 | - closed 6 | - labeled 7 | 8 | jobs: 9 | backport: 10 | runs-on: ubuntu-18.04 11 | name: Backport 12 | steps: 13 | - name: Backport Bot 14 | if: github.event.pull_request.merged && ( ( github.event.action == 'closed' && contains( join( github.event.pull_request.labels.*.name ), 'backport') ) || contains( github.event.label.name, 'backport' ) ) 15 | uses: Gaurav0/backport@d69fd1d2469762a7b4007f671857e4f94deed0af # Version 1.0.24 16 | with: 17 | bot_username: bot-of-gabrieldemarmiesse 18 | bot_token: 1353d990cdb8b8ceb1b73d301dce83cc0da3db29 19 | bot_token_key: a1b2c3d47311f8e29e204f85a81b4df4a44e252c 20 | github_token: ${{ secrets.GITHUB_TOKEN }} 21 | -------------------------------------------------------------------------------- /tensorflow_addons/image/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | py_library( 6 | name = "image", 7 | srcs = glob(["*.py"]), 8 | data = [ 9 | ":sparse_image_warp_test_data", 10 | "//tensorflow_addons/custom_ops/image:_distort_image_ops.so", 11 | "//tensorflow_addons/custom_ops/image:_image_ops.so", 12 | "//tensorflow_addons/custom_ops/image:_resampler_ops.so", 13 | "//tensorflow_addons/testing", 14 | "//tensorflow_addons/utils", 15 | ], 16 | ) 17 | 18 | filegroup( 19 | name = "sparse_image_warp_test_data", 20 | srcs = glob(["tests/test_data/*.png"]), 21 | ) 22 | 23 | py_test( 24 | name = "image_test", 25 | size = "small", 26 | srcs = glob(["tests/*"]), 27 | main = "tests/run_all_test.py", 28 | deps = [ 29 | ":image", 30 | ], 31 | ) 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-performance-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug/Performance Issue 3 | about: Use this template for reporting a bug or a performance issue. 4 | 5 | --- 6 | 7 | **System information** 8 | - OS Platform and Distribution (e.g., Linux Ubuntu 16.04): 9 | - TensorFlow version and how it was installed (source or binary): 10 | - TensorFlow-Addons version and how it was installed (source or binary): 11 | - Python version: 12 | - Is GPU used? (yes/no): 13 | 14 | **Describe the bug** 15 | 16 | A clear and concise description of what the bug is. 
17 | 18 | **Code to reproduce the issue** 19 | 20 | Provide a reproducible test case that is the bare minimum necessary to generate the problem. 21 | 22 | **Other info / logs** 23 | 24 | Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached. 25 | -------------------------------------------------------------------------------- /tools/install_deps/buildifier.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | wget -O /usr/local/bin/buildifier https://github.com/bazelbuild/buildtools/releases/download/0.29.0/buildifier 18 | chmod +x /usr/local/bin/buildifier 19 | -------------------------------------------------------------------------------- /docs/tutorials/_toc.yaml: -------------------------------------------------------------------------------- 1 | toc: 2 | - title: Overview 3 | path: /addons/overview 4 | - heading: Tutorials 5 | - title: Triplet loss 6 | path: /addons/tutorials/losses_triplet 7 | - title: Image Ops 8 | path: /addons/tutorials/image_ops 9 | - title: Normalization layers 10 | path: /addons/tutorials/layers_normalizations 11 | - title: Weight normalization layer 12 | path: /addons/tutorials/layers_weightnormalization 13 | - title: LazyAdam optimizer 14 | path: /addons/tutorials/optimizers_lazyadam 15 | - title: ConditionalGradient Optimizer 16 | path: /addons/tutorials/optimizers_conditionalgradient 17 | - title: TQDM Progress Bar 18 | path: /addons/tutorials/tqdm_progress_bar 19 | - title: Seq2Seq for Translation 20 | path: /addons/tutorials/networks_seq2seq_nmt 21 | - title: Moving Average Optimizer Checkpoint 22 | path: /addons/tutorials/average_optimizers_callback 23 | - title: Time Stopping Callback 24 | path: /addons/tutorials/time_stopping -------------------------------------------------------------------------------- /tensorflow_addons/text/README.md: -------------------------------------------------------------------------------- 1 | # Addons - Text 2 | 3 | ## Components 4 | https://www.tensorflow.org/addons/api_docs/python/tfa/text 5 | 6 | ## Contribution Guidelines 7 | #### Standard API 8 | In order to conform with the current API standard, all text ops 9 | must: 10 | * Be impossible to implement in one of the other API 11 | standards (Layers, Losses, etc.). 12 | * Be related to text processing. 13 | 14 | #### Testing Requirements 15 | * Simple unittests that demonstrate the text op is behaving as 16 | expected. 17 | * To run your `tf.functions` in eager mode and graph mode in the tests, 18 | you can use the `@pytest.mark.usefixtures("maybe_run_functions_eagerly")` 19 | decorator. 
This will run the tests twice, once normally, and once 20 | with `tf.config.run_functions_eagerly(True)`. 21 | 22 | #### Documentation Requirements 23 | * Update the [CODEOWNERS file](https://github.com/tensorflow/addons/blob/master/.github/CODEOWNERS) 24 | -------------------------------------------------------------------------------- /tools/docker/dev_container.Dockerfile: -------------------------------------------------------------------------------- 1 | #syntax=docker/dockerfile:1.1.5-experimental 2 | FROM tensorflow/tensorflow:2.1.0-custom-op-ubuntu16 as dev_container_cpu 3 | ARG TF_PACKAGE 4 | ARG TF_VERSION 5 | 6 | # Temporary until custom-op container is updated 7 | RUN ln -sf /usr/bin/python3 /usr/bin/python 8 | RUN ln -sf /usr/local/bin/pip3 /usr/local/bin/pip 9 | RUN pip install --default-timeout=1000 $TF_PACKAGE==$TF_VERSION 10 | 11 | COPY tools/install_deps /install_deps 12 | COPY requirements.txt /tmp/requirements.txt 13 | RUN pip install -r /install_deps/black.txt \ 14 | -r /install_deps/flake8.txt \ 15 | -r /install_deps/pytest.txt \ 16 | -r /install_deps/typedapi.txt \ 17 | -r /tmp/requirements.txt 18 | 19 | RUN bash /install_deps/buildifier.sh 20 | RUN bash /install_deps/clang-format.sh 21 | 22 | ENV ADDONS_DEV_CONTAINER="1" 23 | 24 | # Clean up 25 | RUN apt-get autoremove -y \ 26 | && apt-get clean -y \ 27 | && rm -rf /var/lib/apt/lists/* 28 | -------------------------------------------------------------------------------- /tensorflow_addons/optimizers/README.md: -------------------------------------------------------------------------------- 1 | # Addons - Optimizers 2 | 3 | ## Components 4 | https://www.tensorflow.org/addons/api_docs/python/tfa/optimizers 5 | 6 | ## Contribution Guidelines 7 | #### Standard API 8 | In order to conform with the current API standard, all optimizers 9 | must: 10 | * Inherit from either `keras.optimizer_v2.OptimizerV2` or its subclasses. 11 | * Register as a keras global object so it can be serialized properly: `@tf.keras.utils.register_keras_serializable(package='Addons')` 12 | 13 | #### Testing Requirements 14 | * To run your `tf.functions` in eager mode and graph mode in the tests, 15 | you can use the `@pytest.mark.usefixtures("maybe_run_functions_eagerly")` 16 | decorator. This will run the tests twice, once normally, and once 17 | with `tf.config.run_functions_eagerly(True)`. 18 | 19 | #### Documentation Requirements 20 | * Update the [CODEOWNERS file](https://github.com/tensorflow/addons/blob/master/.github/CODEOWNERS) 21 | -------------------------------------------------------------------------------- /tensorflow_addons/image/README.md: -------------------------------------------------------------------------------- 1 | # Addons - Image 2 | 3 | ## Components 4 | https://www.tensorflow.org/addons/api_docs/python/tfa/image 5 | 6 | 7 | ## Contribution Guidelines 8 | #### Standard API 9 | In order to conform with the current API standard, all image ops 10 | must: 11 | * Be a standard image processing technique 12 | * Must be impossible to implement in one of the other API 13 | standards (Layers, Losses, etc.). 14 | 15 | #### Testing Requirements 16 | * Simple unittests that demonstrate the image op is behaving as 17 | expected. 18 | * To run your `tf.functions` in eager mode and graph mode in the tests, 19 | you can use the `@pytest.mark.usefixtures("maybe_run_functions_eagerly")` 20 | decorator. This will run the tests twice, once normally, and once 21 | with `tf.config.run_functions_eagerly(True)`. 
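To make the testing guideline above concrete, here is a sketch of what a minimal image-op test using the `maybe_run_functions_eagerly` fixture could look like; the specific op and shapes are illustrative, not a test that exists in this sub-package:

```python
import pytest
import tensorflow as tf
import tensorflow_addons as tfa


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_mean_filter2d_preserves_shape():
    # The fixture (defined in tensorflow_addons/utils/test_utils.py and exposed
    # via conftest.py) runs this test twice: once normally and once with
    # tf.config.run_functions_eagerly(True).
    image = tf.random.uniform([2, 8, 8, 3])
    filtered = tfa.image.mean_filter2d(image, filter_shape=(3, 3))
    assert filtered.shape == image.shape
```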
22 | 23 | #### Documentation Requirements 24 | * Update the [CODEOWNERS file](https://github.com/tensorflow/addons/blob/master/.github/CODEOWNERS) 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Use this template for raising a feature request 4 | 5 | --- 6 | 7 | **Describe the feature and the current behavior/state.** 8 | 9 | **Relevant information** 10 | - Are you willing to contribute it (yes/no): 11 | *If you wish to contribute, then read the requirements for new contributions in [`CONTRIBUTING.md`](https://github.com/tensorflow/addons/blob/master/CONTRIBUTING.md#requirements-for-new-contributions-to-the-repository)* 12 | - Are you willing to maintain it going forward? (yes/no): 13 | - Is there a relevant academic paper? (if so, where): 14 | - Does the relevant academic paper exceed 50 citations? (yes/no): 15 | - Is there already an implementation in another framework? (if so, where): 16 | - Was it part of tf.contrib? (if so, where): 17 | 18 | **Which API type would this fall under (layer, metric, optimizer, etc.)** 19 | 20 | **Who will benefit from this feature?** 21 | 22 | **Any other info.** 23 | -------------------------------------------------------------------------------- /docs/tutorials/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Addons Tutorials 2 | 3 | TensorFlow Addons welcomes and highly encourages tutorial contributions. 4 | 5 | 6 | ## How To Contribute 7 | 8 | Addons tutorials are created using [Google Colab](https://colab.research.google.com/) 9 | and the Jupyter notebooks are saved to this directory in the repository. To do 10 | this, follow the steps below: 11 | 12 | 1. Create a new branch on your fork of TensorFlow Addons 13 | 2. Go to [Google Colab](https://colab.research.google.com/) and start a new 14 | notebook using the Addons example template: 15 | [docs/tutorials/_template.ipynb](_template.ipynb) 16 | 3. Edit the links for the "View source on GitHub" and "Run in Google Colab" 17 | URL boxes so that they match the name of your new example notebook 18 | 4. Follow the guidelines of the template 19 | 5. "Save a copy in GitHub" and select your new branch. The notebook should be 20 | named `subpackage_submodule` 21 | 6. Submit the branch as a PR on the TF-Addons GitHub 22 | -------------------------------------------------------------------------------- /tensorflow_addons/callbacks/README.md: -------------------------------------------------------------------------------- 1 | # Addons - Callbacks 2 | 3 | ## Contents 4 | https://www.tensorflow.org/addons/api_docs/python/tfa/callbacks 5 | 6 | ## Contribution Guidelines 7 | #### Standard API 8 | In order to conform with the current API standard, all callbacks 9 | must: 10 | * Inherit from `tf.keras.callbacks.Callback`. 11 | * Register as a keras global object so it can be serialized properly: `@tf.keras.utils.register_keras_serializable(package='Addons')` (see the sketch below) 12 | 13 | #### Testing Requirements 14 | * Simple unittests that demonstrate the callback is behaving as expected. 15 | * To run your `tf.functions` in eager mode and graph mode in the tests, 16 | you can use the `@pytest.mark.usefixtures("maybe_run_functions_eagerly")` 17 | decorator. This will run the tests twice, once normally, and once 18 | with `tf.config.run_functions_eagerly(True)`.
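As referenced above, a minimal sketch of a callback that follows both conventions could look like this (the class is purely illustrative and not an addon that exists in the package):

```python
import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="Addons")
class EpochCounter(tf.keras.callbacks.Callback):
    """Toy callback that counts completed epochs (illustrative only)."""

    def __init__(self):
        super().__init__()
        self.epochs_seen = 0

    def on_epoch_end(self, epoch, logs=None):
        self.epochs_seen += 1

    def get_config(self):
        # No constructor arguments, so the config is empty.
        return {}
```

A real contribution would also ship unittests for it, as described in the testing requirements above.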
19 | 20 | #### Documentation Requirements 21 | * Update the [CODEOWNERS file](https://github.com/tensorflow/addons/blob/master/.github/CODEOWNERS) 22 | -------------------------------------------------------------------------------- /tensorflow_addons/conftest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import tensorflow as tf 4 | 5 | import tensorflow_addons as tfa 6 | 7 | from tensorflow_addons.utils.test_utils import ( # noqa: F401 8 | maybe_run_functions_eagerly, 9 | only_run_functions_eagerly, 10 | run_custom_and_py_ops, 11 | run_with_mixed_precision_policy, 12 | pytest_make_parametrize_id, 13 | data_format, 14 | set_seeds, 15 | pytest_addoption, 16 | set_global_variables, 17 | pytest_configure, 18 | device, 19 | pytest_generate_tests, 20 | pytest_collection_modifyitems, 21 | ) 22 | 23 | # fixtures present in this file will be available 24 | # when running tests and can be referenced with strings 25 | # https://docs.pytest.org/en/latest/fixture.html#conftest-py-sharing-fixture-functions 26 | 27 | 28 | @pytest.fixture(autouse=True) 29 | def add_doctest_namespace(doctest_namespace): 30 | doctest_namespace["np"] = np 31 | doctest_namespace["tf"] = tf 32 | doctest_namespace["tfa"] = tfa 33 | -------------------------------------------------------------------------------- /.github/boring-cyborg.yml: -------------------------------------------------------------------------------- 1 | labelPRBasedOnFilePath: 2 | # Subpackages 3 | activations: 4 | - tensorflow_addons/activations/**/* 5 | 6 | callbacks: 7 | - tensorflow_addons/callbacks/**/* 8 | 9 | custom-ops: 10 | - tensorflow_addons/custom_ops/**/* 11 | 12 | image: 13 | - tensorflow_addons/image/**/* 14 | 15 | layers: 16 | - tensorflow_addons/layers/**/* 17 | 18 | losses: 19 | - tensorflow_addons/losses/**/* 20 | 21 | metrics: 22 | - tensorflow_addons/metrics/**/* 23 | 24 | optimizers: 25 | - tensorflow_addons/optimizers/**/* 26 | 27 | seq2seq: 28 | - tensorflow_addons/seq2seq/**/* 29 | 30 | text: 31 | - tensorflow_addons/text/**/* 32 | 33 | # Other labels 34 | build: 35 | - build_deps/**/* 36 | - tools/releases/**/* 37 | 38 | documentation: 39 | - docs/**/* 40 | 41 | tutorials: 42 | - docs/tutorials/**/* 43 | 44 | test-cases: 45 | - tools/testing/**/ 46 | 47 | style: 48 | - STYLE_GUIDE.md 49 | 50 | github: 51 | - .github/**/* 52 | -------------------------------------------------------------------------------- /tensorflow_addons/losses/README.md: -------------------------------------------------------------------------------- 1 | # Addons - Losses 2 | 3 | ## Components 4 | https://www.tensorflow.org/addons/api_docs/python/tfa/losses 5 | 6 | ## Contribution Guidelines 7 | #### Standard API 8 | In order to conform with the current API standard, all losses 9 | must: 10 | * Inherit from `keras.losses.Loss`. 11 | * Register as a keras global object so it can be serialized properly: `@tf.keras.utils.register_keras_serializable(package='Addons')` 12 | 13 | #### Testing Requirements 14 | * Simple unittests that demonstrate the loss is behaving as expected on 15 | some set of known inputs and outputs. 16 | * To run your `tf.functions` in eager mode and graph mode in the tests, 17 | you can use the `@pytest.mark.usefixtures("maybe_run_functions_eagerly")` 18 | decorator. This will run the tests twice, once normally, and once 19 | with `tf.config.run_functions_eagerly(True)`. 
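To make the API bullets above concrete, here is a minimal sketch of a conforming loss (the class itself is hypothetical and only meant to show the expected structure):

```python
import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="Addons")
class ScaledMAE(tf.keras.losses.Loss):
    """Mean absolute error scaled by a constant factor (toy example)."""

    def __init__(self, scale=1.0, name="scaled_mae", **kwargs):
        super().__init__(name=name, **kwargs)
        self.scale = scale

    def call(self, y_true, y_pred):
        # Per-sample loss values; Keras applies the configured reduction.
        y_true = tf.cast(y_true, y_pred.dtype)
        return self.scale * tf.reduce_mean(tf.abs(y_true - y_pred), axis=-1)

    def get_config(self):
        config = super().get_config()
        config.update({"scale": self.scale})
        return config
```

Such a loss would then be exercised on a small set of known inputs and outputs in its tests, using the fixture described above when `tf.function`s are involved.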
20 | 21 | #### Documentation Requirements 22 | * Update the [CODEOWNERS file](https://github.com/tensorflow/addons/blob/master/.github/CODEOWNERS) 23 | 24 | -------------------------------------------------------------------------------- /tools/install_deps/clang-format.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | 18 | wget -O /usr/local/bin/clang-format-9 https://github.com/DoozyX/clang-format-lint-action/raw/master/clang-format/clang-format9 19 | chmod +x /usr/local/bin/clang-format-9 20 | ln -s /usr/local/bin/clang-format-9 /usr/local/bin/clang-format 21 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/README.md: -------------------------------------------------------------------------------- 1 | # Addons - Layers 2 | 3 | ## Components 4 | https://www.tensorflow.org/addons/api_docs/python/tfa/layers 5 | 6 | ## Contribution Guidelines 7 | #### Standard API 8 | In order to conform with the current API standard, all layers 9 | must: 10 | * Inherit from either `keras.layers.Layer` or its subclasses. 11 | * Register as a keras global object so it can be serialized properly: `@tf.keras.utils.register_keras_serializable(package='Addons')` 12 | 13 | #### Testing Requirements 14 | * Simple unittests that demonstrate the layer is behaving as expected. 15 | * To run your `tf.functions` in eager mode and graph mode in the tests, 16 | you can use the `@pytest.mark.usefixtures("maybe_run_functions_eagerly")` 17 | decorator. This will run the tests twice, once normally, and once 18 | with `tf.config.run_functions_eagerly(True)`. 19 | * Run `layer_test` on the layer. 20 | 21 | #### Documentation Requirements 22 | * Update the [CODEOWNERS file](https://github.com/tensorflow/addons/blob/master/.github/CODEOWNERS) 23 | -------------------------------------------------------------------------------- /tensorflow_addons/tests/register_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import pytest 4 | import tensorflow as tf 5 | from tensorflow_addons.register import register_all, _get_all_shared_objects 6 | from tensorflow_addons.utils import resource_loader 7 | 8 | 9 | def test_multiple_register(): 10 | if resource_loader.SKIP_CUSTOM_OPS: 11 | pytest.skip( 12 | "Skipping the test because a custom ops " 13 | "was being loaded while --skip-custom-ops was set." 14 | ) 15 | register_all() 16 | register_all() 17 | 18 | 19 | def test_get_all_shared_objects(): 20 | if resource_loader.SKIP_CUSTOM_OPS: 21 | pytest.skip( 22 | "Skipping the test because a custom ops " 23 | "was being loaded while --skip-custom-ops was set." 
24 | ) 25 | all_shared_objects = _get_all_shared_objects() 26 | assert len(all_shared_objects) >= 4 27 | 28 | for file in all_shared_objects: 29 | tf.load_op_library(file) 30 | 31 | 32 | if __name__ == "__main__": 33 | sys.exit(pytest.main([__file__])) 34 | -------------------------------------------------------------------------------- /tensorflow_addons/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Additional callbacks that conform to Keras API.""" 16 | 17 | from tensorflow_addons.callbacks.average_model_checkpoint import AverageModelCheckpoint 18 | from tensorflow_addons.callbacks.time_stopping import TimeStopping 19 | from tensorflow_addons.callbacks.tqdm_progress_bar import TQDMProgressBar 20 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # [tensorflow.org/addons](https://tensorflow.org/addons) 2 | 3 | This directory contains the source for [tensorflow.org/addons](https://tensorflow.org/addons). 4 | 5 | It comprises two main components: 6 | 7 | ## 1. Narrative Docs 8 | 9 | Any markdown or notebook files in this directory will be published to tensorflow.org/addons. 10 | 11 | `tutorials/_toc.yaml` controls the left-nav on the tutorials tab. Make sure to keep that file up to date. 12 | Notify the tensorflow/docs team if you need to make major changes. 13 | 14 | The preferred formatting for TensorFlow notebooks is to use the [tensorflow/docs](https://github.com/tensorflow/docs) [`nbfmt` tool](https://github.com/tensorflow/docs/tree/master/tools/tensorflow_docs/tools).
If modifying a tutorial gives you 15 | an unreadable diff use the following commands to re-apply the standard formatting: 16 | 17 | ``` 18 | pip install git+https://github.com/tensorflow/docs 19 | python -m tensorflow_docs.tools.nbfmt {path to notebook file or directory} 20 | ``` 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /tensorflow_addons/text/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | # TODO: Once TF exports symbols in a DLL we can enable parse_time_op for windows 6 | # https://github.com/tensorflow/addons/issues/782 7 | py_library( 8 | name = "text", 9 | srcs = glob(["*.py"]), 10 | data = select({ 11 | "//tensorflow_addons:windows": [ 12 | "//tensorflow_addons/custom_ops/text:_skip_gram_ops.so", 13 | "//tensorflow_addons/testing", 14 | "//tensorflow_addons/utils", 15 | ], 16 | "//conditions:default": [ 17 | "//tensorflow_addons/custom_ops/text:_parse_time_op.so", 18 | "//tensorflow_addons/custom_ops/text:_skip_gram_ops.so", 19 | "//tensorflow_addons/testing", 20 | "//tensorflow_addons/utils", 21 | ], 22 | }), 23 | ) 24 | 25 | py_test( 26 | name = "text_test", 27 | size = "small", 28 | srcs = glob(["tests/*"]), 29 | main = "tests/run_all_test.py", 30 | deps = [ 31 | ":text", 32 | ], 33 | ) 34 | -------------------------------------------------------------------------------- /tensorflow_addons/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | config_setting( 6 | name = "windows", 7 | constraint_values = ["@bazel_tools//platforms:windows"], 8 | ) 9 | 10 | py_library( 11 | name = "tensorflow_addons", 12 | srcs = glob(["*.py"]), 13 | deps = [ 14 | "//tensorflow_addons/activations", 15 | "//tensorflow_addons/callbacks", 16 | "//tensorflow_addons/image", 17 | "//tensorflow_addons/layers", 18 | "//tensorflow_addons/losses", 19 | "//tensorflow_addons/metrics", 20 | "//tensorflow_addons/optimizers", 21 | "//tensorflow_addons/rnn", 22 | "//tensorflow_addons/seq2seq", 23 | "//tensorflow_addons/testing", 24 | "//tensorflow_addons/text", 25 | "//tensorflow_addons/utils", 26 | ], 27 | ) 28 | 29 | py_test( 30 | name = "tensorflow_addons_test", 31 | size = "small", 32 | srcs = glob(["tests/*"]), 33 | main = "tests/run_all_test.py", 34 | deps = [ 35 | ":tensorflow_addons", 36 | ], 37 | ) 38 | -------------------------------------------------------------------------------- /tools/releases/tf_auditwheel_patch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | set -e -x 18 | 19 | SITE_PKG_LOCATION=$(python -c "import site; print(site.getsitepackages()[0])") 20 | TF_SHARED_LIBRARY_NAME=$(grep -r TF_SHARED_LIBRARY_NAME .bazelrc | awk -F= '{print$2}') 21 | POLICY_JSON="${SITE_PKG_LOCATION}/auditwheel/policy/policy.json" 22 | sed -i "s/libresolv.so.2\"/libresolv.so.2\", $TF_SHARED_LIBRARY_NAME/g" $POLICY_JSON 23 | -------------------------------------------------------------------------------- /tensorflow_addons/rnn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Additional RNN cells that conform to Keras API.""" 16 | 17 | from tensorflow_addons.rnn.nas_cell import NASCell 18 | from tensorflow_addons.rnn.layer_norm_lstm_cell import LayerNormLSTMCell 19 | from tensorflow_addons.rnn.layer_norm_simple_rnn_cell import LayerNormSimpleRNNCell 20 | from tensorflow_addons.rnn.esn_cell import ESNCell 21 | from tensorflow_addons.rnn.peephole_lstm_cell import PeepholeLSTMCell 22 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | 3 | ignore = 4 | # defaults flake8 ignores 5 | E121,E123,E126,E226,E24,E704,W503,W504 6 | # whitespace before ':' 7 | # https://black.readthedocs.io/en/stable/the_black_code_style.html#slices 8 | E203 9 | # missing whitespace after ',' 10 | # black takes care of that. Sometimes it may 11 | # add a comma at the end of lists. 12 | E231 13 | # Line too long 14 | # We use black, no need to enforce line length 15 | E501 16 | # lowercase ... imported as non lowercase 17 | # Useful to ignore for "import keras.backend as K" 18 | N812 19 | 20 | per-file-ignores = 21 | # imported but unused in __init__.py, that's ok. 22 | **/__init__.py:F401 23 | # import not at top okay due to TF installation check 24 | tensorflow_addons/__init__.py:F401,E402 25 | # function name should be lowercase 26 | tensorflow_addons/image/utils.py:N802 27 | tensorflow_addons/image/tests/utils_test.py:N802 28 | tensorflow_addons/image/tests/color_ops_test.py:N802 29 | tensorflow_addons/optimizers/tests/conditional_gradient_test.py:N802 30 | # variable ...
in function should be lowercase 31 | tensorflow_addons/callbacks/tests/time_stopping_test.py:N806 32 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/README.md: -------------------------------------------------------------------------------- 1 | # Addons - Activations 2 | 3 | ## Contents 4 | https://www.tensorflow.org/addons/api_docs/python/tfa/activations 5 | 6 | ## Contribution Guidelines 7 | #### Standard API 8 | In order to conform with the current API standard, all activations 9 | must: 10 | * Be a `tf.function` unless it is a straightforward call to a custom op or likely to be retraced. 11 | * Register as a keras global object so it can be serialized properly: `@tf.keras.utils.register_keras_serializable(package='Addons')` 12 | 13 | #### Testing Requirements 14 | * Simple unittests that demonstrate the activation is behaving as expected. 15 | * To run your `tf.functions` in eager mode and graph mode in the tests, 16 | you can use the `@pytest.mark.usefixtures("maybe_run_functions_eagerly")` 17 | decorator. This will run the tests twice, once normally, and once 18 | with `tf.config.run_functions_eagerly(True)`. 19 | * Add activation name to [activations_test.py](https://github.com/tensorflow/addons/tree/master/tensorflow_addons/activations/tests/activations_test.py) to test serialization. 20 | 21 | #### Documentation Requirements 22 | * Update the [CODEOWNERS file](https://github.com/tensorflow/addons/blob/master/.github/CODEOWNERS) 23 | -------------------------------------------------------------------------------- /tensorflow_addons/rnn/README.md: -------------------------------------------------------------------------------- 1 | # Addons - RNN 2 | 3 | ## Components 4 | https://www.tensorflow.org/addons/api_docs/python/tfa/rnn 5 | 6 | ## Contribution Guidelines 7 | #### Prerequisites 8 | * For any cell based on a research paper, the original paper has to be well recognized. 9 | The criterion here is >= 100 citations on Google Scholar. If the contributor feels 10 | this requirement needs to be overruled, please specify the detailed justification in the 11 | PR. 12 | 13 | #### Standard API 14 | In order to conform with the current API standard, all cells must: 15 | * Inherit from either `keras.layers.AbstractRNNCell` or `keras.layers.Layer` with 16 | required properties. 17 | * Register as a keras global object so it can be serialized properly: `@tf.keras.utils.register_keras_serializable(package='Addons')` 18 | 19 | #### Testing Requirements 20 | * To run your `tf.functions` in eager mode and graph mode in the tests, 21 | you can use the `@pytest.mark.usefixtures("maybe_run_functions_eagerly")` 22 | decorator. This will run the tests twice, once normally, and once 23 | with `tf.config.run_functions_eagerly(True)`. 24 | 25 | #### Documentation Requirements 26 | * Update the [CODEOWNERS file](https://github.com/tensorflow/addons/blob/master/.github/CODEOWNERS) 27 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/tanhshrink_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import pytest 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | from tensorflow_addons.activations import tanhshrink 21 | from tensorflow_addons.utils import test_utils 22 | 23 | 24 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 25 | def test_tanh(dtype): 26 | x = tf.constant([-1.0, 0.0, 1.0], dtype=dtype) 27 | expected_result = tf.constant([-0.23840582, 0.0, 0.238405825], dtype=dtype) 28 | test_utils.assert_allclose_according_to_type(tanhshrink(x), expected_result) 29 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/gelu_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for GELU activation.""" 16 | 17 | 18 | import pytest 19 | import numpy as np 20 | from tensorflow_addons.layers.gelu import GELU 21 | from tensorflow_addons.utils import test_utils 22 | 23 | 24 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 25 | def test_random(dtype): 26 | x = np.array([[0.5, 1.2, -0.3]]).astype(dtype) 27 | val = np.array([[0.345714, 1.0617027, -0.11462909]]).astype(dtype) 28 | test_utils.layer_test( 29 | GELU, kwargs={"dtype": dtype}, input_data=x, expected_output=val 30 | ) 31 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/snake_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import pytest 17 | 18 | import numpy as np 19 | from tensorflow_addons.activations import snake 20 | from tensorflow_addons.utils import test_utils 21 | 22 | 23 | @pytest.mark.usefixtures("maybe_run_functions_eagerly") 24 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 25 | def test_activation(dtype): 26 | x = dtype(np.random.rand(2, 5)) 27 | a = dtype(np.random.randn()) 28 | expected_result = x + np.power(np.sin(a * x), 2) / a 29 | test_utils.assert_allclose_according_to_type(snake(x, a), expected_result) 30 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/lisht_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import pytest 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | from tensorflow_addons.activations import lisht 21 | from tensorflow_addons.utils import test_utils 22 | 23 | 24 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 25 | def test_lisht(dtype): 26 | x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype) 27 | expected_result = tf.constant( 28 | [1.9280552, 0.7615942, 0.0, 0.7615942, 1.9280552], dtype=dtype 29 | ) 30 | test_utils.assert_allclose_according_to_type(lisht(x), expected_result) 31 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/mish_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import pytest 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | from tensorflow_addons.activations import mish 21 | from tensorflow_addons.utils import test_utils 22 | 23 | 24 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 25 | def test_mish(dtype): 26 | x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype) 27 | expected_result = tf.constant( 28 | [-0.2525015, -0.30340144, 0.0, 0.86509836, 1.943959], dtype=dtype 29 | ) 30 | test_utils.assert_allclose_according_to_type(mish(x), expected_result) 31 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Additional activation functions.""" 16 | 17 | from tensorflow_addons.activations.gelu import gelu 18 | from tensorflow_addons.activations.hardshrink import hardshrink 19 | from tensorflow_addons.activations.lisht import lisht 20 | from tensorflow_addons.activations.mish import mish 21 | from tensorflow_addons.activations.softshrink import softshrink 22 | from tensorflow_addons.activations.rrelu import rrelu 23 | from tensorflow_addons.activations.snake import snake 24 | from tensorflow_addons.activations.sparsemax import sparsemax 25 | from tensorflow_addons.activations.tanhshrink import tanhshrink 26 | -------------------------------------------------------------------------------- /.venv/bin/activate.csh: -------------------------------------------------------------------------------- 1 | # This file must be used with "source bin/activate.csh" *from csh*. 2 | # You cannot run it directly. 3 | # Created by Davide Di Blasi . 4 | # Ported to Python 3.3 venv by Andrew Svetlov 5 | 6 | alias deactivate 'test $?_OLD_VIRTUAL_PATH != 0 && setenv PATH "$_OLD_VIRTUAL_PATH" && unset _OLD_VIRTUAL_PATH; rehash; test $?_OLD_VIRTUAL_PROMPT != 0 && set prompt="$_OLD_VIRTUAL_PROMPT" && unset _OLD_VIRTUAL_PROMPT; unsetenv VIRTUAL_ENV; test "\!:*" != "nondestructive" && unalias deactivate' 7 | 8 | # Unset irrelevant variables. 9 | deactivate nondestructive 10 | 11 | setenv VIRTUAL_ENV "/workspaces/addons/.venv" 12 | 13 | set _OLD_VIRTUAL_PATH="$PATH" 14 | setenv PATH "$VIRTUAL_ENV/bin:$PATH" 15 | 16 | 17 | set _OLD_VIRTUAL_PROMPT="$prompt" 18 | 19 | if (! 
"$?VIRTUAL_ENV_DISABLE_PROMPT") then 20 | if (".venv" != "") then 21 | set env_name = ".venv" 22 | else 23 | if (`basename "VIRTUAL_ENV"` == "__") then 24 | # special case for Aspen magic directories 25 | # see https://aspen.io/ 26 | set env_name = `basename \`dirname "$VIRTUAL_ENV"\`` 27 | else 28 | set env_name = `basename "$VIRTUAL_ENV"` 29 | endif 30 | endif 31 | set prompt = "[$env_name] $prompt" 32 | unset env_name 33 | endif 34 | 35 | alias pydoc python -m pydoc 36 | 37 | rehash 38 | -------------------------------------------------------------------------------- /tools/docker/cpu_tests.Dockerfile: -------------------------------------------------------------------------------- 1 | #syntax=docker/dockerfile:1.1.5-experimental 2 | FROM python:3.6 as build_wheel 3 | 4 | ARG TF_VERSION=2.5.0 5 | RUN pip install --default-timeout=1000 tensorflow-cpu==$TF_VERSION 6 | 7 | RUN apt-get update && apt-get install -y sudo rsync 8 | COPY tools/install_deps/install_bazelisk.sh .bazelversion ./ 9 | RUN bash install_bazelisk.sh 10 | 11 | COPY requirements.txt ./ 12 | RUN pip install -r requirements.txt 13 | 14 | COPY tools/install_deps/pytest.txt ./ 15 | RUN pip install -r pytest.txt pytest-cov 16 | 17 | COPY ./ /addons 18 | WORKDIR addons 19 | RUN python configure.py 20 | RUN pip install -e ./ 21 | RUN --mount=type=cache,id=cache_bazel,target=/root/.cache/bazel \ 22 | bash tools/install_so_files.sh 23 | RUN pytest -v -n auto --durations=25 --doctest-modules ./tensorflow_addons \ 24 | --cov=tensorflow_addons ./tensorflow_addons/ 25 | 26 | RUN bazel build --enable_runfiles build_pip_pkg 27 | RUN bazel-bin/build_pip_pkg artifacts 28 | 29 | 30 | FROM python:3.6 31 | 32 | COPY tools/install_deps/tensorflow-cpu.txt ./ 33 | RUN pip install --default-timeout=1000 -r tensorflow-cpu.txt 34 | 35 | COPY --from=0 /addons/artifacts /artifacts 36 | 37 | RUN pip install /artifacts/tensorflow_addons-*.whl 38 | 39 | # check that we didnd't forget to add a py file to 40 | # The corresponding BUILD file. 
41 | # Also test that the wheel works in a fresh environment 42 | RUN python -c "import tensorflow_addons" 43 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/image/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | load("//tensorflow_addons:tensorflow_addons.bzl", "custom_op_library") 6 | 7 | custom_op_library( 8 | name = "_distort_image_ops.so", 9 | srcs = [ 10 | "cc/kernels/adjust_hsv_in_yiq_op.cc", 11 | "cc/kernels/adjust_hsv_in_yiq_op.h", 12 | "cc/ops/distort_image_ops.cc", 13 | ], 14 | cuda_srcs = [ 15 | "cc/kernels/adjust_hsv_in_yiq_op.h", 16 | "cc/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc", 17 | ], 18 | ) 19 | 20 | custom_op_library( 21 | name = "_image_ops.so", 22 | srcs = [ 23 | "cc/kernels/connected_components.cc", 24 | "cc/kernels/connected_components.h", 25 | "cc/kernels/euclidean_distance_transform_op.cc", 26 | "cc/kernels/euclidean_distance_transform_op.h", 27 | "cc/ops/image_ops.cc", 28 | ], 29 | cuda_srcs = [ 30 | "cc/kernels/euclidean_distance_transform_op.h", 31 | "cc/kernels/euclidean_distance_transform_op_gpu.cu.cc", 32 | ], 33 | ) 34 | 35 | custom_op_library( 36 | name = "_resampler_ops.so", 37 | srcs = [ 38 | "cc/kernels/resampler_ops.cc", 39 | "cc/kernels/resampler_ops.h", 40 | "cc/ops/resampler_ops.cc", 41 | ], 42 | cuda_srcs = [ 43 | "cc/kernels/resampler_ops.h", 44 | "cc/kernels/resampler_ops_gpu.cu.cc", 45 | ], 46 | ) 47 | -------------------------------------------------------------------------------- /tensorflow_addons/utils/tests/test_utils_test.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import pytest 5 | import tensorflow as tf 6 | from tensorflow_addons.utils import test_utils 7 | 8 | 9 | def test_seed_is_set(): 10 | assert random.randint(0, 10000) == 6311 11 | assert np.random.randint(0, 10000) == 2732 12 | assert tf.random.uniform([], 0, 10000, dtype=tf.int64).numpy() == 9457 13 | 14 | 15 | @pytest.mark.with_device(["cpu", "gpu", tf.distribute.MirroredStrategy]) 16 | def test_all_scopes(device): 17 | assert isinstance(device, str) or isinstance(device, tf.distribute.Strategy) 18 | 19 | 20 | def train_small_model(): 21 | model_input = tf.keras.layers.Input((3,)) 22 | model_output = tf.keras.layers.Dense(4)(model_input) 23 | model = tf.keras.Model(model_input, model_output) 24 | model.compile(loss="mse") 25 | 26 | x = np.random.uniform(size=(5, 3)) 27 | y = np.random.uniform(size=(5, 4)) 28 | model.fit(x, y, epochs=1) 29 | 30 | 31 | @pytest.mark.with_device([tf.distribute.MirroredStrategy]) 32 | def test_distributed_strategy(device): 33 | assert isinstance(device, tf.distribute.Strategy) 34 | train_small_model() 35 | 36 | 37 | @pytest.mark.with_device(["no_device"]) 38 | @pytest.mark.needs_gpu 39 | def test_custom_device_placement(): 40 | with tf.device(test_utils.gpus_for_testing()[0]): 41 | train_small_model() 42 | 43 | strategy = tf.distribute.MirroredStrategy(test_utils.gpus_for_testing()) 44 | with strategy.scope(): 45 | train_small_model() 46 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/hardshrink_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import pytest 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | from tensorflow_addons.activations.hardshrink import hardshrink 21 | from tensorflow_addons.utils import test_utils 22 | 23 | 24 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 25 | def test_hardshrink(dtype): 26 | x = tf.constant([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=dtype) 27 | expected_result = tf.constant([-2.0, 0.0, 0.0, 0.0, 2.0], dtype=dtype) 28 | test_utils.assert_allclose_according_to_type(hardshrink(x), expected_result) 29 | 30 | expected_result = tf.constant([-2.0, 0.0, 0.0, 0.0, 2.0], dtype=dtype) 31 | test_utils.assert_allclose_according_to_type( 32 | hardshrink(x, lower=-1.0, upper=1.0), expected_result 33 | ) 34 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/softshrink_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import pytest 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | from tensorflow_addons.activations import softshrink 21 | 22 | from tensorflow_addons.utils import test_utils 23 | 24 | 25 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 26 | def test_softshrink(dtype): 27 | x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype) 28 | expected_result = tf.constant([-1.5, -0.5, 0.0, 0.5, 1.5], dtype=dtype) 29 | test_utils.assert_allclose_according_to_type(softshrink(x), expected_result) 30 | 31 | expected_result = tf.constant([-1.0, 0.0, 0.0, 0.0, 1.0], dtype=dtype) 32 | test_utils.assert_allclose_according_to_type( 33 | softshrink(x, lower=-1.0, upper=1.0), expected_result 34 | ) 35 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/snake_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for Snake layer.""" 16 | 17 | import pytest 18 | 19 | import numpy as np 20 | import tensorflow as tf 21 | 22 | from tensorflow_addons.layers.snake import Snake 23 | from tensorflow_addons.activations.snake import snake 24 | 25 | from tensorflow_addons.utils import test_utils 26 | 27 | 28 | @pytest.mark.usefixtures("maybe_run_functions_eagerly") 29 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 30 | def test_layer(dtype): 31 | x = np.random.rand(2, 5).astype(dtype) 32 | a = np.random.randn() 33 | val = snake(x, a) 34 | test_utils.layer_test( 35 | Snake, 36 | kwargs={"frequency_initializer": tf.constant_initializer(a), "dtype": dtype}, 37 | input_data=x, 38 | expected_output=val, 39 | ) 40 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/gelu_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import pytest 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | from tensorflow_addons.activations import gelu 21 | from tensorflow_addons.utils import test_utils 22 | 23 | 24 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 25 | def test_gelu(dtype): 26 | x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype) 27 | expected_result = tf.constant( 28 | [-0.04540229, -0.158808, 0.0, 0.841192, 1.9545977], dtype=dtype 29 | ) 30 | test_utils.assert_allclose_according_to_type(gelu(x), expected_result) 31 | 32 | expected_result = tf.constant( 33 | [-0.04550028, -0.15865526, 0.0, 0.8413447, 1.9544997], dtype=dtype 34 | ) 35 | test_utils.assert_allclose_according_to_type(gelu(x, False), expected_result) 36 | -------------------------------------------------------------------------------- /tools/format.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from subprocess import check_call, CalledProcessError 3 | 4 | 5 | def check_bash_call(string): 6 | check_call(["bash", "-c", string]) 7 | 8 | 9 | def _run_format_and_flake8(): 10 | files_changed = False 11 | 12 | try: 13 | check_bash_call("python -m black --check ./") 14 | except CalledProcessError: 15 | check_bash_call("python -m black ./") 16 | files_changed = True 17 | 18 | try: 19 | check_bash_call("buildifier -mode=check -r .") 20 | except CalledProcessError: 21 | check_bash_call("buildifier -r .") 22 | files_changed = True 23 | 24 | # todo: find a way to check if files changed 25 | # see https://github.com/DoozyX/clang-format-lint-action for inspiration 26 | check_bash_call( 27 | "shopt -s globstar && clang-format-9 -i --style=google **/*.cc **/*.h", 28 | ) 29 | 30 | if files_changed: 31 | print("Some files have changed.") 32 | print("Please do git add and git commit again") 33 | else: 34 | print("No formatting needed.") 35 | 36 | print("Running flake8.") 37 | check_bash_call("flake8") 38 | print("Done") 39 | 40 | if files_changed: 41 | exit(1) 42 | 43 | 44 | def run_format_and_flake8(): 45 | try: 46 | _run_format_and_flake8() 47 | except CalledProcessError as error: 48 | print("Pre-commit returned exit code", error.returncode) 49 | exit(error.returncode) 50 | 51 | 52 | if __name__ == "__main__": 53 | run_format_and_flake8() 54 | -------------------------------------------------------------------------------- /tools/testing/build_and_run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | # ============================================================================== 17 | # usage: bash tools/testing/build_and_run_tests.sh 18 | 19 | set -x -e 20 | 21 | python -m pip install -r tools/install_deps/pytest.txt -e ./ 22 | python ./configure.py 23 | bash tools/install_so_files.sh 24 | python -c "import tensorflow as tf; print(tf.config.list_physical_devices())" 25 | 26 | # use 10 workers if a gpu is available, otherwise, 27 | # one worker per cpu core. Kokoro has 38 cores, that'd be too much 28 | # for the gpu memory, until we change the device placement to 29 | # use multiple gpus when they are available. 30 | EXTRA_ARGS="-n 10" 31 | if ! [ -x "$(command -v nvidia-smi)" ]; then 32 | EXTRA_ARGS="-n auto" 33 | fi 34 | 35 | bazel clean 36 | python -m pytest -v --functions-durations=20 --modules-durations=5 $EXTRA_ARGS ./tensorflow_addons 37 | -------------------------------------------------------------------------------- /tensorflow_addons/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Useful extra functionality for TensorFlow maintained by SIG-addons.""" 16 | from tensorflow_addons.utils.ensure_tf_install import _check_tf_version 17 | 18 | _check_tf_version() 19 | 20 | # Local project imports 21 | from tensorflow_addons import activations 22 | from tensorflow_addons import callbacks 23 | from tensorflow_addons import image 24 | from tensorflow_addons import layers 25 | from tensorflow_addons import losses 26 | from tensorflow_addons import metrics 27 | from tensorflow_addons import optimizers 28 | from tensorflow_addons import rnn 29 | from tensorflow_addons import seq2seq 30 | from tensorflow_addons import text 31 | from tensorflow_addons import options 32 | from tensorflow_addons.register import register_all 33 | from tensorflow_addons.utils import types 34 | 35 | from tensorflow_addons.version import __version__ 36 | -------------------------------------------------------------------------------- /tensorflow_addons/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Additional metrics that conform to Keras API.""" 16 | 17 | from tensorflow_addons.metrics.cohens_kappa import CohenKappa 18 | from tensorflow_addons.metrics.f_scores import F1Score, FBetaScore 19 | from tensorflow_addons.metrics.hamming import ( 20 | HammingLoss, 21 | hamming_distance, 22 | hamming_loss_fn, 23 | ) 24 | from tensorflow_addons.metrics.utils import MeanMetricWrapper 25 | from tensorflow_addons.metrics.matthews_correlation_coefficient import ( 26 | MatthewsCorrelationCoefficient, 27 | ) 28 | from tensorflow_addons.metrics.multilabel_confusion_matrix import ( 29 | MultiLabelConfusionMatrix, 30 | ) 31 | from tensorflow_addons.metrics.r_square import RSquare 32 | from tensorflow_addons.metrics.geometric_mean import GeometricMean 33 | from tensorflow_addons.metrics.harmonic_mean import HarmonicMean 34 | from tensorflow_addons.metrics.kendalls_tau import KendallsTau 35 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tanhshrink.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import tensorflow as tf 17 | 18 | from tensorflow_addons.utils.types import TensorLike 19 | 20 | 21 | @tf.keras.utils.register_keras_serializable(package="Addons") 22 | def tanhshrink(x: TensorLike) -> tf.Tensor: 23 | r"""Tanh shrink function. 24 | 25 | Applies the element-wise function: 26 | 27 | $$ 28 | \mathrm{tanhshrink}(x) = x - \tanh(x). 29 | $$ 30 | 31 | Usage: 32 | 33 | >>> x = tf.constant([-1.0, 0.0, 1.0]) 34 | >>> tfa.activations.tanhshrink(x) 35 | 36 | 37 | Args: 38 | x: A `Tensor`. Must be one of the following types: 39 | `bfloat16`, `float16`, `float32`, `float64`. 40 | Returns: 41 | A `Tensor`. Has the same type as `x`. 42 | """ 43 | x = tf.convert_to_tensor(x) 44 | return x - tf.math.tanh(x) 45 | -------------------------------------------------------------------------------- /tensorflow_addons/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Define TensorFlow Addons version information.""" 16 | 17 | # Required TensorFlow version [min, max) 18 | INCLUSIVE_MIN_TF_VERSION = "2.3.0" 19 | EXCLUSIVE_MAX_TF_VERSION = "2.6.0" 20 | 21 | # We follow Semantic Versioning (https://semver.org/) 22 | _MAJOR_VERSION = "0" 23 | _MINOR_VERSION = "14" 24 | _PATCH_VERSION = "0" 25 | 26 | # When building releases, we can update this value on the release branch to 27 | # reflect the current release candidate ('rc0', 'rc1') or, finally, the official 28 | # stable release (indicated by `_VERSION_SUFFIX = ''`). Outside the context of a 29 | # release branch, the current version is by default assumed to be a 30 | # 'development' version, labeled 'dev'. 31 | _VERSION_SUFFIX = "dev" 32 | 33 | # Example, '0.1.0-dev' 34 | __version__ = ".".join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION]) 35 | if _VERSION_SUFFIX: 36 | __version__ = "{}-{}".format(__version__, _VERSION_SUFFIX) 37 | -------------------------------------------------------------------------------- /tools/update_release_version.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | 18 | # Usage 19 | if [ $# -lt 1 ]; then 20 | echo "Usage: bash tools/update_release_version.sh " 21 | echo "e.g. bash tools/update_release_version.sh 2.3.0 2.3.1" 22 | exit 1 23 | fi 24 | 25 | last_version=${BASH_ARGV[0]} 26 | tf_version='' 27 | for ver in $@ 28 | do 29 | if [ -z $tf_version ]; then 30 | tf_version="'$ver'" 31 | else 32 | tf_version="$tf_version, '$ver'" 33 | fi 34 | done 35 | echo $tf_version 36 | echo $last_version 37 | sed -ri "s/(tf-version: \[)'.+'/\1$tf_version/g" \ 38 | .github/workflows/release.yml 39 | sed -ri "s/(tensorflow(-cpu)*(~|=)=)[0-9]+[a-zA-Z0-9_.-]+/\1$1/g" \ 40 | CONTRIBUTING.md \ 41 | tools/install_deps/tensorflow-cpu.txt \ 42 | tools/install_deps/tensorflow.txt 43 | sed -ri "s/(TF_VERSION=)\S+/\1$last_version/g" \ 44 | tools/docker/cpu_tests.Dockerfile \ 45 | tools/run_gpu_tests.sh \ 46 | tools/build_dev_container.sh 47 | -------------------------------------------------------------------------------- /tensorflow_addons/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Additional losses that conform to Keras API.""" 16 | 17 | from tensorflow_addons.losses.contrastive import contrastive_loss, ContrastiveLoss 18 | from tensorflow_addons.losses.focal_loss import ( 19 | sigmoid_focal_crossentropy, 20 | SigmoidFocalCrossEntropy, 21 | ) 22 | from tensorflow_addons.losses.giou_loss import giou_loss, GIoULoss 23 | from tensorflow_addons.losses.lifted import lifted_struct_loss, LiftedStructLoss 24 | from tensorflow_addons.losses.sparsemax_loss import sparsemax_loss, SparsemaxLoss 25 | from tensorflow_addons.losses.triplet import ( 26 | triplet_semihard_loss, 27 | triplet_hard_loss, 28 | TripletSemiHardLoss, 29 | TripletHardLoss, 30 | ) 31 | from tensorflow_addons.losses.quantiles import pinball_loss, PinballLoss 32 | 33 | 34 | from tensorflow_addons.losses.npairs import ( 35 | npairs_loss, 36 | NpairsLoss, 37 | npairs_multilabel_loss, 38 | NpairsMultilabelLoss, 39 | ) 40 | from tensorflow_addons.losses.kappa_loss import WeightedKappaLoss 41 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/seq2seq/cc/kernels/beam_search_ops.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_ADDONS_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ 17 | #define TENSORFLOW_ADDONS_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ 18 | 19 | #include "tensorflow/core/framework/tensor_types.h" 20 | #include "tensorflow/core/platform/types.h" 21 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 22 | 23 | namespace tensorflow { 24 | class OpKernelContext; 25 | 26 | namespace addons { 27 | 28 | namespace functor { 29 | 30 | template <typename Device, typename T> 31 | struct GatherTree { 32 | void operator()(OpKernelContext* ctx, const Device& d, 33 | typename TTypes<T, 3>::ConstTensor step_ids, 34 | typename TTypes<T, 3>::ConstTensor parent_ids, 35 | TTypes<int32>::ConstVec max_sequence_lengths, 36 | const T end_token, typename TTypes<T, 3>::Tensor beams); 37 | }; 38 | 39 | } // namespace functor 40 | } // end namespace addons 41 | } // namespace tensorflow 42 | 43 | #endif // TENSORFLOW_ADDONS_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ 44 | -------------------------------------------------------------------------------- /STYLE_GUIDE.md: -------------------------------------------------------------------------------- 1 | #### C++ 2 | C++ code should conform to [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html). 3 | 4 | Addons uses [clang-format](https://clang.llvm.org/docs/ClangFormat.html) 5 | to check your C/C++ changes. Sometimes you have some manually formatted 6 | code that you don’t want clang-format to touch. 7 | You can disable formatting like this: 8 | 9 | ```cpp 10 | int formatted_code; 11 | // clang-format off 12 | void unformatted_code ; 13 | // clang-format on 14 | void formatted_code_again; 15 | ``` 16 | 17 | Install Clang-format 9 with: 18 | 19 | ```bash 20 | wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - 21 | sudo add-apt-repository -u 'http://apt.llvm.org/bionic/ llvm-toolchain-bionic-9 main' 22 | sudo apt install clang-format-9 23 | ``` 24 | 25 | Format all files with: 26 | ```bash 27 | clang-format-9 -i --style=google **/*.cc **/*.h 28 | ``` 29 | 30 | #### Python 31 | 32 | Addons uses [flake8](http://flake8.pycqa.org/en/latest/) to check PEP 8 compliance and 33 | to run code analysis. 34 | 35 | Addons uses [Black](https://black.readthedocs.io/en/stable/) to format our code. 36 | The continuous integration check will fail if you do not use it. 37 | 38 | Install them with: 39 | ``` 40 | pip install flake8 black 41 | ``` 42 | 43 | Be sure to run them both before you push your commits, otherwise the CI will fail! 44 | 45 | ``` 46 | python -m black ./ 47 | python -m flake8 48 | ``` 49 | 50 | #### TensorFlow Conventions 51 | 52 | Follow the guidance in the [TensorFlow Style Guide - Conventions](https://www.tensorflow.org/community/contribute/code_style#tensorflow_conventions_and_special_uses). 53 | 54 | Please note that Addons follows the conventions of the TensorFlow library, but formats our code using [PEP8](https://www.python.org/dev/peps/pep-0008/) guidelines. 55 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/mish.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import tensorflow as tf 17 | 18 | from tensorflow_addons.utils.types import TensorLike 19 | 20 | 21 | @tf.keras.utils.register_keras_serializable(package="Addons") 22 | def mish(x: TensorLike) -> tf.Tensor: 23 | r"""Mish: A Self Regularized Non-Monotonic Neural Activation Function. 24 | 25 | Computes mish activation: 26 | 27 | $$ 28 | \mathrm{mish}(x) = x \cdot \tanh(\mathrm{softplus}(x)). 29 | $$ 30 | 31 | See [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681). 32 | 33 | Usage: 34 | 35 | >>> x = tf.constant([1.0, 0.0, 1.0]) 36 | >>> tfa.activations.mish(x) 37 | 38 | 39 | Args: 40 | x: A `Tensor`. Must be one of the following types: 41 | `bfloat16`, `float16`, `float32`, `float64`. 42 | Returns: 43 | A `Tensor`. Has the same type as `x`. 44 | """ 45 | x = tf.convert_to_tensor(x) 46 | return x * tf.math.tanh(tf.math.softplus(x)) 47 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/lisht.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import tensorflow as tf 17 | 18 | from tensorflow_addons.utils.types import TensorLike 19 | 20 | 21 | @tf.keras.utils.register_keras_serializable(package="Addons") 22 | def lisht(x: TensorLike) -> tf.Tensor: 23 | r"""LiSHT: Non-Parameteric Linearly Scaled Hyperbolic Tangent Activation Function. 24 | 25 | Computes linearly scaled hyperbolic tangent (LiSHT): 26 | 27 | $$ 28 | \mathrm{lisht}(x) = x * \tanh(x). 29 | $$ 30 | 31 | See [LiSHT: Non-Parameteric Linearly Scaled Hyperbolic Tangent Activation Function for Neural Networks](https://arxiv.org/abs/1901.05894). 32 | 33 | Usage: 34 | 35 | >>> x = tf.constant([1.0, 0.0, 1.0]) 36 | >>> tfa.activations.lisht(x) 37 | 38 | 39 | Args: 40 | x: A `Tensor`. Must be one of the following types: 41 | `bfloat16`, `float16`, `float32`, `float64`. 42 | Returns: 43 | A `Tensor`. Has the same type as `x`. 
44 | """ 45 | x = tf.convert_to_tensor(x) 46 | return x * tf.math.tanh(x) 47 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/activations_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import pytest 17 | import tensorflow as tf 18 | from tensorflow_addons import activations 19 | 20 | 21 | ALL_ACTIVATIONS = [ 22 | "gelu", 23 | "hardshrink", 24 | "lisht", 25 | "mish", 26 | "rrelu", 27 | "softshrink", 28 | "sparsemax", 29 | "tanhshrink", 30 | "snake", 31 | ] 32 | 33 | 34 | @pytest.mark.parametrize("name", ALL_ACTIVATIONS) 35 | def test_serialization(name): 36 | fn = tf.keras.activations.get("Addons>" + name) 37 | ref_fn = getattr(activations, name) 38 | assert fn == ref_fn 39 | config = tf.keras.activations.serialize(fn) 40 | fn = tf.keras.activations.deserialize(config) 41 | assert fn == ref_fn 42 | 43 | 44 | @pytest.mark.parametrize("name", ALL_ACTIVATIONS) 45 | def test_serialization_with_layers(name): 46 | layer = tf.keras.layers.Dense(3, activation=getattr(activations, name)) 47 | config = tf.keras.layers.serialize(layer) 48 | deserialized_layer = tf.keras.layers.deserialize(config) 49 | assert deserialized_layer.__class__.__name__ == layer.__class__.__name__ 50 | assert deserialized_layer.activation.__name__ == name 51 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/sparsemax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import tensorflow as tf 17 | from tensorflow_addons.activations.sparsemax import sparsemax 18 | from typeguard import typechecked 19 | 20 | 21 | @tf.keras.utils.register_keras_serializable(package="Addons") 22 | class Sparsemax(tf.keras.layers.Layer): 23 | """Sparsemax activation function. 24 | 25 | The output shape is the same as the input shape. 
26 | 27 | See [From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification](https://arxiv.org/abs/1602.02068). 28 | 29 | Args: 30 | axis: Integer, axis along which the sparsemax normalization is applied. 31 | """ 32 | 33 | @typechecked 34 | def __init__(self, axis: int = -1, **kwargs): 35 | super().__init__(**kwargs) 36 | self.supports_masking = True 37 | self.axis = axis 38 | 39 | def call(self, inputs): 40 | return sparsemax(inputs, axis=self.axis) 41 | 42 | def get_config(self): 43 | config = {"axis": self.axis} 44 | base_config = super().get_config() 45 | return {**base_config, **config} 46 | 47 | def compute_output_shape(self, input_shape): 48 | return input_shape 49 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/snake.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import tensorflow as tf 17 | 18 | from tensorflow_addons.utils import types 19 | 20 | 21 | @tf.keras.utils.register_keras_serializable(package="Addons") 22 | def snake(x: types.TensorLike, frequency: types.Number = 1) -> tf.Tensor: 23 | r"""Snake activation to learn periodic functions. 24 | 25 | Computes snake activation: 26 | 27 | $$ 28 | \mathrm{snake}(x) = \frac{x + (1 - \cos(2 \cdot \mathrm{frequency} \cdot x))}{2 \cdot \mathrm{frequency}}. 29 | $$ 30 | 31 | See [Neural Networks Fail to Learn Periodic Functions and How to Fix It](https://arxiv.org/abs/2006.08195). 32 | 33 | Usage: 34 | 35 | >>> x = tf.constant([-1.0, 0.0, 1.0]) 36 | >>> tfa.activations.snake(x) 37 | 38 | 39 | Args: 40 | x: A `Tensor`. 41 | frequency: A scalar, frequency of the periodic part. 42 | Returns: 43 | A `Tensor`. Has the same type as `x`. 44 | """ 45 | x = tf.convert_to_tensor(x) 46 | frequency = tf.cast(frequency, x.dtype) 47 | 48 | return x + (1 - tf.cos(2 * frequency * x)) / (2 * frequency) 49 | -------------------------------------------------------------------------------- /tensorflow_addons/text/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Additional text-processing ops.""" 16 | 17 | # Conditional Random Field 18 | from tensorflow_addons.text import crf 19 | from tensorflow_addons.text.crf import CrfDecodeForwardRnnCell 20 | from tensorflow_addons.text.crf import crf_binary_score 21 | from tensorflow_addons.text.crf import crf_constrained_decode 22 | from tensorflow_addons.text.crf import crf_decode 23 | from tensorflow_addons.text.crf import crf_decode_backward 24 | from tensorflow_addons.text.crf import crf_decode_forward 25 | from tensorflow_addons.text.crf import crf_filtered_inputs 26 | from tensorflow_addons.text.crf import crf_forward 27 | from tensorflow_addons.text.crf import crf_log_likelihood 28 | from tensorflow_addons.text.crf import crf_log_norm 29 | from tensorflow_addons.text.crf import crf_multitag_sequence_score 30 | from tensorflow_addons.text.crf import crf_sequence_score 31 | from tensorflow_addons.text.crf import crf_unary_score 32 | from tensorflow_addons.text.crf import viterbi_decode 33 | from tensorflow_addons.text.parse_time_op import parse_time 34 | 35 | # Skip Gram Sampling 36 | from tensorflow_addons.text.skip_gram_ops import skip_gram_sample 37 | from tensorflow_addons.text.skip_gram_ops import skip_gram_sample_with_text_vocab 38 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/maxout_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for Maxout layer.""" 16 | 17 | 18 | import pytest 19 | import numpy as np 20 | 21 | from tensorflow_addons.layers.maxout import Maxout 22 | from tensorflow_addons.utils import test_utils 23 | 24 | 25 | pytestmark = pytest.mark.usefixtures("maybe_run_functions_eagerly") 26 | 27 | 28 | def test_simple(): 29 | test_utils.layer_test(Maxout, kwargs={"num_units": 3}, input_shape=(5, 4, 2, 18)) 30 | 31 | 32 | def test_nchw(): 33 | test_utils.layer_test( 34 | Maxout, kwargs={"num_units": 4, "axis": 1}, input_shape=(2, 20, 3, 6) 35 | ) 36 | 37 | test_utils.layer_test( 38 | Maxout, kwargs={"num_units": 4, "axis": -3}, input_shape=(2, 20, 3, 6) 39 | ) 40 | 41 | 42 | def test_unknown(): 43 | inputs = np.random.random((5, 4, 2, 18)).astype("float32") 44 | test_utils.layer_test( 45 | Maxout, kwargs={"num_units": 3}, input_shape=(5, 4, 2, None), input_data=inputs 46 | ) 47 | 48 | test_utils.layer_test( 49 | Maxout, 50 | kwargs={"num_units": 3}, 51 | input_shape=(None, None, None, None), 52 | input_data=inputs, 53 | ) 54 | 55 | 56 | def test_invalid_shape(): 57 | with pytest.raises(ValueError, match="number of features"): 58 | test_utils.layer_test(Maxout, kwargs={"num_units": 3}, input_shape=(5, 4, 2, 7)) 59 | -------------------------------------------------------------------------------- /tensorflow_addons/utils/types.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Types for typing functions signatures.""" 16 | 17 | from typing import Union, Callable, List 18 | 19 | import numpy as np 20 | import tensorflow as tf 21 | 22 | # TODO: Remove once https://github.com/tensorflow/tensorflow/issues/44613 is resolved 23 | from tensorflow.python.keras.engine import keras_tensor 24 | 25 | 26 | Number = Union[ 27 | float, 28 | int, 29 | np.float16, 30 | np.float32, 31 | np.float64, 32 | np.int8, 33 | np.int16, 34 | np.int32, 35 | np.int64, 36 | np.uint8, 37 | np.uint16, 38 | np.uint32, 39 | np.uint64, 40 | ] 41 | 42 | Initializer = Union[None, dict, str, Callable, tf.keras.initializers.Initializer] 43 | Regularizer = Union[None, dict, str, Callable, tf.keras.regularizers.Regularizer] 44 | Constraint = Union[None, dict, str, Callable, tf.keras.constraints.Constraint] 45 | Activation = Union[None, str, Callable] 46 | Optimizer = Union[tf.keras.optimizers.Optimizer, str] 47 | 48 | TensorLike = Union[ 49 | List[Union[Number, list]], 50 | tuple, 51 | Number, 52 | np.ndarray, 53 | tf.Tensor, 54 | tf.SparseTensor, 55 | tf.Variable, 56 | keras_tensor.KerasTensor, 57 | ] 58 | FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32, np.float64] 59 | AcceptableDTypes = Union[tf.DType, np.dtype, type, int, str, None] 60 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implements GELU activation.""" 16 | 17 | import tensorflow as tf 18 | from tensorflow_addons.activations import gelu 19 | from typeguard import typechecked 20 | 21 | 22 | @tf.keras.utils.register_keras_serializable(package="Addons") 23 | class GELU(tf.keras.layers.Layer): 24 | """Gaussian Error Linear Unit. 25 | 26 | A smoother version of ReLU generally used 27 | in the BERT or BERT architecture based models. 28 | Original paper: https://arxiv.org/abs/1606.08415 29 | 30 | Input shape: 31 | Arbitrary. Use the keyword argument `input_shape` 32 | (tuple of integers, does not include the samples axis) 33 | when using this layer as the first layer in a model. 34 | 35 | Output shape: 36 | Same shape as the input. 
37 | """ 38 | 39 | @typechecked 40 | def __init__(self, approximate: bool = True, **kwargs): 41 | super().__init__(**kwargs) 42 | self.approximate = approximate 43 | self.supports_masking = True 44 | 45 | def call(self, inputs): 46 | return gelu(inputs, approximate=self.approximate) 47 | 48 | def get_config(self): 49 | config = {"approximate": self.approximate} 50 | base_config = super().get_config() 51 | return {**base_config, **config} 52 | 53 | def compute_output_shape(self, input_shape): 54 | return input_shape 55 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/stochastic_depth_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from tensorflow_addons.layers.stochastic_depth import StochasticDepth 6 | from tensorflow_addons.utils import test_utils 7 | 8 | _KEEP_SEED = 1111 9 | _DROP_SEED = 2222 10 | 11 | 12 | @pytest.mark.parametrize("seed", [_KEEP_SEED, _DROP_SEED]) 13 | @pytest.mark.parametrize("training", [True, False]) 14 | def stochastic_depth_test(seed, training): 15 | np.random.seed(seed) 16 | tf.random.set_seed(seed) 17 | 18 | survival_probability = 0.5 19 | 20 | shortcut = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32) 21 | residual = np.asarray([[0.2, 0.4, 0.5]]).astype(np.float32) 22 | 23 | if training: 24 | if seed == _KEEP_SEED: 25 | # shortcut + residual 26 | expected_output = np.asarray([[0.4, 0.5, 0.9]]).astype(np.float32) 27 | elif seed == _DROP_SEED: 28 | # shortcut 29 | expected_output = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32) 30 | else: 31 | # shortcut + p_l * residual 32 | expected_output = np.asarray([[0.3, 0.3, 0.65]]).astype(np.float32) 33 | 34 | test_utils.layer_test( 35 | StochasticDepth, 36 | kwargs={"survival_probability": survival_probability}, 37 | input_data=[shortcut, residual], 38 | expected_output=expected_output, 39 | ) 40 | 41 | 42 | @pytest.mark.usefixtures("run_with_mixed_precision_policy") 43 | def test_with_mixed_precision_policy(): 44 | policy = tf.keras.mixed_precision.experimental.global_policy() 45 | 46 | shortcut = np.asarray([[0.2, 0.1, 0.4]]) 47 | residual = np.asarray([[0.2, 0.4, 0.5]]) 48 | 49 | output = StochasticDepth()([shortcut, residual]) 50 | 51 | assert output.dtype == policy.compute_dtype 52 | 53 | 54 | def test_serialization(): 55 | stoch_depth = StochasticDepth(survival_probability=0.5) 56 | serialized_stoch_depth = tf.keras.layers.serialize(stoch_depth) 57 | new_layer = tf.keras.layers.deserialize(serialized_stoch_depth) 58 | assert stoch_depth.get_config() == new_layer.get_config() 59 | -------------------------------------------------------------------------------- /tensorflow_addons/optimizers/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Additional Utilities used for tfa.optimizers.""" 16 | 17 | import tensorflow as tf 18 | 19 | 20 | def fit_bn(model, *args, **kwargs): 21 | """Resets batch normalization layers of model, and recalculates the 22 | statistics for each batchnorm layer by running a pass on the data. 23 | 24 | Args: 25 | model: An instance of tf.keras.Model 26 | *args, **kwargs: Params that'll be passed to `.fit` method of model 27 | """ 28 | kwargs["epochs"] = 1 29 | if not isinstance(model, tf.keras.Model): 30 | raise TypeError("model must be an instance of tf.keras.Model") 31 | 32 | if not model.built: 33 | raise ValueError("Call `fit_bn` after the model is built and trained") 34 | 35 | assign_ops = [] 36 | for layer in model.layers: 37 | if isinstance(layer, tf.keras.layers.BatchNormalization): 38 | assign_ops.extend( 39 | [ 40 | layer.moving_mean.assign(tf.zeros_like(layer.moving_mean)), 41 | layer.moving_variance.assign(tf.ones_like(layer.moving_variance)), 42 | ] 43 | ) 44 | 45 | _trainable = model.trainable 46 | _metrics = model._metrics 47 | model.trainable = False 48 | model._metrics = [] 49 | 50 | model.fit(*args, **kwargs) 51 | 52 | model.trainable = _trainable 53 | model._metrics = _metrics 54 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/snake.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implements Snake layer.""" 16 | 17 | import tensorflow as tf 18 | from typeguard import typechecked 19 | 20 | from tensorflow_addons.activations.snake import snake 21 | 22 | from tensorflow_addons.utils import types 23 | 24 | 25 | @tf.keras.utils.register_keras_serializable(package="Addons") 26 | class Snake(tf.keras.layers.Layer): 27 | """Snake layer to learn periodic functions with the trainable `frequency` scalar. 28 | 29 | See [Neural Networks Fail to Learn Periodic Functions and How to Fix It](https://arxiv.org/abs/2006.08195). 30 | 31 | Args: 32 | frequency_initializer: Initializer for the `frequency` scalar. 
33 | """ 34 | 35 | @typechecked 36 | def __init__(self, frequency_initializer: types.Initializer = "ones", **kwargs): 37 | super().__init__(**kwargs) 38 | self.frequency_initializer = tf.keras.initializers.get(frequency_initializer) 39 | self.frequency = self.add_weight( 40 | initializer=frequency_initializer, trainable=True 41 | ) 42 | 43 | def call(self, inputs): 44 | return snake(inputs, self.frequency) 45 | 46 | def get_config(self): 47 | config = { 48 | "frequency_initializer": tf.keras.initializers.serialize( 49 | self.frequency_initializer 50 | ), 51 | } 52 | base_config = super().get_config() 53 | return {**base_config, **config} 54 | -------------------------------------------------------------------------------- /tensorflow_addons/metrics/harmonic_mean.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implements HarmonicMean.""" 16 | 17 | import tensorflow as tf 18 | 19 | from typeguard import typechecked 20 | from tensorflow_addons.utils.types import AcceptableDTypes 21 | 22 | 23 | @tf.keras.utils.register_keras_serializable(package="Addons") 24 | class HarmonicMean(tf.keras.metrics.Mean): 25 | """Compute Harmonic Mean 26 | The harmonic mean is a kind of mean. It can be expressed as the reciprocal of 27 | the arithmetic mean of the reciprocals of the given set of numbers. 28 | Note: `tfa.metrics.HarmonicMean` can be used the same as `tf.keras.metrics.Mean`. 29 | Args: 30 | name: (Optional) String name of the metric instance. 31 | dtype: (Optional) Data type of the metric result. 32 | Usage: 33 | >>> metric = tfa.metrics.HarmonicMean() 34 | >>> metric.update_state([1, 4, 4]) 35 | >>> metric.result().numpy() 36 | 2.0 37 | """ 38 | 39 | @typechecked 40 | def __init__( 41 | self, name: str = "harmonic_mean", dtype: AcceptableDTypes = None, **kwargs 42 | ): 43 | super().__init__(name=name, dtype=dtype, **kwargs) 44 | 45 | def update_state(self, values, sample_weight=None) -> None: 46 | values = tf.cast(values, dtype=self.dtype) 47 | super().update_state(tf.math.reciprocal(values), sample_weight) 48 | 49 | def result(self) -> tf.Tensor: 50 | return tf.math.reciprocal_no_nan(super().result()) 51 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_ADDONS_LAYERS_KERNELS_CORRELATION_COST_OP_H_ 17 | #define TENSORFLOW_ADDONS_LAYERS_KERNELS_CORRELATION_COST_OP_H_ 18 | 19 | #include "tensorflow/core/framework/op_kernel.h" 20 | #include "tensorflow/core/util/tensor_format.h" 21 | 22 | namespace tensorflow { 23 | namespace addons { 24 | namespace functor { 25 | 26 | template <typename Device, typename T> 27 | struct CorrelationCostFunctor { 28 | Status operator()(OpKernelContext* context, const Tensor& input_a_t, 29 | const Tensor& input_b_t, Tensor* output_t, 30 | /* params */ 31 | int kernel_size, int max_displacement, int stride_1, 32 | int stride_2, int pad, TensorFormat data_format); 33 | }; 34 | 35 | template <typename Device, typename T> 36 | struct CorrelationCostGradFunctor { 37 | Status operator()(OpKernelContext* context, const Tensor& input_a_t, 38 | const Tensor& input_b_t, const Tensor& topdiff_t, 39 | Tensor* output_a_gradient_t, Tensor* output_b_gradient_t, 40 | /* params */ 41 | int kernel_size, int max_displacement, int stride_1, 42 | int stride_2, int pad, TensorFormat data_format); 43 | }; 44 | 45 | } // namespace functor 46 | } // namespace addons 47 | } // namespace tensorflow 48 | 49 | #endif // TENSORFLOW_ADDONS_LAYERS_KERNELS_CORRELATION_COST_OP_H_ 50 | -------------------------------------------------------------------------------- /tensorflow_addons/losses/giou_loss.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "giou_loss.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "code", 21 | "metadata": { 22 | "id": "XNGF7YQBloHY" 23 | }, 24 | "source": [ 25 | "import tensorflow as tf\n", 26 | "import tensorflow_addons as tfa" 27 | ], 28 | "execution_count": null, 29 | "outputs": [] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "metadata": { 34 | "id": "oD9HV-tWlQDr" 35 | }, 36 | "source": [ 37 | "# Initialize the loss function.\n", 38 | "gl = tfa.losses.GIoULoss()" 39 | ], 40 | "execution_count": null, 41 | "outputs": [] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "metadata": { 46 | "id": "rY8tF2vTltZ3" 47 | }, 48 | "source": [ 49 | "# Define the bounding box coordinates.\n", 50 | "boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])\n", 51 | "boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]])" 52 | ], 53 | "execution_count": null, 54 | "outputs": [] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "metadata": { 59 | "id": "9HX8Tn1HlvfR" 60 | }, 61 | "source": [ 62 | "# Calculate and print the GIoU loss.\n", 63 | "loss = gl(boxes1, boxes2)\n", 64 | "loss" 65 | ], 66 | "execution_count": null, 67 | "outputs": [] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "metadata": { 72 | "id": "u0xvrpbWlw6Z" 73 | }, 74 | "source": [ 75 | "# Utilization with a Keras model.\n", 76 | "model = 
tf.keras.Model()\n", 77 | "model.compile('sgd', loss=tfa.losses.GIoULoss())" 78 | ], 79 | "execution_count": null, 80 | "outputs": [] 81 | } 82 | ] 83 | } -------------------------------------------------------------------------------- /tensorflow_addons/callbacks/tests/time_stopping_test.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import pytest 4 | import numpy as np 5 | import tensorflow as tf 6 | from tensorflow.keras.models import Sequential 7 | from tensorflow.keras.layers import Dense 8 | 9 | from tensorflow_addons.callbacks.time_stopping import TimeStopping 10 | 11 | 12 | class SleepLayer(tf.keras.layers.Layer): 13 | def __init__(self, secs): 14 | self.secs = secs 15 | super().__init__(dynamic=True) 16 | 17 | def call(self, inputs): 18 | time.sleep(self.secs) 19 | return inputs 20 | 21 | 22 | def get_data_and_model(secs): 23 | X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]]) 24 | y = np.array([[0], [1], [1], [0]]) 25 | 26 | model = Sequential() 27 | model.add(SleepLayer(secs)) 28 | model.add(Dense(1)) 29 | model.compile(loss="mean_squared_error") 30 | 31 | # In case there is some initialization going on. 32 | model.fit(X, y, epochs=1, verbose=0) 33 | return X, y, model 34 | 35 | 36 | def test_stop_at_the_right_time(): 37 | X, y, model = get_data_and_model(0.1) 38 | 39 | time_stopping = TimeStopping(2, verbose=0) 40 | history = model.fit(X, y, epochs=30, verbose=0, callbacks=[time_stopping]) 41 | 42 | assert len(history.epoch) <= 20 43 | 44 | 45 | def test_default_value(): 46 | X, y, model = get_data_and_model(0.1) 47 | 48 | time_stopping = TimeStopping() 49 | history = model.fit(X, y, epochs=15, verbose=0, callbacks=[time_stopping]) 50 | 51 | assert len(history.epoch) == 15 52 | 53 | 54 | @pytest.mark.parametrize("verbose", [0, 1]) 55 | def test_time_stopping_verbose(capsys, verbose): 56 | X, y, model = get_data_and_model(0.25) 57 | 58 | time_stopping = TimeStopping(1, verbose=verbose) 59 | 60 | capsys.readouterr() # flush the stdout/stderr buffer. 61 | history = model.fit(X, y, epochs=10, verbose=0, callbacks=[time_stopping]) 62 | fit_stdout = capsys.readouterr().out 63 | nb_epochs_run = len(history.epoch) 64 | message = "Timed stopping at epoch " + str(nb_epochs_run) 65 | if verbose: 66 | assert message in fit_stdout 67 | else: 68 | assert message not in fit_stdout 69 | assert len(history.epoch) <= 4 70 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/netvlad_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for NetVLAD layer.""" 16 | 17 | 18 | import pytest 19 | import numpy as np 20 | from tensorflow_addons.layers.netvlad import NetVLAD 21 | from tensorflow_addons.utils import test_utils 22 | 23 | 24 | pytestmark = pytest.mark.usefixtures("maybe_run_functions_eagerly") 25 | 26 | 27 | @pytest.mark.parametrize("num_clusters", [1, 4]) 28 | def test_simple(num_clusters): 29 | test_utils.layer_test( 30 | NetVLAD, 31 | kwargs={"num_clusters": num_clusters}, 32 | input_shape=(5, 4, 100), 33 | expected_output_shape=(None, num_clusters * 100), 34 | ) 35 | 36 | 37 | def test_unknown(): 38 | inputs = np.random.random((5, 4, 100)).astype("float32") 39 | test_utils.layer_test( 40 | NetVLAD, 41 | kwargs={"num_clusters": 3}, 42 | input_shape=(None, None, 100), 43 | input_data=inputs, 44 | expected_output_shape=(None, 3 * 100), 45 | ) 46 | 47 | 48 | def test_invalid_shape(): 49 | with pytest.raises(ValueError) as exception_info: 50 | test_utils.layer_test( 51 | NetVLAD, kwargs={"num_clusters": 0}, input_shape=(5, 4, 20) 52 | ) 53 | assert "`num_clusters` must be greater than 1" in str(exception_info.value) 54 | 55 | with pytest.raises(ValueError) as exception_info: 56 | test_utils.layer_test( 57 | NetVLAD, kwargs={"num_clusters": 2}, input_shape=(5, 4, 4, 20) 58 | ) 59 | assert "must have rank 3" in str(exception_info.value) 60 | -------------------------------------------------------------------------------- /tensorflow_addons/optimizers/tests/cocob_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for COntinuos COin Betting (COCOB) Backprop optimizer""" 16 | 17 | import numpy as np 18 | import tensorflow as tf 19 | from tensorflow_addons.optimizers import COCOB 20 | 21 | 22 | def run_dense_sample(iterations, expected, optimizer): 23 | var_0 = tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32) 24 | var_1 = tf.Variable([3.0, 4.0], dtype=tf.dtypes.float32) 25 | 26 | grad_0 = tf.constant([0.1, 0.2], dtype=tf.dtypes.float32) 27 | grad_1 = tf.constant([0.03, 0.04], dtype=tf.dtypes.float32) 28 | 29 | grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1])) 30 | 31 | for _ in range(iterations): 32 | optimizer.apply_gradients(grads_and_vars) 33 | 34 | np.testing.assert_allclose(var_0.read_value(), expected[0], atol=2e-4) 35 | np.testing.assert_allclose(var_1.read_value(), expected[1], atol=2e-4) 36 | 37 | 38 | def test_dense_sample_with_default_alpha(): 39 | run_dense_sample( 40 | iterations=10, 41 | expected=[[0.84528893, 1.845289], [2.845289, 3.845289]], 42 | optimizer=COCOB(), 43 | ) 44 | 45 | 46 | def test_dense_sample_with_custom_int_alpha(): 47 | run_dense_sample( 48 | iterations=7, 49 | expected=[[0.09346926, 1.0934693], [2.0934694, 3.0934694]], 50 | optimizer=COCOB(20), 51 | ) 52 | 53 | 54 | def test_dense_sample_with_custom_float_alpha(): 55 | run_dense_sample( 56 | iterations=5, 57 | expected=[[0.89307845, 1.8930784], [2.8930783, 3.8930783]], 58 | optimizer=COCOB(55.7), 59 | ) 60 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/hardshrink.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import tensorflow as tf 17 | from tensorflow_addons.utils.types import Number, TensorLike 18 | 19 | 20 | @tf.keras.utils.register_keras_serializable(package="Addons") 21 | def hardshrink(x: TensorLike, lower: Number = -0.5, upper: Number = 0.5) -> tf.Tensor: 22 | r"""Hard shrink function. 23 | 24 | Computes hard shrink function: 25 | 26 | $$ 27 | \mathrm{hardshrink}(x) = 28 | \begin{cases} 29 | x & \text{if } x < \text{lower} \\ 30 | x & \text{if } x > \text{upper} \\ 31 | 0 & \text{otherwise} 32 | \end{cases}. 33 | $$ 34 | 35 | Usage: 36 | 37 | >>> x = tf.constant([1.0, 0.0, 1.0]) 38 | >>> tfa.activations.hardshrink(x) 39 | 40 | 41 | Args: 42 | x: A `Tensor`. Must be one of the following types: 43 | `bfloat16`, float16`, `float32`, `float64`. 44 | lower: `float`, lower bound for setting values to zeros. 45 | upper: `float`, upper bound for setting values to zeros. 46 | Returns: 47 | A `Tensor`. Has the same type as `x`. 
48 | """ 49 | if lower > upper: 50 | raise ValueError( 51 | "The value of lower is {} and should" 52 | " not be higher than the value " 53 | "variable upper, which is {} .".format(lower, upper) 54 | ) 55 | x = tf.convert_to_tensor(x) 56 | mask_lower = x < lower 57 | mask_upper = upper < x 58 | mask = tf.logical_or(mask_lower, mask_upper) 59 | mask = tf.cast(mask, x.dtype) 60 | return x * mask 61 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.h: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | // ============================================================================= 15 | 16 | #ifndef TENSORFLOW_ADDONS_IMAGE_KERNELS_RESAMPLER_OPS_H_ 17 | #define TENSORFLOW_ADDONS_IMAGE_KERNELS_RESAMPLER_OPS_H_ 18 | 19 | #if PLATFORM_WINDOWS 20 | #define __restrict__ __restrict 21 | #endif 22 | 23 | #include "tensorflow/core/framework/op_kernel.h" 24 | 25 | namespace tensorflow { 26 | namespace addons { 27 | namespace functor { 28 | 29 | // Helper functor for the Resampler Op in 2D 30 | template <typename Device, typename T> 31 | struct Resampler2DFunctor { 32 | void operator()(OpKernelContext* ctx, const Device& d, 33 | const T* __restrict__ data, const T* __restrict__ warp, 34 | T* __restrict__ output, const int batch_size, 35 | const int data_height, const int data_width, 36 | const int data_channels, const int num_sampling_points); 37 | }; 38 | 39 | // Helper functor for the Resampler Gradient Op in 2D 40 | template <typename Device, typename T> 41 | struct ResamplerGrad2DFunctor { 42 | void operator()(OpKernelContext* ctx, const Device& d, 43 | const T* __restrict__ data, const T* __restrict__ warp, 44 | const T* __restrict__ grad_output, T* __restrict__ grad_data, 45 | T* __restrict__ grad_warp, const int batch_size, 46 | const int data_height, const int data_width, 47 | const int data_channels, const int num_sampling_points); 48 | }; 49 | 50 | } // namespace functor 51 | } // namespace addons 52 | } // namespace tensorflow 53 | #endif // TENSORFLOW_ADDONS_IMAGE_KERNELS_RESAMPLER_OPS_H_ -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/sparsemax_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | 17 | import pytest 18 | import numpy as np 19 | 20 | from tensorflow_addons.layers import Sparsemax 21 | from tensorflow_addons.utils import test_utils 22 | 23 | test_obs = 17 24 | 25 | 26 | def _np_sparsemax(z): 27 | z = z - np.mean(z, axis=1)[:, np.newaxis] 28 | 29 | # sort z 30 | z_sorted = np.sort(z, axis=1)[:, ::-1] 31 | 32 | # calculate k(z) 33 | z_cumsum = np.cumsum(z_sorted, axis=1) 34 | k = np.arange(1, z.shape[1] + 1) 35 | z_check = 1 + k * z_sorted > z_cumsum 36 | # use argmax to get the index by row as .nonzero() doesn't 37 | # take an axis argument. np.argmax return the first index, but the last 38 | # index is required here, use np.flip to get the last index and 39 | # `z.shape[axis]` to compensate for np.flip afterwards. 40 | k_z = z.shape[1] - np.argmax(z_check[:, ::-1], axis=1) 41 | 42 | # calculate tau(z) 43 | tau_sum = z_cumsum[np.arange(0, z.shape[0]), k_z - 1] 44 | tau_z = ((tau_sum - 1) / k_z).reshape(-1, 1) 45 | 46 | # calculate p 47 | return np.maximum(0, z - tau_z) 48 | 49 | 50 | @pytest.mark.usefixtures("maybe_run_functions_eagerly") 51 | @pytest.mark.parametrize("dtype", [np.float32, np.float64]) 52 | def test_sparsemax_layer_against_numpy(dtype): 53 | """check sparsemax kernel against numpy.""" 54 | random = np.random.RandomState(1) 55 | 56 | z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype) 57 | 58 | test_utils.layer_test( 59 | Sparsemax, 60 | kwargs={"dtype": dtype}, 61 | input_data=z, 62 | expected_output=_np_sparsemax(z).astype(dtype), 63 | ) 64 | -------------------------------------------------------------------------------- /tensorflow_addons/rnn/tests/peephole_lstm_cell_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for Peephole Cell.""" 16 | 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | from tensorflow_addons.rnn import PeepholeLSTMCell 21 | 22 | 23 | def test_peephole_lstm_cell(): 24 | def _run_cell(cell_fn, **kwargs): 25 | inputs = tf.one_hot([1, 2, 3, 4], 4) 26 | cell = cell_fn(5, **kwargs) 27 | cell.build(inputs.shape) 28 | initial_state = cell.get_initial_state( 29 | inputs=inputs, batch_size=4, dtype=tf.float32 30 | ) 31 | output, _ = cell(inputs, initial_state) 32 | return output 33 | 34 | tf.random.set_seed(12345) 35 | first_implementation_output = _run_cell( 36 | PeepholeLSTMCell, 37 | kernel_initializer="ones", 38 | recurrent_activation="sigmoid", 39 | implementation=1, 40 | ) 41 | second_implementation_output = _run_cell( 42 | PeepholeLSTMCell, 43 | kernel_initializer="ones", 44 | recurrent_activation="sigmoid", 45 | implementation=2, 46 | ) 47 | expected_output = np.asarray( 48 | [ 49 | [0.417551, 0.417551, 0.417551, 0.417551, 0.417551], 50 | [0.417551, 0.417551, 0.417551, 0.417551, 0.417551], 51 | [0.417551, 0.417551, 0.417551, 0.417551, 0.417551], 52 | [0.0, 0.0, 0.0, 0.0, 0.0], 53 | ], 54 | dtype=np.float32, 55 | ) 56 | np.testing.assert_allclose( 57 | first_implementation_output, second_implementation_output 58 | ) 59 | np.testing.assert_allclose( 60 | first_implementation_output, expected_output, rtol=1e-6, atol=1e-6 61 | ) 62 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/tlu_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for TLU activation.""" 16 | 17 | 18 | import pytest 19 | import numpy as np 20 | import tensorflow as tf 21 | 22 | from tensorflow_addons.layers.tlu import TLU 23 | from tensorflow_addons.utils import test_utils 24 | 25 | 26 | @pytest.mark.usefixtures("maybe_run_functions_eagerly") 27 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 28 | def test_random(dtype): 29 | x = np.array([[-2.5, 0.0, 0.3]]).astype(dtype) 30 | val = np.array([[0.0, 0.0, 0.3]]).astype(dtype) 31 | test_utils.layer_test( 32 | TLU, kwargs={"dtype": dtype}, input_data=x, expected_output=val 33 | ) 34 | 35 | 36 | @pytest.mark.usefixtures("maybe_run_functions_eagerly") 37 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 38 | def test_affine(dtype): 39 | x = np.array([[-2.5, 0.0, 0.3]]).astype(dtype) 40 | val = np.array([[-1.5, 1.0, 1.3]]).astype(dtype) 41 | test_utils.layer_test( 42 | TLU, 43 | kwargs={ 44 | "affine": True, 45 | "dtype": dtype, 46 | "alpha_initializer": "ones", 47 | "tau_initializer": "ones", 48 | }, 49 | input_data=x, 50 | expected_output=val, 51 | ) 52 | 53 | 54 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 55 | def test_serialization(dtype): 56 | tlu = TLU( 57 | affine=True, alpha_initializer="ones", tau_initializer="ones", dtype=dtype 58 | ) 59 | serialized_tlu = tf.keras.layers.serialize(tlu) 60 | new_layer = tf.keras.layers.deserialize(serialized_tlu) 61 | assert tlu.get_config() == new_layer.get_config() 62 | -------------------------------------------------------------------------------- /tensorflow_addons/utils/tests/keras_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for Keras utils.""" 16 | 17 | import sys 18 | 19 | import pytest 20 | import tensorflow as tf 21 | 22 | from tensorflow_addons.utils import keras_utils 23 | 24 | 25 | def test_normalize_data_format(): 26 | assert keras_utils.normalize_data_format("Channels_Last") == "channels_last" 27 | assert keras_utils.normalize_data_format("CHANNELS_FIRST") == "channels_first" 28 | 29 | with pytest.raises(ValueError, match="The `data_format` argument must be one of"): 30 | keras_utils.normalize_data_format("invalid") 31 | 32 | 33 | def test_normalize_tuple(): 34 | assert (2, 2, 2) == keras_utils.normalize_tuple(2, n=3, name="strides") 35 | assert (2, 1, 2) == keras_utils.normalize_tuple((2, 1, 2), n=3, name="strides") 36 | 37 | with pytest.raises(ValueError): 38 | keras_utils.normalize_tuple((2, 1), n=3, name="strides") 39 | 40 | with pytest.raises(TypeError): 41 | keras_utils.normalize_tuple(None, n=3, name="strides") 42 | 43 | 44 | def test_standard_cell(): 45 | keras_utils.assert_like_rnncell("cell", tf.keras.layers.LSTMCell(10)) 46 | 47 | 48 | def test_non_cell(): 49 | with pytest.raises(TypeError): 50 | keras_utils.assert_like_rnncell("cell", tf.keras.layers.Dense(10)) 51 | 52 | 53 | def test_custom_cell(): 54 | class CustomCell(tf.keras.layers.AbstractRNNCell): 55 | @property 56 | def output_size(self): 57 | raise ValueError("assert_like_rnncell should not run code") 58 | 59 | keras_utils.assert_like_rnncell("cell", CustomCell()) 60 | 61 | 62 | if __name__ == "__main__": 63 | sys.exit(pytest.main([__file__])) 64 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Additional layers that conform to Keras API.""" 16 | 17 | from tensorflow_addons.layers.adaptive_pooling import ( 18 | AdaptiveAveragePooling1D, 19 | AdaptiveMaxPooling1D, 20 | AdaptiveAveragePooling2D, 21 | AdaptiveMaxPooling2D, 22 | AdaptiveAveragePooling3D, 23 | AdaptiveMaxPooling3D, 24 | ) 25 | from tensorflow_addons.layers.gelu import GELU 26 | from tensorflow_addons.layers.max_unpooling_2d import MaxUnpooling2D 27 | from tensorflow_addons.layers.maxout import Maxout 28 | from tensorflow_addons.layers.multihead_attention import MultiHeadAttention 29 | from tensorflow_addons.layers.normalizations import FilterResponseNormalization 30 | from tensorflow_addons.layers.normalizations import GroupNormalization 31 | from tensorflow_addons.layers.normalizations import InstanceNormalization 32 | from tensorflow_addons.layers.optical_flow import CorrelationCost 33 | from tensorflow_addons.layers.poincare import PoincareNormalize 34 | from tensorflow_addons.layers.polynomial import PolynomialCrossing 35 | from tensorflow_addons.layers.snake import Snake 36 | from tensorflow_addons.layers.sparsemax import Sparsemax 37 | from tensorflow_addons.layers.spectral_normalization import SpectralNormalization 38 | from tensorflow_addons.layers.spatial_pyramid_pooling import SpatialPyramidPooling2D 39 | from tensorflow_addons.layers.tlu import TLU 40 | from tensorflow_addons.layers.wrappers import WeightNormalization 41 | from tensorflow_addons.layers.esn import ESN 42 | from tensorflow_addons.layers.stochastic_depth import StochasticDepth 43 | from tensorflow_addons.layers.noisy_dense import NoisyDense 44 | from tensorflow_addons.layers.crf import CRF 45 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/softshrink.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import tensorflow as tf 17 | from tensorflow_addons.utils.types import Number, TensorLike 18 | 19 | 20 | @tf.keras.utils.register_keras_serializable(package="Addons") 21 | def softshrink(x: TensorLike, lower: Number = -0.5, upper: Number = 0.5) -> tf.Tensor: 22 | r"""Soft shrink function. 23 | 24 | Computes soft shrink function: 25 | 26 | $$ 27 | \mathrm{softshrink}(x) = 28 | \begin{cases} 29 | x - \mathrm{lower} & \text{if } x < \mathrm{lower} \\ 30 | x - \mathrm{upper} & \text{if } x > \mathrm{upper} \\ 31 | 0 & \text{otherwise} 32 | \end{cases}. 33 | $$ 34 | 35 | Usage: 36 | 37 | >>> x = tf.constant([-1.0, 0.0, 1.0]) 38 | >>> tfa.activations.softshrink(x) 39 | 40 | 41 | Args: 42 | x: A `Tensor`. Must be one of the following types: 43 | `bfloat16`, `float16`, `float32`, `float64`. 
44 | lower: `float`, lower bound for setting values to zeros. 45 | upper: `float`, upper bound for setting values to zeros. 46 | Returns: 47 | A `Tensor`. Has the same type as `x`. 48 | """ 49 | if lower > upper: 50 | raise ValueError( 51 | "The value of lower is {} and should" 52 | " not be higher than the value " 53 | "variable upper, which is {} .".format(lower, upper) 54 | ) 55 | x = tf.convert_to_tensor(x) 56 | values_below_lower = tf.where(x < lower, x - lower, 0) 57 | values_above_upper = tf.where(upper < x, x - upper, 0) 58 | return values_below_lower + values_above_upper 59 | -------------------------------------------------------------------------------- /tensorflow_addons/tensorflow_addons.bzl: -------------------------------------------------------------------------------- 1 | load("@local_config_tf//:build_defs.bzl", "D_GLIBCXX_USE_CXX11_ABI") 2 | load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda", "if_cuda_is_configured") 3 | 4 | def custom_op_library( 5 | name, 6 | srcs = [], 7 | cuda_srcs = [], 8 | deps = [], 9 | cuda_deps = [], 10 | copts = [], 11 | **kwargs): 12 | deps = deps + [ 13 | "@local_config_tf//:libtensorflow_framework", 14 | "@local_config_tf//:tf_header_lib", 15 | ] 16 | 17 | if cuda_srcs: 18 | copts = copts + if_cuda(["-DGOOGLE_CUDA=1"]) 19 | cuda_copts = copts + if_cuda_is_configured([ 20 | "-x cuda", 21 | "-nvcc_options=relaxed-constexpr", 22 | "-nvcc_options=ftz=true", 23 | ]) 24 | cuda_deps = deps + if_cuda_is_configured(cuda_deps) + if_cuda_is_configured([ 25 | "@local_config_cuda//cuda:cuda_headers", 26 | "@local_config_cuda//cuda:cudart_static", 27 | ]) 28 | basename = name.split(".")[0] 29 | native.cc_library( 30 | name = basename + "_gpu", 31 | srcs = cuda_srcs, 32 | deps = cuda_deps, 33 | copts = cuda_copts, 34 | alwayslink = 1, 35 | **kwargs 36 | ) 37 | deps = deps + if_cuda_is_configured([":" + basename + "_gpu"]) 38 | 39 | copts = copts + select({ 40 | "//tensorflow_addons:windows": [ 41 | "/DEIGEN_STRONG_INLINE=inline", 42 | "-DTENSORFLOW_MONOLITHIC_BUILD", 43 | "/D_USE_MATH_DEFINES", 44 | "/DPLATFORM_WINDOWS", 45 | "/DEIGEN_HAS_C99_MATH", 46 | "/DTENSORFLOW_USE_EIGEN_THREADPOOL", 47 | "/DEIGEN_AVOID_STL_ARRAY", 48 | "/Iexternal/gemmlowp", 49 | "/wd4018", 50 | "/wd4577", 51 | "/DNOGDI", 52 | "/UTF_COMPILE_LIBRARY", 53 | ], 54 | "//conditions:default": ["-pthread", "-std=c++11", D_GLIBCXX_USE_CXX11_ABI], 55 | }) 56 | 57 | native.cc_binary( 58 | name = name, 59 | srcs = srcs, 60 | copts = copts, 61 | linkshared = 1, 62 | features = select({ 63 | "//tensorflow_addons:windows": ["windows_export_all_symbols"], 64 | "//conditions:default": [], 65 | }), 66 | deps = deps, 67 | **kwargs 68 | ) 69 | -------------------------------------------------------------------------------- /tensorflow_addons/testing/tests/serialization_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import tensorflow as tf 4 | 5 | from tensorflow.keras.metrics import MeanAbsoluteError, TrueNegatives, Metric 6 | from tensorflow_addons.testing.serialization import check_metric_serialization 7 | 8 | 9 | def test_check_metric_serialization_mae(): 10 | check_metric_serialization(MeanAbsoluteError(), (2, 2), (2, 2)) 11 | check_metric_serialization(MeanAbsoluteError(name="hello"), (2, 2), (2, 2)) 12 | check_metric_serialization(MeanAbsoluteError(), (2, 2, 2), (2, 2, 2)) 13 | check_metric_serialization(MeanAbsoluteError(), (2, 2, 2), (2, 2, 2), (2, 2, 1)) 14 | 15 | 16 | def get_random_booleans(): 17 
| return np.random.uniform(0, 2, size=(2, 2)) 18 | 19 | 20 | def test_check_metric_serialization_true_negative(): 21 | check_metric_serialization( 22 | TrueNegatives(0.8), 23 | np.random.uniform(0, 2, size=(2, 2)).astype(np.bool), 24 | np.random.uniform(0, 1, size=(2, 2)).astype(np.float32), 25 | ) 26 | 27 | 28 | class MyDummyMetric(Metric): 29 | def __init__(self, stuff, name): 30 | super().__init__(name) 31 | self.stuff = stuff 32 | 33 | def update_state(self, y_true, y_pred, sample_weights): 34 | pass 35 | 36 | def get_config(self): 37 | return super().get_config() 38 | 39 | def result(self): 40 | return 3 41 | 42 | 43 | def test_missing_arg(): 44 | with pytest.raises(KeyError) as exception_info: 45 | check_metric_serialization(MyDummyMetric("dodo", "dada"), (2,), (2,)) 46 | 47 | assert "stuff" in str(exception_info.value) 48 | 49 | 50 | class MyOtherDummyMetric(Metric): 51 | def __init__(self, to_add, name=None, dtype=None): 52 | super().__init__(name, dtype) 53 | self.to_add = to_add 54 | self.sum_of_y_pred = self.add_weight(name="my_sum", initializer="zeros") 55 | 56 | def update_state(self, y_true, y_pred, sample_weights=None): 57 | self.sum_of_y_pred.assign_add(tf.math.reduce_sum(y_pred) + self.to_add) 58 | 59 | def get_config(self): 60 | config = {"to_add": self.to_add + 1} 61 | config.update(super().get_config()) 62 | return config 63 | 64 | def result(self): 65 | return self.sum_of_y_pred 66 | 67 | 68 | def test_wrong_serialization(): 69 | with pytest.raises(AssertionError): 70 | check_metric_serialization(MyOtherDummyMetric(5), (2,), (2,)) 71 | -------------------------------------------------------------------------------- /.tours/giou-loss-overview.tour: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://aka.ms/codetour-schema", 3 | "title": "GIoU Loss Overview", 4 | "steps": [ 5 | { 6 | "file": "giou_loss.py", 7 | "description": "GIoU loss, or generalized intersection over union, is a metric and a loss for bounding box regression. It is the most popular evaluation metric for tasks such as segmentation, object detection, and tracking. Object detection consists of two sub-tasks: localization, which is determining the location of an object in an image; and classification, which is assigning a class to that object. The goal of localization in object detection is to draw a 2D bounding box around the objects in the scene.", 8 | "line": 95, 9 | "title": "What is GIoU loss?" 10 | }, 11 | { 12 | "file": "giou_loss.py", 13 | "description": "Define the bounding boxes: one for a prediction, one for ground truth.", 14 | "line": 108, 15 | "title": "Define the bounding boxes." 16 | }, 17 | { 18 | "file": "giou_loss.py", 19 | "description": "Calculate the area of the two bounding boxes.", 20 | "line": 114, 21 | "title": "Calculate the area of the two bounding boxes." 22 | }, 23 | { 24 | "file": "giou_loss.py", 25 | "description": "Calculate the intersection between the two bounding boxes.", 26 | "line": 117, 27 | "title": "Calculate the intersection between the two bounding boxes." 28 | }, 29 | { 30 | "file": "giou_loss.py", 31 | "description": "Find the coordinates of the smallest enclosing bounding box, $B^c$", 32 | "line": 125, 33 | "title": "Find the coordinates of the smallest enclosing bounding box." 34 | }, 35 | { 36 | "file": "giou_loss.py", 37 | "description": "Calculate the IoU.", 38 | "line": 126, 39 | "title": "Calculate the IoU." 
40 | }, 41 | { 42 | "file": "giou_loss.py", 43 | "description": "Calculate the area of the smallest enclosing bounding box.", 44 | "line": 130, 45 | "title": "Calculate the area of the smallest enclosing bounding box." 46 | }, 47 | { 48 | "file": "giou_loss.py", 49 | "description": "Calculate the GIoU.", 50 | "line": 137, 51 | "title": "Calculate the GIoU." 52 | }, 53 | { 54 | "file": "giou_loss.py", 55 | "description": "Return the GIoU loss for the input boxes.", 56 | "line": 92, 57 | "title": "Return the GIoU loss for the input boxes." 58 | } 59 | ] 60 | } 61 | -------------------------------------------------------------------------------- /.venv/bin/activate: -------------------------------------------------------------------------------- 1 | # This file must be used with "source bin/activate" *from bash* 2 | # you cannot run it directly 3 | 4 | deactivate () { 5 | # reset old environment variables 6 | if [ -n "${_OLD_VIRTUAL_PATH:-}" ] ; then 7 | PATH="${_OLD_VIRTUAL_PATH:-}" 8 | export PATH 9 | unset _OLD_VIRTUAL_PATH 10 | fi 11 | if [ -n "${_OLD_VIRTUAL_PYTHONHOME:-}" ] ; then 12 | PYTHONHOME="${_OLD_VIRTUAL_PYTHONHOME:-}" 13 | export PYTHONHOME 14 | unset _OLD_VIRTUAL_PYTHONHOME 15 | fi 16 | 17 | # This should detect bash and zsh, which have a hash command that must 18 | # be called to get it to forget past commands. Without forgetting 19 | # past commands the $PATH changes we made may not be respected 20 | if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then 21 | hash -r 22 | fi 23 | 24 | if [ -n "${_OLD_VIRTUAL_PS1:-}" ] ; then 25 | PS1="${_OLD_VIRTUAL_PS1:-}" 26 | export PS1 27 | unset _OLD_VIRTUAL_PS1 28 | fi 29 | 30 | unset VIRTUAL_ENV 31 | if [ ! "${1:-}" = "nondestructive" ] ; then 32 | # Self destruct! 33 | unset -f deactivate 34 | fi 35 | } 36 | 37 | # unset irrelevant variables 38 | deactivate nondestructive 39 | 40 | VIRTUAL_ENV="/workspaces/addons/.venv" 41 | export VIRTUAL_ENV 42 | 43 | _OLD_VIRTUAL_PATH="$PATH" 44 | PATH="$VIRTUAL_ENV/bin:$PATH" 45 | export PATH 46 | 47 | # unset PYTHONHOME if set 48 | # this will fail if PYTHONHOME is set to the empty string (which is bad anyway) 49 | # could use `if (set -u; : $PYTHONHOME) ;` in bash 50 | if [ -n "${PYTHONHOME:-}" ] ; then 51 | _OLD_VIRTUAL_PYTHONHOME="${PYTHONHOME:-}" 52 | unset PYTHONHOME 53 | fi 54 | 55 | if [ -z "${VIRTUAL_ENV_DISABLE_PROMPT:-}" ] ; then 56 | _OLD_VIRTUAL_PS1="${PS1:-}" 57 | if [ "x(.venv) " != x ] ; then 58 | PS1="(.venv) ${PS1:-}" 59 | else 60 | if [ "`basename \"$VIRTUAL_ENV\"`" = "__" ] ; then 61 | # special case for Aspen magic directories 62 | # see https://aspen.io/ 63 | PS1="[`basename \`dirname \"$VIRTUAL_ENV\"\``] $PS1" 64 | else 65 | PS1="(`basename \"$VIRTUAL_ENV\"`)$PS1" 66 | fi 67 | fi 68 | export PS1 69 | fi 70 | 71 | # This should detect bash and zsh, which have a hash command that must 72 | # be called to get it to forget past commands. Without forgetting 73 | # past commands the $PATH changes we made may not be respected 74 | if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then 75 | hash -r 76 | fi 77 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/text/cc/ops/skip_gram_ops.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "tensorflow/core/framework/op.h" 17 | #include "tensorflow/core/framework/shape_inference.h" 18 | 19 | namespace tensorflow { 20 | namespace addons { 21 | REGISTER_OP("Addons>SkipGramGenerateCandidates") 22 | .Input("input_tensor: T") 23 | .Input("min_skips: int32") 24 | .Input("max_skips: int32") 25 | .Input("start: int32") 26 | .Input("limit: int32") 27 | .Input("emit_self_as_target: bool") 28 | .Output("tokens: T") 29 | .Output("labels: T") 30 | .Attr("T: type") 31 | // The seed attributes are needed by GuardedPhiloxRandom 32 | .Attr("seed: int = 0") 33 | .Attr("seed2: int = 0") 34 | .SetIsStateful() 35 | .SetShapeFn([](shape_inference::InferenceContext* c) { 36 | shape_inference::ShapeHandle unused; 37 | // input_tensor must be of rank-1. 38 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); 39 | // All other args must be scalar. 40 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 41 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 42 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 43 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 44 | 45 | // Due to possible randomness in selecting skips, we only know that the 46 | // outputs will be of rank-1, but not their sizes. 47 | c->set_output(0, c->Vector(c->UnknownDim())); 48 | c->set_output(1, c->Vector(c->UnknownDim())); 49 | return Status::OK(); 50 | }) 51 | .Doc(R"doc( 52 | Generates skip-gram token and label paired Tensors from the input tensor. 53 | See docs for the public-facing skip_gram_sample() Python op for more details. 54 | )doc"); 55 | } // end namespace addons 56 | } // namespace tensorflow -------------------------------------------------------------------------------- /tensorflow_addons/callbacks/time_stopping.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Callback that stops training when a specified amount of time has passed.""" 16 | 17 | import datetime 18 | import time 19 | from typeguard import typechecked 20 | 21 | import tensorflow as tf 22 | from tensorflow.keras.callbacks import Callback 23 | 24 | 25 | @tf.keras.utils.register_keras_serializable(package="Addons") 26 | class TimeStopping(Callback): 27 | """Stop training when a specified amount of time has passed. 
28 | 29 | Args: 30 | seconds: maximum amount of time before stopping. 31 | Defaults to 86400 (1 day). 32 | verbose: verbosity mode. Defaults to 0. 33 | """ 34 | 35 | @typechecked 36 | def __init__(self, seconds: int = 86400, verbose: int = 0): 37 | super().__init__() 38 | 39 | self.seconds = seconds 40 | self.verbose = verbose 41 | self.stopped_epoch = None 42 | 43 | def on_train_begin(self, logs=None): 44 | self.stopping_time = time.time() + self.seconds 45 | 46 | def on_epoch_end(self, epoch, logs={}): 47 | if time.time() >= self.stopping_time: 48 | self.model.stop_training = True 49 | self.stopped_epoch = epoch 50 | 51 | def on_train_end(self, logs=None): 52 | if self.stopped_epoch is not None and self.verbose > 0: 53 | formatted_time = datetime.timedelta(seconds=self.seconds) 54 | msg = "Timed stopping at epoch {} after training for {}".format( 55 | self.stopped_epoch + 1, formatted_time 56 | ) 57 | print(msg) 58 | 59 | def get_config(self): 60 | config = { 61 | "seconds": self.seconds, 62 | "verbose": self.verbose, 63 | } 64 | 65 | base_config = super().get_config() 66 | return {**base_config, **config} 67 | -------------------------------------------------------------------------------- /tensorflow_addons/metrics/tests/metrics_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import inspect 16 | 17 | from tensorflow.keras.metrics import Metric 18 | from tensorflow_addons import metrics 19 | 20 | 21 | def test_update_state_signature(): 22 | public_params = ["sample_weight"] 23 | params_comb = [["y_true", "y_pred"], ["values"]] 24 | for name, obj in inspect.getmembers(metrics): 25 | if inspect.isclass(obj) and issubclass(obj, Metric): 26 | check_update_state_signature(obj, public_params, params_comb) 27 | 28 | 29 | def check_update_state_signature(metric_class, public_params, case_list): 30 | error_msg = ( 31 | "Class {} is missing the parameter {} in the `update_state` " 32 | "method. If the method doesn't use this argument, declare " 33 | "it anyway and raise a UserWarning if it is " 34 | "not None." 
35 | ) 36 | 37 | update_state_signature = inspect.signature(metric_class.update_state) 38 | 39 | for expected_param in public_params: 40 | if expected_param not in update_state_signature.parameters.keys(): 41 | raise ValueError(error_msg.format(metric_class.__name__, expected_param)) 42 | 43 | missing_params = [] 44 | for case in case_list: 45 | case_miss_params = [] 46 | case_check = True 47 | for expected_param in case: 48 | if expected_param not in update_state_signature.parameters.keys(): 49 | case_miss_params.append(expected_param) 50 | case_check = False 51 | break 52 | if case_check: 53 | return 54 | missing_params.append(case_miss_params) 55 | missing_params = [", ".join(p) for p in missing_params] 56 | missing_params = " or ".join(missing_params) 57 | raise ValueError(error_msg.format(metric_class.__name__, missing_params)) 58 | -------------------------------------------------------------------------------- /tensorflow_addons/custom_ops/image/cc/ops/resampler_ops.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "tensorflow/core/framework/common_shape_fns.h" 17 | #include "tensorflow/core/framework/op.h" 18 | #include "tensorflow/core/framework/shape_inference.h" 19 | 20 | namespace tensorflow { 21 | namespace addons { 22 | 23 | using shape_inference::DimensionHandle; 24 | using shape_inference::InferenceContext; 25 | using shape_inference::ShapeHandle; 26 | 27 | // -------------------------------------------------------------------------- 28 | REGISTER_OP("Addons>Resampler") 29 | .Input("data: T") 30 | .Input("warp: T") 31 | .Output("output: T") 32 | .Attr("T: {half, float, double}") 33 | .SetShapeFn([](InferenceContext* c) { 34 | ShapeHandle data; 35 | ShapeHandle warp; 36 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data)); 37 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &warp)); 38 | 39 | ShapeHandle output; // will be warp[:-1] + [data[-1]] 40 | TF_RETURN_IF_ERROR(c->Subshape(warp, 0, -1, &output)); 41 | TF_RETURN_IF_ERROR( 42 | c->Concatenate(output, c->Vector(c->Dim(data, -1)), &output)); 43 | 44 | c->set_output(0, output); 45 | return Status::OK(); 46 | }) 47 | .Doc(R"doc(Resampler op.)doc"); 48 | 49 | // -------------------------------------------------------------------------- 50 | REGISTER_OP("Addons>ResamplerGrad") 51 | .Input("data: T") 52 | .Input("warp: T") 53 | .Input("grad_output: T") 54 | .Output("grad_data: T") 55 | .Output("grad_warp: T") 56 | .Attr("T: {half, float, double}") 57 | .SetShapeFn([](InferenceContext* c) { 58 | c->set_output(0, c->input(0)); 59 | c->set_output(1, c->input(1)); 60 | return Status::OK(); 61 | }) 62 | .Doc(R"doc(Resampler Grad op.)doc"); 63 | 64 | } // namespace addons 65 | } // namespace tensorflow 
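For reference, the shape function registered for `Addons>Resampler` above encodes the op's contract: the output keeps every leading dimension of `warp` and takes its channel count from the last dimension of `data`. Below is a minimal sketch of that contract, assuming an installed `tensorflow_addons` and calling the public `tfa.image.resampler` wrapper exported from `tensorflow_addons/image/__init__.py` (shown later in this listing); the tensor values themselves are illustrative only.

```python
# Minimal shape sketch for the Addons>Resampler op registered above.
# Assumption: tensorflow_addons is installed and tfa.image.resampler wraps this kernel.
import tensorflow as tf
import tensorflow_addons as tfa

data = tf.random.uniform([2, 8, 8, 3])              # [batch, height, width, channels]
warp = tf.random.uniform([2, 5, 5, 2], maxval=7.0)  # trailing dimension holds (x, y) sample points
output = tfa.image.resampler(data, warp)

# Matches the shape function: output shape = warp.shape[:-1] + [data.shape[-1]]
print(output.shape)  # (2, 5, 5, 3)
```

The `Addons>ResamplerGrad` registration mirrors this: its gradients have the same shapes as `data` and `warp`, which is why its shape function simply copies the two input shapes to the two outputs.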
-------------------------------------------------------------------------------- /tensorflow_addons/losses/tests/metric_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for metric learning.""" 16 | 17 | 18 | import pytest 19 | import numpy as np 20 | import tensorflow as tf 21 | from tensorflow_addons.losses.metric_learning import pairwise_distance 22 | 23 | 24 | def test_zero_distance(): 25 | """Test that equal embeddings have a pairwise distance of 0.""" 26 | equal_embeddings = tf.constant([[1.0, 0.5], [1.0, 0.5]]) 27 | 28 | distances = pairwise_distance(equal_embeddings, squared=False) 29 | np.testing.assert_allclose(tf.math.reduce_sum(distances), 0, 1e-6, 1e-6) 30 | 31 | 32 | def test_positive_distances(): 33 | """Test that the pairwise distances are always positive.""" 34 | 35 | # Create embeddings very close to each other in [1.0 - 2e-7, 1.0 + 2e-7] 36 | # This will encourage errors in the computation 37 | embeddings = 1.0 + 2e-7 * tf.random.uniform([64, 6], dtype=tf.float32) 38 | distances = pairwise_distance(embeddings, squared=False) 39 | assert np.all(distances >= 0) 40 | 41 | 42 | def test_correct_distance(): 43 | """Compare against numpy calculation.""" 44 | tf_embeddings = tf.constant([[0.5, 0.5], [1.0, 1.0]]) 45 | 46 | expected_distance = np.array([[0, np.sqrt(2) / 2], [np.sqrt(2) / 2, 0]]) 47 | 48 | distances = pairwise_distance(tf_embeddings, squared=False) 49 | np.testing.assert_allclose(expected_distance, distances, 1e-6, 1e-6) 50 | 51 | 52 | @pytest.mark.usefixtures("maybe_run_functions_eagerly") 53 | def test_correct_distance_squared(): 54 | """Compare against numpy calculation for squared distances.""" 55 | tf_embeddings = tf.constant([[0.5, 0.5], [1.0, 1.0]]) 56 | 57 | expected_distance = np.array([[0, 0.5], [0.5, 0]]) 58 | 59 | distances = pairwise_distance(tf_embeddings, squared=True) 60 | np.testing.assert_allclose(expected_distance, distances, 1e-6, 1e-6) 61 | -------------------------------------------------------------------------------- /tensorflow_addons/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Additional optimizers that conform to Keras API.""" 16 | 17 | from tensorflow_addons.optimizers.average_wrapper import AveragedOptimizerWrapper 18 | from tensorflow_addons.optimizers.conditional_gradient import ConditionalGradient 19 | from tensorflow_addons.optimizers.cyclical_learning_rate import CyclicalLearningRate 20 | from tensorflow_addons.optimizers.cyclical_learning_rate import ( 21 | TriangularCyclicalLearningRate, 22 | ) 23 | from tensorflow_addons.optimizers.cyclical_learning_rate import ( 24 | Triangular2CyclicalLearningRate, 25 | ) 26 | from tensorflow_addons.optimizers.cyclical_learning_rate import ( 27 | ExponentialCyclicalLearningRate, 28 | ) 29 | from tensorflow_addons.optimizers.discriminative_layer_training import ( 30 | MultiOptimizer, 31 | ) 32 | from tensorflow_addons.optimizers.lamb import LAMB 33 | from tensorflow_addons.optimizers.lazy_adam import LazyAdam 34 | from tensorflow_addons.optimizers.lookahead import Lookahead 35 | from tensorflow_addons.optimizers.moving_average import MovingAverage 36 | from tensorflow_addons.optimizers.novograd import NovoGrad 37 | from tensorflow_addons.optimizers.proximal_adagrad import ProximalAdagrad 38 | from tensorflow_addons.optimizers.rectified_adam import RectifiedAdam 39 | from tensorflow_addons.optimizers.stochastic_weight_averaging import SWA 40 | from tensorflow_addons.optimizers.weight_decay_optimizers import AdamW 41 | from tensorflow_addons.optimizers.weight_decay_optimizers import SGDW 42 | from tensorflow_addons.optimizers.weight_decay_optimizers import ( 43 | extend_with_decoupled_weight_decay, 44 | ) 45 | from tensorflow_addons.optimizers.weight_decay_optimizers import ( 46 | DecoupledWeightDecayExtension, 47 | ) 48 | from tensorflow_addons.optimizers.yogi import Yogi 49 | from tensorflow_addons.optimizers.cocob import COCOB 50 | -------------------------------------------------------------------------------- /tensorflow_addons/image/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Additional image manipulation ops.""" 16 | 17 | from tensorflow_addons.image.distort_image_ops import adjust_hsv_in_yiq 18 | from tensorflow_addons.image.compose_ops import blend 19 | from tensorflow_addons.image.color_ops import equalize 20 | from tensorflow_addons.image.color_ops import sharpness 21 | from tensorflow_addons.image.connected_components import connected_components 22 | from tensorflow_addons.image.cutout_ops import cutout 23 | from tensorflow_addons.image.dense_image_warp import dense_image_warp 24 | from tensorflow_addons.image.distance_transform import euclidean_dist_transform 25 | from tensorflow_addons.image.dense_image_warp import interpolate_bilinear 26 | from tensorflow_addons.image.interpolate_spline import interpolate_spline 27 | from tensorflow_addons.image.filters import gaussian_filter2d 28 | from tensorflow_addons.image.filters import mean_filter2d 29 | from tensorflow_addons.image.filters import median_filter2d 30 | from tensorflow_addons.image.cutout_ops import random_cutout 31 | from tensorflow_addons.image.distort_image_ops import random_hsv_in_yiq 32 | from tensorflow_addons.image.resampler_ops import resampler 33 | from tensorflow_addons.image.transform_ops import rotate 34 | from tensorflow_addons.image.transform_ops import shear_x 35 | from tensorflow_addons.image.transform_ops import shear_y 36 | from tensorflow_addons.image.sparse_image_warp import sparse_image_warp 37 | from tensorflow_addons.image.transform_ops import compose_transforms 38 | from tensorflow_addons.image.transform_ops import angles_to_projective_transforms 39 | from tensorflow_addons.image.transform_ops import transform 40 | from tensorflow_addons.image.translate_ops import translate 41 | from tensorflow_addons.image.translate_ops import translate_xy 42 | from tensorflow_addons.image.translate_ops import translations_to_projective_transforms 43 | -------------------------------------------------------------------------------- /tensorflow_addons/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import warnings 4 | import traceback 5 | 6 | try: 7 | _TF_ADDONS_PY_OPS = bool(int(os.environ["TF_ADDONS_PY_OPS"])) 8 | except KeyError: 9 | if platform.system() == "Linux": 10 | _TF_ADDONS_PY_OPS = False 11 | else: 12 | _TF_ADDONS_PY_OPS = True 13 | 14 | _FALLBACK_WARNING_TEMPLATE = """{} 15 | 16 | The {} C++/CUDA custom op could not be loaded. 17 | For this reason, Addons will fall back to an implementation written 18 | in Python with public TensorFlow ops. The worst you might experience with 19 | this is a moderate slowdown on GPU. There can be multiple 20 | reasons for this loading error; one of them may be an ABI incompatibility between 21 | the TensorFlow installed on your system and the TensorFlow used to compile 22 | TensorFlow Addons' custom ops. The stacktrace generated when loading the 23 | shared object file was displayed above. 24 | 25 | If you want this warning to disappear, either make sure the TensorFlow installed 26 | is compatible with this version of Addons, or tell TensorFlow Addons to 27 | prefer using Python implementations and not custom C++/CUDA ones.
You can do that 28 | by setting the environment variable `TF_ADDONS_PY_OPS=1`: 29 | ```bash 30 | TF_ADDONS_PY_OPS=1 python my_script.py 31 | ``` 32 | or run `tfa.options.disable_custom_kernel()` in your code, after your imports: 33 | ```python 34 | import tensorflow_addons as tfa 35 | import ... 36 | import ... 37 | 38 | tfa.options.disable_custom_kernel() 39 | ``` 40 | """ 41 | 42 | 43 | def warn_fallback(op_name): 44 | warning_msg = _FALLBACK_WARNING_TEMPLATE.format(traceback.format_exc(), op_name) 45 | warnings.warn(warning_msg, RuntimeWarning) 46 | disable_custom_kernel() 47 | 48 | 49 | def enable_custom_kernel(): 50 | """Prefer custom C++/CUDA kernel to pure python operations. 51 | 52 | Enable using custom C++/CUDA kernel instead of pure python operations. 53 | It has the same effect as setting environment variable `TF_ADDONS_PY_OPS=0`. 54 | """ 55 | global _TF_ADDONS_PY_OPS 56 | _TF_ADDONS_PY_OPS = False 57 | 58 | 59 | def disable_custom_kernel(): 60 | """Prefer pure python operations to custom C++/CUDA kernel. 61 | 62 | Disable using custom C++/CUDA kernel instead of pure python operations. 63 | It has the same effect as setting environment variable `TF_ADDONS_PY_OPS=1`. 64 | """ 65 | global _TF_ADDONS_PY_OPS 66 | _TF_ADDONS_PY_OPS = True 67 | 68 | 69 | def is_custom_kernel_disabled(): 70 | """Return whether custom C++/CUDA kernel is disabled.""" 71 | return _TF_ADDONS_PY_OPS 72 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/tests/esn_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | """Tests for Echo State recurrent Network (ESN).""" 16 | 17 | import pytest 18 | import numpy as np 19 | import tensorflow as tf 20 | from tensorflow_addons.layers.esn import ESN 21 | from tensorflow_addons.utils import test_utils 22 | 23 | 24 | @pytest.mark.usefixtures("maybe_run_functions_eagerly") 25 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 26 | def layer_test_esn(dtype): 27 | inp = np.asanyarray( 28 | [[[1.0, 1.0, 1.0, 1.0]], [[2.0, 2.0, 2.0, 2.0]], [[3.0, 3.0, 3.0, 3.0]]] 29 | ).astype(dtype) 30 | out = np.asarray([[2.5, 2.5, 2.5], [4.5, 4.5, 4.5], [6.5, 6.5, 6.5]]).astype(dtype) 31 | 32 | const_initializer = tf.constant_initializer(0.5) 33 | kwargs = { 34 | "units": 3, 35 | "connectivity": 1, 36 | "leaky": 1, 37 | "spectral_radius": 0.9, 38 | "use_norm2": True, 39 | "use_bias": True, 40 | "activation": None, 41 | "kernel_initializer": const_initializer, 42 | "recurrent_initializer": const_initializer, 43 | "bias_initializer": const_initializer, 44 | "dtype": dtype, 45 | } 46 | 47 | test_utils.layer_test(ESN, kwargs=kwargs, input_data=inp, expected_output=out) 48 | 49 | 50 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 51 | def test_serialization(dtype): 52 | esn = ESN( 53 | units=3, 54 | connectivity=1, 55 | leaky=1, 56 | spectral_radius=0.9, 57 | use_norm2=False, 58 | use_bias=True, 59 | activation=None, 60 | kernel_initializer="ones", 61 | recurrent_initializer="ones", 62 | bias_initializer="ones", 63 | ) 64 | serialized_esn = tf.keras.layers.serialize(esn) 65 | new_layer = tf.keras.layers.deserialize(serialized_esn) 66 | assert esn.get_config() == new_layer.get_config() 67 | -------------------------------------------------------------------------------- /tensorflow_addons/optimizers/tests/standard_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import pytest 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | from tensorflow_addons import optimizers 21 | from tensorflow_addons.utils.test_utils import discover_classes 22 | 23 | class_exceptions = [ 24 | "MultiOptimizer", # is wrapper 25 | "SGDW", # is wrapper 26 | "AdamW", # is wrapper 27 | "SWA", # is wrapper 28 | "AveragedOptimizerWrapper", # is wrapper 29 | "ConditionalGradient", # is wrapper 30 | "Lookahead", # is wrapper 31 | "MovingAverage", # is wrapper 32 | ] 33 | 34 | 35 | classes_to_test = discover_classes( 36 | optimizers, tf.keras.optimizers.Optimizer, class_exceptions 37 | ) 38 | 39 | 40 | @pytest.mark.parametrize("optimizer", classes_to_test) 41 | @pytest.mark.parametrize("serialize", [True, False]) 42 | def test_optimizer_minimize_serialize(optimizer, serialize, tmpdir): 43 | """ 44 | Purpose of this test is to confirm that the optimizer can minimize the loss in toy conditions. 45 | It also tests for serialization as a parameter. 46 | """ 47 | model = tf.keras.Sequential([tf.keras.Input(shape=[1]), tf.keras.layers.Dense(1)]) 48 | 49 | x = np.array(np.ones([1])) 50 | y = np.array(np.zeros([1])) 51 | 52 | opt = optimizer() 53 | loss = tf.keras.losses.MSE 54 | 55 | model.compile(optimizer=opt, loss=loss) 56 | 57 | # serialize whole model including optimizer, clear the session, then reload the whole model. 58 | # successfully serialized optimizers should not require a compile before training 59 | if serialize: 60 | model.save(str(tmpdir), save_format="tf") 61 | tf.keras.backend.clear_session() 62 | model = tf.keras.models.load_model(str(tmpdir)) 63 | 64 | history = model.fit(x, y, batch_size=1, epochs=10) 65 | 66 | loss_values = history.history["loss"] 67 | 68 | np.testing.assert_array_less(loss_values[-1], loss_values[0]) 69 | -------------------------------------------------------------------------------- /.venv/bin/activate.fish: -------------------------------------------------------------------------------- 1 | # This file must be used with ". bin/activate.fish" *from fish* (http://fishshell.org) 2 | # you cannot run it directly 3 | 4 | function deactivate -d "Exit virtualenv and return to normal shell environment" 5 | # reset old environment variables 6 | if test -n "$_OLD_VIRTUAL_PATH" 7 | set -gx PATH $_OLD_VIRTUAL_PATH 8 | set -e _OLD_VIRTUAL_PATH 9 | end 10 | if test -n "$_OLD_VIRTUAL_PYTHONHOME" 11 | set -gx PYTHONHOME $_OLD_VIRTUAL_PYTHONHOME 12 | set -e _OLD_VIRTUAL_PYTHONHOME 13 | end 14 | 15 | if test -n "$_OLD_FISH_PROMPT_OVERRIDE" 16 | functions -e fish_prompt 17 | set -e _OLD_FISH_PROMPT_OVERRIDE 18 | functions -c _old_fish_prompt fish_prompt 19 | functions -e _old_fish_prompt 20 | end 21 | 22 | set -e VIRTUAL_ENV 23 | if test "$argv[1]" != "nondestructive" 24 | # Self destruct! 25 | functions -e deactivate 26 | end 27 | end 28 | 29 | # unset irrelevant variables 30 | deactivate nondestructive 31 | 32 | set -gx VIRTUAL_ENV "/workspaces/addons/.venv" 33 | 34 | set -gx _OLD_VIRTUAL_PATH $PATH 35 | set -gx PATH "$VIRTUAL_ENV/bin" $PATH 36 | 37 | # unset PYTHONHOME if set 38 | if set -q PYTHONHOME 39 | set -gx _OLD_VIRTUAL_PYTHONHOME $PYTHONHOME 40 | set -e PYTHONHOME 41 | end 42 | 43 | if test -z "$VIRTUAL_ENV_DISABLE_PROMPT" 44 | # fish uses a function instead of an env var to generate the prompt. 
45 | 46 | # save the current fish_prompt function as the function _old_fish_prompt 47 | functions -c fish_prompt _old_fish_prompt 48 | 49 | # with the original prompt function renamed, we can override with our own. 50 | function fish_prompt 51 | # Save the return status of the last command 52 | set -l old_status $status 53 | 54 | # Prompt override? 55 | if test -n "(.venv) " 56 | printf "%s%s" "(.venv) " (set_color normal) 57 | else 58 | # ...Otherwise, prepend env 59 | set -l _checkbase (basename "$VIRTUAL_ENV") 60 | if test $_checkbase = "__" 61 | # special case for Aspen magic directories 62 | # see https://aspen.io/ 63 | printf "%s[%s]%s " (set_color -b blue white) (basename (dirname "$VIRTUAL_ENV")) (set_color normal) 64 | else 65 | printf "%s(%s)%s" (set_color -b blue white) (basename "$VIRTUAL_ENV") (set_color normal) 66 | end 67 | end 68 | 69 | # Restore the return status of the previous command. 70 | echo "exit $old_status" | . 71 | _old_fish_prompt 72 | end 73 | 74 | set -gx _OLD_FISH_PROMPT_OVERRIDE "$VIRTUAL_ENV" 75 | end 76 | -------------------------------------------------------------------------------- /tensorflow_addons/layers/poincare.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implementing PoincareNormalize layer.""" 16 | 17 | import tensorflow as tf 18 | from typeguard import typechecked 19 | from typing import Union, List 20 | 21 | 22 | @tf.keras.utils.register_keras_serializable(package="Addons") 23 | class PoincareNormalize(tf.keras.layers.Layer): 24 | """Project into the Poincare ball with `norm <= 1.0 - epsilon`. 25 | 26 | See [Poincaré Embeddings for Learning Hierarchical Representations](https://arxiv.org/pdf/1705.08039.pdf), 27 | and [wiki](https://en.wikipedia.org/wiki/Poincare_ball_model). 28 | 29 | For a 1-D tensor with `axis = 0`, computes 30 | 31 | (x * (1 - epsilon)) / ||x|| if ||x|| > 1 - epsilon 32 | output = 33 | x otherwise 34 | 35 | For `x` with more dimensions, independently normalizes each 1-D slice along 36 | dimension `axis`. 37 | 38 | Args: 39 | axis: Axis along which to normalize. A scalar or a vector of integers. 40 | epsilon: A small deviation from the edge of the unit sphere for 41 | numerical stability. 
42 | """ 43 | 44 | @typechecked 45 | def __init__( 46 | self, axis: Union[None, int, List[int]] = 1, epsilon: float = 1e-5, **kwargs 47 | ): 48 | super().__init__(**kwargs) 49 | self.axis = axis 50 | self.epsilon = epsilon 51 | 52 | def call(self, inputs): 53 | x = tf.convert_to_tensor(inputs) 54 | square_sum = tf.math.reduce_sum(tf.math.square(x), self.axis, keepdims=True) 55 | x_inv_norm = tf.math.rsqrt(square_sum) 56 | x_inv_norm = tf.math.minimum((1.0 - self.epsilon) * x_inv_norm, 1.0) 57 | outputs = tf.math.multiply(x, x_inv_norm) 58 | return outputs 59 | 60 | def compute_output_shape(self, input_shape): 61 | return input_shape 62 | 63 | def get_config(self): 64 | config = {"axis": self.axis, "epsilon": self.epsilon} 65 | base_config = super().get_config() 66 | return {**base_config, **config} 67 | -------------------------------------------------------------------------------- /tools/docker/build_wheel.Dockerfile: -------------------------------------------------------------------------------- 1 | #syntax=docker/dockerfile:1.1.5-experimental 2 | ARG TF_VERSION 3 | ARG PY_VERSION 4 | FROM gcr.io/tensorflow-testing/nosla-cuda11.2-cudnn8.1-ubuntu18.04-manylinux2010-multipython as base_install 5 | ENV TF_NEED_CUDA="1" 6 | 7 | # Required for setuptools v50.0.0 8 | # https://setuptools.readthedocs.io/en/latest/history.html#v50-0-0 9 | # https://github.com/pypa/setuptools/issues/2352 10 | ENV SETUPTOOLS_USE_DISTUTILS=stdlib 11 | 12 | # Fix presented in 13 | # https://stackoverflow.com/questions/44967202/pip-is-showing-error-lsb-release-a-returned-non-zero-exit-status-1/44967506 14 | RUN echo "#! /usr/bin/python2.7" >> /usr/bin/lsb_release2 15 | RUN cat /usr/bin/lsb_release >> /usr/bin/lsb_release2 16 | RUN mv /usr/bin/lsb_release2 /usr/bin/lsb_release 17 | 18 | ARG PY_VERSION 19 | RUN ln -sf /usr/local/bin/python$PY_VERSION /usr/bin/python 20 | 21 | ARG TF_VERSION 22 | RUN python -m pip install --default-timeout=1000 tensorflow==$TF_VERSION 23 | 24 | COPY tools/install_deps/ /install_deps 25 | RUN python -m pip install -r /install_deps/pytest.txt 26 | 27 | COPY requirements.txt . 
28 | RUN python -m pip install -r requirements.txt 29 | 30 | COPY ./ /addons 31 | RUN rm /addons/.bazelversion 32 | WORKDIR /addons 33 | 34 | # ------------------------------------------------------------------- 35 | FROM base_install as tfa_gpu_tests 36 | CMD ["bash", "tools/testing/build_and_run_tests.sh"] 37 | 38 | # ------------------------------------------------------------------- 39 | FROM base_install as make_wheel 40 | ARG NIGHTLY_FLAG 41 | ARG NIGHTLY_TIME 42 | 43 | RUN python configure.py 44 | 45 | RUN bash tools/testing/build_and_run_tests.sh && \ 46 | bazel build \ 47 | --noshow_progress \ 48 | --noshow_loading_progress \ 49 | --verbose_failures \ 50 | --test_output=errors \ 51 | --crosstool_top=//build_deps/toolchains/gcc7_manylinux2010-nvcc-cuda11:toolchain \ 52 | build_pip_pkg && \ 53 | # Package Whl 54 | bazel-bin/build_pip_pkg artifacts $NIGHTLY_FLAG 55 | 56 | RUN bash tools/releases/tf_auditwheel_patch.sh 57 | RUN python -m auditwheel repair --plat manylinux2010_x86_64 artifacts/*.whl 58 | RUN ls -al wheelhouse/ 59 | 60 | # ------------------------------------------------------------------- 61 | 62 | FROM python:$PY_VERSION as test_wheel_in_fresh_environment 63 | 64 | ARG TF_VERSION 65 | RUN python -m pip install --default-timeout=1000 tensorflow==$TF_VERSION 66 | 67 | COPY --from=make_wheel /addons/wheelhouse/ /addons/wheelhouse/ 68 | RUN pip install /addons/wheelhouse/*.whl 69 | 70 | RUN python -c "import tensorflow_addons as tfa; print(tfa.register_all())" 71 | 72 | # ------------------------------------------------------------------- 73 | FROM scratch as output 74 | 75 | COPY --from=test_wheel_in_fresh_environment /addons/wheelhouse/ . 76 | -------------------------------------------------------------------------------- /tensorflow_addons/metrics/tests/harmonic_mean_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests HarmonicMean metrics.""" 16 | 17 | import numpy as np 18 | import pytest 19 | import tensorflow as tf 20 | 21 | from tensorflow_addons.metrics import HarmonicMean 22 | 23 | 24 | def get_test_data(): 25 | return [ 26 | ([np.inf] * 2, 0), 27 | ([0, 0, 0, 0], 0), 28 | ([1, 4, 4], 2.0), 29 | ([0, 0, 0, 0, 0, 0, 0, 1, 2, 6], 0), 30 | ([0.2, 0.5, 0.3, 0.6, 0.1, 0.7], 0.25609756), 31 | ([8, 4, 1, 7, 2, 11, 9, 22, 52], 3.9394846), 32 | ([8.2, 9.7, 9.1, 2.7, 1.1, 2.0], 2.8376906), 33 | ([0.6666666, 0.215213, 0.15167], 0.23548213), 34 | ] 35 | 36 | 37 | def assert_result(expected, result): 38 | np.testing.assert_allclose(expected, result, atol=1e-6) 39 | 40 | 41 | def check_result(obj, expected_result, expected_count): 42 | result = obj.result().numpy() 43 | count = obj.count.numpy() 44 | assert_result(expected_result, result) 45 | np.testing.assert_equal(expected_count, count) 46 | 47 | 48 | @pytest.mark.parametrize("values, expected", get_test_data()) 49 | def test_vector_update_state_hmean(values, expected): 50 | obj = HarmonicMean() 51 | values = tf.constant(values, tf.float32) 52 | obj.update_state(values) 53 | check_result(obj, expected, len(values)) 54 | 55 | 56 | @pytest.mark.parametrize("values, expected", get_test_data()) 57 | def test_call_hmean(values, expected): 58 | obj = HarmonicMean() 59 | result = obj(tf.constant(values, tf.float32)) 60 | count = obj.count.numpy() 61 | assert_result(expected, result) 62 | np.testing.assert_equal(len(values), count) 63 | 64 | 65 | @pytest.mark.parametrize( 66 | "values, sample_weight, expected", 67 | [ 68 | ([1, 2, 3, 4, 5], 1, 2.1897807), 69 | ([2.1, 4.6, 7.1], [1, 2, 3], 4.499409), 70 | ([9.6, 1.8, 8.2], [0.2, 0.5, 0.3], 2.9833248), 71 | ], 72 | ) 73 | def test_sample_weight_hmean(values, sample_weight, expected): 74 | obj = HarmonicMean() 75 | obj.update_state(values, sample_weight=sample_weight) 76 | assert_result(expected, obj.result().numpy()) 77 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | Brief Description of the PR: 4 | 5 | Fixes # (issue) 6 | 7 | ## Type of change 8 | 9 | - [ ] Bug fix 10 | - [ ] New Tutorial 11 | - [ ] Updated or additional documentation 12 | - [ ] Additional Testing 13 | - [ ] New Activation and the changes conform to the [activation contribution guidelines](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/activations/README.md#contribution-guidelines) 14 | - [ ] New Callback and the changes conform to the [callback contribution guidelines](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/callbacks/README.md#contribution-guidelines) 15 | - [ ] New Image addition and the changes conform to the [image op contribution guidelines](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/image/README.md#contribution-guidelines) 16 | - [ ] New Layer and the changes conform to the [layer contribution guidelines](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/layers/README.md#contribution-guidelines) 17 | - [ ] New Loss and the changes conform to the [loss contribution guidelines](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/losses/README.md#contribution-guidelines) 18 | - [ ] New Metric and the changes conform to the [metric contribution 
guidelines](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/metrics/README.md#contribution-guidelines) 19 | - [ ] New Optimizer and the changes conform to the [optimizer contribution guidelines](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/README.md#contribution-guidelines) 20 | - [ ] New RNN Cell and the changes conform to the [rnn contribution guidelines](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/rnn/README.md#contribution-guidelines) 21 | - [ ] New Seq2seq addition and the changes conform to the [seq2seq contribution guidelines](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/seq2seq/README.md#contribution-guidelines) 22 | - [ ] New Text addition and the changes conform to the [text op contribution guidelines](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/text/README.md#contribution-guidelines) 23 | 24 | # Checklist: 25 | 26 | - [ ] I've properly [formatted my code according to the guidelines](https://github.com/tensorflow/addons/blob/master/CONTRIBUTING.md#coding-style) 27 | - [ ] By running Black + Flake8 28 | - [ ] By running pre-commit hooks 29 | - [ ] This PR addresses an already submitted issue for TensorFlow Addons 30 | - [ ] I have made corresponding changes to the documentation 31 | - [ ] I have added tests that prove my fix is effective or that my feature works 32 | - [ ] This PR contains modifications to C++ custom-ops 33 | 34 | # How Has This Been Tested? 35 | 36 | If you're adding a bugfix or new feature please describe the tests that you ran to verify your changes: 37 | * 38 | -------------------------------------------------------------------------------- /tensorflow_addons/activations/tests/rrelu_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import pytest 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | from tensorflow_addons.activations import rrelu 21 | from tensorflow_addons.utils import test_utils 22 | 23 | SEED = 111111 24 | 25 | 26 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 27 | @pytest.mark.parametrize("training", [True, False]) 28 | def test_rrelu_old(dtype, training): 29 | x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype) 30 | lower = 0.1 31 | upper = 0.2 32 | 33 | tf.random.set_seed(SEED) 34 | training_results = { 35 | np.float16: [-0.288330078, -0.124206543, 0, 1, 2], 36 | np.float32: [-0.26851666, -0.116421416, 0, 1, 2], 37 | np.float64: [-0.3481333923206531, -0.17150176242558851, 0, 1, 2], 38 | } 39 | result = rrelu(x, lower, upper, training=training, seed=SEED) 40 | if training: 41 | expect_result = training_results.get(dtype) 42 | else: 43 | expect_result = [ 44 | -0.30000001192092896, 45 | -0.15000000596046448, 46 | 0, 47 | 1, 48 | 2, 49 | ] 50 | test_utils.assert_allclose_according_to_type(result, expect_result) 51 | 52 | 53 | @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) 54 | @pytest.mark.parametrize("training", [True, False]) 55 | def test_rrelu(dtype, training): 56 | x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype) 57 | lower = 0.1 58 | upper = 0.2 59 | training_results = { 60 | np.float16: [-0.3826, -0.165, 0, 1, 2], 61 | np.float32: [-0.282151192, -0.199812651, 0, 1, 2], 62 | np.float64: [-0.25720977, -0.1221586, 0, 1, 2], 63 | } 64 | result = rrelu( 65 | x, 66 | lower, 67 | upper, 68 | training=training, 69 | seed=None, 70 | rng=tf.random.Generator.from_seed(SEED), 71 | ) 72 | if training: 73 | expect_result = training_results.get(dtype) 74 | else: 75 | expect_result = [-0.30000001192092896, -0.15000000596046448, 0, 1, 2] 76 | test_utils.assert_allclose_according_to_type(result, expect_result) 77 | -------------------------------------------------------------------------------- /tensorflow_addons/image/tests/compose_ops_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
/tensorflow_addons/image/tests/compose_ops_test.py:
--------------------------------------------------------------------------------
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests of augmentation ops"""

import pytest
import tensorflow as tf
import numpy as np

from tensorflow_addons.image import compose_ops

_DTYPES = {
    tf.dtypes.uint8,
    tf.dtypes.int32,
    tf.dtypes.int64,
    tf.dtypes.float16,
    tf.dtypes.float32,
    tf.dtypes.float64,
}


def blend_np(image1, image2, factor):
    """NumPy reference implementation of `compose_ops.blend`."""
    image1 = image1.astype("float32")
    image2 = image2.astype("float32")
    difference = image2 - image1
    scaled = factor * difference
    temp = image1 + scaled
    if factor >= 0.0 and factor <= 1.0:
        temp = np.round(temp)
        return temp
    temp = np.round(np.clip(temp, 0.0, 255.0))
    return temp


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("dtype", _DTYPES)
def test_blend(dtype):
    image1 = tf.constant(
        [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], dtype=dtype
    )
    image2 = tf.constant(
        [
            [255, 255, 255, 255],
            [255, 255, 255, 255],
            [255, 255, 255, 255],
            [255, 255, 255, 255],
        ],
        dtype=dtype,
    )
    blended = compose_ops.blend(image1, image2, 0.5).numpy()
    np.testing.assert_equal(
        blended,
        [
            [128, 128, 128, 128],
            [128, 128, 128, 128],
            [128, 128, 128, 128],
            [128, 128, 128, 128],
        ],
    )

    np.random.seed(0)
    image1 = np.random.randint(0, 255, (3, 5, 5), np.uint8)
    image2 = np.random.randint(0, 255, (3, 5, 5), np.uint8)
    tf.random.set_seed(0)
    factor = tf.random.uniform(shape=[], maxval=1, dtype=tf.dtypes.float32, seed=0)
    blended = compose_ops.blend(
        tf.convert_to_tensor(image1), tf.convert_to_tensor(image2), factor
    ).numpy()
    expected = blend_np(image1, image2, factor.numpy())
    np.testing.assert_equal(blended, expected)
    assert blended.dtype == expected.dtype
--------------------------------------------------------------------------------
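As a quick arithmetic check on the constant-image case above: with `factor = 0.5`, each blended pixel is `0 + 0.5 * (255 - 0) = 127.5`, and rounding gives the expected 128. A small illustrative snippet, not part of the repository:

import numpy as np

image1 = np.zeros((4, 4), dtype="float32")
image2 = np.full((4, 4), 255.0, dtype="float32")
# Same arithmetic as blend_np for a factor inside [0, 1]: interpolate, then round.
halfway = np.round(image1 + 0.5 * (image2 - image1))
print(halfway[0, 0])  # 128.0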
/tensorflow_addons/image/compose_ops.py:
--------------------------------------------------------------------------------
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Compose Ops"""

import tensorflow as tf

from tensorflow_addons.utils.types import TensorLike, Number


def blend(image1: TensorLike, image2: TensorLike, factor: Number) -> tf.Tensor:
    """Blend `image1` and `image2` using `factor`.

    Factor can be 0.0 or above (it is not limited to [0.0, 1.0]).
    A value of 0.0 means only `image1` is used. A value of 1.0 means
    only `image2` is used. A value between 0.0 and 1.0 means we
    linearly interpolate the pixel values between the two images.
    A value greater than 1.0 "extrapolates" the difference between
    the two pixel values, and we clip the results to values between
    0 and 255.

    Args:
      image1: An image Tensor of shape
        `(num_rows, num_columns, num_channels)` (HWC), or
        `(num_rows, num_columns)` (HW), or
        `(num_channels, num_rows, num_columns)` (CHW).
      image2: An image Tensor of shape
        `(num_rows, num_columns, num_channels)` (HWC), or
        `(num_rows, num_columns)` (HW), or
        `(num_channels, num_rows, num_columns)` (CHW).
      factor: A floating point value or Tensor of type `tf.float32`,
        0.0 or above.

    Returns:
      A blended image Tensor of `tf.float32`.

    """
    with tf.name_scope("blend"):

        if factor == 0.0:
            return tf.convert_to_tensor(image1)
        if factor == 1.0:
            return tf.convert_to_tensor(image2)

        image1 = tf.cast(image1, dtype=tf.dtypes.float32)
        image2 = tf.cast(image2, dtype=tf.dtypes.float32)

        difference = image2 - image1
        scaled = factor * difference

        # Do addition in float.
        temp = image1 + scaled

        # Interpolate
        if factor > 0.0 and factor < 1.0:
            # Interpolation means we always stay within 0 and 255.
            temp = tf.round(temp)
            return temp

        # Extrapolate:
        #
        # We need to clip and then round.
        temp = tf.round(tf.clip_by_value(temp, 0.0, 255.0))
        return temp
--------------------------------------------------------------------------------
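Putting the docstring above into practice, a minimal usage sketch (illustrative only; the tensors and variable names are made up, but `compose_ops.blend` is the op defined above):

import tensorflow as tf
from tensorflow_addons.image import compose_ops

dark = tf.zeros([2, 2, 3])                      # all-zero image
light = tf.fill([2, 2, 3], 200.0)               # constant image of 200s

halfway = compose_ops.blend(dark, light, 0.5)   # interpolation: every pixel becomes 100.0
beyond = compose_ops.blend(dark, light, 1.5)    # extrapolation: 300.0 is clipped to 255.0
print(halfway.dtype, float(tf.reduce_max(beyond)))  # <dtype: 'float32'> 255.0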