├── .editorconfig ├── .github └── workflows │ ├── run_linter.yml │ └── run_tests.yml ├── .gitignore ├── .readthedocs.yml ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── Makefile ├── README.md ├── requirements.txt └── source │ ├── _static │ ├── objax.js │ └── theme_overrides.css │ ├── advanced │ ├── gradients.rst │ ├── io.rst │ ├── jit.rst │ └── variables_and_modules.rst │ ├── conf.py │ ├── dev │ ├── adding_module.rst │ └── setup.rst │ ├── examples.rst │ ├── faq.rst │ ├── index.rst │ ├── installation_setup.rst │ ├── notebooks │ ├── Custom_Networks.ipynb │ ├── Logistic_Regression.ipynb │ └── Objax_Basics.ipynb │ ├── objax │ ├── functional.rst │ ├── index.rst │ ├── io.rst │ ├── jaxboard.rst │ ├── nn.rst │ ├── objax.rst │ ├── optimizer.rst │ ├── privacy.rst │ ├── random.rst │ ├── util.rst │ └── zoo.rst │ └── tutorials.rst ├── examples ├── README.md ├── fixmatch │ ├── README.md │ ├── fixmatch.py │ ├── libml │ │ ├── __init__.py │ │ ├── augment │ │ │ ├── __init__.py │ │ │ ├── augment.py │ │ │ ├── core.py │ │ │ ├── ctaugment.py │ │ │ ├── randaugment │ │ │ │ ├── __init__.py │ │ │ │ ├── augment_ops.py │ │ │ │ └── randaugment.py │ │ │ └── tf_ctaugment.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── core.py │ │ │ ├── fsl.py │ │ │ └── ssl.py │ │ ├── models.py │ │ ├── train.py │ │ ├── util.py │ │ └── zoo │ │ │ ├── convnet.py │ │ │ └── resnet.py │ └── scripts │ │ ├── create_datasets.py │ │ ├── create_split.py │ │ ├── create_unlabeled.py │ │ └── extract_accuracy.py ├── gpt-2 │ ├── README.md │ └── gpt2.py ├── image_classification │ ├── README.md │ ├── __init__.py │ ├── cifar10_advanced.py │ ├── cifar10_simple.py │ ├── horses_or_humans_logistic.py │ ├── imagenet_pretrained_vgg.md │ ├── imagenet_pretrained_vgg.py │ ├── imagenet_resnet50.md │ ├── imagenet_resnet50_data.py │ ├── imagenet_resnet50_train.py │ ├── mnist_cnn.py │ ├── mnist_dnn.py │ ├── mnist_dp.py │ └── tfdata │ │ ├── __init__.py │ │ └── data.py ├── jaxboard │ ├── README.md │ └── summary.py ├── 
maml │ ├── README.md │ └── maml.py ├── requirements.txt ├── text_generation │ ├── README.md │ └── shakespeare_rnn.py └── tutorials │ ├── cifar10.ipynb │ ├── metric-learning.ipynb │ ├── mnist-tutorial.ipynb │ └── objax_to_tf.ipynb ├── objax ├── __init__.py ├── _patch_jax.py ├── _version.py ├── constants.py ├── functional │ ├── __init__.py │ ├── core │ │ ├── __init__.py │ │ ├── activation.py │ │ ├── ops.py │ │ └── pooling.py │ ├── divergence.py │ ├── loss.py │ └── parallel.py ├── gradient.py ├── io │ ├── __init__.py │ ├── checkpoint.py │ └── ops.py ├── jaxboard.py ├── module.py ├── nn │ ├── __init__.py │ ├── init.py │ └── layers.py ├── optimizer │ ├── __init__.py │ ├── adam.py │ ├── ema.py │ ├── lars.py │ ├── momentum.py │ ├── scheduler.py │ └── sgd.py ├── privacy │ ├── __init__.py │ └── dpsgd │ │ ├── __init__.py │ │ ├── gradient.py │ │ └── privacyaccountant.py ├── random │ ├── __init__.py │ └── random.py ├── typing.py ├── util │ ├── __init__.py │ ├── check.py │ ├── image.py │ ├── objax2tf.py │ ├── tracing.py │ └── util.py ├── variable.py └── zoo │ ├── __init__.py │ ├── convnet.py │ ├── dnnet.py │ ├── resnet_v2.py │ ├── rnn.py │ ├── vgg.py │ └── wide_resnet.py ├── requirements.txt ├── setup.py └── tests ├── conv.py ├── conv_transpose.py ├── dropout.py ├── functional_interpolate.py ├── functional_pooling.py ├── gradient.py ├── jit.py ├── linear.py ├── loss.py ├── module.py ├── nn_init.py ├── nn_moving_average.py ├── normalization.py ├── objax2tf.py ├── optimizer.py ├── parallel.py ├── repr.py ├── requirements.txt ├── resnet_v2.py ├── run_linter.sh ├── run_tests.sh ├── scan.py ├── scheduler.py ├── sequential.py ├── testio.py ├── testrandom.py ├── tracing.py ├── util.py ├── util_image.py ├── var_collection.py ├── variable.py ├── vectorize.py └── wide_resnet.py /.editorconfig: -------------------------------------------------------------------------------- 1 | # EditorConfig file to specify formatting, for more details: https://EditorConfig.org 2 | 3 | [*] 4 | charset = 
utf-8 5 | end_of_line = lf 6 | indent_size = 4 7 | indent_style = space 8 | insert_final_newline = false 9 | max_line_length = 120 10 | tab_width = 4 11 | ij_continuation_indent_size = 8 12 | ij_formatter_off_tag = @formatter:off 13 | ij_formatter_on_tag = @formatter:on 14 | ij_formatter_tags_enabled = false 15 | ij_smart_tabs = false 16 | ij_wrap_on_typing = false 17 | 18 | [{*.py,*.pyw}] 19 | ij_python_align_collections_and_comprehensions = true 20 | ij_python_align_multiline_imports = true 21 | ij_python_align_multiline_parameters = true 22 | ij_python_align_multiline_parameters_in_calls = true 23 | ij_python_blank_line_at_file_end = true 24 | ij_python_blank_lines_after_imports = 1 25 | ij_python_blank_lines_after_local_imports = 0 26 | ij_python_blank_lines_around_class = 1 27 | ij_python_blank_lines_around_method = 1 28 | ij_python_blank_lines_around_top_level_classes_functions = 2 29 | ij_python_blank_lines_before_first_method = 0 30 | ij_python_dict_alignment = 0 31 | ij_python_dict_new_line_after_left_brace = false 32 | ij_python_dict_new_line_before_right_brace = false 33 | ij_python_dict_wrapping = 1 34 | ij_python_from_import_new_line_after_left_parenthesis = false 35 | ij_python_from_import_new_line_before_right_parenthesis = false 36 | ij_python_from_import_parentheses_force_if_multiline = false 37 | ij_python_from_import_trailing_comma_if_multiline = false 38 | ij_python_from_import_wrapping = 1 39 | ij_python_hang_closing_brackets = false 40 | ij_python_keep_blank_lines_in_code = 1 41 | ij_python_keep_blank_lines_in_declarations = 1 42 | ij_python_keep_indents_on_empty_lines = false 43 | ij_python_keep_line_breaks = true 44 | ij_python_new_line_after_colon = false 45 | ij_python_new_line_after_colon_multi_clause = true 46 | ij_python_optimize_imports_always_split_from_imports = false 47 | ij_python_optimize_imports_case_insensitive_order = false 48 | ij_python_optimize_imports_join_from_imports_with_same_source = false 49 | 
ij_python_optimize_imports_sort_by_type_first = true 50 | ij_python_optimize_imports_sort_imports = true 51 | ij_python_optimize_imports_sort_names_in_from_imports = false 52 | ij_python_space_after_comma = true 53 | ij_python_space_after_number_sign = true 54 | ij_python_space_after_py_colon = true 55 | ij_python_space_before_backslash = true 56 | ij_python_space_before_comma = false 57 | ij_python_space_before_for_semicolon = false 58 | ij_python_space_before_lbracket = false 59 | ij_python_space_before_method_call_parentheses = false 60 | ij_python_space_before_method_parentheses = false 61 | ij_python_space_before_number_sign = true 62 | ij_python_space_before_py_colon = false 63 | ij_python_space_within_empty_method_call_parentheses = false 64 | ij_python_space_within_empty_method_parentheses = false 65 | ij_python_spaces_around_additive_operators = true 66 | ij_python_spaces_around_assignment_operators = true 67 | ij_python_spaces_around_bitwise_operators = true 68 | ij_python_spaces_around_eq_in_keyword_argument = false 69 | ij_python_spaces_around_eq_in_named_parameter = false 70 | ij_python_spaces_around_equality_operators = true 71 | ij_python_spaces_around_multiplicative_operators = true 72 | ij_python_spaces_around_power_operator = true 73 | ij_python_spaces_around_relational_operators = true 74 | ij_python_spaces_around_shift_operators = true 75 | ij_python_spaces_within_braces = false 76 | ij_python_spaces_within_brackets = false 77 | ij_python_spaces_within_method_call_parentheses = false 78 | ij_python_spaces_within_method_parentheses = false 79 | ij_python_use_continuation_indent_for_arguments = false 80 | ij_python_use_continuation_indent_for_collection_and_comprehensions = false 81 | ij_python_wrap_long_lines = false 82 | -------------------------------------------------------------------------------- /.github/workflows/run_linter.yml: -------------------------------------------------------------------------------- 1 | name: Style check 2 | 3 | # 
Trigger workflow on pull requests or push to master branch. 4 | on: 5 | push: 6 | branches: [ master ] 7 | pull_request: 8 | branches: [ master ] 9 | 10 | jobs: 11 | linting: 12 | runs-on: ubuntu-latest 13 | 14 | strategy: 15 | matrix: 16 | python-version: [3.9] 17 | 18 | steps: 19 | # Checks-out repository under $GITHUB_WORKSPACE 20 | - uses: actions/checkout@v2 21 | 22 | - name: Setup Python # Set Python version 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | pip install flake8 31 | pip install --upgrade -r requirements.txt 32 | - name: Lint with flake8 33 | run: | 34 | export PYTHONPATH=$PYTHONPATH:. 35 | ./tests/run_linter.sh 36 | -------------------------------------------------------------------------------- /.github/workflows/run_tests.yml: -------------------------------------------------------------------------------- 1 | name: Testing 2 | 3 | # Trigger workflow on pull requests or push to master branch. 4 | on: 5 | push: 6 | branches: [ master ] 7 | pull_request: 8 | branches: [ master ] 9 | 10 | jobs: 11 | test: 12 | runs-on: ubuntu-latest 13 | 14 | strategy: 15 | matrix: 16 | python-version: [3.9] 17 | 18 | steps: 19 | # Checks-out repository under $GITHUB_WORKSPACE 20 | - uses: actions/checkout@v2 21 | 22 | - name: Setup Python # Set Python version 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | pip install --upgrade -r requirements.txt 31 | pip install --upgrade -r tests/requirements.txt 32 | 33 | - name: Test with pytest 34 | run: | 35 | export PYTHONPATH=$PYTHONPATH:. 
36 | CUDA_VISIBLE_DEVICES= pytest tests/*.py 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | *.iml 3 | *.log 4 | *.pyc 5 | .idea 6 | .vscode 7 | __pycache__ 8 | build/ 9 | dist/ 10 | docs/build/ 11 | docs/source/_autosummary 12 | experiments/ 13 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | build: 9 | os: "ubuntu-20.04" 10 | tools: 11 | python: "3.9" 12 | 13 | # Build documentation in the docs/ directory with Sphinx 14 | sphinx: 15 | configuration: docs/source/conf.py 16 | 17 | # Optionally build your docs in additional formats such as PDF and ePub 18 | formats: 19 | - htmlzip 20 | 21 | # Optionally set the version of Python and requirements required to build your docs 22 | python: 23 | install: 24 | - requirements: docs/requirements.txt 25 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | Below are some basic requirements, for a more detailed discussion see the 7 | [setup guide](https://objax.readthedocs.io/en/latest/dev/setup.html). 8 | 9 | In addition to them take a look at a 10 | [guide on adding new modules](https://objax.readthedocs.io/en/latest/dev/adding_module.html). 
11 | 12 | ## Contributor License Agreement 13 | 14 | Contributions to this project must be accompanied by a Contributor License 15 | Agreement. You (or your employer) retain the copyright to your contribution; 16 | this simply gives us permission to use and redistribute your contributions as 17 | part of the project. Head over to <https://cla.developers.google.com/> to see 18 | your current agreements on file or to sign a new one. 19 | 20 | You generally only need to submit a CLA once, so if you've already submitted one 21 | (even if it was for a different project), you probably don't need to do it 22 | again. 23 | 24 | ## Code reviews 25 | 26 | All submissions, including submissions by project members, require review. We 27 | use GitHub pull requests for this purpose. Consult 28 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 29 | information on using pull requests. 30 | 31 | ## Running Tests 32 | 33 | Before submitting a PR, it can help to run the unit tests and linter. Install the linter 34 | ```bash 35 | pip install flake8 36 | ``` 37 | 38 | and then run 39 | ```bash 40 | ./tests/run_linter.sh 41 | ./tests/run_tests.sh 42 | ``` 43 | to confirm that the tests all pass. 44 | 45 | A single test can be run with 46 | ```bash 47 | CUDA_VISIBLE_DEVICES= python3 -m unittest tests/jit.py 48 | ``` 49 | 50 | ## Community Guidelines 51 | 52 | This project follows [Google's Open Source Community 53 | Guidelines](https://opensource.google.com/conduct/). 
54 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Objax 2 | 3 | [**Tutorials**](https://objax.readthedocs.io/en/latest/notebooks/Objax_Basics.html) 4 | | [**Install**](https://objax.readthedocs.io/en/latest/installation_setup.html) 5 | | [**Documentation**](https://objax.readthedocs.io/en/latest/) 6 | | [**Philosophy**](https://objax.readthedocs.io/en/latest/index.html#objax-philosophy) 7 | 8 | This is not an officially supported Google product. 9 | 10 | Objax is an open source machine learning framework that accelerates research and learning thanks to a 11 | minimalist object-oriented design and a readable code base. 12 | Its name comes from the contraction of Object and [JAX](https://github.com/google/jax) -- a popular high-performance 13 | framework. 14 | Objax is designed **by researchers for researchers** with a focus on simplicity and understandability. 15 | Its users should be able to easily read, understand, extend, and modify it to fit their needs. 16 | 17 | This is the developer repository of Objax, there is very little user documentation 18 | here, for the full documentation go to [objax.readthedocs.io](https://objax.readthedocs.io/). 19 | 20 | You can find READMEs in the subdirectory of this project, for example: 21 | 22 | * [Sample Code](examples/README.md) 23 | * [Writing documentation](docs/README.md) 24 | 25 | 26 | ## User installation guide 27 | 28 | You install Objax using `pip` as follows: 29 | 30 | ```bash 31 | pip install --upgrade objax 32 | ``` 33 | 34 | Objax supports GPUs but assumes that you already have some version of CUDA 35 | installed. 
Here are the extra steps required to install CUDA-enabled jaxlib 36 | (jaxlib releases require CUDA 11.2 or newer): 37 | 38 | ```bash 39 | RELEASE_URL="https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" 40 | JAX_VERSION=`python3 -c 'import jax; print(jax.__version__)'` 41 | pip uninstall -y jaxlib 42 | pip install -f $RELEASE_URL jax[cuda]==$JAX_VERSION 43 | ``` 44 | 45 | For more installation options, see https://github.com/google/jax#pip-installation-gpu-cuda 46 | 47 | ### Useful environment configurations 48 | 49 | Here are a few useful options: 50 | 51 | ```bash 52 | # Prevent JAX from taking the whole GPU memory 53 | # (useful if you want to run several programs on a single GPU) 54 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 55 | ``` 56 | 57 | ### Testing your installation 58 | 59 | You can test your installation by running the code below: 60 | 61 | ```python 62 | import jax 63 | import objax 64 | 65 | print(f'Number of GPUs {jax.device_count()}') 66 | 67 | x = objax.random.normal(shape=(100, 4)) 68 | m = objax.nn.Linear(nin=4, nout=5) 69 | print('Matrix product shape', m(x).shape) # (100, 5) 70 | 71 | x = objax.random.normal(shape=(100, 3, 32, 32)) 72 | m = objax.nn.Conv2D(nin=3, nout=4, k=3) 73 | print('Conv2D return shape', m(x).shape) # (100, 4, 32, 32) 74 | ``` 75 | 76 | Typically if you get errors running this using CUDA, it probably means your 77 | installation of CUDA or CuDNN has issues. 
78 | 79 | ### Running code examples 80 | 81 | Clone the code repository: 82 | 83 | ```bash 84 | git clone https://github.com/google/objax.git 85 | cd objax/examples 86 | ``` 87 | 88 | ### Citing Objax 89 | 90 | To cite this repository: 91 | 92 | ``` 93 | @software{objax2020github, 94 | author = {{Objax Developers}}, 95 | title = {{Objax}}, 96 | url = {https://github.com/google/objax}, 97 | version = {1.2.0}, 98 | year = {2020}, 99 | } 100 | ``` 101 | 102 | ## Developer documentation 103 | 104 | Here is information about 105 | [development setup](https://objax.readthedocs.io/en/latest/dev/setup.html) 106 | and a [guide on adding new code](https://objax.readthedocs.io/en/latest/dev/adding_module.html). 107 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | 22 | clean: 23 | rm -rf build source/_autosummary 24 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Documentation folder 2 | 3 | The document uses `.rst` format which stands for [reStructuredText 4 | (reST)](https://docutils.sourceforge.io/docs/user/rst/quickstart.html). 5 | 6 | [Cheat sheet](http://openalea.gforge.inria.fr/doc/openalea/doc/_build/html/source/sphinx/rest_syntax.html) 7 | for reST. 8 | 9 | ## Initial setup 10 | 11 | ```bash 12 | # Install python libraries 13 | pip install --upgrade -r docs/requirements.txt 14 | 15 | # Install pandoc, see also https://pandoc.org/installing.html 16 | sudo apt install pandoc 17 | ``` 18 | 19 | ## Building 20 | 21 | ```bash 22 | cd docs 23 | make clean 24 | PYTHONPATH=$PYTHONPATH:.. make html 25 | ``` 26 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx >= 3.0 2 | sphinx_rtd_theme 3 | recommonmark 4 | nbsphinx 5 | pandoc 6 | jupyter_client 7 | ipykernel 8 | objax 9 | # needed to avoid https://github.com/sphinx-doc/sphinx/issues/8198 10 | pygments>=2.4.1 11 | tensorflow 12 | . 
13 | -------------------------------------------------------------------------------- /docs/source/_static/objax.js: -------------------------------------------------------------------------------- 1 | document.addEventListener("DOMContentLoaded", function () { 2 | document.body.innerHTML = document.body.innerHTML 3 | .replace(/<\/em>[\n^<]*em>/g, '') 4 | .replace(/Union\[jax\.numpy\.lax_numpy\.ndarray, jax.interpreters\.xla\.DeviceArray(, jax\.interpreters\.pxla\.ShardedDeviceArray)?]/g, 'objax.JaxArray'); 5 | }); 6 | -------------------------------------------------------------------------------- /docs/source/_static/theme_overrides.css: -------------------------------------------------------------------------------- 1 | /* override table width restrictions */ 2 | @media screen and (min-width: 767px) { 3 | 4 | .wy-table-responsive table td { 5 | /* !important prevents the common CSS stylesheets from overriding 6 | this as on RTD they are loaded after this stylesheet */ 7 | white-space: normal !important; 8 | } 9 | 10 | .wy-table-responsive { 11 | overflow: visible !important; 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Configuration file for the Sphinx documentation builder. 
16 | # 17 | # This file only contains a selection of the most common options. For a full 18 | # list see the documentation: 19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 20 | 21 | # -- Path setup -------------------------------------------------------------- 22 | 23 | # If extensions (or modules to document with autodoc) are in another directory, 24 | # add these directories to sys.path here. If the directory is relative to the 25 | # documentation root, use os.path.abspath to make it absolute, like shown here. 26 | # 27 | import os 28 | 29 | import sys 30 | 31 | import objax 32 | 33 | sys.path.insert(0, os.path.abspath('../..')) 34 | 35 | # -- Project information ----------------------------------------------------- 36 | 37 | project = 'Objax' 38 | copyright = '2020, Google LLC' 39 | author = 'Objax team' 40 | 41 | # The full version, including alpha/beta/rc tags 42 | release = objax.__version__ 43 | 44 | # -- General configuration --------------------------------------------------- 45 | 46 | # Add any Sphinx extension module names here, as strings. They can be 47 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 48 | # ones. 49 | extensions = [ 50 | 'nbsphinx', 51 | 'recommonmark', 52 | 'sphinx.ext.autodoc', 53 | 'sphinx.ext.mathjax', 54 | 'sphinx.ext.napoleon', 55 | 'sphinx.ext.autosummary', 56 | 'sphinx.ext.napoleon', 57 | 'sphinx.ext.viewcode', 58 | 'sphinx_rtd_theme', 59 | 'sphinx.ext.autosectionlabel', 60 | ] 61 | 62 | # Add any paths that contain templates here, relative to this directory. 63 | templates_path = ['_templates'] 64 | 65 | # The master toctree document. 66 | master_doc = 'index' 67 | 68 | # List of patterns, relative to source directory, that match files and 69 | # directories to ignore when looking for source files. 70 | # This pattern also affects html_static_path and html_extra_path. 
71 | exclude_patterns = [] 72 | 73 | autosummary_generate = True 74 | napolean_use_rtype = False 75 | autodoc_default_options = { 76 | 'member-order': 'bysource', 77 | 'special-members': '__init__, __call__', 78 | 'undoc-members': True, 79 | 'exclude-members': '__weakref__,_abc_impl' 80 | } 81 | autodoc_typehints = 'description' 82 | 83 | # -- Options for HTML output ------------------------------------------------- 84 | 85 | # The theme to use for HTML and HTML Help pages. See the documentation for 86 | # a list of builtin themes. 87 | # 88 | # html_theme = 'alabaster' 89 | html_theme = 'sphinx_rtd_theme' 90 | 91 | # Add any paths that contain custom static files (such as style sheets) here, 92 | # relative to this directory. They are copied after the builtin static files, 93 | # so a file named "default.css" will overwrite the builtin "default.css". 94 | html_static_path = ['_static'] 95 | html_css_files = ['theme_overrides.css'] 96 | html_js_files = ['objax.js'] 97 | 98 | # -- Options for nbsphinx ----------------------------------------------------- 99 | 100 | # Execute notebooks before conversion: 'always', 'never', 'auto' (default) 101 | # We never execute notebooks to avoid problems if nbsphinx won't find all dependencies. 102 | nbsphinx_execute = 'never' 103 | 104 | # If True, the build process is continued even if an exception occurs: 105 | nbsphinx_allow_errors = True 106 | 107 | # Controls when a cell will time out (defaults to 30; use -1 for no timeout): 108 | nbsphinx_timeout = 180 109 | 110 | # Default Pygments lexer for syntax highlighting in code cells: 111 | nbsphinx_codecell_lexer = 'ipython3' 112 | 113 | nbsphinx_prolog = """ 114 | {% set docname = 'docs/source/' + env.doc2path(env.docname, base=None) %} 115 | 116 | .. only:: html 117 | 118 | .. role:: raw-html(raw) 119 | :format: html 120 | 121 | .. 
nbinfo:: 122 | Interactive online version: 123 | :raw-html:`Open In Colab` 124 | """ 125 | -------------------------------------------------------------------------------- /docs/source/dev/setup.rst: -------------------------------------------------------------------------------- 1 | Development setup 2 | ================= 3 | 4 | This section describes some basic setup to start developing and extending Objax. 5 | 6 | Environment setup 7 | ----------------- 8 | 9 | First of all you need to install all necessary dependencies. 10 | We recommend to setup a separate :code:`virtualenv` to work on Objax, 11 | it could be done with following commands on Ubuntu or similar Linux distribution: 12 | 13 | .. code-block:: bash 14 | 15 | # Install virtualenv if you haven't done so already 16 | sudo apt install python3-dev python3-virtualenv python3-tk imagemagick virtualenv pandoc 17 | # Create a virtual environment (for example ~/.venv/objax, you can use your name here) 18 | virtualenv -p python3 --system-site-packages ~/.venv/objax 19 | # Start the virtual environment 20 | . ~/.venv/objax/bin/activate 21 | 22 | # Clone objax git repository, if you haven't. 23 | git clone https://github.com/google/objax.git 24 | cd objax 25 | 26 | # Install python dependencies. 27 | pip install --upgrade -r requirements.txt 28 | pip install --upgrade -r tests/requirements.txt 29 | pip install --upgrade -r docs/requirements.txt 30 | pip install --upgrade -r examples/requirements.txt 31 | pip install flake8 32 | 33 | # jaxlib releases require CUDA 11.2 or newer 34 | RELEASE_URL="https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" 35 | JAX_VERSION=`python3 -c 'import jax; print(jax.__version__)'` 36 | pip uninstall -y jaxlib 37 | pip install -f $RELEASE_URL jax[cuda]==$JAX_VERSION 38 | 39 | Running tests and linter 40 | ------------------------ 41 | 42 | Run linter: 43 | 44 | .. code-block:: bash 45 | 46 | ./tests/run_linter.sh 47 | 48 | Run tests: 49 | 50 | .. 
code-block:: bash 51 | 52 | ./tests/run_tests.sh 53 | 54 | Running a single test: 55 | 56 | .. code-block:: bash 57 | 58 | CUDA_VISIBLE_DEVICES= python3 -m unittest tests/jit.py 59 | -------------------------------------------------------------------------------- /docs/source/faq.rst: -------------------------------------------------------------------------------- 1 | Frequently Asked Questions 2 | ========================== 3 | 4 | What is the difference between Objax and other JAX frameworks? 5 | -------------------------------------------------------------- 6 | 7 | JAX itself as well as most of JAX frameworks (other than Objax) 8 | follows a functional style programming paradigm. 9 | This means that all computations are expected to be performed by 10 | stateless `pure functions <https://en.wikipedia.org/wiki/Pure_function>`_. 11 | And state (i.e. model weights) has to be manually passed to these functions. 12 | 13 | On the other hand, Objax follows an object-oriented programming paradigm 14 | (similar to PyTorch and Tensorflow). 15 | Objax provides objects (called Objax modules) which store and manage 16 | the state of a neural network. 17 | 18 | To better illustrate this distinction, 19 | below are two examples of a similar code written in pure JAX and Objax. 20 | 21 | Every time when a user calls neural network components in JAX (and many JAX frameworks), 22 | they have to pass both neural network parameters :code:`params` 23 | as well as training examples :code:`batch['x'], batch['y']`:: 24 | 25 | params = (jn.zeros(ndim), jn.zeros(1)) 26 | 27 | def loss(params, x, y): 28 | w, b = params 29 | pred = jn.dot(x, w) + b 30 | return 0.5 * ((y - pred) ** 2).mean() 31 | 32 | g_fn = jax.grad(loss) # g_fn is a function 33 | 34 | # Need to pass both parameters and batch to g_fn 35 | g_value = g_fn(params, batch['x'], batch['y']) 36 | 37 | On the other hand, modules in Objax store parameters and state internally. 
38 | Thus a user only has to pass around training examples :code:`batch['x'], batch['y']`:: 39 | 40 | w = objax.TrainVar(jn.zeros(ndim)) 41 | b = objax.TrainVar(jn.zeros(1)) 42 | 43 | def loss(x, y): 44 | pred = jn.dot(x, w) + b 45 | return 0.5 * ((y - pred) ** 2).mean() 46 | 47 | g_fn = objax.Grad(loss, # g_fn is Objax module 48 | objax.VarCollection({'w': w, 'b': b})) 49 | 50 | # Need to pass only batch to g_fn 51 | g_value = g_fn(batch['x'], batch['y']) 52 | 53 | What is the difference between Objax and PyTorch/Tensorflow? 54 | ------------------------------------------------------------ 55 | 56 | Execution runtime 57 | ^^^^^^^^^^^^^^^^^ 58 | 59 | Objax is implemented on top of JAX, 60 | while PyTorch and Tensorflow have their own underlying runtime environments. 61 | In practice it mainly means that to interoperate between these frameworks 62 | some conversion needs to be done. 63 | For example convert PyTorch/Tensorflow tensor to NumPy array 64 | and then feed this NumPy array to code in Objax. 65 | 66 | Design of API 67 | ^^^^^^^^^^^^^ 68 | 69 | Objax was inspired by the best of other machine learning frameworks 70 | (including PyTorch and Tensorflow). 71 | Thus readers may observe similarities between Objax API and API of PyTorch 72 | (or some other frameworks). 73 | 74 | Nevertheless, **Objax is not intended to be a re-implementation of the API 75 | of any other framework and each Objax design decision is weighted on its own merit**. 76 | So there will always be differences between Objax API and APIs of other frameworks. 77 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to Objax's documentation! 2 | ================================= 3 | 4 | Objax is an open source machine learning framework that accelerates research and learning thanks to a 5 | minimalist object-oriented design and a readable code base. 
6 | Its name comes from the contraction of Object and `JAX <https://github.com/google/jax>`_ -- a popular high-performance 7 | framework. 8 | Objax is designed **by researchers for researchers** with a focus on simplicity and understandability. 9 | Its users should be able to easily read, understand, extend, and modify it to fit their needs. 10 | 11 | :doc:`Try the 5 minutes tutorial. <notebooks/Objax_Basics>` 12 | 13 | Machine learning's :code:`'Hello world'`: optimizing the weights of classifier ``net`` through gradient descent:: 14 | 15 | opt = objax.optimizer.Adam(net.vars()) 16 | 17 | @objax.Function.with_vars(net.vars()) 18 | def loss(x, y): 19 | logits = net(x) # Output of classifier on x 20 | xe = cross_entropy_logits(logits, y) 21 | return xe.mean() 22 | 23 | # Perform gradient descent wrt to net weights 24 | gv = objax.GradValues(loss, net.vars()) 25 | 26 | @objax.Function.with_vars(net.vars() + opt.vars()) 27 | def train_op(x, y): 28 | g, v = gv(x, y) # returns gradients g and loss v 29 | opt(lr, g) # update weights 30 | return v 31 | 32 | train_op = objax.Jit(train_op) 33 | 34 | Objax philosophy 35 | ---------------- 36 | 37 | .. epigraph:: 38 | 39 | Objax pursues the quest for the **simplest design and code** that's as **easy** as possible **to extend** 40 | without sacrificing **performance**. 41 | 42 | -- Objax Devs 43 | 44 | Motivation 45 | ^^^^^^^^^^ 46 | 47 | Researchers and students look at machine learning frameworks in their own way. 48 | Often they read the code of some technique, say an Adam optimizer, to understand how it works 49 | so they can extend it or design a new optimizer. 50 | This is how machine learning frameworks differ from standard libraries: a large class of 51 | users not only look at the APIs but also at the code behind these APIs. 52 | 53 | Coded for simplicity 54 | ^^^^^^^^^^^^^^^^^^^^ 55 | 56 | Source code should be understandable by everyone, including users without background in computer science. 57 | So how simple is it really? 
Judge for yourself with this tutorial: :doc:`notebooks/Logistic_Regression`. 58 | 59 | Object-oriented 60 | ^^^^^^^^^^^^^^^ 61 | It is common in machine learning to separate the inputs (:math:`X`) 62 | from the parameters (:math:`\theta`) of a function :math:`f(X; 63 | \theta)`. 64 | Math notation captures this difference by using a semi-colon to semantically separate the first group of arguments from the other. 65 | 66 | Objax represents this semantic distinction through :py:class:`objax.Module`: 67 | 68 | * the module's parameters :math:`\theta` are attributes of the form :code:`self.w, ...` 69 | * inputs :math:`X` are method arguments such as :code:`def __call__(self, x, y, ...):` 70 | 71 | Designed for flexibility 72 | ^^^^^^^^^^^^^^^^^^^^^^^^ 73 | 74 | Objax minimizes the number of abstractions users need to understand. There are two main ones: *Modules* and *Variables*. 75 | Everything is built out of these two basic classes. You can read more about this in :doc:`advanced/variables_and_modules`. 76 | 77 | Engineered for performance 78 | ^^^^^^^^^^^^^^^^^^^^^^^^^^ 79 | 80 | In machine learning, performance is essential. 81 | Every second counts. 82 | Objax makes it count by using the JAX/XLA engine that also powers TensorFlow. 83 | Read more about this in :doc:`advanced/jit`. 84 | 85 | 86 | .. toctree:: 87 | :maxdepth: 1 88 | :caption: Getting Started 89 | 90 | installation_setup 91 | notebooks/Objax_Basics 92 | notebooks/Logistic_Regression 93 | notebooks/Custom_Networks 94 | examples 95 | tutorials 96 | faq 97 | 98 | .. toctree:: 99 | :maxdepth: 2 100 | :caption: API documentation 101 | 102 | objax/index 103 | 104 | .. toctree:: 105 | :maxdepth: 2 106 | :caption: In-depth topics 107 | 108 | advanced/variables_and_modules 109 | advanced/gradients 110 | advanced/jit 111 | advanced/io 112 | 113 | .. 
toctree:: 114 | :maxdepth: 1 115 | :caption: Developer documentation 116 | 117 | dev/setup 118 | dev/adding_module 119 | 120 | 121 | 122 | Indices and tables 123 | ================== 124 | 125 | * :ref:`genindex` 126 | * :ref:`modindex` 127 | * :ref:`search` 128 | -------------------------------------------------------------------------------- /docs/source/installation_setup.rst: -------------------------------------------------------------------------------- 1 | Installation and Setup 2 | ====================== 3 | 4 | For developing or contributing to Objax, see :ref:`Development setup`. 5 | 6 | User installation 7 | ----------------- 8 | 9 | Install using :code:`pip` with the following command: 10 | 11 | .. code-block:: bash 12 | 13 | pip install --upgrade objax 14 | 15 | For GPU support, we assume you already have some version of CUDA installed (jaxlib releases require 16 | CUDA 11.2 or newer). Here are the extra steps: 17 | 18 | .. code-block:: bash 19 | 20 | RELEASE_URL="https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" 21 | JAX_VERSION=`python3 -c 'import jax; print(jax.__version__)'` 22 | pip uninstall -y jaxlib 23 | pip install -f $RELEASE_URL jax[cuda]==$JAX_VERSION 24 | 25 | Useful shell configurations 26 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 27 | 28 | Here are a few useful options: 29 | 30 | ..
code-block:: bash 31 | 32 | # Prevent JAX from taking the whole GPU memory 33 | # (useful if you want to run several programs on a single GPU) 34 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 35 | 36 | Testing your installation 37 | ^^^^^^^^^^^^^^^^^^^^^^^^^ 38 | 39 | You can run the code below to test your installation:: 40 | 41 | import jax 42 | import objax 43 | 44 | print(f'Number of GPUs {jax.device_count()}') 45 | 46 | x = objax.random.normal((100, 4)) 47 | m = objax.nn.Linear(4, 5) 48 | print('Matrix product shape', m(x).shape) # (100, 5) 49 | 50 | x = objax.random.normal((100, 3, 32, 32)) 51 | m = objax.nn.Conv2D(3, 4, k=3) 52 | print('Conv2D return shape', m(x).shape) # (100, 4, 32, 32) 53 | 54 | If you get errors running this using CUDA, it probably means your installation of CUDA or CuDNN has issues. 55 | 56 | Installing examples 57 | ^^^^^^^^^^^^^^^^^^^ 58 | 59 | Clone the code repository: 60 | 61 | .. code-block:: bash 62 | 63 | git clone https://github.com/google/objax.git 64 | cd objax/examples 65 | -------------------------------------------------------------------------------- /docs/source/objax/functional.rst: -------------------------------------------------------------------------------- 1 | objax.functional package 2 | ======================== 3 | 4 | .. currentmodule:: objax.functional 5 | 6 | .. contents:: 7 | :local: 8 | :depth: 1 9 | 10 | objax.functional 11 | ---------------- 12 | 13 | .. currentmodule:: objax.functional 14 | 15 | Due to the large number of APIs in this section, we organized it into the following sub-sections: 16 | 17 | .. contents:: 18 | :local: 19 | :depth: 1 20 | 21 | Activation 22 | ^^^^^^^^^^ 23 | 24 | .. autosummary:: 25 | 26 | celu 27 | elu 28 | leaky_relu 29 | log_sigmoid 30 | log_softmax 31 | logsumexp 32 | relu 33 | selu 34 | sigmoid 35 | softmax 36 | softplus 37 | tanh 38 | 39 | .. autofunction:: celu 40 | .. autofunction:: elu 41 | .. autofunction:: leaky_relu 42 | .. autofunction:: log_sigmoid 43 | .. 
autofunction:: log_softmax 44 | .. autofunction:: logsumexp 45 | .. autofunction:: relu 46 | .. autofunction:: selu 47 | .. autofunction:: sigmoid 48 | .. autofunction:: softmax 49 | .. autofunction:: softplus 50 | .. autofunction:: tanh 51 | 52 | Pooling 53 | ^^^^^^^ 54 | 55 | .. autosummary:: 56 | 57 | average_pool_2d 58 | batch_to_space2d 59 | channel_to_space2d 60 | max_pool_2d 61 | space_to_batch2d 62 | space_to_channel2d 63 | 64 | .. autofunction:: average_pool_2d 65 | 66 | For a definition of pooling, including examples see 67 | `Pooling Layer `_. 68 | 69 | .. autofunction:: batch_to_space2d 70 | .. autofunction:: channel_to_space2d 71 | .. autofunction:: max_pool_2d 72 | 73 | For a definition of pooling, including examples see 74 | `Pooling Layer `_. 75 | 76 | .. autofunction:: space_to_batch2d 77 | .. autofunction:: space_to_channel2d 78 | 79 | Misc 80 | ^^^^ 81 | 82 | .. autosummary:: 83 | 84 | dynamic_slice 85 | flatten 86 | interpolate 87 | one_hot 88 | pad 89 | scan 90 | stop_gradient 91 | top_k 92 | rsqrt 93 | upsample_2d 94 | upscale_nn 95 | 96 | .. autofunction:: dynamic_slice 97 | .. autofunction:: flatten 98 | .. autofunction:: interpolate 99 | .. autofunction:: one_hot 100 | .. autofunction:: pad 101 | .. autofunction:: scan 102 | .. autofunction:: stop_gradient 103 | .. autofunction:: top_k 104 | .. autofunction:: rsqrt 105 | .. autofunction:: upsample_2d 106 | .. autofunction:: upscale_nn 107 | 108 | objax.functional.divergence 109 | --------------------------- 110 | 111 | .. currentmodule:: objax.functional.divergence 112 | 113 | .. autosummary:: 114 | 115 | kl 116 | 117 | .. autofunction:: kl 118 | 119 | .. math:: 120 | kl(p,q) = p \cdot \log{\frac{p + \epsilon}{q + \epsilon}} 121 | 122 | The :math:`\epsilon` term is added to ensure that neither :code:`p` nor :code:`q` are zero. 123 | 124 | objax.functional.loss 125 | --------------------- 126 | 127 | .. currentmodule:: objax.functional.loss 128 | 129 | .. 
autosummary:: 130 | 131 | cross_entropy_logits 132 | cross_entropy_logits_sparse 133 | l2 134 | mean_absolute_error 135 | mean_squared_error 136 | mean_squared_log_error 137 | sigmoid_cross_entropy_logits 138 | 139 | .. autofunction:: cross_entropy_logits 140 | 141 | Calculates the cross entropy loss, defined as follows: 142 | 143 | .. math:: 144 | 145 | \begin{eqnarray} 146 | l(y,\hat{y}) & = & - \sum_{j=1}^{q} y_j \log \frac{e^{o_j}}{\sum_{k=1}^{q} e^{o_k}} \nonumber \\ 147 | & = & \log \sum_{k=1}^{q} e^{o_k} - \sum_{j=1}^{q} y_j o_j \nonumber 148 | \end{eqnarray} 149 | 150 | where :math:`o_k` are the logits and :math:`y_k` are the labels. 151 | 152 | .. autofunction:: cross_entropy_logits_sparse 153 | 154 | .. autofunction:: l2 155 | 156 | Calculates the l2 loss, as: 157 | 158 | .. math:: 159 | l_2 = \frac{\sum_{i} x_{i}^2}{2} 160 | 161 | .. autofunction:: mean_absolute_error 162 | 163 | .. autofunction:: mean_squared_error 164 | 165 | .. autofunction:: sigmoid_cross_entropy_logits 166 | 167 | 168 | objax.functional.parallel 169 | ------------------------- 170 | 171 | .. currentmodule:: objax.functional.parallel 172 | 173 | .. autosummary:: 174 | 175 | pmax 176 | pmean 177 | pmin 178 | psum 179 | 180 | .. autofunction:: pmax 181 | .. autofunction:: pmean 182 | .. autofunction:: pmin 183 | .. autofunction:: psum 184 | 185 | 186 | -------------------------------------------------------------------------------- /docs/source/objax/index.rst: -------------------------------------------------------------------------------- 1 | Objax API 2 | ========= 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | objax 8 | functional 9 | io 10 | jaxboard 11 | nn 12 | optimizer 13 | privacy 14 | random 15 | util 16 | zoo 17 | -------------------------------------------------------------------------------- /docs/source/objax/io.rst: -------------------------------------------------------------------------------- 1 | objax.io package 2 | ================ 3 | 4 | .. 
currentmodule:: objax.io 5 | 6 | .. autosummary:: 7 | 8 | Checkpoint 9 | load_var_collection 10 | save_var_collection 11 | 12 | .. autoclass:: Checkpoint 13 | :members: 14 | 15 | .. autofunction:: load_var_collection 16 | 17 | .. autofunction:: save_var_collection 18 | -------------------------------------------------------------------------------- /docs/source/objax/jaxboard.rst: -------------------------------------------------------------------------------- 1 | objax.jaxboard package 2 | ====================== 3 | 4 | .. currentmodule:: objax.jaxboard 5 | 6 | .. autosummary:: 7 | 8 | Reducer 9 | Summary 10 | SummaryWriter 11 | 12 | .. autoclass:: Summary 13 | :members: 14 | 15 | .. autoclass:: SummaryWriter 16 | :members: 17 | -------------------------------------------------------------------------------- /docs/source/objax/privacy.rst: -------------------------------------------------------------------------------- 1 | objax.privacy.dpsgd package 2 | =========================== 3 | 4 | .. currentmodule:: objax.privacy.dpsgd 5 | 6 | .. autosummary:: 7 | 8 | PrivateGradValues 9 | analyze_dp 10 | analyze_renyi 11 | convert_renyidp_to_dp 12 | 13 | .. automodule:: objax.privacy.dpsgd 14 | :members: 15 | :imported-members: 16 | -------------------------------------------------------------------------------- /docs/source/objax/random.rst: -------------------------------------------------------------------------------- 1 | objax.random package 2 | ==================== 3 | 4 | .. currentmodule:: objax.random 5 | 6 | .. autosummary:: 7 | 8 | Generator 9 | normal 10 | randint 11 | truncated_normal 12 | uniform 13 | 14 | .. autoclass:: Generator 15 | :members: __init__, seed, __call__, key 16 | 17 | The default generator can be accessed through :code:`objax.random.DEFAULT_GENERATOR`. 18 | Its seed is **0** by default, and can be set through :code:`objax.random.DEFAULT_GENERATOR.seed(s)` 19 | where integer **s** is the desired seed. 20 | 21 | .. 
autofunction:: normal 22 | .. autofunction:: randint 23 | .. autofunction:: truncated_normal 24 | .. autofunction:: uniform 25 | -------------------------------------------------------------------------------- /docs/source/objax/util.rst: -------------------------------------------------------------------------------- 1 | objax.util package 2 | ================== 3 | 4 | .. currentmodule:: objax.util 5 | 6 | objax.util 7 | ---------- 8 | 9 | .. autosummary:: 10 | 11 | EasyDict 12 | Objax2Tf 13 | Renamer 14 | args_indexes 15 | dummy_context_mgr 16 | ilog2 17 | positional_args_names 18 | to_tuple 19 | 20 | .. autoclass:: EasyDict 21 | :members: 22 | :inherited-members: 23 | 24 | .. autoclass:: Objax2Tf 25 | :members: 26 | 27 | .. autoclass:: Renamer 28 | :members: 29 | 30 | .. autofunction:: args_indexes 31 | 32 | .. autofunction:: dummy_context_mgr 33 | 34 | .. autofunction:: find_used_variables 35 | 36 | .. autofunction:: ilog2 37 | 38 | .. autofunction:: positional_args_names 39 | 40 | .. autofunction:: to_tuple 41 | 42 | objax.util.image 43 | ---------------- 44 | 45 | .. currentmodule:: objax.util.image 46 | 47 | .. autosummary:: 48 | 49 | nchw 50 | nhwc 51 | normalize_to_uint8 52 | normalize_to_unit_float 53 | to_png 54 | 55 | .. automodule:: objax.util.image 56 | :members: 57 | -------------------------------------------------------------------------------- /docs/source/objax/zoo.rst: -------------------------------------------------------------------------------- 1 | objax.zoo package 2 | ================= 3 | 4 | .. currentmodule:: objax.zoo 5 | 6 | objax.zoo.convnet 7 | ----------------- 8 | 9 | .. currentmodule:: objax.zoo.convnet 10 | 11 | .. autoclass:: ConvNet 12 | :members: 13 | 14 | objax.zoo.dnnet 15 | --------------- 16 | 17 | .. currentmodule:: objax.zoo.dnnet 18 | 19 | .. autoclass:: DNNet 20 | :members: 21 | 22 | objax.zoo.resnet_v2 23 | ------------------- 24 | 25 | .. currentmodule:: objax.zoo.resnet_v2 26 | 27 | .. 
autoclass:: ResNetV2 28 | :members: 29 | 30 | .. autoclass:: ResNet18 31 | :members: 32 | 33 | .. autoclass:: ResNet34 34 | :members: 35 | 36 | .. autoclass:: ResNet50 37 | :members: 38 | 39 | .. autoclass:: ResNet101 40 | :members: 41 | 42 | .. autoclass:: ResNet152 43 | :members: 44 | 45 | .. autoclass:: ResNet200 46 | :members: 47 | 48 | objax.zoo.wide_resnet 49 | --------------------- 50 | 51 | .. currentmodule:: objax.zoo.wide_resnet 52 | 53 | .. autoclass:: WRNBlock 54 | :members: 55 | 56 | .. autoclass:: WideResNetGeneral 57 | :members: 58 | 59 | .. autoclass:: WideResNet 60 | :members: 61 | 62 | objax.zoo.rnn 63 | ------------- 64 | 65 | .. currentmodule:: objax.zoo.rnn 66 | 67 | .. autoclass:: RNN 68 | :members: 69 | 70 | objax.zoo.vgg 71 | ------------- 72 | 73 | .. currentmodule:: objax.zoo.vgg 74 | 75 | .. autoclass:: VGG19 76 | :members: 77 | -------------------------------------------------------------------------------- /docs/source/tutorials.rst: -------------------------------------------------------------------------------- 1 | Additional tutorials 2 | ==================== 3 | 4 | This section includes various tutorials for Objax. 5 | 6 | * `MNIST Tutorial <https://github.com/google/objax/blob/master/examples/tutorials/mnist-tutorial.ipynb>`_. 7 | Open directly in `colab <https://colab.research.google.com/github/google/objax/blob/master/examples/tutorials/mnist-tutorial.ipynb>`_. 8 | * `CIFAR10 Tutorial <https://github.com/google/objax/blob/master/examples/tutorials/cifar10.ipynb>`_. 9 | Open directly in `colab <https://colab.research.google.com/github/google/objax/blob/master/examples/tutorials/cifar10.ipynb>`_. 10 | * `Metric learning for image similarity search <https://github.com/google/objax/blob/master/examples/tutorials/metric-learning.ipynb>`_. 11 | Open directly in `colab <https://colab.research.google.com/github/google/objax/blob/master/examples/tutorials/metric-learning.ipynb>`_. 12 | * `Conversion of Objax models to Tensorflow <https://github.com/google/objax/blob/master/examples/tutorials/objax_to_tf.ipynb>`_. 13 | Open directly in `colab <https://colab.research.google.com/github/google/objax/blob/master/examples/tutorials/objax_to_tf.ipynb>`_. -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | [home](../README.md) > examples 2 | 3 | # Examples 4 | 5 | This directory contains fully functional code examples that you can use to learn more about Objax.
6 | 7 | Examples from classic machine learning: 8 | * [Image Classification](image_classification/README.md) 9 | * [Text Generation](text_generation/README.md) 10 | 11 | Examples from recent research: 12 | * [Model-Agnostic Meta-Learning](maml/README.md) 13 | * [FixMatch](fixmatch/README.md) 14 | * [GPT-2](gpt-2/README.md) 15 | 16 | Other examples: 17 | * [Tutorials](tutorials/README.md) 18 | * [JaxBoard](jaxboard/README.md) -------------------------------------------------------------------------------- /examples/fixmatch/README.md: -------------------------------------------------------------------------------- 1 | [home](../../README.md) > [examples](../README.md) > fixmatch 2 | 3 | # Semi-Supervised Image Classification with [FixMatch](https://arxiv.org/abs/2001.07685) 4 | 5 | ## Setup 6 | 7 | ### Required environment variables 8 | 9 | ```bash 10 | export PYTHONPATH=$PYTHONPATH:. 11 | export ML_DATA="path to where you want the datasets saved" 12 | export PROJECT="ObjaxSSL" 13 | export SSL_PATH=examples/fixmatch 14 | ``` 15 | 16 | ## Data preparation 17 | 18 | ```bash 19 | # Download datasets 20 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_datasets.py 21 | cp $ML_DATA/$PROJECT/svhn-test.tfrecord $ML_DATA/$PROJECT/svhnx-test.tfrecord 22 | 23 | # Create unlabeled datasets 24 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_unlabeled.py $ML_DATA/$PROJECT/SSL/cifar10 $ML_DATA/$PROJECT/cifar10-train.tfrecord & 25 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_unlabeled.py $ML_DATA/$PROJECT/SSL/cifar100 $ML_DATA/$PROJECT/cifar100-train.tfrecord & 26 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_unlabeled.py $ML_DATA/$PROJECT/SSL/stl10 $ML_DATA/$PROJECT/stl10-train.tfrecord $ML_DATA/$PROJECT/stl10-unlabeled.tfrecord & 27 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_unlabeled.py $ML_DATA/$PROJECT/SSL/svhn $ML_DATA/$PROJECT/svhn-train.tfrecord & 28 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_unlabeled.py $ML_DATA/$PROJECT/SSL/svhnx 
$ML_DATA/$PROJECT/svhn-train.tfrecord $ML_DATA/$PROJECT/svhn-extra.tfrecord & 29 | wait 30 | 31 | # Create semi-supervised subsets 32 | for seed in 0 1 2 3 4 5; do 33 | for size in 40 100 250 1000 4000; do 34 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_split.py --seed=$seed --size=$size $ML_DATA/$PROJECT/SSL/cifar10 $ML_DATA/$PROJECT/cifar10-train.tfrecord & 35 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_split.py --seed=$seed --size=$size $ML_DATA/$PROJECT/SSL/svhn $ML_DATA/$PROJECT/svhn-train.tfrecord & 36 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_split.py --seed=$seed --size=$size $ML_DATA/$PROJECT/SSL/svhnx $ML_DATA/$PROJECT/svhn-train.tfrecord $ML_DATA/$PROJECT/svhn-extra.tfrecord & 37 | done 38 | for size in 400 1000 2500 10000; do 39 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_split.py --seed=$seed --size=$size $ML_DATA/$PROJECT/SSL/cifar100 $ML_DATA/$PROJECT/cifar100-train.tfrecord & 40 | done 41 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_split.py --seed=$seed --size=1000 $ML_DATA/$PROJECT/SSL/stl10 $ML_DATA/$PROJECT/stl10-train.tfrecord $ML_DATA/$PROJECT/stl10-unlabeled.tfrecord & 42 | wait 43 | done 44 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_split.py --seed=1 --size=5000 $ML_DATA/$PROJECT/stl10 $ML_DATA/stl10-train.tfrecord $ML_DATA/stl10-unlabeled.tfrecord 45 | ``` 46 | 47 | ## Training 48 | 49 | ```bash 50 | # FixMatch 51 | python $SSL_PATH/fixmatch.py --dataset=cifar10.3@250-0 --unlabeled=cifar10 --uratio=5 --augment='CTA(sm,sm,sm)' 52 | ``` 53 | 54 | ## Tensorboard 55 | 56 | ```bash 57 | tensorboard --port 6006 --logdir_spec=experiments 58 | ``` 59 | -------------------------------------------------------------------------------- /examples/fixmatch/libml/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the 
License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /examples/fixmatch/libml/augment/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /examples/fixmatch/libml/augment/core.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
def cutout(x, w):
    """Zero out one randomly positioned ``w`` x ``w`` square of ``x``.

    Args:
        x: image tensor; the first two axes are indexed as rows and columns
            (assumes a (height, width, channels) layout -- TODO confirm).
        w: side length of the square to blank.

    Returns:
        ``x`` multiplied by a float32 mask that is 0 inside the square and 1 elsewhere.
    """
    # The random draw stays the first op so the random stream is untouched.
    offsets = tf.random.uniform([2], 0, 1)
    dims = tf.shape(x)

    def start(fraction, extent):
        # Map a fraction in [0, 1) onto a valid top-left coordinate.
        return tf.cast(tf.round(fraction * (tf.cast(extent, tf.float32) - w)), tf.int32)

    top = start(offsets[0], dims[0])
    left = start(offsets[1], dims[1])
    rows = tf.range(dims[0])[:, None, None]
    cols = tf.range(dims[1])[None, :, None]
    in_rows = (rows >= top) & (rows < top + w)
    in_cols = (cols >= left) & (cols < left + w)
    keep = 1 - tf.cast(in_rows & in_cols, tf.float32)
    return keep * x


def mirror(x):
    """Randomly flip ``x`` left-right."""
    flipped = tf.image.random_flip_left_right(x)
    return flipped


def shift(x, w):
    """Randomly translate ``x`` by up to ``w`` pixels along each spatial axis.

    The image is reflect-padded by ``w`` on both spatial axes, then a random
    crop with the original shape is taken from the padded image.
    """
    margins = [[w, w], [w, w], [0, 0]]
    padded = tf.pad(x, margins, mode='REFLECT')
    return tf.image.random_crop(padded, tf.shape(x))


def noise(x, std):
    """Add zero-mean Gaussian noise scaled by ``std`` to ``x``."""
    perturbation = std * tf.random.normal(tf.shape(x), dtype=x.dtype)
    return x + perturbation


def get_tf_augment(augment, size=32):
    """Return a function applying the augmentation policy named ``augment``.

    Policies: ``x`` (identity), ``s`` (shift), ``sc`` (shift + cutout),
    ``sm`` (shift + mirror), ``smc`` (shift + mirror + cutout).

    Args:
        augment: policy name; looked up only when the returned function is called.
        size: image size; shifts use ``size >> 3`` and cutouts ``size >> 1``.

    Returns:
        A callable taking a dict of features. Every policy except ``x`` requires
        an ``image`` entry and replaces it with the augmented image.
    """

    def identity(**kw):
        return kw

    def shifted(image, **kw):
        return dict(image=shift(image, size >> 3), **kw)

    def shifted_cutout(image, **kw):
        return dict(image=cutout(shift(image, size >> 3), size >> 1), **kw)

    def shifted_mirrored(image, **kw):
        return dict(image=mirror(shift(image, size >> 3)), **kw)

    def shifted_mirrored_cutout(image, **kw):
        return dict(image=cutout(mirror(shift(image, size >> 3)), size >> 1), **kw)

    policies = {
        'x': identity,
        's': shifted,
        'sc': shifted_cutout,
        'sm': shifted_mirrored,
        'smc': shifted_mirrored_cutout,
    }

    def apply(x):
        # Late lookup: an unknown policy name raises KeyError at call time.
        return policies[augment](**x)

    return apply
-------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .randaugment import RandAugment 16 | -------------------------------------------------------------------------------- /examples/fixmatch/libml/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /examples/fixmatch/libml/data/ssl.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
class DataSetsUnlabeled:
    """A named unlabeled dataset wrapping a single training ``core.DataSet``."""

    def __init__(self, name: str, train: core.DataSet):
        self.train = train
        self.name = name

    @property
    def height(self):
        """First entry of the training set's ``image_shape``."""
        return self.train.image_shape[0]

    @property
    def width(self):
        """Second entry of the training set's ``image_shape``."""
        return self.train.image_shape[1]

    @property
    def colors(self):
        """Third entry of the training set's ``image_shape``."""
        return self.train.image_shape[2]

    @classmethod
    def creator(cls, name: str, train_files: List[str], parse_fn: Callable = core.record_parse,
                height: int = 32, width: int = 32, colors: int = 3, cache: bool = False):
        """Build a lazy factory for the dataset called ``name``.

        Args:
            name: dataset name, also used as the key of the returned pair.
            train_files: file names relative to ``core.DATA_DIR``.
            parse_fn: record parser forwarded to ``core.DataSet.from_files``.
            height: image height.
            width: image width.
            colors: number of image channels.
            cache: when true, ``cache()`` is called on the dataset after loading.

        Returns:
            A ``(name, create)`` pair; calling ``create()`` instantiates the dataset.
        """
        paths = [os.path.join(core.DATA_DIR, file_name) for file_name in train_files]

        def create():
            # Nothing is opened until the factory is actually invoked.
            dataset = core.DataSet.from_files(paths, (height, width, colors), parse_fn=parse_fn)
            return cls(name, dataset.cache() if cache else dataset)

        return name, create
def network(arch: str):
    """Return the model class registered under the architecture name ``arch``.

    Args:
        arch: either ``'convnet'`` or ``'resnet'``.

    Returns:
        ``ConvNet`` or ``ResNet`` (the class itself, not an instance).

    Raises:
        ValueError: if ``arch`` is not a known architecture name.
    """
    # Guard clause first: reject unknown names before touching the classes.
    if arch not in ('convnet', 'resnet'):
        raise ValueError('Architecture not recognized', arch)
    return ConvNet if arch == 'convnet' else ResNet
def setup_tf():
    """Hide all GPUs from TensorFlow by setting its visible GPU list to empty.

    NOTE(review): presumably this leaves the GPUs to JAX while TensorFlow only
    runs the input pipeline -- confirm against the callers.
    """
    no_devices = []
    tf.config.experimental.set_visible_devices(no_devices, device_type="GPU")
class ConvNet(objax.nn.Sequential):
    """Plain convolutional classifier built as a flat ``objax.nn.Sequential``.

    A stem convolution is followed by ``scales`` stages, each made of two 3x3
    convolutions with leaky ReLU and a 2x2 average pooling. A final 3x3
    convolution produces ``nclass`` maps that are averaged over axes 2 and 3.
    """

    @staticmethod
    def _mean_reduce(x: JaxArray) -> JaxArray:
        """Average ``x`` over axes 2 and 3."""
        return x.mean(axis=(2, 3))

    def __init__(self, nin, nclass, scales, filters, filters_max, **kwargs):
        """Assemble the layer list.

        Args:
            nin: number of input channels.
            nclass: number of output classes.
            scales: number of conv-conv-pool stages.
            filters: filter count of the first stage, doubled at every stage.
            filters_max: upper bound on the filter count.
            **kwargs: accepted and ignored.
        """
        del kwargs

        def width_at(level):
            # Doubles per level, capped at filters_max.
            return min(filters_max, filters << level)

        activation = objax.functional.leaky_relu
        # Layers are created in network order; that order fixes the initializer draws.
        layers = [objax.nn.Conv2D(nin, width_at(0), 3), activation]
        for level in range(scales):
            same_width = objax.nn.Conv2D(width_at(level), width_at(level), 3)
            widen = objax.nn.Conv2D(width_at(level), width_at(level + 1), 3)
            pool = functools.partial(objax.functional.average_pool_2d, size=2, strides=2)
            layers += [same_width, activation, widen, activation, pool]
        layers += [objax.nn.Conv2D(width_at(scales), nclass, 3), self._mean_reduce]
        super().__init__(layers)
def leaky_relu(x):
    """Apply a leaky ReLU with the negative slope fixed to 0.1."""
    return objax.functional.leaky_relu(x, 0.1)


def conv_args(k, f):
    """Return the ``Conv2D`` keyword arguments that select the weight initializer.

    Args:
        k: side length of the square convolution kernel.
        f: number of output filters.

    Returns:
        A dict with a single ``w_init`` entry: ``objax.random.normal`` with
        ``stddev = rsqrt(0.5 * k * k * f)``.
    """
    return dict(w_init=functools.partial(objax.random.normal, stddev=objax.functional.rsqrt(0.5 * k * k * f)))


class ResNetBlock(objax.Module):
    """Residual block in which batch norm and activation precede the convolutions."""

    def __init__(self, nin: int, nout: int, stride: int = 1, activate_before_residual: bool = False,
                 bn: Callable = objax.nn.BatchNorm2D):
        """Create the block.

        Args:
            nin: number of input channels.
            nout: number of output channels.
            stride: stride of the first convolution and of the shortcut convolution.
            activate_before_residual: if true, the shortcut branch also receives
                the normalized and activated input instead of the raw input.
            bn: factory for the normalization layers, called as ``bn(n, momentum=0.999)``.
        """
        self.activate_before_residual = activate_before_residual
        self.bn = bn(nin, momentum=0.999)
        self.residual = objax.nn.Sequential([objax.nn.Conv2D(nin, nout, 3, strides=stride, **conv_args(3, nout)),
                                             bn(nout, momentum=0.999), leaky_relu,
                                             objax.nn.Conv2D(nout, nout, 3, **conv_args(3, nout))])
        # A 1x1 shortcut convolution is only created when the channel count changes.
        self.passthrough = objax.nn.Conv2D(nin, nout, 1, strides=stride, **conv_args(1, nout)) if nin != nout else None

    def __call__(self, x: JaxArray, training: bool) -> JaxArray:
        """Return ``shortcut(x) + residual(leaky_relu(bn(x)))``."""
        y = leaky_relu(self.bn(x, training))
        if self.activate_before_residual:
            x = y
        # Explicit None test: do not depend on the truthiness of a module object.
        if self.passthrough is not None:
            x = self.passthrough(x)
        return x + self.residual(y, training=training)


class ResNet(objax.nn.Sequential):
    """Residual classifier built as a flat ``objax.nn.Sequential``."""

    @staticmethod
    def mean_reduce(x: JaxArray) -> JaxArray:
        """Average ``x`` over axes 2 and 3."""
        return x.mean((2, 3))

    def __init__(self, nin: int, nclass: int, scales: int, filters: int, repeat: int, dropout: float = 0,
                 bn: Callable = objax.nn.BatchNorm2D, **kwargs):
        """Assemble the layer list.

        Args:
            nin: number of input channels.
            nclass: number of output classes.
            scales: number of stages; stage ``s`` uses ``filters << s`` channels
                and every stage but the first has stride 2.
            filters: channel count of the first stage.
            repeat: number of residual blocks per stage.
            dropout: rate in [0, 1]; ``1 - dropout`` is passed to ``objax.nn.Dropout``
                (presumably its keep probability -- confirm against the Objax API).
            bn: factory for the normalization layers.
            **kwargs: accepted and ignored.
        """
        del kwargs
        n = 16
        ops = [objax.nn.Conv2D(nin, n, 3, **conv_args(3, n))]
        for scale in range(scales):
            last_n, n = n, filters << scale
            ops.append(ResNetBlock(last_n, n, stride=2 if scale else 1, activate_before_residual=scale == 0, bn=bn))
            ops.extend([ResNetBlock(n, n, bn=bn) for _ in range(repeat - 1)])
        ops.extend([bn(n, momentum=0.999), leaky_relu, self.mean_reduce,
                    objax.nn.Dropout(1 - dropout),
                    objax.nn.Linear(n, nclass, w_init=objax.nn.init.xavier_truncated_normal)])
        super().__init__(ops)
def main(argv):
    """Create a labelled SSL split whose class proportions follow the input records.

    Args:
        argv: command line as passed by ``app.run``: program name, target path prefix, then one or
            more input TFRecord files.

    Raises:
        app.UsageError: if ``--size`` is unset or the target / input files are missing.
        FileNotFoundError: if an input file does not exist.
        FileExistsError: if the split (prefix or one of its output files) already exists.
        ValueError: if the inputs are empty, the labels are not ``0..nclass-1``, or a class does
            not hold enough records for the requested size.
    """
    if not FLAGS.size:
        raise app.UsageError('--size (size of labelled set) must be specified.')
    # Work on a copy rather than popping from the list owned by the caller.
    args = list(argv[1:])
    if len(args) < 2:
        raise app.UsageError('Expected a target prefix followed by at least one input file.')
    target_prefix, input_files = args[0], args[1:]
    if any(not tf.io.gfile.exists(f) for f in input_files):
        raise FileNotFoundError(input_files)
    target = '%s.%d@%d' % (target_prefix, FLAGS.seed, FLAGS.size)
    # The files actually written carry a suffix, so they must be checked too for the guard to work.
    for candidate in (target, target + '-label.tfrecord', target + '-label.json'):
        if tf.io.gfile.exists(candidate):
            raise FileExistsError('For safety overwriting is not allowed', candidate)

    count = 0
    class_id = defaultdict(list)
    print('Computing class distribution')
    dataset = tf.data.TFRecordDataset(input_files).map(get_class, 4).batch(1 << 10)
    with tqdm(leave=False) as t:  # One progress bar for the whole scan.
        for it in dataset:
            for c in it.numpy():
                class_id[int(c)].append(count)
                count += 1
            t.update(it.shape[0])
    print('%d records found' % count)
    if not count:
        raise ValueError('No records found in the input files.')
    nclass = len(class_id)
    if min(class_id.keys()) != 0 or max(class_id.keys()) != nclass - 1:
        raise ValueError('Class labels must be the contiguous integers 0..%d.' % (nclass - 1))
    train_stats = np.array([len(class_id[i]) for i in range(nclass)], np.float64)
    train_stats /= train_stats.max()
    is_stl10 = 'stl10' in input_files[0]
    if is_stl10:
        # All of the unlabeled data is given label 0, but we know that
        # STL has equally distributed data among the 10 classes.
        train_stats[:] = 1

    print(' Stats', ' '.join(['%.2f' % (100 * x) for x in train_stats]))
    class_indices = [np.array(class_id[i], dtype=np.int64) for i in range(nclass)]
    del class_id
    if FLAGS.seed:
        np.random.seed(FLAGS.seed)
        for i in range(nclass):
            np.random.shuffle(class_indices[i])

    # Distribute labels to match the input distribution.
    npos = np.zeros(nclass, np.int64)
    label = []
    for _ in range(FLAGS.size):
        # Pick the class that is currently the most under-represented.
        c = np.argmax(train_stats - npos / max(npos.max(), 1))
        if npos[c] >= len(class_indices[c]):
            raise ValueError('Class %d has only %d records, not enough for a split of size %d.'
                             % (c, len(class_indices[c]), FLAGS.size))
        label.append(class_indices[c][npos[c]])
        npos[c] += 1

    del npos, class_indices
    label = frozenset(int(x) for x in label)
    if is_stl10 and FLAGS.size == 1000:
        # STL-10 has predefined folds of 1000 labels; the seed selects the fold.
        with tf.io.gfile.GFile(os.path.join(core.DATA_DIR, 'stl10_fold_indices.txt'), 'r') as f:
            data = f.read()
        label = frozenset(map(int, data.split('\n')[FLAGS.seed].split()))

    print('Creating split in %s' % target)
    tf.io.gfile.makedirs(os.path.dirname(target))
    with tf.io.TFRecordWriter(target + '-label.tfrecord') as writer_label:
        pos = 0
        with trange(count, desc='Writing records') as loop:
            for input_file in input_files:
                for record in tf.compat.v1.python_io.tf_record_iterator(input_file):
                    if pos in label:
                        writer_label.write(record)
                    pos += 1
                    loop.update()
    with tf.io.gfile.GFile(target + '-label.json', 'w') as writer:
        writer.write(json.dumps(dict(distribution=train_stats.tolist(), label=sorted(label)), indent=2, sort_keys=True))
def main(argv):
    """Write all input records as an unlabeled set, interleaved to follow the class distribution.

    Args:
        argv: command line as passed by ``app.run``: program name, target path prefix, then one or
            more input TFRecord files.

    Raises:
        app.UsageError: if the target or the input files are missing from the command line.
        FileNotFoundError: if an input file does not exist.
        ValueError: if the inputs are empty or the labels are not ``0..nclass-1``.
    """
    from collections import deque  # Local import: the file only imports defaultdict.

    # Work on a copy rather than popping from the list owned by the caller.
    args = list(argv[1:])
    if len(args) < 2:
        raise app.UsageError('Expected a target prefix followed by at least one input file.')
    target, input_files = args[0], args[1:]
    if any(not tf.io.gfile.exists(f) for f in input_files):
        raise FileNotFoundError(input_files)

    count = 0
    id_class = []  # Class of every record, in reading order.
    class_count = defaultdict(int)
    print('Computing class distribution')
    dataset = tf.data.TFRecordDataset(input_files).map(get_class, 4).batch(1 << 10)
    with tqdm(leave=False) as t:  # One progress bar for the whole scan.
        for it in dataset:
            batch_classes = [int(c) for c in it.numpy()]
            id_class.extend(batch_classes)
            for c in batch_classes:
                class_count[c] += 1
            count += len(batch_classes)
            t.update(it.shape[0])
    print('%d records found' % count)
    if not count:
        raise ValueError('No records found in the input files.')
    nclass = len(class_count)
    if min(class_count.keys()) != 0 or max(class_count.keys()) != nclass - 1:
        raise ValueError('Class labels must be the contiguous integers 0..%d.' % (nclass - 1))
    train_stats = np.array([class_count[i] for i in range(nclass)], np.float64)
    train_stats /= train_stats.max()
    if 'stl10' in input_files[0]:
        # All of the unlabeled data is given label 0, but we know that
        # STL has equally distributed data among the 10 classes.
        train_stats[:] = 1

    print(' Stats', ' '.join(['%.2f' % (100 * x) for x in train_stats]))
    del class_count

    print('Creating unlabeled dataset in %s' % target)
    npos = np.zeros(nclass, np.int64)
    # Per-class FIFO of records read but not yet written; deque gives O(1) pops from the left.
    class_data = [deque() for _ in range(nclass)]
    unlabel = []
    tf.io.gfile.makedirs(os.path.dirname(target))
    with tf.io.TFRecordWriter(target + '-unlabel.tfrecord') as writer_unlabel:
        pos = 0
        with trange(count, desc='Writing records') as loop:
            for input_file in input_files:
                for record in tf.compat.v1.python_io.tf_record_iterator(input_file):
                    class_data[id_class[pos]].append((pos, record))
                    # Emit buffered records for as long as the most under-represented class has one.
                    while True:
                        c = np.argmax(train_stats - npos / max(npos.max(), 1))
                        if not class_data[c]:
                            break
                        p, v = class_data[c].popleft()
                        unlabel.append(p)
                        writer_unlabel.write(v)
                        npos[c] += 1
                    pos += 1
                    loop.update()
            # Flush whatever is still buffered, class by class.
            for remain in class_data:
                for p, v in remain:
                    unlabel.append(p)
                    writer_unlabel.write(v)
    with tf.io.gfile.GFile(target + '-unlabel.json', 'w') as writer:
        writer.write(json.dumps(dict(distribution=train_stats.tolist(), indexes=unlabel), indent=2, sort_keys=True))
def main(argv):
    """Collect the ``TAG`` scalar from the event files of a run and save summary statistics.

    Every ``tb/events.out.tfevents.*`` file under the given folder is read in sorted order and the
    statistics returned by ``summary_dict`` are written to ``<folder>/stats/accuracy.json``.

    Args:
        argv: command line as passed by ``app.run``: program name followed by the experiment folder.

    Raises:
        app.UsageError: if the folder argument is missing or extra arguments are given.
        FileNotFoundError: if the folder contains no event file.
        ValueError: if no event carries the ``TAG`` scalar.
    """
    if len(argv) > 2:
        raise app.UsageError('Too many command-line arguments.')
    if len(argv) < 2:
        # Without this check argv[1] below fails with a bare IndexError.
        raise app.UsageError('Missing command-line argument: experiment folder.')
    folder = argv[1]
    matches = sorted(tf.io.gfile.glob(os.path.join(folder, 'tb/events.out.tfevents.*')))
    if not matches:
        raise FileNotFoundError('No events files found')
    tags = set()  # Tags seen before the first match, reported when TAG is never found.
    accuracies = []
    for event_file in matches:
        try:
            for e in tf.compat.v1.train.summary_iterator(event_file):
                for v in e.summary.value:
                    if v.tag == TAG:
                        accuracies.append(v.simple_value)
                        break
                    elif not accuracies:
                        tags.add(v.tag)
        except tf.errors.DataLossError:
            # Best effort: keep what was read from a truncated file and move on to the next one.
            continue

    if not accuracies:
        raise ValueError('No "%s" tag found. Found tags = %s' % (TAG, tags))
    target_dir = os.path.join(folder, 'stats')
    target_file = os.path.join(target_dir, 'accuracy.json')
    tf.io.gfile.makedirs(target_dir)

    with tf.io.gfile.GFile(target_file, 'w') as f:
        json.dump(summary_dict(accuracies), f, sort_keys=True, indent=4)
    print('Saved: %s' % target_file)
As such, the network does not achieve State 12 | of the Art (SOTA) classification accuracy. A Convolutional Neural Network (CNN) 13 | should be used for that purpose. 14 | 15 | * `mnist_cnn.py` - a CNN-based MNIST classification example. 16 | 17 | * `mnist_dp.py` - MNIST example with differential privacy. 18 | 19 | * [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) 20 | 21 | * `cifar10_simple.py` - very simple CIFAR10 classification example which 22 | demonstrates how to write a basic training loop with data augmentation. 23 | 24 | * `cifar10_advanced.py` - more advanced CIFAR10 example which allows the user to configure 25 | the neural network architecture and other hyperparameters. It also supports training on multiple 26 | GPUs using `objax.Parallel`. 27 | 28 | * [Imagenet](http://www.image-net.org/challenges/LSVRC/2012/) 29 | 30 | * `imagenet_pretrained_vgg.py` - example which shows how to load pre-trained weights for a VGG model and use it 31 | to classify input images. For more details see [documentation](imagenet_pretrained_vgg.md). 32 | 33 | * `imagenet_resnet50_train.py` - example which shows how to train a ResNet50 model on Imagenet. 34 | For more details see example [documentation](imagenet_resnet50.md). 35 | 36 | * [Horses or Humans](https://www.kaggle.com/sanikamal/horses-or-humans-dataset) 37 | 38 | * `horses_or_humans_logistic.py` - simple example using logistic regression. -------------------------------------------------------------------------------- /examples/image_classification/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /examples/image_classification/cifar10_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def augment(x, shift=4):
    """Apply one random flip and one random shift to a whole batch of images.

    The same flip decision and the same offset are used for every image of the batch.

    Args:
        x: array indexed as (batch, channels, height, width).
        shift: maximum shift in pixels along height and width; the default matches the
            previously hard-coded value.

    Returns:
        An array with the same shape as ``x``.
    """
    if random.random() < .5:
        x = x[:, :, :, ::-1]  # Mirror the batch along its last (width) axis.
    height, width = x.shape[2], x.shape[3]
    # Pixel-shift all images in the batch by up to `shift` pixels in any direction.
    x_pad = np.pad(x, [[0, 0], [0, 0], [shift, shift], [shift, shift]], 'reflect')
    # randint excludes its upper bound: offsets must span 0..2*shift so that both extreme shifts
    # (-shift and +shift) can occur; offset == shift leaves the images in place.
    rx, ry = np.random.randint(0, 2 * shift + 1), np.random.randint(0, 2 * shift + 1)
    x = x_pad[:, :, rx:rx + height, ry:ry + width]
    return x
def prepare(x, downscale=3):
    """Average-pool images by ``downscale``, flatten them and rescale the values.

    Args:
        x: image batch indexed as (batch, height, width, channels); height and width must be
            multiples of ``downscale``.
        downscale: side of the square window that is averaged into one output pixel.

    Returns:
        A float array of shape (batch, -1) holding ``mean / 127.5 - 1``, i.e. values in [-1, 1]
        for 8-bit inputs.
    """
    shape = x.shape
    batch, rows, cols, channels = shape[0], shape[1], shape[2], shape[3]
    # Expose each downscale x downscale window as two dedicated axes (2 and 4).
    windows = x.astype('f').reshape((batch, rows // downscale, downscale, cols // downscale, downscale, channels))
    pooled = windows.mean((2, 4))
    flat = pooled.reshape((batch, -1))
    return flat * (1 / 127.5) - 1
69 | train_op = objax.Jit(train_op) 70 | 71 | # Training 72 | for epoch in range(epochs): 73 | # Train 74 | avg_loss = 0 75 | for it in range(0, train.image.shape[0], batch): 76 | sel = np.random.randint(size=(batch,), low=0, high=train.image.shape[0]) 77 | avg_loss += float(train_op(train.image[sel], train.label[sel])[0]) * batch 78 | avg_loss /= it + batch 79 | 80 | # Eval 81 | accuracy = 0 82 | for it in range(0, test.image.shape[0], batch): 83 | x, y = test.image[it: it + batch], test.label[it: it + batch] 84 | accuracy += (np.round(objax.functional.sigmoid(model(x)))[:, 0] == y).sum() 85 | accuracy /= test.image.shape[0] 86 | print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1, avg_loss, 100 * accuracy)) 87 | -------------------------------------------------------------------------------- /examples/image_classification/imagenet_pretrained_vgg.md: -------------------------------------------------------------------------------- 1 | # Image Classification with Pretrained VGG model 2 | 3 | This [example](imagenet_pretrained_vgg.py) demonstrates how to run image classification with the 4 | [VGG-19](https://www.robots.ox.ac.uk/~vgg/publications/2015/Simonyan15/simonyan15.pdf) model using 5 | weights pretrained on the [ImageNet dataset](http://www.image-net.org/). 6 | 7 | ## Getting weights of VGG-19 pretrained model 8 | 9 | Please download the weights of the VGG-19 pretrained model from this 10 | [link](https://mega.nz/file/xZ8glS6J#MAnE91ND_WyfZ_8mvkuSa2YcA7q-1ehfSm-Q1fxOvvs) and copy them to 11 | `./objax/zoo/pretrained/vgg19.npy`. 12 | 13 | ## Classifying images 14 | 15 | This [example](imagenet_pretrained_vgg.py) shows how to classify an image downloaded from the internet. 16 | You can set an `IMAGE_PATH` to classify your own image. 
IMAGE_URL = 'https://upload.wikimedia.org/wikipedia/commons/b/b0/Bengal_tiger_%28Panthera_tigris_tigris%29_female_3_crop.jpg'
IMAGE_PATH = './examples/classify/img/misc/001.jpg'
SYNSET_PATH = './objax/zoo/pretrained/synset.txt'

# Load input image.
image_dir = os.path.dirname(IMAGE_PATH)
if image_dir:
    os.makedirs(image_dir, exist_ok=True)
if not os.path.exists(IMAGE_PATH):
    # Download the sample only when nothing is there, so a user-provided image is not overwritten.
    request.urlretrieve(IMAGE_URL, IMAGE_PATH)
with Image.open(IMAGE_PATH) as image_file:
    # Force 3 channels so that grayscale or RGBA files still give a (224, 224, 3) array.
    img = np.array(image_file.convert('RGB').resize((224, 224))).astype(np.float32)
# Reorder to (channels, height, width) and add a leading batch axis.
img = jn.array(img).transpose((2, 0, 1))[None,]

# Load model with pretrained weights and make a prediction.
model = vgg.VGG19(pretrained=True)
logit = model(img)
prob = objax.functional.softmax(logit)[0]

# Present prediction output.
with open(SYNSET_PATH) as synset_file:
    synset = [line.strip() for line in synset_file]
pred = jn.argsort(prob)[::-1][:5]  # Indices of the 5 most probable classes, best first.
for i in range(5):
    class_index = int(pred[i])
    print('Top {:d} (prob {:.3f}) {}'.format(i + 1, float(prob[class_index]), synset[class_index]))
23 | 24 | ## Training the model 25 | 26 | Use the following command to train: 27 | 28 | ``` 29 | python examples/image_classification/imagenet_resnet50_train.py \ 30 | --model_dir="${HOME}/experiments/resnet50" 31 | ``` 32 | 33 | Some additional useful flags include the following: 34 | 35 | * `--train_device_batch_size` controls per-device training batch size. You may need to adjust it if you don't have enough GPU memory. 36 | * `--eval_device_batch_size` controls per-device evaluation batch size. You may need to adjust it if you don't have enough GPU memory. 37 | * `--eval_every_n_steps` controls the number of training steps between evaluation and checkpointing. 38 | * `--tfds_data_dir` overrides the directory where TFDS looks for datasets. 39 | -------------------------------------------------------------------------------- /examples/image_classification/mnist_cnn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def augment(x, shift=4):
    """Shift all images in the batch by up to ``shift`` pixels in any direction.

    The batch is zero-padded by ``shift`` on both sides of its last two axes and one random crop
    of the original size is taken, so every image receives the same shift.

    Args:
        x: array indexed as (batch, colors, height, width).
        shift: maximum shift in pixels along height and width.

    Returns:
        An array with the same shape as ``x``.
    """
    height, width = x.shape[2], x.shape[3]
    x_pad = np.pad(x, [[0, 0], [0, 0], [shift, shift], [shift, shift]])
    # Offset == shift means "no shift". randint excludes its upper bound, so sampling from
    # [0, 2 * shift] is needed to reach zero and both directions.
    rx, ry = np.random.randint(0, 2 * shift + 1, size=2)
    return x_pad[:, :, rx:rx + height, ry:ry + width]
99 | 100 | # Training 101 | print(model.vars()) 102 | for epoch in range(epochs): 103 | # Train one epoch 104 | loop = trange(0, train_size, batch, 105 | leave=False, unit='img', unit_scale=batch, 106 | desc='Epoch %d/%d ' % (1 + epoch, epochs)) 107 | for it in loop: 108 | sel = np.random.randint(size=(batch,), low=0, high=train.image.shape[0]) 109 | v = train_op(augment(train.image[sel]), train.label[sel]) 110 | 111 | # Eval 112 | accuracy = 0 113 | for it in trange(0, test.image.shape[0], test_batch, leave=False, desc='Evaluating'): 114 | x = test.image[it: it + test_batch] 115 | xl = test.label[it: it + test_batch] 116 | accuracy += (np.argmax(predict(x), axis=1) == xl).sum() 117 | accuracy /= test.image.shape[0] 118 | print(f'Epoch {epoch + 1:04d} Accuracy {100 * accuracy:.2f}') 119 | -------------------------------------------------------------------------------- /examples/image_classification/mnist_dnn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # *Note*: The purpose of the example on MNIST is to demonstrate the use of a deep 16 | # neural network for classification. As such, the network does not achieve State 17 | # of the Art (SOTA) classification accurary. A Convolutional Neural Network (CNN) 18 | # should be used for that purpose. 
@objax.Function.with_vars(model.vars() + gv.vars() + opt.vars() + model_ema.vars())
def train_op(x, xl):
    """Run one optimization step and return what ``gv`` reports as values.

    Args:
        x: batch of inputs forwarded to ``gv``.
        xl: batch of targets forwarded to ``gv``.

    Returns:
        The second element returned by ``gv`` (the loss values).
    """
    gradients, values = gv(x, xl)
    opt(lr, gradients)  # Apply the gradients with the module-level learning rate.
    model_ema.update_ema()  # Refresh the moving average after the weights changed.
    return values
# --------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Tuple, List

import numpy as np
import tensorflow as tf


def record_parse(serialized_example: str, image_shape: Tuple[int, int, int]):
    """Parses one serialized example into a dict holding a normalized image and its label.

    Args:
        serialized_example: serialized example with an encoded 'image' string and an int64 'label'.
        image_shape: static shape to assign to the decoded image.

    Returns:
        dict with keys 'image' (float32 tensor rescaled from [0, 255] to [-1, 1]) and 'label'.
    """
    features = tf.io.parse_single_example(serialized_example,
                                          features={'image': tf.io.FixedLenFeature([], tf.string),
                                                    'label': tf.io.FixedLenFeature([], tf.int64)})
    image = tf.image.decode_image(features['image'])
    # set_shape() updates the static shape in place and returns None, so its result must not be
    # assigned back to image.
    image.set_shape(image_shape)
    image = tf.cast(image, tf.float32) * (2.0 / 255) - 1.0
    return dict(image=image, label=features['label'])


class DataSet:
    """Wrapper for tf.data.Dataset to permit extensions.

    Attributes that are not defined on the wrapper are forwarded to the wrapped dataset; forwarded
    calls that return a tf.data.Dataset are re-wrapped so that chained calls keep returning DataSet.
    """

    def __init__(self, data: tf.data.Dataset,
                 image_shape: Tuple[int, int, int],
                 augment_fn: Optional[Callable] = None,
                 parse_fn: Optional[Callable] = record_parse):
        """Stores the wrapped dataset together with the image shape and the optional callbacks.

        Args:
            data: dataset to wrap.
            image_shape: shape of a single image, handed to parse_fn by parse().
            augment_fn: optional function mapped over the dataset by augment().
            parse_fn: optional function mapped over the dataset by parse().
        """
        self.data = data
        self.parse_fn = parse_fn
        self.augment_fn = augment_fn
        self.image_shape = image_shape

    @classmethod
    def from_arrays(cls, images: np.ndarray, labels: np.ndarray, augment_fn: Optional[Callable] = None):
        """Builds a DataSet from in-memory arrays; no parsing step is attached."""
        return cls(tf.data.Dataset.from_tensor_slices(dict(image=images, label=labels)), images.shape[1:],
                   augment_fn=augment_fn, parse_fn=None)

    @classmethod
    def from_files(cls, filenames: List[str],
                   image_shape: Tuple[int, int, int],
                   augment_fn: Optional[Callable] = None,
                   parse_fn: Optional[Callable] = record_parse):
        """Builds a DataSet reading TFRecord files matching the given glob patterns.

        Args:
            filenames: list of glob patterns.
            image_shape: shape of a single image.
            augment_fn: optional augmentation function (defaults to None like the other constructors).
            parse_fn: function used to parse each record.

        Raises:
            ValueError: if no file matches any of the patterns.
        """
        filenames_in = filenames
        filenames = sorted(sum([tf.io.gfile.glob(x) for x in filenames], []))
        if not filenames:
            raise ValueError('Empty dataset, files not found:', filenames_in)
        return cls(tf.data.TFRecordDataset(filenames), image_shape, augment_fn=augment_fn, parse_fn=parse_fn)

    @classmethod
    def from_tfds(cls, dataset: tf.data.Dataset, image_shape: Tuple[int, int, int],
                  augment_fn: Optional[Callable] = None):
        """Builds a DataSet from a dataset of dicts, rescaling 'image' with x / 127.5 - 1 as float32."""
        return cls(dataset.map(lambda x: dict(image=tf.cast(x['image'], tf.float32) / 127.5 - 1, label=x['label'])),
                   image_shape, augment_fn=augment_fn, parse_fn=None)

    def __iter__(self):
        return iter(self.data)

    def __getattr__(self, item):
        """Forwards unknown attributes to the wrapped dataset.

        Raises:
            AttributeError: if the wrapped dataset is not set yet or does not have the attribute.
        """
        # __getattr__ only runs once regular lookup has failed, so item is never an instance attribute.
        # 'data' is read through __dict__ to avoid recursing into __getattr__ when it is not set yet.
        try:
            data = self.__dict__['data']
        except KeyError:
            raise AttributeError(item) from None
        # Resolved eagerly so that unknown names raise AttributeError here (keeps hasattr() meaningful).
        attr = getattr(data, item)
        if not callable(attr):
            return attr

        def call_and_update(*args, **kwargs):
            v = attr(*args, **kwargs)
            if isinstance(v, tf.data.Dataset):
                return self.__class__(v, self.image_shape, augment_fn=self.augment_fn, parse_fn=self.parse_fn)
            return v

        return call_and_update

    def augment(self, para_augment: int = 4):
        """Maps augment_fn over the dataset, passing para_augment as second argument of map()."""
        if self.augment_fn:
            return self.map(self.augment_fn, para_augment)
        return self

    def nchw(self):
        """Transposes 'image' with permutation [0, 3, 1, 2] (assumes batched 4-D images - TODO confirm)."""
        return self.map(lambda x: dict(image=tf.transpose(x['image'], [0, 3, 1, 2]), label=x['label']))

    def one_hot(self, nclass: int):
        """Replaces 'label' by its one-hot encoding over nclass classes."""
        return self.map(lambda x: dict(image=x['image'], label=tf.one_hot(x['label'], nclass)))

    def parse(self, para_parse: int = 2):
        """Maps parse_fn over the dataset, passing image_shape to it when one is set."""
        if not self.parse_fn:
            return self
        if self.image_shape:
            return self.map(lambda x: self.parse_fn(x, self.image_shape), para_parse)
        return self.map(self.parse_fn, para_parse)

# --------------------------------------------------------------------------------
/examples/jaxboard/README.md: -------------------------------------------------------------------------------- 1 | [home](../../README.md) > [examples](../README.md) > optimization 2 | 3 | # Saving to tensorboard 4 | 5 | This directory contains examples on how to visualize data in tensorboard with Objax. 6 | 7 | ```bash 8 | python3 examples/jaxboard/summary.py 9 | tensorboard --logdir experiments/summary_test/tb 10 | ``` 11 | -------------------------------------------------------------------------------- /examples/jaxboard/summary.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | 17 | import objax 18 | 19 | LOGDIR = 'experiments/summary_test/tb' 20 | with objax.jaxboard.SummaryWriter(LOGDIR) as tensorboard: 21 | summary = objax.jaxboard.Summary() 22 | summary.text('text', '
Hello this just text\nand a newline
') 23 | summary.text('html', '' 24 | '' 25 | '
col1col2
row1.1row1.2
row2.1row2.2
') 26 | img = np.zeros((3, 32, 32), 'f') 27 | img[0] += np.linspace(-1, 1, 32) 28 | img[1] += np.linspace(-1, 1, 32)[:, None] 29 | img[2] += np.linspace(-1, 1, 32)[:, None] * np.linspace(-1, 1, 32) 30 | summary.image('image', img) 31 | summary.scalar('avg', 0) 32 | summary.scalar('avg', 1) 33 | summary.scalar('avg', 2) 34 | summary.scalar('avg', 3) 35 | tensorboard.write(summary, step=1) 36 | 37 | summary = objax.jaxboard.Summary() 38 | summary.text('text', '
Hello this just text\nat step 2
') 39 | summary.scalar('avg', 4) 40 | summary.scalar('avg', 7) 41 | tensorboard.write(summary, step=2) 42 | 43 | print(f'Saved to {LOGDIR}') 44 | print(f'Visualize with: tensorboard --logdir "{LOGDIR}"') 45 | -------------------------------------------------------------------------------- /examples/maml/README.md: -------------------------------------------------------------------------------- 1 | [home](../../README.md) > [examples](../README.md) > optimization 2 | 3 | # Optimization 4 | 5 | This directory contains examples on Model-Agnostic Meta-Learning (MAML). 6 | -------------------------------------------------------------------------------- /examples/maml/maml.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | MAML implementation to demonstrate gradient of gradient. 
17 | 18 | https://github.com/ericjang/maml-jax/blob/master/maml.ipynb 19 | """ 20 | 21 | import jax.numpy as jn 22 | import matplotlib.pyplot as plt 23 | import numpy as np 24 | from tqdm import trange 25 | 26 | import objax 27 | 28 | 29 | def sample_tasks(outer_batch_size, inner_batch_size): 30 | # Select amplitude and phase for the task 31 | amplitudes = [] 32 | phases = [] 33 | for _ in range(outer_batch_size): 34 | amplitudes.append(np.random.uniform(low=0.1, high=.5)) 35 | phases.append(np.random.uniform(low=0., high=np.pi)) 36 | 37 | def get_batch(): 38 | xs, ys = [], [] 39 | for amplitude, phase in zip(amplitudes, phases): 40 | x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1)) 41 | y = amplitude * np.sin(x + phase) 42 | xs.append(x) 43 | ys.append(y) 44 | return np.stack(xs), np.stack(ys) 45 | 46 | x1, y1 = get_batch() 47 | x2, y2 = get_batch() 48 | return x1, y1, x2, y2 49 | 50 | 51 | def make_net(): 52 | return objax.nn.Sequential([ 53 | objax.nn.Linear(1, 40), objax.functional.relu, 54 | objax.nn.Linear(40, 40), objax.functional.relu, 55 | objax.nn.Linear(40, 1) 56 | ]) 57 | 58 | 59 | source = jn.linspace(-5, 5, 100).reshape((100, 1)) # (k, 1) 60 | target = jn.sin(source) 61 | 62 | print('Standard training.') 63 | net = make_net() 64 | opt = objax.optimizer.Adam(net.vars()) 65 | 66 | 67 | @objax.Function.with_vars(net.vars()) 68 | def loss(x, y): 69 | return ((y - net(x)) ** 2).mean() 70 | 71 | 72 | gv = objax.GradValues(loss, net.vars()) 73 | 74 | 75 | @objax.Function.with_vars(net.vars() + opt.vars()) 76 | def train_op(): 77 | g, v = gv(source, target) 78 | opt(0.01, g) 79 | return v 80 | 81 | 82 | train_op = objax.Jit(train_op) 83 | 84 | for i in range(100): 85 | train_op() 86 | 87 | plt.plot(source, net(source), label='prediction') 88 | plt.plot(source, (target - net(source)) ** 2, label='loss') 89 | plt.plot(source, target, label='target') 90 | plt.legend() 91 | plt.show() 92 | 93 | print('MAML training') 94 | net = make_net() 95 | 
opt = objax.optimizer.Adam(net.vars()) 96 | 97 | 98 | @objax.Function.with_vars(net.vars()) 99 | def loss(x, y): 100 | return ((y - net(x)) ** 2).mean() 101 | 102 | 103 | gv = objax.GradValues(loss, net.vars()) 104 | 105 | 106 | @objax.Function.with_vars(net.vars()) 107 | def maml_loss(x1, y1, x2, y2, alpha=0.1): 108 | net_vars = net.vars() 109 | original_weights = net_vars.tensors() # Save original weights 110 | g_x1y1 = gv(x1, y1)[0] # Compute gradient at (x1, y1) 111 | # Apply gradient update using SGD 112 | net_vars.assign([v - alpha * g for v, g in zip(original_weights, g_x1y1)]) 113 | loss_x2y2 = loss(x2, y2) 114 | net_vars.assign(original_weights) # Restore original weights 115 | return loss_x2y2 116 | 117 | 118 | vec_maml_loss = objax.Vectorize(maml_loss, batch_axis=(0, 0, 0, 0, None)) 119 | 120 | 121 | @objax.Function.with_vars(vec_maml_loss.vars()) 122 | def batch_maml_loss(x1, y1, x2, y2, alpha=0.1): 123 | return vec_maml_loss(x1, y1, x2, y2, alpha).mean() 124 | 125 | 126 | maml_gv = objax.GradValues(batch_maml_loss, vec_maml_loss.vars()) 127 | 128 | 129 | @objax.Function.with_vars(vec_maml_loss.vars() + opt.vars()) 130 | def train_op(x1, y1, x2, y2): 131 | g, v = maml_gv(x1, y1, x2, y2) 132 | opt(0.001, g) 133 | return v 134 | 135 | 136 | train_op = objax.Jit(train_op) 137 | 138 | for i in trange(20000, leave=False): 139 | x1, y1, x2, y2 = sample_tasks(4, 20) 140 | train_op(x1, y1, x2, y2) 141 | 142 | x1 = np.random.uniform(low=-5., high=5., size=(10, 1)) 143 | y1 = 1. * np.sin(x1 + 0.) 
144 | 145 | tensors = net.vars().tensors() 146 | for shot in range(1, 3): 147 | for v, g in zip(net.vars(), gv(x1, y1)[0]): 148 | if isinstance(v, objax.TrainVar): 149 | v.assign(v.value - 0.1 * g) 150 | plt.plot(source, net(source), label='%d-shot predictions' % shot) 151 | net.vars().assign(tensors) 152 | 153 | plt.plot(source, net(source), label='pre-update predictions') 154 | plt.plot(source, target, label='target') 155 | plt.legend() 156 | plt.show() 157 | -------------------------------------------------------------------------------- /examples/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | tensorflow-cpu>=2.3.0 3 | tensorflow_datasets>=3.2.1 4 | tqdm 5 | -------------------------------------------------------------------------------- /examples/text_generation/README.md: -------------------------------------------------------------------------------- 1 | [home](../../README.md) > [examples](../README.md) > text_generation 2 | 3 | # Examples 4 | 5 | This directory contains text generation examples. 6 | 7 | See: 8 | * `shakespeare_rnn.py` - predict characters from Shakespeare's plays using an RNN. 9 | -------------------------------------------------------------------------------- /objax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import sys 16 | 17 | from ._patch_jax import * 18 | 19 | pass # To avoid reordering imports from above 20 | 21 | from . import functional 22 | from . import io 23 | from . import jaxboard 24 | from . import nn 25 | from . import optimizer 26 | from . import privacy 27 | from . import random 28 | from . import typing 29 | from . import util 30 | from ._version import __version__ 31 | from .constants import * 32 | from .gradient import * 33 | from .module import * 34 | from .variable import * 35 | 36 | assert sys.version_info >= (3, 6) 37 | -------------------------------------------------------------------------------- /objax/_patch_jax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 


__all__ = []

from typing import Union, Sequence, Tuple, Callable, Optional

import jax.numpy as jn

from .typing import JaxArray
from .util import re_sign


# Signature-only stub: the body does nothing, only the parameter list matters.
# No docstring is added here on purpose since the function object is handed to re_sign below.
def _pad(array: JaxArray,
         pad_width: Union[Sequence[Tuple[int, int]], Tuple[int, int], int],
         mode: Optional[Union[str, Callable]] = 'constant',
         *,
         stat_length: Optional[Union[Sequence[Tuple[int, int]], int]] = None,
         constant_values: Optional[Union[Sequence[Tuple[int, int]], int]] = 0,
         end_values: Optional[Union[Sequence[Tuple[int, int]], int]] = None,
         reflect_type: Optional[str] = None):
    # This is just to have a proper signature for jax.numpy.pad since the API, like in numpy, makes use of kwargs
    # and doesn't expose its arguments properly.
    pass


# Module-level side effect: jax.numpy.pad is replaced at import time.
# NOTE(review): presumably re_sign(_pad) returns a wrapper exposing the signature of _pad - confirm in objax.util.
jn.pad = re_sign(_pad)(jn.pad)

# --------------------------------------------------------------------------------
# /objax/_version.py:
# --------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Version string of the package.
__version__ = '1.8.0'

# --------------------------------------------------------------------------------
# /objax/constants.py:
# --------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ['ConvPadding', 'Interpolate']

import enum


class ConvPadding(enum.Enum):
    """An Enum holding the possible padding values for convolution modules.

    Each member value is the upper-case string of its own name.
    """
    SAME = 'SAME'
    VALID = 'VALID'


class Interpolate(enum.Enum):
    """An Enum holding the possible interpolation values for upsampling.

    Each member value is the lower-case string of its own name.
    """
    NEAREST = 'nearest'
    LINEAR = 'linear'
    BILINEAR = 'bilinear'
    TRILINEAR = 'trilinear'
    TRIANGLE = 'triangle'
    CUBIC = 'cubic'
    BICUBIC = 'bicubic'
    TRICUBIC = 'tricubic'
    LANCZOS3 = 'lanczos3'
    LANCZOS5 = 'lanczos5'

# --------------------------------------------------------------------------------
# /objax/functional/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import divergence 16 | from . import loss 17 | from . import parallel 18 | from .core import * 19 | -------------------------------------------------------------------------------- /objax/functional/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .activation import * 16 | from .ops import * 17 | from .pooling import * 18 | -------------------------------------------------------------------------------- /objax/functional/core/activation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ['celu', 'elu', 'leaky_relu', 'log_sigmoid', 'log_softmax', 'logsumexp', 'relu',
           'selu', 'sigmoid', 'softmax', 'softplus', 'swish', 'tanh']

import jax.nn
import jax.scipy.special
from jax import lax

from objax.typing import JaxArray

# The names below are direct aliases of the corresponding JAX functions.
celu = jax.nn.celu
elu = jax.nn.elu
leaky_relu = jax.nn.leaky_relu
log_sigmoid = jax.nn.log_sigmoid
log_softmax = jax.nn.log_softmax
logsumexp = jax.scipy.special.logsumexp
selu = jax.nn.selu
sigmoid = jax.nn.sigmoid
softmax = jax.nn.softmax
softplus = jax.nn.softplus
tanh = lax.tanh
swish = jax.nn.swish


# Have to redefine relu since jax.nn.relu isn't picklable.
def relu(x: JaxArray) -> JaxArray:
    """Rectified linear unit activation function.

    Thin wrapper delegating to jax.nn.relu.

    Args:
        x: input tensor.

    Returns:
        tensor with the element-wise output relu(x) = max(x, 0).
    """
    return jax.nn.relu(x)

# --------------------------------------------------------------------------------
# /objax/functional/core/ops.py:
# --------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


__all__ = ['dynamic_slice', 'flatten', 'interpolate', 'one_hot', 'pad', 'rsqrt', 'scan', 'stop_gradient',
           'top_k', 'upsample_2d', 'upscale_nn']

from typing import Union, Tuple

import jax.nn
from jax import numpy as jn, lax

from objax import util
from objax.constants import Interpolate
from objax.typing import JaxArray

dynamic_slice = lax.dynamic_slice
one_hot = jax.nn.one_hot
pad = jn.pad
scan = lax.scan
stop_gradient = lax.stop_gradient
top_k = lax.top_k  # Current code doesn't work with gradient.
rsqrt = lax.rsqrt


def flatten(x: JaxArray) -> JaxArray:
    """Flattens input tensor to a 2D tensor.

    Args:
        x: input tensor with dimensions (n_1, n_2, ..., n_k)

    Returns:
        The input tensor reshaped to two dimensions (n_1, n_prod),
        where n_prod is equal to the product of n_2 to n_k.
    """
    return x.reshape([x.shape[0], -1])


def interpolate(input: JaxArray,
                size: Union[int, Tuple[int, ...]] = None,
                scale_factor: Union[int, Tuple[int, ...]] = None,
                mode: Union[Interpolate, str] = Interpolate.BILINEAR) -> JaxArray:
    """Interpolates a JaxArray to a given size or by a given scaling factor.

    Exactly one of size and scale_factor must be given.

    Args:
        input: input tensor.
        size: int or tuple for output size. An int replaces the size of the last dimension,
            a tuple replaces the sizes of the trailing len(size) dimensions.
        scale_factor: int or tuple scaling factor. An int scales every dimension except the first one,
            a tuple scales the trailing len(scale_factor) dimensions.
        mode: str or Interpolate interpolation method e.g. ['bilinear', 'nearest'].

    Returns:
        output JaxArray after interpolation.
    """
    assert size or scale_factor, f'both size: {size} and scale_factor: {scale_factor} can not be None .'
    assert bool(size) ^ bool(scale_factor), f'either size or scale_factor must be none ' \
                                            f'size: {size}, scale_factor: {scale_factor} .'
    input_shape = input.shape
    input_dim = len(input_shape)
    if scale_factor:
        if isinstance(scale_factor, (int, float)):
            factors = (scale_factor,) * (input_dim - 1)
        else:
            factors = tuple(scale_factor)
        assert input_dim >= len(factors), f'Number of dimensions of "{scale_factor}"' \
                                          f' must be < = to input shape"{input_shape}" '
        lead = input_dim - len(factors)
        # The target shape is computed with Python integers: jax.image.resize needs a static shape,
        # which array arithmetic does not provide when the function is traced (e.g. under jit).
        size = (*input_shape[:lead], *(int(d * f) for d, f in zip(input_shape[lead:], factors)))
    else:
        if isinstance(size, int):
            size = (size,)
        output_dim = len(size)
        assert input_dim >= output_dim, f'Number of dimensions of "{size}"' \
                                        f' must be < = to input shape"{input_shape}" '
        size = (*input_shape[:input_dim - output_dim], *size)
    output = jax.image.resize(input,
                              shape=size,
                              method=util.to_interpolate(mode))
    return output


def upsample_2d(x: JaxArray,
                scale: Union[Tuple[int, int], int],
                method: Union[Interpolate, str] = Interpolate.BILINEAR) -> JaxArray:
    """Function to upscale 2D images.

    Args:
        x: input tensor with 4 dimensions; dimensions 2 and 3 are the ones being scaled.
        scale: int or tuple scaling factor
        method: str or Interpolate interpolation methods e.g. ['bilinear', 'nearest'].

    Returns:
        upscaled 2d image tensor
    """
    s = x.shape
    assert len(s) == 4, f'{s} must have 4 dimensions to be upsampled, or you can try interpolate function.'
    scale = util.to_tuple(scale, 2)
    # The resize is done in channel-last layout, then the result is transposed back.
    y = jax.image.resize(x.transpose([0, 2, 3, 1]),
                         shape=(s[0], s[2] * scale[0], s[3] * scale[1], s[1]),
                         method=util.to_interpolate(method))
    return y.transpose([0, 3, 1, 2])


def upscale_nn(x: JaxArray, scale: int = 2) -> JaxArray:
    """Nearest neighbor upscale for image batches of shape (N, C, H, W).

    Args:
        x: input tensor of shape (N, C, H, W).
        scale: integer scaling factor.

    Returns:
        Output tensor of shape (N, C, H * scale, W * scale).
    """
    s = x.shape
    # Insert singleton axes after H and W, repeat along them, then merge them back into H and W.
    x = x.reshape(s[:2] + (s[2], 1, s[3], 1))
    x = jn.tile(x, (1, 1, 1, scale, 1, scale))
    return x.reshape(s[:2] + (scale * s[2], scale * s[3]))

# --------------------------------------------------------------------------------
# /objax/functional/divergence.py:
# --------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ['kl']

import jax.numpy as jn

from objax.typing import JaxArray


def kl(p: JaxArray, q: JaxArray, eps: float = 2 ** -17) -> JaxArray:
    """Computes the Kullback-Leibler divergence of q from p.

    Args:
        p: first array.
        q: second array.
        eps: constant added to p and q before taking the logarithm.

    Returns:
        the dot product of p with log(p + eps) - log(q + eps).
    """
    log_p = jn.log(p + eps)
    log_q = jn.log(q + eps)
    return p.dot(log_p - log_q)

# --------------------------------------------------------------------------------
# /objax/functional/parallel.py:
# --------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | 15 | __all__ = ['partial', 'pmax', 'pmean', 'pmin', 'psum'] 16 | 17 | from functools import partial 18 | 19 | import jax 20 | from jax import lax 21 | 22 | 23 | def pmax(x: jax.Array, axis_name: str = 'device'): 24 | """Compute a multi-device reduce max on x over the device axis axis_name.""" 25 | return lax.pmax(x, axis_name) 26 | 27 | 28 | def pmean(x: jax.Array, axis_name: str = 'device'): 29 | """Compute a multi-device reduce mean on x over the device axis axis_name.""" 30 | return lax.pmean(x, axis_name) 31 | 32 | 33 | def pmin(x: jax.Array, axis_name: str = 'device'): 34 | """Compute a multi-device reduce min on x over the device axis axis_name.""" 35 | return lax.pmin(x, axis_name) 36 | 37 | 38 | def psum(x: jax.Array, axis_name: str = 'device'): 39 | """Compute a multi-device reduce sum on x over the device axis axis_name.""" 40 | return lax.psum(x, axis_name) 41 | -------------------------------------------------------------------------------- /objax/io/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def load_var_collection(file: Union[str, IO[BinaryIO]],
                        vc: VarCollection,
                        renamer: Optional[Renamer] = None):
    """Loads values of all variables in the given variables collection from file.

    Values loaded from file will replace old values in the variables collection.
    If variable exists in the file, but does not exist in the variables collection it will be ignored.
    If variable exists in the variables collection, but not found in the file then exception will be raised.

    Args:
        file: filename or python file handle of the input file. A filename is opened (and always
            closed again) by this function; a handle passed by the caller is left open.
        vc: variables collection which will be loaded from file.
        renamer: optional renamer to pre-process variables names from the file being read.

    Raises:
        ValueError: if variable from variables collection is not found in the input file.
        AssertionError: if assigning a stored value to a variable fails its assignment checks.
    """
    renamer = renamer or (lambda x: x)
    do_close = isinstance(file, str)
    if do_close:
        file = open(file, 'rb')
    try:
        data = np.load(file, allow_pickle=False)
        # The archive stores values under '0', '1', ... and the matching names in the 'names' array.
        name_index = {renamer(k): str(i) for i, k in enumerate(data['names'])}
        # Group every name of the collection by the identity of the variable it points to, so that
        # a variable reachable under several names is restored once, from whichever name is on disk.
        var_names = collections.defaultdict(list)
        var_values = {}
        for k, v in vc.items():
            if isinstance(v, TrainRef):
                v = v.ref
            var_names[id(v)].append(k)
            var_values[id(v)] = v
        misses = []
        used_vars = set()
        for var_id, names in var_names.items():
            v = var_values[var_id]
            for name in names:
                index = name_index.get(name)
                if index is not None:
                    used_vars.add(name)
                    try:
                        v.assign(jn.array(data[index]))
                    except AssertionError as e:
                        raise AssertionError(f'Error when restoring variable {name}: {str(e)}') from None
                    break
            else:
                # None of the aliases of this variable was found in the file.
                misses += names
        if misses:
            not_used = set(name_index.keys()) - used_vars
            raise ValueError(f'Missing value for variables currently in the model: {misses}. '
                             f'The following variables on disk were not used, '
                             f'maybe the missing variable was renamed from one of these: {not_used}.')
    finally:
        # Close the handle we opened ourselves even when restoring fails.
        if do_close:
            file.close()


def save_var_collection(file: Union[str, IO[BinaryIO]], vc: VarCollection):
    """Saves variables collection into file.

    When ``file`` is a filename the data is first written to ``<file>.tmp`` and then moved over
    the final name, so a job killed while saving does not leave a truncated checkpoint behind.

    Args:
        file: filename or python file handle of the file where variables will be saved.
        vc: variables collection which will be saved into file.
    """
    do_close = isinstance(file, str)
    if do_close:
        filename, file = file, open(file + '.tmp', 'wb')  # Save to a temporary in case the job is killed while saving.
    try:
        data, names, seen = {}, [], set()
        for k, v in vc.items():
            if isinstance(v, TrainRef):
                v = v.ref
            # A variable reachable under several names is stored once, under its first name.
            if id(v) not in seen:
                names.append(k)
                data[str(len(data))] = v.value
                seen.add(id(v))

        def _disabled_seek(*_):
            raise AttributeError('seek() is disabled on this object.')

        # NOTE(review): seek() is disabled while numpy writes the archive, presumably to force the
        # zip writer into its non-seekable (streaming) mode -- confirm before changing.
        _old_seek = getattr(file, 'seek')
        setattr(file, 'seek', _disabled_seek)
        try:
            np.savez(file, names=np.array(names), **data)
        finally:
            # Always hand the file object back with a working seek().
            setattr(file, 'seek', _old_seek)
    finally:
        if do_close:
            file.close()
    if do_close:
        # Atomic move to avoid a broken file (when killed while saving); unlike os.rename,
        # os.replace also overwrites an existing destination on Windows.
        os.replace(filename + '.tmp', filename)
class Reducer(enum.Enum):
    """Reduces tensor batch into a single tensor.

    The entries are functions, which ``enum`` leaves as plain callables rather than
    turning into enum members, so ``Reducer.MEAN(values)`` can be called directly.
    """
    FIRST = lambda x: x[0]
    LAST = lambda x: x[-1]
    MEAN = lambda x: np.mean(x)


class DelayedScalar:
    """Accumulates scalar values and reduces them only when called."""

    def __init__(self, reduce: Union[Callable, Reducer]):
        self.values = []
        self.reduce = reduce

    def __call__(self):
        """Returns the stored values folded by the reducer given at construction."""
        reducer = self.reduce
        return reducer(self.values)


class Image:
    """Holds an encoded PNG together with the shape of the image it was made from."""

    def __init__(self, shape: Tuple[int, int, int], png: ByteString):
        self.shape = shape
        self.png = png


class Text:
    """Holds a text entry of a summary."""

    def __init__(self, text: str):
        self.text = text


class Summary(dict):
    """Writes entries to `Summary` protocol buffer."""

    def image(self, tag: str, image: np.ndarray):
        """Adds image to the summary. Float image in [-1, 1] in CHW format expected."""
        shape = image.shape
        png = util.image.to_png(image)
        self[tag] = Image(shape, png)

    def scalar(self, tag: str, value: float, reduce: Union[Callable, Reducer] = Reducer.MEAN):
        """Adds scalar to the summary.

        Values added under the same tag are collected and reduced when the summary is built;
        the reducer of the first call for a tag is the one that is kept.
        """
        try:
            pending = self[tag]
        except KeyError:
            pending = self[tag] = DelayedScalar(reduce)
        pending.values.append(value)

    def text(self, tag: str, text: str):
        """Adds text to the summary."""
        self[tag] = Text(text)

    def __call__(self):
        """Builds the ``summary_pb2.Summary`` message holding one value per stored tag.

        Raises:
            NotImplementedError: if an entry is not a scalar, image or text entry.
        """
        entries = []
        for tag, value in self.items():
            if isinstance(value, DelayedScalar):
                entry = summary_pb2.Summary.Value(tag=tag, simple_value=value())
            elif isinstance(value, Image):
                # The stored shape is read as (colorspace, height, width).
                image_summary = summary_pb2.Summary.Image(encoded_image_string=value.png,
                                                          colorspace=value.shape[0],
                                                          height=value.shape[1],
                                                          width=value.shape[2])
                entry = summary_pb2.Summary.Value(tag=tag, image=image_summary)
            elif isinstance(value, Text):
                plugin_data = summary_pb2.SummaryMetadata.PluginData(plugin_name='text')
                metadata = summary_pb2.SummaryMetadata(plugin_data=plugin_data)
                tensor = make_tensor_proto(values=value.text.encode('utf-8'), shape=(1,))
                entry = summary_pb2.Summary.Value(tag=tag, metadata=metadata, tensor=tensor)
            else:
                raise NotImplementedError(tag, value)
            entries.append(entry)
        return summary_pb2.Summary(value=entries)


class SummaryWriter:
    """Writes entries to event files in the logdir to be consumed by Tensorboard."""

    def __init__(self, logdir: str, queue_size: int = 5, write_interval: int = 5):
        """Creates SummaryWriter instance.

        Args:
            logdir: directory where event file will be written; created when it does not exist.
            queue_size: size of the queue for pending events and summaries
                        before one of the 'add' calls forces a flush to disk.
            write_interval: how often, in seconds, to write the pending events and summaries to disk.
        """
        logdir_exists = os.path.isdir(logdir)
        if not logdir_exists:
            os.makedirs(logdir, exist_ok=True)

        self.writer = EventFileWriter(logdir, queue_size, write_interval)

    def write(self, summary: Summary, step: int):
        """Adds an event, stamped with the current wall time, to the event file."""
        event = event_pb2.Event(step=step, summary=summary(), wall_time=time())
        self.writer.add_event(event)

    def close(self):
        """Flushes the event file to disk and closes the file."""
        self.writer.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Closing on exit makes the writer usable in a ``with`` statement.
        self.close()
class Adam(Module):
    """Adam optimizer."""

    def __init__(self, vc: VarCollection, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8):
        """Constructor for Adam optimizer class.

        Args:
            vc: collection of variables to optimize.
            beta1: value of Adam's beta1 hyperparameter. Defaults to 0.9.
            beta2: value of Adam's beta2 hyperparameter. Defaults to 0.999.
            eps: value of Adam's epsilon hyperparameter. Defaults to 1e-8.
        """
        self.beta1, self.beta2, self.eps = beta1, beta2, eps
        # Step counter used for the bias correction of the learning rate.
        self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda x: x[0])
        self.train_vars = ModuleList(TrainRef(var) for var in vc.subset(TrainVar))

        def zero_slots():
            # One zero-initialised state variable per trainable variable.
            return ModuleList(StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars)

        self.m = zero_slots()
        self.v = zero_slots()

    def __call__(self, lr: float, grads: List[JaxArray], beta1: Optional[float] = None, beta2: Optional[float] = None):
        """Updates variables and other state based on Adam algorithm.

        Args:
            lr: the learning rate.
            grads: the gradients to apply, one per trainable variable and in the same order.
            beta1: optional, override the default beta1.
            beta2: optional, override the default beta2.
        """
        assert len(grads) == len(self.train_vars), 'Expecting as many gradients as trainable variables'
        beta1 = self.beta1 if beta1 is None else beta1
        beta2 = self.beta2 if beta2 is None else beta2
        self.step.value += 1
        step = self.step.value
        # Fold the bias correction of both moments into the learning rate.
        lr *= jn.sqrt(1 - beta2 ** step) / (1 - beta1 ** step)
        for grad, param, first, second in zip(grads, self.train_vars, self.m, self.v):
            first.value += (1 - beta1) * (grad - first.value)
            second.value += (1 - beta2) * (grad ** 2 - second.value)
            param.value -= lr * first.value * functional.rsqrt(second.value + self.eps)

    def __repr__(self):
        return f'{class_name(self)}(beta1={self.beta1}, beta2={self.beta2}, eps={self.eps})'


class LARS(Module):
    """Layerwise adaptive rate scaling (LARS) optimizer.

    See https://arxiv.org/abs/1708.03888
    """

    def __init__(self, vc: VarCollection,
                 momentum: float = 0.9,
                 weight_decay: float = 1e-4,
                 tc: float = 1e-3,
                 eps: float = 1e-5):
        """Constructor for LARS optimizer.

        Args:
            vc: collection of variables to optimize.
            momentum: coefficient used for the moving average of the gradient.
            weight_decay: weight decay coefficient.
            tc: trust coefficient eta ( < 1) for trust ratio computation.
            eps: epsilon used for trust ratio computation.
        """
        self.momentum, self.weight_decay = momentum, weight_decay
        self.tc, self.eps = tc, eps
        self.train_vars = ModuleList(TrainRef(var) for var in vc.subset(TrainVar))
        self.m = ModuleList(StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars)

    def __call__(self, lr: float, grads: List[JaxArray]):
        """Updates variables based on LARS algorithm.

        Args:
            lr: learning rate. The LARS paper suggests using lr = lr_0 * (1 -t/T)**2,
                where t is the current epoch number and T the maximum number of epochs.
            grads: the gradients to apply, one per trainable variable and in the same order.
        """
        assert len(grads) == len(self.train_vars), 'Expecting as many gradients as trainable variables'

        for grad, param, velocity in zip(grads, self.train_vars, self.m):
            p_norm = jn.linalg.norm(param.value)
            g_norm = jn.linalg.norm(grad)
            denominator = g_norm + self.weight_decay * p_norm + self.eps
            trust_ratio = self.tc * p_norm / denominator
            # When either norm is zero the flag is True, so the multiplier is at least 1.
            degenerate = jn.logical_or(p_norm == 0, g_norm == 0)
            local_lr = lr * jn.maximum(degenerate, trust_ratio)
            decayed_grad = grad + self.weight_decay * param.value
            velocity.value = self.momentum * velocity.value + local_lr * decayed_grad
            param.value -= velocity.value

    def __repr__(self):
        return f'{class_name(self)}(momentum={self.momentum}, weight_decay={self.weight_decay}, ' \
               f'tc={self.tc}, eps={self.eps})'
class Momentum(Module):
    """Momentum optimizer."""

    def __init__(self, vc: VarCollection, momentum: float = 0.9, nesterov: bool = False):
        """Constructor for momentum optimizer class.

        Args:
            vc: collection of variables to optimize.
            momentum: the momentum hyperparameter.
            nesterov: bool indicating whether to use the Nesterov method.
        """
        self.momentum, self.nesterov = momentum, nesterov
        self.train_vars = ModuleList(TrainRef(var) for var in vc.subset(TrainVar))
        self.m = ModuleList(StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars)

    def __call__(self, lr: float, grads: List[jn.ndarray], momentum: Optional[float] = None):
        """Updates variables and other state based on momentum (or Nesterov) SGD.

        Args:
            lr: the learning rate.
            grads: the gradients to apply, one per trainable variable and in the same order.
            momentum: optional, override the default momentum.
        """
        assert len(grads) == len(self.train_vars), 'Expecting as many gradients as trainable variables'
        momentum = self.momentum if momentum is None else momentum
        for grad, param, velocity in zip(grads, self.train_vars, self.m):
            velocity.value = grad + momentum * velocity.value
            if self.nesterov:
                # Nesterov looks one step ahead along the freshly updated velocity.
                param.value -= lr * (grad + momentum * velocity.value)
            else:
                param.value -= lr * velocity.value

    def __repr__(self):
        return f'{class_name(self)}(momentum={self.momentum}, nesterov={self.nesterov})'


class Scheduler:
    """Base class of learning rate schedulers: ``base_lr`` scaled by a step-dependent multiplier."""

    def __init__(self,
                 base_lr: float = 1.0):
        """Constructs an instance for learning rate scheduler.

        Args:
            base_lr: base learning rate.
        """
        self.base_lr = base_lr

    @abc.abstractmethod
    def multiplier(self, step: float):
        """Returns learning rate multiplier w.r.t. certain schedule."""
        raise NotImplementedError

    def __call__(self, step: float):
        """Returns learning rate or multiplier at certain step.

        Args:
            step: number of training step.

        Returns:
            ``base_lr`` times the multiplier at ``step``; with the default ``base_lr`` of 1.0
            this is the multiplier itself.
        """
        scale = self.multiplier(step)
        return self.base_lr * scale


class LinearAnnealing(Scheduler):
    """Learning rate that decreases linearly from ``base_lr`` to ``min_lr`` over ``max_step`` steps."""

    def __init__(self,
                 max_step: float,
                 base_lr: float = 1.0,
                 is_cycle: bool = True,
                 min_lr: float = 0.0):
        """Constructs an instance for linear annealing learning rate scheduler.

        Args:
            max_step: maximum number of train step.
            base_lr: base learning rate.
            is_cycle: trigger cyclical learning rate multiplier when step
                exceeds max_step.
            min_lr: minimum learning rate at max_step.
        """
        super().__init__(base_lr=base_lr)
        assert base_lr >= min_lr, (
            'base_lr should be greater than or equal to min_lr.')
        self.max_step = max_step
        self.is_cycle = is_cycle
        # Stored relative to base_lr because multiplier() works on multipliers.
        self.min_lr_multiplier = min_lr / self.base_lr

    def multiplier(self, step: float):
        """Returns linear annealing learning rate multiplier."""
        # Cyclic schedules wrap the step around max_step; otherwise the step saturates there.
        bound = jn.remainder if self.is_cycle else jn.minimum
        step = bound(step, self.max_step)
        progress = step / self.max_step
        return 1.0 - progress * (
            1.0 - self.min_lr_multiplier)


class StepDecay(Scheduler):
    """Learning rate multiplied by ``gamma`` at regular intervals or at given step boundaries."""

    def __init__(self,
                 step_size: Union[float, List, Tuple],
                 base_lr: float = 1.0,
                 gamma: float = 0.1):
        """Constructs an instance for step decay learning rate scheduler.

        Args:
            step_size: number of train steps to reduce learning rate, or a list/tuple
                of steps at which the learning rate is reduced.
            base_lr: base learning rate.
            gamma: learning rate decay rate.
        """
        super().__init__(base_lr=base_lr)
        self.gamma = gamma
        self.step_size = step_size

    def multiplier(self, step: float):
        """Returns step decay learning rate multiplier."""
        boundaries = self.step_size
        if not isinstance(boundaries, (tuple, list)):
            # Scalar step size: one decay per completed interval.
            return self.gamma ** (step // boundaries)
        # Explicit boundaries: one decay per boundary already reached.
        reached = jn.greater_equal(step, jn.array(boundaries))
        return self.gamma ** jn.sum(reached)


class SGD(Module):
    """Stochastic Gradient Descent (SGD) optimizer."""

    def __init__(self, vc: VarCollection):
        """Constructor for SGD optimizer.

        Args:
            vc: collection of variables to optimize.
        """
        self.train_vars = ModuleList(TrainRef(var) for var in vc.subset(TrainVar))

    def __call__(self, lr: float, grads: List[JaxArray]):
        """Updates variables based on SGD algorithm.

        Args:
            lr: the learning rate.
            grads: the gradients to apply, one per trainable variable and in the same order.
        """
        assert len(grads) == len(self.train_vars), 'Expecting as many gradients as trainable variables'
        for grad, param in zip(grads, self.train_vars):
            param.value -= lr * grad

    def __repr__(self):
        return f'{class_name(self)}()'
14 | 15 | from .gradient import * 16 | from .privacyaccountant import * 17 | -------------------------------------------------------------------------------- /objax/random/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .random import * 16 | -------------------------------------------------------------------------------- /objax/random/random.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class Generator(Module):
    """Random number generator module."""

    def __init__(self, seed: int = 0):
        """Create a random key generator, seed is the random generator initial seed."""
        super().__init__()
        self.initial_seed = seed
        # The RandomState is created lazily, on first access of ``key``.
        self._key: Optional[RandomState] = None

    @property
    def key(self):
        """The random generator state, created from ``initial_seed`` on first access."""
        state = self._key
        if state is None:
            state = self._key = RandomState(self.initial_seed)
        return state

    def seed(self, seed: int = 0):
        """Sets a new random generator seed."""
        self.initial_seed = seed
        if self._key is None:
            # Nothing to reseed yet; the new seed is used when the key gets created.
            return
        self._key.seed(seed)

    def __call__(self):
        """Generate a new generator state."""
        fresh_keys = self.key.split(1)
        return fresh_keys[0]

    def vars(self, scope: str = '') -> VarCollection:
        _ = self.key  # Make sure the key is created before collecting the vars.
        return super().vars(scope)

    def __repr__(self):
        return f'{class_name(self)}(seed={self.initial_seed})'


DEFAULT_GENERATOR = Generator(0)


def normal(shape: Tuple[int, ...], *, mean: float = 0, stddev: float = 1, generator: Generator = DEFAULT_GENERATOR):
    """Returns a ``JaxArray`` of shape ``shape`` with random numbers from a normal distribution
    with mean ``mean`` and standard deviation ``stddev``.

    NOTE: if random numbers are generated inside a jitted, parallelized or vectorized function
    then generator variables (including DEFAULT_GENERATOR) have to be added to the
    variable collection."""
    standard = jr.normal(generator(), shape=shape)
    return standard * stddev + mean


def randint(shape: Tuple[int, ...], low: int, high: int, generator: Generator = DEFAULT_GENERATOR):
    """Returns a ``JaxArray`` of shape ``shape`` with random integers in {low, ..., high-1}.

    NOTE: if random numbers are generated inside a jitted, parallelized or vectorized function
    then generator variables (including DEFAULT_GENERATOR) have to be added to the
    variable collection."""
    key = generator()
    return jr.randint(key, shape=shape, minval=low, maxval=high)


def truncated_normal(shape: Tuple[int, ...], *,
                     stddev: float = 1,
                     lower: float = -2,
                     upper: float = 2,
                     generator: Generator = DEFAULT_GENERATOR):
    """Returns a ``JaxArray`` of shape ``shape`` with random numbers from a normal distribution
    with mean 0 and standard deviation ``stddev`` truncated by (``lower``, ``upper``).

    The truncation bounds are applied before the samples are scaled by ``stddev``.

    NOTE: if random numbers are generated inside a jitted, parallelized or vectorized function
    then generator variables (including DEFAULT_GENERATOR) have to be added to the
    variable collection."""
    truncated = jr.truncated_normal(generator(), shape=shape, lower=lower, upper=upper)
    return truncated * stddev


def uniform(shape: Tuple[int, ...], generator: Generator = DEFAULT_GENERATOR):
    """Returns a ``JaxArray`` of shape ``shape`` with random numbers from a uniform distribution,
    using the default bounds of ``jax.random.uniform``.

    NOTE: if random numbers are generated inside a jitted, parallelized or vectorized function
    then generator variables (including DEFAULT_GENERATOR) have to be added to the
    variable collection."""
    key = generator()
    return jr.uniform(key, shape=shape)
14 | 15 | """This module contains type declarations for Objax.""" 16 | 17 | __all__ = ['FileOrStr', 'JaxArray', 'JaxDType'] 18 | 19 | from typing import Union, IO, BinaryIO, Sequence, Tuple 20 | 21 | import jax 22 | import jax.numpy as jn 23 | 24 | ConvPaddingInt = Union[Sequence[Tuple[int, int]], Tuple[int, int], int] 25 | FileOrStr = Union[str, IO[BinaryIO]] 26 | JaxArray = jax.Array 27 | JaxDType = Union[jn.complex64, jn.complex128, jn.bfloat16, 28 | jn.float16, jn.float32, jn.float64, 29 | jn.int8, jn.int16, jn.int32, jn.int64, 30 | jn.uint8, jn.uint16, jn.uint32, jn.uint64] 31 | -------------------------------------------------------------------------------- /objax/util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import image 16 | from . import check 17 | from .util import * 18 | from .objax2tf import Objax2Tf 19 | from .tracing import find_used_variables -------------------------------------------------------------------------------- /objax/util/check.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
def split_shape_and_device(array):
    """Separate the leading device dimension of a pmap-sharded array from its shape.

    Args:
        array: object exposing a ``shape`` attribute.

    Returns:
        A pair ``(num_devices, shape)``. For a ``jax.Array`` whose ``sharding`` is a
        ``jax.sharding.PmapSharding`` this is ``(array.shape[0], array.shape[1:])``;
        for every other input it is ``(None, array.shape)``.
    """
    if isinstance(array, jax.Array) and hasattr(array, 'sharding') and isinstance(
            array.sharding, jax.sharding.PmapSharding):
        return array.shape[0], array.shape[1:]
    else:
        return None, array.shape


def assert_assigned_type_and_shape_match(existing_tensor, new_tensor):
    """Assert that ``new_tensor`` may be assigned over ``existing_tensor``.

    Args:
        existing_tensor: value currently held by the variable.
        new_tensor: value about to be assigned.

    Raises:
        AssertionError: if ``new_tensor`` is not a ``jax.Array``, if both tensors
            carry a device dimension and the two differ, or if the shapes are
            incompatible.
    """
    # The message used to contain a stray literal "f" before the placeholder,
    # which rendered as "received f<class ...>".
    assert isinstance(new_tensor, jax.Array), \
        f'Assignments to variable must be an instance of JaxArray, but received {type(new_tensor)}.'

    new_tensor_device, new_tensor_shape = split_shape_and_device(new_tensor)
    self_device, self_shape = split_shape_and_device(existing_tensor)

    # A device count of None means "not replicated" and is compatible with anything.
    device_mismatch_error = f'Can not replicate a variable that is currently on ' \
                            f'{self_device} devices to {new_tensor_device} devices.'
    assert (new_tensor_device is None) or (self_device is None) or (self_device == new_tensor_device), \
        device_mismatch_error

    # When either side is a tracer, only the trailing dimensions common to both
    # shapes have to agree.
    shorter_length = min(len(new_tensor.shape), len(existing_tensor.shape))
    is_special_ok = (isinstance(new_tensor, TRACER_TYPES) or isinstance(existing_tensor, TRACER_TYPES))
    is_special_ok = is_special_ok and existing_tensor.shape[-shorter_length:] == new_tensor.shape[-shorter_length:]

    shape_mismatch_error = f'Assign can not change shape of variable. The current variable shape is {self_shape},' \
                           f' but the requested new shape is {new_tensor_shape}.'
    # Accept either matching per-device shapes or matching full shapes.
    assert is_special_ok or new_tensor_shape == self_shape or new_tensor.shape == existing_tensor.shape, \
        shape_mismatch_error
def image_grid(image: np.ndarray) -> np.ndarray:
    """Tile an array of images (nh, nw, c, h, w) into one image of shape (c, nh * h, nw * w)."""
    # Bring channels first and interleave each grid axis with its pixel axis.
    tiled = image.transpose([2, 0, 3, 1, 4])
    channels, grid_rows, height, grid_cols, width = tiled.shape
    return tiled.reshape([channels, grid_rows * height, grid_cols * width])


def nchw(x: Union[np.ndarray, JaxArray]) -> Union[np.ndarray, JaxArray]:
    """Converts an array in (N,H,W,C) format to (N,C,H,W) format."""
    axes = list(range(x.ndim))
    # Take the last (channel) axis and re-insert it ahead of the two spatial axes.
    channel_axis = axes.pop()
    axes.insert(-2, channel_axis)
    return x.transpose(axes)


def nhwc(x: Union[np.ndarray, JaxArray]) -> Union[np.ndarray, JaxArray]:
    """Converts an array in (N,C,H,W) format to (N,H,W,C) format."""
    axes = list(range(x.ndim))
    # Take the third-from-last (channel) axis and move it to the end.
    channel_axis = axes.pop(-3)
    axes.append(channel_axis)
    return x.transpose(axes)


def normalize_to_uint8(x: Union[np.ndarray, JaxArray]) -> Union[np.ndarray, JaxArray]:
    """Map a float image in [1/256-1, 1-1/256] to uint8 {0, 1, ..., 255}."""
    shifted = x + (1 - 1 / 256)
    scaled = 128 * shifted
    # Out-of-range values are clipped before rounding.
    return scaled.clip(0, 255).round().astype('uint8')


def normalize_to_unit_float(x: Union[np.ndarray, JaxArray]) -> Union[np.ndarray, JaxArray]:
    """Map an uint8 image in {0, 1, ..., 255} to float interval [1/256-1, 1-1/256]."""
    scale = 1 / 128
    offset = 1 / 256 - 1
    return x * scale + offset


def to_png(x: Union[np.ndarray, JaxArray]) -> bytes:
    """Encode an image array as PNG bytes.

    Float input (float16/32/64) is treated as (C,H,W): it is converted with
    normalize_to_uint8 and transposed to (H,W,C). uint8 input is passed on without
    transposition.
    NOTE(review): uint8 input therefore appears to be expected in (H,W,C) already -- confirm.

    Raises:
        ValueError: if the dtype is neither one of the float types above nor uint8.
    """
    pixels = np.array(x) if isinstance(x, jn.ndarray) else x
    float_types = (np.float64, np.float32, np.float16)
    if pixels.dtype in float_types:
        pixels = np.transpose(normalize_to_uint8(pixels), (1, 2, 0))
    elif pixels.dtype != np.uint8:
        raise ValueError('Unsupported array type, expecting float or uint8', pixels.dtype)
    # A single-channel image is expanded to three identical channels.
    if pixels.shape[2] == 1:
        pixels = np.broadcast_to(pixels, pixels.shape[:2] + (3,))
    buffer = io.BytesIO()
    with buffer:
        Image.fromarray(pixels).save(buffer, 'png')
        return buffer.getvalue()
class Objax2Tf(tf.Module):
    """Objax to Tensorflow converter, which converts Objax module to tf.Module."""

    def __init__(self, module: Module):
        """Create a Tensorflow module from Objax module.

        Args:
            module: Objax module to be converted to Tensorflow tf.Module.

        Raises:
            AssertionError: if Tensorflow is not installed, if its major version is
                below 2, or if ``module`` is not an Objax ``Module``.
        """
        from jax.experimental import jax2tf
        assert hasattr(tf, '__version__'), 'Tensorflow must be installed for Objax2Tf to work.'
        # Compare the major version numerically: comparing the raw strings would
        # order '10.0' before '2.0'.
        assert int(tf.__version__.split('.')[0]) >= 2, 'Objax2Tf works only with Tensorflow 2.'
        assert isinstance(module, Module), 'Input argument to Objax2Tf must be an Objax module.'

        super().__init__()

        module_vars = module.vars()

        def wrapped_op(tensor_list: List[JaxArray], kwargs, *args):
            # Run the module with the supplied tensors substituted for its variables,
            # then restore the original values whether or not the call succeeded.
            original_values = module_vars.tensors()
            try:
                module_vars.assign(tensor_list)
                return module(*args, **kwargs)
            finally:
                module_vars.assign(original_values)

        tf_function = jax2tf.convert(wrapped_op)
        # One tf.Variable per tensor of the Objax module, in module_vars order.
        self._tf_vars = [tf.Variable(v) for v in module_vars.tensors()]
        self._tf_call = tf_function

    @tf.function(autograph=False)
    def __call__(self, *args, **kwargs):
        """Calls Tensorflow function which was generated from Objax module."""
        return self._tf_call(self._tf_vars, kwargs, *args)
class ConvNet(objax.nn.Sequential):
    """ConvNet implementation."""

    @staticmethod
    def _mean_reduce(x: JaxArray) -> JaxArray:
        # Average over axes 2 and 3 of the final convolution output.
        return x.mean((2, 3))

    def __init__(self, nin, nclass, scales, filters, filters_max,
                 pooling=objax.functional.max_pool_2d, **kwargs):
        """Creates ConvNet instance.

        Args:
            nin: number of channels in the input image.
            nclass: number of output classes.
            scales: number of pooling stages; each stage pools with size 2 and stride 2.
            filters: base number of convolution filters.
                The number of filters doubles at every scale until it reaches filters_max.
            filters_max: maximum number of filters.
            pooling: pooling function, called with ``size=2, strides=2``.
            **kwargs: accepted and ignored.
        """
        del kwargs

        def width(level):
            # filters * 2**level, capped at filters_max.
            return min(filters_max, filters << level)

        activation = objax.functional.leaky_relu
        layers = [objax.nn.Conv2D(nin, width(0), 3), activation]
        for level in range(scales):
            layers.append(objax.nn.Conv2D(width(level), width(level), 3))
            layers.append(activation)
            layers.append(objax.nn.Conv2D(width(level), width(level + 1), 3))
            layers.append(activation)
            layers.append(functools.partial(pooling, size=2, strides=2))
        # Classification head: one convolution to nclass channels, then the mean reduction.
        layers.append(objax.nn.Conv2D(width(scales), nclass, 3))
        layers.append(self._mean_reduce)
        super().__init__(layers)
class RNN(Module):
    """Recurrent Neural Network (RNN) block."""

    def __init__(self,
                 nstate: int,
                 nin: int,
                 nout: int,
                 activation: Callable = jn.tanh,
                 w_init: Callable = kaiming_normal):
        """Creates an RNN instance.

        Args:
            nstate: number of hidden units.
            nin: number of input units.
            nout: number of output units.
            activation: activation function for the hidden layer.
            w_init: weight initializer for the RNN model weights.
        """
        self.num_inputs = nin
        self.num_outputs = nout
        self.nstate = nstate
        self.activation = activation

        # Hidden layer parameters: input-to-hidden, hidden-to-hidden and bias.
        self.w_xh = TrainVar(w_init((nin, nstate)))
        self.w_hh = TrainVar(w_init((nstate, nstate)))
        self.b_h = TrainVar(jn.zeros(nstate))

        # Linear layer applied to the hidden state at every step.
        self.output_layer = Linear(nstate, nout)

    def init_state(self, batch_size):
        """Initialize hidden state for input batch of size ``batch_size``."""
        zero_state = jn.zeros((batch_size, self.nstate))
        self.state = StateVar(zero_state)

    def __call__(self, inputs: JaxArray, only_return_final=False) -> JaxArray:
        """Forward pass through RNN.

        ``self.state`` is created by ``init_state`` and is updated in place at every step.

        Args:
            inputs: ``JaxArray`` with dimensions ``num_steps, batch_size, vocabulary_size``.
            only_return_final: return only the last output if ``True``, or all outputs otherwise.

        Returns:
            The last step's output if ``only_return_final`` is ``True``; otherwise the
            per-step outputs concatenated along axis 0.
        """
        per_step = []
        # Iterating over the leading axis of inputs yields one time step per iteration.
        for step_input in inputs:
            pre_activation = (jn.dot(step_input, self.w_xh.value)
                              + jn.dot(self.state.value, self.w_hh.value)
                              + self.b_h.value)
            self.state.value = self.activation(pre_activation)
            per_step.append(self.output_layer(self.state.value))
        return per_step[-1] if only_return_final else jn.concatenate(per_step, axis=0)
# you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an 'AS IS' BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | 17 | from pkg_resources import parse_requirements 18 | from setuptools import find_packages, setup 19 | 20 | README_FILE = 'README.md' 21 | REQUIREMENTS_FILE = 'requirements.txt' 22 | VERSION_FILE = 'objax/_version.py' 23 | VERSION_REGEXP = r'^__version__ = \'(\d+\.\d+\.\d+)\'' 24 | 25 | r = re.search(VERSION_REGEXP, open(VERSION_FILE).read(), re.M) 26 | if r is None: 27 | raise RuntimeError(f'Unable to find version string in {VERSION_FILE}.') 28 | 29 | version = r.group(1) 30 | long_description = open(README_FILE, encoding='utf-8').read() 31 | install_requires = [str(r) for r in parse_requirements(open(REQUIREMENTS_FILE, 'rt'))] 32 | 33 | setup( 34 | name='objax', 35 | version=version, 36 | description='Objax is a machine learning framework that provides an Object Oriented layer for JAX.', 37 | long_description=long_description, 38 | long_description_content_type='text/markdown', 39 | author='Objax team', 40 | author_email='objax-dev@google.com', 41 | url='https://github.com/google/objax', 42 | packages=find_packages(), 43 | classifiers=[ 44 | 'Development Status :: 5 - Production/Stable', 45 | 'Intended Audience :: Developers', 46 | 'Intended Audience :: Science/Research', 47 | 'License :: OSI Approved :: Apache Software License', 48 | 'Programming Language :: Python :: 3.9', 49 | 'Programming Language :: Python :: 3.10', 50 | 'Programming Language :: Python :: 3.11', 51 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 
52 | ], 53 | install_requires=install_requires, 54 | ) 55 | -------------------------------------------------------------------------------- /tests/dropout.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Unit Tests for Dropout layer.""" 16 | 17 | import unittest 18 | 19 | import jax.numpy as jn 20 | 21 | import objax 22 | from objax import random 23 | 24 | 25 | class TestDropout(unittest.TestCase): 26 | def test_on_dropout_0_5(self): 27 | """ 28 | Pass an input through a Dropout layer 29 | that keeps half the input and test 30 | that half of the output values are zero. 31 | """ 32 | 33 | drop_input = jn.array([[1., 2., 3., 4., 5., 6.]]) 34 | keep = 0.5 35 | test_generator = random.DEFAULT_GENERATOR 36 | test_generator.seed(3) 37 | dropout_layer = objax.nn.Dropout(keep, test_generator) 38 | training = True 39 | drop_output = dropout_layer(drop_input, training) 40 | self.assertEqual(jn.count_nonzero(drop_output), 3) 41 | for index in range(drop_output.shape[1]): 42 | if drop_output[0][index] != 0: 43 | self.assertEqual(drop_output[0][index], drop_input[0][index] / keep) 44 | 45 | def test_on_dropout_two_dimension(self): 46 | """ 47 | Pass a two dimensional input through a Dropout layer 48 | that keeps half the input and test 49 | that half of the output values are zero. 
50 | """ 51 | 52 | drop_input = jn.array([[1., 2., 3.], [5., 6., 7.]]) 53 | keep = 0.5 54 | test_generator = random.DEFAULT_GENERATOR 55 | test_generator.seed(3) 56 | dropout_layer = objax.nn.Dropout(keep, test_generator) 57 | training = True 58 | drop_output = dropout_layer(drop_input, training) 59 | self.assertEqual(jn.count_nonzero(drop_output), 3) 60 | 61 | def test_on_dropout_1_0(self): 62 | """ 63 | Pass an input through a Dropout layer 64 | that keeps all of the input and test 65 | that all of the output values are non-zero. 66 | """ 67 | 68 | drop_input = jn.array([[1., 2., 3., 4., 5., 6.]]) 69 | keep = 1.0 70 | dropout_layer = objax.nn.Dropout(keep) 71 | training = True 72 | drop_output = dropout_layer(drop_input, training) 73 | self.assertEqual(jn.count_nonzero(drop_output), 6) 74 | self.assertTrue(jn.array_equal(drop_input, drop_output)) 75 | 76 | def test_on_dropout_0_0(self): 77 | """ 78 | Pass an input through a Dropout layer 79 | that keeps none of the input and test 80 | that all of the output values are zero. 81 | """ 82 | 83 | drop_input = jn.array([[1., 2., 3., 4., 5., 6.]]) 84 | keep = 0.0 85 | test_generator = random.DEFAULT_GENERATOR 86 | test_generator.seed(1) 87 | dropout_layer = objax.nn.Dropout(keep, test_generator) 88 | training = True 89 | drop_output = dropout_layer(drop_input, training) 90 | self.assertEqual(jn.count_nonzero(drop_output), 0) 91 | 92 | def test_on_dropout_inference(self): 93 | """ 94 | Pass an input to the Dropout layer when 95 | training is false and test that the output 96 | is equal to the input. 
97 | """ 98 | 99 | drop_input = jn.array([[1., 2., 3., 4., 5.]]) 100 | dropout_layer = objax.nn.Dropout(0.5) 101 | training = False 102 | drop_output = dropout_layer(drop_input, training) 103 | self.assertTrue(jn.array_equal(drop_input, drop_output)) 104 | 105 | 106 | if __name__ == '__main__': 107 | unittest.main() 108 | -------------------------------------------------------------------------------- /tests/functional_interpolate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Unittests for functional upsample operations.""" 16 | 17 | import unittest 18 | import jax 19 | import jax.numpy as jn 20 | import numpy as np 21 | 22 | import objax 23 | 24 | 25 | def shaparange(s): 26 | return jn.arange(np.prod(s), dtype=float).reshape(s) 27 | 28 | 29 | class TestUpsample(unittest.TestCase): 30 | methods = ['nearest', 'linear', 'bilinear', 'trilinear', 'triangle', 'cubic', 31 | 'bicubic', 'tricubic', 'lanczos3', 'lanczos5'] 32 | 33 | def test_upsample2d(self): 34 | x = shaparange((2, 3, 10, 30)) 35 | shape = x.shape 36 | for method in self.methods: 37 | y = objax.functional.core.ops.upsample_2d(x, (2, 3), method) 38 | output = jax.image.resize(x.transpose([0, 2, 3, 1]), 39 | shape=(shape[0], shape[2] * 2, shape[3] * 3, shape[1]), 40 | method=method).transpose([0, 3, 1, 2]) 41 | self.assertEqual(y.tolist(), output.tolist()) 42 | 43 | def test_interpolate(self): 44 | x = shaparange((1, 3, 2, 3, 10, 30)) 45 | shape = x.shape 46 | for method in self.methods: 47 | output = 2 48 | y = objax.functional.core.ops.interpolate(x, scale_factor=output, mode=method) 49 | self.assertEqual(y.shape, (shape[0], *(jn.array(shape[1:])) * output)) 50 | output = (2, 2, 2) 51 | y = objax.functional.core.ops.interpolate(x, scale_factor=output, mode=method) 52 | self.assertEqual(y.shape, (*shape[:len(shape) - len(output)], 53 | *(jn.array(shape[len(shape) - len(output):]) * jn.array(output)))) 54 | output = (2, 2, 2, 2) 55 | y = objax.functional.core.ops.interpolate(x, scale_factor=output, mode=method) 56 | self.assertEqual(y.shape, (*shape[:len(shape) - len(output)], 57 | *(jn.array(shape[len(shape) - len(output):]) * jn.array(output)))) 58 | output = 2 59 | y = objax.functional.core.ops.interpolate(x, size=output, mode=method) 60 | self.assertEqual(y.shape, (*shape[:-1], output)) 61 | output = (2, 2, 2) 62 | y = objax.functional.core.ops.interpolate(x, size=output, mode=method) 63 | self.assertEqual(y.shape, (*shape[:len(shape) - len(output)], *output)) 64 
| output = (2, 2, 2, 2) 65 | y = objax.functional.core.ops.interpolate(x, size=output, mode=method) 66 | self.assertEqual(y.shape, (*shape[:len(shape) - len(output)], *output)) 67 | 68 | 69 | if __name__ == '__main__': 70 | unittest.main() 71 | -------------------------------------------------------------------------------- /tests/functional_pooling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Unittests for functional pooling operations.""" 16 | 17 | import unittest 18 | 19 | import jax.numpy as jn 20 | import numpy as np 21 | 22 | import objax 23 | 24 | 25 | def shaparange(s): 26 | return jn.arange(np.prod(s), dtype=float).reshape(s) 27 | 28 | 29 | def pad(x, pad_width): 30 | return np.pad(x, pad_width, mode='constant') 31 | 32 | 33 | class TestPooling(unittest.TestCase): 34 | def test_average_pooling2d(self): 35 | x = shaparange((2, 3, 10, 30)) 36 | y = objax.functional.average_pool_2d(x, size=5) 37 | z = x.reshape((2, 3, 2, 5, 6, 5)).mean((-3, -1)) 38 | self.assertEqual(y.tolist(), z.tolist()) 39 | y = objax.functional.average_pool_2d(x, size=5, strides=1) 40 | z = np.zeros((2, 3, 6, 26), dtype=float) 41 | for i in range(6): 42 | for j in range(26): 43 | z[:, :, i, j] = x[:, :, i:i + 5, j:j + 5].mean((-2, -1)) 44 | self.assertEqual(y.tolist(), z.tolist()) 45 | y = objax.functional.average_pool_2d(x, size=(2, 3)) 46 | z = x.reshape((2, 3, 5, 2, 10, 3)).mean((-3, -1)) 47 | self.assertEqual(y.tolist(), z.tolist()) 48 | 49 | def test_max_pooling2d(self): 50 | x = shaparange((2, 3, 10, 30)) 51 | y = objax.functional.max_pool_2d(x, size=5) 52 | z = x.reshape((2, 3, 2, 5, 6, 5)).max((-3, -1)) 53 | self.assertEqual(y.tolist(), z.tolist()) 54 | y = objax.functional.max_pool_2d(x, size=5, strides=1) 55 | z = np.zeros((2, 3, 6, 26), dtype=float) 56 | for i in range(6): 57 | for j in range(26): 58 | z[:, :, i, j] = x[:, :, i:i + 5, j:j + 5].max((-2, -1)) 59 | self.assertEqual(y.tolist(), z.tolist()) 60 | y = objax.functional.max_pool_2d(x, size=(2, 3)) 61 | z = x.reshape((2, 3, 5, 2, 10, 3)).max((-3, -1)) 62 | self.assertEqual(y.tolist(), z.tolist()) 63 | 64 | def test_pooling2d_padding(self): 65 | x = shaparange((2, 3, 10, 30)) 66 | y = objax.functional.average_pool_2d(x, size=5, padding=(2, 3)) 67 | z = pad(x, ((0, 0), (0, 0), (2, 3), (2, 3))).reshape((2, 3, 3, 5, 7, 5)).mean((-3, -1)) 68 | self.assertEqual(y.tolist(), z.tolist()) 69 | y = 
objax.functional.max_pool_2d(x, size=5, padding=(2, 3)) 70 | z = pad(x, ((0, 0), (0, 0), (2, 3), (2, 3))).reshape((2, 3, 3, 5, 7, 5)).max((-3, -1)) 71 | self.assertEqual(y.tolist(), z.tolist()) 72 | y = objax.functional.average_pool_2d(x, size=5, padding=((2, 3), (3, 2))) 73 | z = pad(x, ((0, 0), (0, 0), (2, 3), (3, 2))).reshape((2, 3, 3, 5, 7, 5)).mean((-3, -1)) 74 | self.assertEqual(y.tolist(), z.tolist()) 75 | y = objax.functional.max_pool_2d(x, size=5, padding=((2, 3), (3, 2))) 76 | z = pad(x, ((0, 0), (0, 0), (2, 3), (3, 2))).reshape((2, 3, 3, 5, 7, 5)).max((-3, -1)) 77 | self.assertEqual(y.tolist(), z.tolist()) 78 | y = objax.functional.average_pool_2d(x, size=2, padding=1) 79 | z = pad(x, ((0, 0), (0, 0), (1, 1), (1, 1))).reshape((2, 3, 6, 2, 16, 2)).mean((-3, -1)) 80 | self.assertEqual(y.tolist(), z.tolist()) 81 | y = objax.functional.max_pool_2d(x, size=2, padding=1) 82 | z = pad(x, ((0, 0), (0, 0), (1, 1), (1, 1))).reshape((2, 3, 6, 2, 16, 2)).max((-3, -1)) 83 | self.assertEqual(y.tolist(), z.tolist()) 84 | 85 | def test_space_batch(self): 86 | """Test batch_to_space2d and space_to_batch2d.""" 87 | x = shaparange((2, 3, 10, 30)) 88 | y = objax.functional.space_to_batch2d(x, size=5) 89 | z = objax.functional.batch_to_space2d(y, size=5) 90 | self.assertEqual(x.tolist(), z.tolist()) 91 | self.assertEqual(y.shape, (50, 3, 2, 6)) 92 | y = objax.functional.space_to_batch2d(x, size=(2, 3)) 93 | z = objax.functional.batch_to_space2d(y, size=(2, 3)) 94 | self.assertEqual(x.tolist(), z.tolist()) 95 | self.assertEqual(y.shape, (12, 3, 5, 10)) 96 | 97 | def test_space_channel(self): 98 | """Test channel_to_space2d and space_to_channel2d.""" 99 | x = shaparange((2, 3, 10, 30)) 100 | y = objax.functional.space_to_channel2d(x, size=5) 101 | z = objax.functional.channel_to_space2d(y, size=5) 102 | self.assertEqual(x.tolist(), z.tolist()) 103 | self.assertEqual(y.shape, (2, 75, 2, 6)) 104 | y = objax.functional.space_to_channel2d(x, size=(2, 3)) 105 | z = 
objax.functional.channel_to_space2d(y, size=(2, 3)) 106 | self.assertEqual(x.tolist(), z.tolist()) 107 | self.assertEqual(y.shape, (2, 18, 5, 10)) 108 | 109 | 110 | if __name__ == '__main__': 111 | unittest.main() 112 | -------------------------------------------------------------------------------- /tests/jit.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class LinearArgs(objax.nn.Linear):
    def __call__(self, x: JaxArray, some_args: float) -> JaxArray:
        """Returns the linear transformation of x with the product scaled by some_args before the bias."""
        scaled = jn.dot(x, self.w.value) * some_args
        if self.b is None:
            return scaled
        return scaled + self.b.value


class LinearTrain(objax.nn.Linear):
    def __call__(self, x: JaxArray, training: bool) -> JaxArray:
        """Returns the linear transformation of x, negating the product when training is truthy."""
        projected = jn.dot(x, self.w.value)
        # Branching on `training` needs a concrete boolean value.
        if training:
            projected = -projected
        if self.b is None:
            return projected
        return projected + self.b.value


class TestJit(unittest.TestCase):
    def _assert_tracks_weight_updates(self, wrap):
        """A wrapped Linear layer must see weight changes made after wrapping."""
        k = objax.nn.Linear(3, 3)
        kj = wrap(k)
        x = objax.random.normal((64, 3))
        y1 = kj(x)
        k.w.assign(k.w.value + 1)
        y2 = kj(x)
        k.w.assign(k.w.value - 1)
        y3 = kj(x)
        # Restoring the weights restores the output; the shifted weights changed it.
        self.assertAlmostEqual(((y1 - y3) ** 2).sum(), 0)
        self.assertNotEqual(((y1 - y2) ** 2).sum(), 0)

    def test_on_linear(self):
        self._assert_tracks_weight_updates(objax.Jit)

    def test_double_jit(self):
        self._assert_tracks_weight_updates(lambda module: objax.Jit(objax.Jit(module)))

    def test_jit_kwargs(self):
        x = objax.random.normal((64, 3))
        scaled_linear = objax.Jit(LinearArgs(3, 3))
        positional = scaled_linear(x, 1)
        by_keyword = scaled_linear(x, some_args=1)
        other_scale = scaled_linear(x, some_args=2)
        self.assertEqual(positional.tolist(), by_keyword.tolist())
        self.assertNotEqual(positional.tolist(), other_scale.tolist())
        train_linear = objax.Jit(LinearTrain(3, 3))
        with self.assertRaises(ConcretizationTypeError):
            train_linear(x, training=True)

    def test_trainvar_assign(self):
        m = objax.ModuleList([objax.TrainVar(jn.zeros(2))])

        def increase():
            var = m[0]
            var.assign(var.value + 1)
            return var.value

        jit_increase = objax.Jit(increase, m.vars())
        jit_increase()
        self.assertEqual(m[0].value.tolist(), [1., 1.])

    def test_trainvar_and_ref_assign(self):
        m = objax.ModuleList([objax.TrainVar(jn.zeros(2))])
        m.append(objax.TrainRef(m[0]))

        def increase():
            # Increment once through the variable and once through its reference.
            for entry in (m[0], m[1]):
                entry.assign(entry.value + 1)
            return m[0].value

        jit_increase = objax.Jit(increase, m.vars())
        v = jit_increase()
        self.assertEqual(v.tolist(), [2., 2.])
        self.assertEqual(m[0].value.tolist(), [2., 2.])

    def test_constant_optimization(self):
        m = objax.nn.Linear(3, 4)
        # Jit with an empty variable collection.
        jit_constant = objax.Jit(m, objax.VarCollection())

        x = objax.random.normal((10, 3))
        squared_gap = ((m(x) - jit_constant(x)) ** 2).sum()
        self.assertEqual(squared_gap, 0)

        # Modify m (which was supposed to be constant!)
        m.b.assign(m.b.value + 1)
        squared_gap = ((m(x) - jit_constant(x)) ** 2).sum()
        self.assertEqual(squared_gap, 40)
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Unittests for Linear Layer.""" 16 | 17 | import unittest 18 | 19 | import jax.numpy as jn 20 | 21 | import objax 22 | 23 | 24 | class TestLinear(unittest.TestCase): 25 | def test_on_linear_three_unit(self): 26 | """ 27 | Pass an input through a linear filter with 3 units and 28 | test the shape and contents of the output. 29 | """ 30 | 31 | # Define linear filter with 1 input channel and 3 output channels 32 | linear_filter = objax.nn.Linear(1, 3, use_bias=False) 33 | weights = objax.TrainVar(jn.array([[1., 2., 1.]])) 34 | linear_filter.w = weights 35 | 36 | # Define data and compute output response of linear filter 37 | data = jn.array([[1.], [2.]]) 38 | features = linear_filter(data) 39 | expected_features = jn.array([[1., 2., 1.], [2., 4., 2.]]) 40 | self.assertEqual(features.shape, (2, 3)) 41 | self.assertTrue(jn.array_equal(features, expected_features)) 42 | 43 | def test_on_linear_three_unit_with_bias(self): 44 | """ 45 | Pass an input through a linear filter with 3 units and bias and 46 | test the shape and contents of the output.
47 | """ 48 | 49 | # Define linear filter with 1 input channel and 3 output channels 50 | linear_filter = objax.nn.Linear(1, 3, use_bias=True) 51 | weights = objax.TrainVar(jn.array([[1., 2., 1.]])) 52 | bias = objax.TrainVar(jn.array([2., 1., 2.])) 53 | linear_filter.w = weights 54 | linear_filter.b = bias 55 | 56 | # Define data and compute output response of linear filter 57 | data = jn.array([[1.], [2.]]) 58 | features = linear_filter(data) 59 | expected_features = jn.array([[3., 3., 3.], [4., 5., 4.]]) 60 | self.assertEqual(features.shape, (2, 3)) 61 | self.assertTrue(jn.array_equal(features, expected_features)) 62 | 63 | 64 | if __name__ == '__main__': 65 | unittest.main() 66 | -------------------------------------------------------------------------------- /tests/nn_moving_average.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Unittests for MovingAverage and ExponentialMovingAverage Layer.""" 16 | 17 | import unittest 18 | 19 | import jax.numpy as jn 20 | import numpy as np 21 | 22 | import objax 23 | 24 | 25 | class TestMovingAverage(unittest.TestCase): 26 | 27 | def test_MovingAverage(self): 28 | """Test MovingAverage.""" 29 | x1 = jn.array([[0, 1, 2]]) 30 | x2 = jn.array([[0, 0, 0]]) 31 | x3 = jn.array([[-3, -4, 5]]) 32 | init_value = 100 33 | shape = x1.shape 34 | ma = objax.nn.MovingAverage(shape=shape, buffer_size=2, init_value=init_value) 35 | 36 | x_ma1 = ma(x1) 37 | x_ma2 = ma(x2) 38 | x_ma3 = ma(x3) 39 | 40 | np.testing.assert_allclose(x_ma1, np.array([[50, 50.5, 51]])) 41 | np.testing.assert_allclose(x_ma2, np.array([[0, 0.5, 1]])) 42 | np.testing.assert_allclose(x_ma3, np.array([[-1.5, -2, 2.5]])) 43 | 44 | def test_ExponentialMovingAverage(self): 45 | """Test ExponentialMovingAverage.""" 46 | x1 = jn.array([[0, 1, 2]]) * 100 47 | x2 = jn.array([[-3, -4, 5]]) * 100 48 | init_value = 100 49 | shape = x1.shape 50 | ema = objax.nn.ExponentialMovingAverage(shape=shape, init_value=init_value, momentum=0.8) 51 | 52 | x_ema1 = ema(x1) 53 | x_ema2 = ema(x2) 54 | 55 | np.testing.assert_allclose(x_ema1, np.array([[80, 100, 120]])) 56 | np.testing.assert_allclose(x_ema2, np.array([[4, 0, 196]])) 57 | 58 | 59 | if __name__ == '__main__': 60 | unittest.main() 61 | -------------------------------------------------------------------------------- /tests/objax2tf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Unittests for Objax2Tf converter.""" 16 | 17 | import shutil 18 | import tempfile 19 | import unittest 20 | 21 | import numpy as np 22 | import objax 23 | from objax.zoo.wide_resnet import WideResNet 24 | import tensorflow as tf 25 | 26 | 27 | BATCH_SIZE = 4 28 | NCHANNELS = 3 29 | NCLASSES = 10 30 | IMAGE_SIZE = 32 31 | 32 | 33 | class TestObjax2Tf(unittest.TestCase): 34 | 35 | def verify_converted_predict_op(self, objax_op, tf_op, shape): 36 | x1 = np.random.normal(size=shape) 37 | x2 = np.random.normal(size=shape) 38 | # due to differences in op implementations, there might be small numerical 39 | # differences between TF and Objax, thus comparing up to 1e-4 relative tolerance 40 | np.testing.assert_allclose(objax_op(x1), tf_op(tf.convert_to_tensor(x1, dtype=tf.float32)), rtol=1e-4) 41 | np.testing.assert_allclose(objax_op(x2), tf_op(tf.convert_to_tensor(x2, dtype=tf.float32)), rtol=1e-4) 42 | 43 | # NOTE: Objax2Tf tests are temporarily disabled until the release of TF 2.8 44 | 45 | def disabled_test_convert_wrn(self): 46 | # Make a model 47 | model = WideResNet(NCHANNELS, NCLASSES, depth=4, width=1) 48 | # Prediction op without JIT 49 | predict_op = objax.nn.Sequential([objax.ForceArgs(model, training=False), objax.functional.softmax]) 50 | predict_tf = objax.util.Objax2Tf(predict_op) 51 | # Compare results 52 | self.verify_converted_predict_op(predict_op, predict_tf, 53 | shape=(BATCH_SIZE, NCHANNELS, IMAGE_SIZE, IMAGE_SIZE)) 54 | # Predict op with JIT 55 | predict_op_jit = objax.Jit(predict_op) 56 |
predict_tf_jit = objax.util.Objax2Tf(predict_op_jit) 57 | # Compare results 58 | self.verify_converted_predict_op(predict_op_jit, predict_tf_jit, 59 | shape=(BATCH_SIZE, NCHANNELS, IMAGE_SIZE, IMAGE_SIZE)) 60 | 61 | def disabled_test_savedmodel_wrn(self): 62 | model_dir = tempfile.mkdtemp() 63 | # Make a model and convert it to TF 64 | model = WideResNet(NCHANNELS, NCLASSES, depth=4, width=1) 65 | predict_op = objax.Jit(objax.nn.Sequential([objax.ForceArgs(model, training=False), objax.functional.softmax])) 66 | predict_tf = objax.util.Objax2Tf(predict_op) 67 | # Save model 68 | input_shape = (BATCH_SIZE, NCHANNELS, IMAGE_SIZE, IMAGE_SIZE) 69 | tf.saved_model.save( 70 | predict_tf, 71 | model_dir, 72 | signatures=predict_tf.__call__.get_concrete_function(tf.TensorSpec(input_shape, tf.float32))) 73 | # Load model 74 | loaded_tf_model = tf.saved_model.load(model_dir) 75 | loaded_predict_tf_op = loaded_tf_model.signatures['serving_default'] 76 | self.verify_converted_predict_op(predict_op, 77 | lambda x: loaded_predict_tf_op(x)['output_0'], 78 | shape=input_shape) 79 | self.verify_converted_predict_op(predict_op, 80 | lambda x: loaded_tf_model(x), 81 | shape=input_shape) 82 | # Cleanup 83 | shutil.rmtree(model_dir) 84 | 85 | 86 | if __name__ == '__main__': 87 | unittest.main() 88 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | numpy 3 | tensorflow 4 | -------------------------------------------------------------------------------- /tests/run_linter.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2020 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Change directory to repository root 17 | cd "$( dirname "${BASH_SOURCE[0]}" )/.." 18 | 19 | # Run linter with following changes to default rules: 20 | # - We allow assignment of lambda, thus ignore E731 error: https://www.flake8rules.com/rules/E731.html 21 | # - Line break should occur before binary operator, thus between W503 and W504 ignore W503 and follow W504, 22 | # https://www.flake8rules.com/rules/W503.html 23 | # - Set max line length to 120 characters 24 | # - Separately lint __init__.py and other files, otherwise flake8 complains about unused imports in __init__.py 25 | flake8 --exclude=__init__.py --max-line-length=120 --ignore=E731,W503 objax/ || exit 1 26 | flake8 --filename=__init__.py --max-line-length=120 --ignore=E731,W503 objax/ || exit 1 27 | flake8 --max-line-length=120 --ignore=E731,W503 tests/ || exit 1 28 | -------------------------------------------------------------------------------- /tests/run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2020 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Change directory to repository root 17 | cd "$( dirname "${BASH_SOURCE[0]}" )/.." 18 | 19 | if python3 -c "import pytest" &> /dev/null ; then 20 | # If pytest is installed then use it to run tests 21 | # Pytest has nicer output compared to unittest package and also it's used 22 | # to run automatic unit tests on GitHub. 23 | CUDA_VISIBLE_DEVICES= pytest tests/*.py 24 | else 25 | # If pytest is not installed then use default unittest to run tests. 26 | for i in tests/*.py; do 27 | CUDA_VISIBLE_DEVICES= python3 -m unittest $i >&$i.log & 28 | done 29 | wait 30 | fgrep FAILED tests/*.log 31 | fi 32 | -------------------------------------------------------------------------------- /tests/scan.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Unittests for scan method.""" 16 | 17 | import unittest 18 | 19 | import jax.numpy as jn 20 | 21 | import objax 22 | 23 | 24 | class TestScan(unittest.TestCase): 25 | def test_scan(self): 26 | def cell(carry, x): 27 | return jn.array([2]) * carry * x, jn.array([3]) * carry * x 28 | 29 | carry = jn.array([8., 8.]) 30 | output = jn.array([[3., 3.], [6., 6.], [12., 12.]]) 31 | test_carry, test_output = objax.functional.scan(cell, jn.ones((2,)), jn.ones((3,))) 32 | self.assertTrue(jn.array_equal(carry, test_carry)) 33 | self.assertTrue(jn.array_equal(output, test_output)) 34 | -------------------------------------------------------------------------------- /tests/scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Unittests for learning rate schedulers.""" 16 | 17 | import unittest 18 | 19 | import numpy as np 20 | 21 | import objax 22 | 23 | 24 | class TestScheduler(unittest.TestCase): 25 | def test_linear_annealing(self): 26 | sched = objax.optimizer.scheduler.LinearAnnealing(max_step=10, base_lr=1, is_cycle=True, min_lr=0) 27 | lrs = [] 28 | for i in range(10): 29 | lrs.append(sched(step=i)) 30 | lrs_gt = [1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1] 31 | np.testing.assert_array_almost_equal(lrs, lrs_gt) 32 | 33 | def test_step_decay(self): 34 | sched = objax.optimizer.scheduler.StepDecay(step_size=3, base_lr=1, gamma=0.9) 35 | lrs = [] 36 | for i in range(10): 37 | lrs.append(sched(step=i)) 38 | lrs_gt = [1, 1, 1, 0.9, 0.9, 0.9, 0.81, 0.81, 0.81, 0.729] 39 | np.testing.assert_array_almost_equal(lrs, lrs_gt) 40 | 41 | def test_multi_step_decay(self): 42 | sched = objax.optimizer.scheduler.StepDecay(step_size=[3, 5, 8], base_lr=1, gamma=0.9) 43 | lrs = [] 44 | for i in range(10): 45 | lrs.append(sched(step=i)) 46 | lrs_gt = [1, 1, 1, 0.9, 0.9, 0.81, 0.81, 0.81, 0.729, 0.729] 47 | np.testing.assert_array_almost_equal(lrs, lrs_gt) 48 | 49 | 50 | if __name__ == '__main__': 51 | unittest.main() 52 | -------------------------------------------------------------------------------- /tests/sequential.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Unittests for Sequential Layer.""" 16 | 17 | import unittest 18 | 19 | import jax.numpy as jn 20 | 21 | import objax 22 | 23 | 24 | class TestSequential(unittest.TestCase): 25 | def test_on_sequential_linear_relu(self): 26 | """ 27 | Pass an input through a linear filter with 3 units followed by ReLU and 28 | test the shape and contents of the output. 29 | """ 30 | 31 | # Define linear filter with 2 input channels and 3 output channels 32 | linear_filter = objax.nn.Linear(2, 3, use_bias=False) 33 | weights = objax.TrainVar(jn.array([[1., 2., 1.], [2., 1., 2.]])) 34 | linear_filter.w = weights 35 | sequential = objax.nn.Sequential([linear_filter, 36 | objax.functional.relu]) 37 | 38 | # Define data and compute output response of linear filter 39 | data = jn.array([[1., -1.], [2., -2.]]) 40 | features = sequential(data) 41 | expected_features = jn.array([[0., 1., 0.], [0., 2., 0.]]) 42 | self.assertEqual(features.shape, (2, 3)) 43 | self.assertTrue(jn.array_equal(features, expected_features)) 44 | 45 | def test_on_sequential_relu_linear(self): 46 | """ 47 | Pass an input through ReLU followed by a linear filter with 3 units and 48 | test the shape and contents of the output.
49 | """ 50 | 51 | # Define linear filter with 2 input channels and 3 output channels 52 | linear_filter = objax.nn.Linear(2, 3, use_bias=False) 53 | weights = objax.TrainVar(jn.array([[1., 2., 1.], [2., 1., 2.]])) 54 | linear_filter.w = weights 55 | sequential = objax.nn.Sequential([objax.functional.relu, 56 | linear_filter]) 57 | 58 | # Define data and compute output response of linear filter 59 | data = jn.array([[1., -1.], [2., -2.]]) 60 | features = sequential(data) 61 | expected_features = jn.array([[1., 2., 1.], [2., 4., 2.]]) 62 | self.assertEqual(features.shape, (2, 3)) 63 | self.assertTrue(jn.array_equal(features, expected_features)) 64 | 65 | def test_kwargs(self): 66 | """Test sequential on modules that take named inputs in kwargs.""" 67 | 68 | class MyModule: 69 | def __init__(self): 70 | pass 71 | 72 | def __call__(self, x, some_param): 73 | return x + some_param 74 | 75 | seq = objax.nn.Sequential([MyModule(), MyModule()]) 76 | self.assertEqual(seq(1, some_param=2), 5) 77 | with self.assertRaises(TypeError): 78 | seq(1) 79 | 80 | def test_variadic(self): 81 | """Test sequential on modules that take multiple inputs and have multiple outputs.""" 82 | 83 | class MyModule: 84 | def __init__(self): 85 | pass 86 | 87 | def __call__(self, x, y): 88 | return x + y, x - y 89 | 90 | seq = objax.nn.Sequential([MyModule(), MyModule()]) 91 | self.assertEqual(seq(1, 2), (2, 4)) 92 | 93 | def test_slice(self): 94 | """Test sequential slices with variadic module.""" 95 | 96 | class MyModule: 97 | def __init__(self, m): 98 | self.m = m 99 | 100 | def __call__(self, x, y): 101 | return self.m * x + y, self.m * x - y 102 | 103 | seq = objax.nn.Sequential([MyModule(2), MyModule(3)]) 104 | self.assertEqual(seq(5, 7), (54, 48)) 105 | self.assertEqual(seq[:1](5, 7), (17, 3)) 106 | self.assertEqual(seq[1:](5, 7), (22, 8)) 107 | 108 | def test_on_sequential_missing_argument(self): 109 | m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.nn.BatchNorm0D(3),
objax.nn.Linear(3, 2)]) 110 | x = jn.array([[1., -1.], [2., -2.]]) 111 | msg = "missing 1 required positional argument: 'training'" 112 | try: 113 | m(x) 114 | assert False 115 | except TypeError as e: 116 | self.assertIn(msg, str(e)) 117 | m.pop() 118 | try: 119 | m(x) 120 | assert False 121 | except TypeError as e: 122 | self.assertIn(msg, str(e)) 123 | 124 | 125 | if __name__ == '__main__': 126 | unittest.main() 127 | -------------------------------------------------------------------------------- /tests/util_image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Unittests for objax.util.image.""" 16 | 17 | import io 18 | import tempfile 19 | import unittest 20 | from typing import Tuple 21 | 22 | import jax.numpy as jn 23 | import numpy as np 24 | from PIL import Image 25 | 26 | import objax 27 | 28 | 29 | class TestUtilImage(unittest.TestCase): 30 | def ndimarange(self, dims: Tuple[int, ...]): 31 | return np.arange(np.prod(dims), dtype=float).reshape(dims) 32 | 33 | def test_nchw(self): 34 | x = self.ndimarange((2, 3, 4, 5)) 35 | self.assertEqual(objax.util.image.nchw(x).tolist(), x.transpose((0, 3, 1, 2)).tolist()) 36 | self.assertEqual(objax.util.image.nchw(jn.array(x)).tolist(), x.transpose((0, 3, 1, 2)).tolist()) 37 | x = self.ndimarange((2, 3, 4, 5, 6)) 38 | self.assertEqual(objax.util.image.nchw(x).tolist(), x.transpose((0, 1, 4, 2, 3)).tolist()) 39 | self.assertEqual(objax.util.image.nchw(jn.array(x)).tolist(), x.transpose((0, 1, 4, 2, 3)).tolist()) 40 | 41 | def test_nhwc(self): 42 | x = self.ndimarange((2, 3, 4, 5)) 43 | self.assertEqual(objax.util.image.nhwc(x).tolist(), x.transpose((0, 2, 3, 1)).tolist()) 44 | self.assertEqual(objax.util.image.nhwc(jn.array(x)).tolist(), x.transpose((0, 2, 3, 1)).tolist()) 45 | x = self.ndimarange((2, 3, 4, 5, 6)) 46 | self.assertEqual(objax.util.image.nhwc(x).tolist(), x.transpose((0, 1, 3, 4, 2)).tolist()) 47 | self.assertEqual(objax.util.image.nhwc(jn.array(x)).tolist(), x.transpose((0, 1, 3, 4, 2)).tolist()) 48 | 49 | def test_normalize(self): 50 | """Test normalize methods.""" 51 | x = np.arange(256) 52 | y = objax.util.image.normalize_to_unit_float(x) 53 | self.assertEqual((x / 128 - (1 - 1 / 256)).tolist(), y.tolist()) 54 | self.assertEqual(y.tolist(), y.clip(-1, 1).tolist()) 55 | z = objax.util.image.normalize_to_uint8(y) 56 | self.assertEqual(x.tolist(), z.tolist()) 57 | z = objax.util.image.normalize_to_uint8(y + 1 / 128) 58 | self.assertEqual((x + 1).clip(0, 255).tolist(), z.tolist()) 59 | z = objax.util.image.normalize_to_uint8(y - 1 / 128) 60 | 
self.assertEqual((x - 1).clip(0, 255).tolist(), z.tolist()) 61 | 62 | def test_to_png(self): 63 | x = np.zeros((3, 32, 32), float) + 1 / 255 64 | x[:, :12, :12] = 1 65 | x[:, -12:, -12:] = -1 66 | y = objax.util.image.to_png(x) 67 | self.assertEqual( 68 | np.array(Image.open(io.BytesIO(y))).tolist(), 69 | np.array(Image.open(io.BytesIO( 70 | b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00 \x00\x00\x00 \x08\x02\x00\x00\x00\xfc' 71 | b'\x18\xed\xa3\x00\x00\x00FIDATx\x9cc\xfc\xff\xff?\x03!\xd0\xd8\xd8HP\r.\xc0D\xb6\xceQ' 72 | b'\x0bF-\x18\xb5`\x04Y\xc0BI9C\x0c\x18\xfaA4j\xc1\x08\xb0\x80\x85\x12\xcd\r\r\r\x04\xd5' 73 | b'\x0c\xfd \x1a\xb5`\xd4\x82Q\x0b\xe8`\x01\x00\xe3\xf1\x07\xc7\x82\x83p\xa5\x00\x00\x00\x00' 74 | b'IEND\xaeB`\x82' 75 | ))).tolist()) 76 | z = np.array(Image.open(io.BytesIO(y))) 77 | z = (z.transpose((2, 0, 1)) - 127.5) / 127.5 78 | self.assertEqual(x.tolist(), z.tolist()) 79 | 80 | def test_to_png_from_file(self): 81 | x = objax.random.randint((3, 32, 24), 0, 256) 82 | x = objax.util.image.normalize_to_unit_float(x) 83 | bin = objax.util.image.to_png(x) 84 | y = objax.util.image.from_file(io.BytesIO(bin)) 85 | self.assertEqual(x.tolist(), y.tolist()) 86 | 87 | def test_image_grid(self): 88 | x = objax.random.randint((5, 7, 3, 8, 4), 0, 256) 89 | y = objax.util.image.image_grid(x) 90 | z = x.transpose((2, 0, 3, 1, 4)).reshape((3, 40, 28)) 91 | self.assertEqual(y.tolist(), z.tolist()) 92 | 93 | def test_from_file_with_filename(self): 94 | x = objax.random.randint((3, 32, 24), 0, 256) 95 | x = objax.util.image.normalize_to_unit_float(x) 96 | with tempfile.NamedTemporaryFile('wb', suffix='.png') as f: 97 | f.write(objax.util.image.to_png(x)) 98 | f.flush() 99 | y = objax.util.image.from_file(f.name) 100 | self.assertEqual(x.tolist(), y.tolist()) 101 | 102 | 103 | if __name__ == '__main__': 104 | unittest.main() 105 | -------------------------------------------------------------------------------- /tests/wide_resnet.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Unittests for Wide ResNet.""" 16 | 17 | import unittest 18 | 19 | import objax 20 | from objax.zoo.wide_resnet import WideResNet, WideResNetGeneral 21 | 22 | 23 | class TestWideResNetGeneral(unittest.TestCase): 24 | 25 | def test_wide_resnet_general(self): 26 | x = objax.random.normal((4, 3, 128, 128)) 27 | model = WideResNetGeneral(nin=3, nclass=10, blocks_per_group=[4, 4, 4, 4], width=2) 28 | # run in eval mode 29 | y_eval = model(x, training=False) 30 | self.assertEqual(y_eval.shape, (4, 10)) 31 | # run in train mode 32 | y_eval = model(x, training=True) 33 | self.assertEqual(y_eval.shape, (4, 10)) 34 | 35 | def test_wide_resnet(self): 36 | x = objax.random.normal((4, 3, 32, 32)) 37 | model = WideResNet(nin=3, nclass=10, depth=28, width=4) 38 | # run in eval mode 39 | y_eval = model(x, training=False) 40 | self.assertEqual(y_eval.shape, (4, 10)) 41 | # run in train mode 42 | y_eval = model(x, training=True) 43 | self.assertEqual(y_eval.shape, (4, 10)) 44 | 45 | 46 | if __name__ == '__main__': 47 | unittest.main() 48 | --------------------------------------------------------------------------------