├── .dockerignore ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── create-release.yml │ ├── publish-package.yml │ └── run-tests.yml ├── .gitignore ├── .gitpod.yml ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── basic-api │ └── modules-losses-metrics.md ├── contributing.md ├── elegy-module.md ├── getting-started │ ├── high-level-api.ipynb │ └── low-level-api.ipynb ├── guides │ └── contributing.md ├── images │ └── favicon.png ├── index.md ├── low-level-api │ ├── basics.md │ ├── default-implementation.md │ ├── methods │ │ ├── pred_step.md │ │ └── test_step.md │ └── states.md └── na.md ├── elegy ├── __init__.py ├── callbacks │ ├── __init__.py │ ├── callback.py │ ├── callback_list.py │ ├── csv_logger.py │ ├── early_stopping.py │ ├── history.py │ ├── lambda_callback.py │ ├── model_checkpoint.py │ ├── progbar_logger.py │ ├── remote_monitor.py │ ├── sigint.py │ ├── tensorboard.py │ ├── terminate_nan.py │ └── wandb_callback.py ├── data │ ├── __init__.py │ ├── array_adapter.py │ ├── data_adapter.py │ ├── data_handler.py │ ├── dataset.py │ ├── generator_adapter.py │ ├── list_adapter.py │ ├── tf_dataset_adapter.py │ ├── torch_dataloader_adapter.py │ └── utils.py ├── model │ ├── __init__.py │ ├── model.py │ ├── model_base.py │ ├── model_core.py │ └── utils.py ├── nets │ ├── __init__.py │ └── resnet.py ├── types.py └── utils.py ├── examples ├── elegy │ ├── mnist.py │ ├── mnist_autoencoder.py │ ├── mnist_conv.py │ ├── mnist_dataloader.py │ ├── mnist_tf_data.py │ ├── mnist_torch_dataloader.py │ ├── mnist_vae.py │ └── toy_mlp.py ├── flax │ ├── mnist_conv.py │ ├── mnist_vae.py │ └── toy_mlp.py ├── haiku │ ├── mnist_conv.py │ ├── mnist_vae.py │ └── toy_mlp.py ├── jax │ ├── linear_classifier_test_step.py │ └── linear_classifier_train_step.py ├── need-fixing │ ├── CIFAR10_95%accuracy.ipynb │ ├── WGAN-GP │ │ ├── README.md │ │ ├── images │ │ │ ├── epoch-0009.png │ │ │ ├── epoch-0049.png │ │ │ └── epoch-0099.png │ │ ├── main.py │ │ └── model.py │ └── imagenet │ │ ├── README.md │ │ ├── input_pipeline.py │ │ ├── requirements.txt │ │ └── resnet_imagenet.py └── requirements.txt ├── gitpod.Dockerfile ├── mkdocs.yml ├── poetry.lock ├── pyproject.toml ├── pytest.ini ├── scripts ├── deploy-docs.sh ├── get-coverage.sh ├── install-remote.sh ├── run-docs.sh ├── test-all-versions.sh ├── test-examples.sh ├── test-gpu.sh ├── test-version.sh ├── update_docs.py └── update_version.py ├── tests ├── callbacks │ └── early_stopping_test.py ├── data │ ├── array_adapter_test.py │ ├── data_utils_test.py │ ├── dataset_test.py │ ├── list_adapter_test.py │ ├── tf_dataset_adapter_test.py │ └── torch_dataloader_adapter_test.py ├── model │ ├── model_base_test.py │ ├── model_core_test.py │ └── model_test.py ├── nets │ └── resnet_test.py └── utils_test.py └── tmp └── test.py /.dockerignore: -------------------------------------------------------------------------------- 1 | * 2 | 3 | !elegy 4 | !tests 5 | !pyproject.toml 6 | !poetry.lock -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[Bug]" 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 
12 | 13 | **Minimal code to reproduce** 14 | Small snippet that contains a minimal amount of code. 15 | ```python 16 | import elegy 17 | ``` 18 | 19 | **Expected behavior** 20 | A clear and concise description of what you expected to happen. 21 | 22 | **Library Info** 23 | Please provide os info and elegy version. 24 | ```python 25 | import elegy 26 | print(elegy.__version__) 27 | ``` 28 | **Screenshots** 29 | If applicable, add screenshots to help explain your problem. 30 | 31 | **Additional context** 32 | Add any other context about the problem here. 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[Feature Request]" 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen or how would you like the API to be designed ( A small code snippet can work) 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered, any example in any other framework 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/workflows/create-release.yml: -------------------------------------------------------------------------------- 1 | name: Create Release 2 | 3 | on: 4 | create 5 | 6 | jobs: 7 | create-release: 8 | if: startsWith(github.ref_name, 'version-') && endsWith(github.ref_name, '-create-release') 9 | name: Create Release 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout 🛎️ 13 | uses: actions/checkout@v2 14 | 15 | - name: Set up Python 3.8 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: 3.8 19 | 20 | - name: Setup 21 | id: setup 22 | run: | 23 | # install python dependencies 24 | pip install typer==0.4.0 click==8.0.3 25 | 26 | # variables 27 | RELEASE_VERSION='${{ github.ref_name }}' 28 | RELEASE_VERSION=${RELEASE_VERSION//version-/} 29 | RELEASE_VERSION=${RELEASE_VERSION//-create-release/} 30 | echo "::set-output name=RELEASE_VERSION::${RELEASE_VERSION}" 31 | 32 | 33 | - name: Test Environment 34 | run: | 35 | RELEASE_VERSION='${{ steps.setup.outputs.RELEASE_VERSION }}' 36 | 37 | # setup git 38 | git config --local user.email "github-actions[bot]@users.noreply.github.com" 39 | git config --local user.name "github-actions[bot]" 40 | 41 | # switch to master 42 | git pull origin master 43 | git checkout master 44 | 45 | # update version 46 | python scripts/update_version.py $RELEASE_VERSION 47 | git commit -am "Update version to $RELEASE_VERSION" 48 | 49 | # create tag 50 | git fetch --tags 51 | git tag $RELEASE_VERSION 52 | 53 | # push to master 54 | git push 55 | git push --tags 56 | 57 | # delete branch 58 | git push -d origin ${{ github.ref_name }} 59 | 60 | - name: Build Changelog 61 | id: github_release 62 | uses: mikepenz/release-changelog-builder-action@v2.9.0 63 | env: 64 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 65 | with: 66 | toTag: ${{ steps.setup.outputs.RELEASE_VERSION }} 67 | 68 | - name: Create 
Release 69 | uses: actions/create-release@v1 70 | with: 71 | tag_name: ${{ steps.setup.outputs.RELEASE_VERSION }} 72 | release_name: ${{ steps.setup.outputs.RELEASE_VERSION }} 73 | body: ${{ steps.github_release.outputs.changelog }} 74 | draft: true 75 | env: 76 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/publish-package.yml: -------------------------------------------------------------------------------- 1 | name: Publish Package 2 | on: 3 | release: 4 | types: [published] 5 | jobs: 6 | publish-docs-and-package: 7 | name: Publish Docs and Package 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Checkout 🛎️ 11 | uses: actions/checkout@v2 12 | 13 | - name: Set up Python 3.8 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: 3.8 17 | 18 | - name: Install Poetry 📖 19 | uses: snok/install-poetry@v1.1.1 20 | with: 21 | version: 1.1.4 22 | 23 | - name: Install Dependencies 24 | run: | 25 | poetry config virtualenvs.create false 26 | pip install -U certifi 27 | poetry install 28 | 29 | - name: Build Docs 🔨 30 | run: | 31 | cp README.md docs/index.md 32 | python scripts/update_docs.py 33 | mkdocs build 34 | 35 | - name: Deploy Page 🚀 36 | uses: JamesIves/github-pages-deploy-action@4.1.6 37 | with: 38 | branch: gh-pages 39 | folder: site 40 | 41 | - name: Publish to PyPI 42 | run: | 43 | poetry build 44 | poetry publish \ 45 | --username ${{ secrets.PYPI_USERNAME }} \ 46 | --password ${{ secrets.PYPI_PASSWORD }} 47 | -------------------------------------------------------------------------------- /.github/workflows/run-tests.yml: -------------------------------------------------------------------------------- 1 | # Checks that we can build and validate the Unittest 2 | name: Run Tests 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | - uses: actions/setup-python@v2 14 | - uses: pre-commit/action@v2.0.3 15 | test: 16 | name: Run Tests 17 | if: ${{ !contains(github.event.pull_request.title, 'WIP') }} 18 | runs-on: ubuntu-latest 19 | strategy: 20 | matrix: 21 | python-version: [3.7, 3.8, 3.9] 22 | steps: 23 | - name: Check out the code 24 | uses: actions/checkout@v2 25 | with: 26 | fetch-depth: 1 27 | - name: Set up Python ${{ matrix.python-version }} 28 | uses: actions/setup-python@v2 29 | with: 30 | python-version: ${{ matrix.python-version }} 31 | 32 | - name: Install Poetry 33 | uses: snok/install-poetry@v1.1.1 34 | with: 35 | version: 1.1.4 36 | 37 | - name: Install Dependencies 38 | run: | 39 | poetry config virtualenvs.create false 40 | pip install -U certifi 41 | poetry install 42 | 43 | - name: Run Tests 44 | run: pytest --cov=elegy --cov-report=term-missing --cov-report=xml 45 | 46 | - name: Upload coverage 47 | uses: codecov/codecov-action@v1 48 | 49 | - name: Test Examples 50 | run: bash scripts/test-examples.sh 51 | 52 | test-import: 53 | name: Test Import without Dev Dependencies 54 | if: ${{ !contains(github.event.pull_request.title, 'WIP') }} 55 | runs-on: ubuntu-latest 56 | strategy: 57 | matrix: 58 | python-version: [3.7, 3.8, 3.9] 59 | steps: 60 | - name: Check out the code 61 | uses: actions/checkout@v2 62 | with: 63 | fetch-depth: 1 64 | - name: Set up Python ${{ matrix.python-version }} 65 | uses: actions/setup-python@v2 66 | with: 67 | python-version: ${{ matrix.python-version }} 68 | 69 | - name: Install Poetry 70 | uses: 
snok/install-poetry@v1.1.1 71 | with: 72 | version: 1.1.4 73 | 74 | - name: Install Dependencies 75 | run: | 76 | pip install -U certifi 77 | poetry config virtualenvs.create false 78 | poetry install --no-dev 79 | 80 | - name: Test Import Elegy 81 | run: python -c "import elegy" 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | .idea/ 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # pytype static type analyzer 136 | .pytype/ 137 | 138 | # Cython debug symbols 139 | cython_debug/ 140 | 141 | # custom 142 | /.vscode 143 | /.theia 144 | /test.* 145 | /summaries 146 | /runs 147 | /models 148 | /docs/models/ 149 | /docs/getting-started/models/ 150 | /saved-models 151 | /docs/saved-models/ 152 | /docs/getting-started/saved-models/ 153 | /TODO 154 | .git 155 | .python-version 156 | .devcontainer -------------------------------------------------------------------------------- /.gitpod.yml: -------------------------------------------------------------------------------- 1 | 2 | image: 3 | file: gitpod.Dockerfile 4 | 5 | tasks: 6 | - init: poetry install -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 22.3.0 4 | hooks: 5 | - id: black 6 | - repo: https://github.com/pycqa/isort 7 | rev: 5.9.3 8 | hooks: 9 | - id: isort 10 | args: ["--profile", "black"] -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | This is a short guide on how to start contributing to Elegy along with some best practices for the project. 3 | 4 | ## Setup 5 | We use `poetry >= 1.1.4`, the easiest way to setup a development environment is run: 6 | 7 | ```bash 8 | poetry config virtualenvs.in-project true --local 9 | poetry install 10 | ``` 11 | 12 | In order for Jax to recognize your GPU, you will probably have to install it again using the command below. 13 | 14 | ```bash 15 | PYTHON_VERSION=cp38 16 | CUDA_VERSION=cuda101 # alternatives: cuda100, cuda101, cuda102, cuda110, check your cuda version 17 | PLATFORM=manylinux2010_x86_64 # alternatives: manylinux2010_x86_64 18 | BASE_URL='https://storage.googleapis.com/jax-releases' 19 | pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.55-$PYTHON_VERSION-none-$PLATFORM.whl 20 | pip install --upgrade jax 21 | ``` 22 | 23 | #### Gitpod 24 | An alternative way to contribute is using [gitpod](https://gitpod.io/) which creates a vscode-based cloud development enviroment. 25 | To get started just login at gitpod, grant the appropriate permissions to github, and open the following link: 26 | 27 | https://gitpod.io/#https://github.com/poets-ai/elegy 28 | 29 | We have built a `python 3.8` enviroment and all development dependencies will install when the enviroment starts. 30 | 31 | ## Creating Losses and Metrics 32 | For this you can follow these guidelines: 33 | 34 | * Each loss / metric should be defined in its own file. 35 | * Inherit from either `elegy.losses.loss.Loss` or `elegy.metrics.metric.Metric` or an existing class that inherits from them. 
36 | * Try to use an existing metric or loss as a template 37 | * You must provide documentation for the following: 38 | * The class definition. 39 | * The `__init__` method. 40 | * The `call` method. 41 | * Try to port the documentation + signature from its Keras counter part. 42 | * If so you must give credits to the original source file. 43 | * You must include tests. 44 | * If you there exists an equivalent loss/metric in Keras you must test numerical equivalence between both. 45 | 46 | ## Testing 47 | To execute all the tests just run 48 | ```bash 49 | pytest 50 | ``` 51 | 52 | ## Documentation 53 | We use `mkdocs`. If you create a new object that requires documentation please do the following: 54 | 55 | * Add a markdown file inside `/docs/api` in the appropriate location according to the project's structure. This file must: 56 | * Contain the path of function / class as header 57 | * Use `mkdocstring` to render the API information. 58 | * Example: 59 | ```markdown 60 | # elegy.losses.BinaryCrossentropy 61 | 62 | ::: elegy.losses.BinaryCrossentropy 63 | selection: 64 | inherited_members: true 65 | members: 66 | - call 67 | - __init__ 68 | ``` 69 | * Add and entry to `mkdocs.yml` inside `nav` pointing to this file. Checkout `mkdocs.yml`. 70 | 71 | To build and visualize the documentation locally run 72 | ```bash 73 | mkdocs serve 74 | ``` 75 | 76 | ## Creating a PR 77 | Before sending a pull request make sure all test run and code is formatted with `black`: 78 | 79 | ```bash 80 | black . 81 | ``` 82 | 83 | ## Changelog 84 | `CHANGELOG.md` is automatically generated using [github-changelog-generator](https://github.com/github-changelog-generator/github-changelog-generator), to update the changelog just run: 85 | ```bash 86 | docker run -it --rm -v (pwd):/usr/local/src/your-app ferrarimarco/github-changelog-generator -u poets-ai -p elegy -t 87 | ``` 88 | where `` token can be obtained from Github at [Personal access tokens](https://github.com/settings/tokens), you only have to give permission for the `repo` section. 89 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 Cristian Garcia and others 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | This is a short guide on how to start contributing to Elegy along with some best practices for the project. 3 | 4 | ## Setup 5 | We use `poetry >= 1.1.4`, the easiest way to setup a development environment is run: 6 | 7 | ```bash 8 | poetry config virtualenvs.in-project true --local 9 | poetry install 10 | ``` 11 | 12 | In order for Jax to recognize your GPU, you will probably have to install it again using the command below. 13 | 14 | ```bash 15 | PYTHON_VERSION=cp38 16 | CUDA_VERSION=cuda101 # alternatives: cuda100, cuda101, cuda102, cuda110, check your cuda version 17 | PLATFORM=manylinux2010_x86_64 # alternatives: manylinux2010_x86_64 18 | BASE_URL='https://storage.googleapis.com/jax-releases' 19 | pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.55-$PYTHON_VERSION-none-$PLATFORM.whl 20 | pip install --upgrade jax 21 | ``` 22 | 23 | #### Gitpod 24 | An alternative way to contribute is using [gitpod](https://gitpod.io/) which creates a vscode-based cloud development enviroment. 25 | To get started just login at gitpod, grant the appropriate permissions to github, and open the following link: 26 | 27 | https://gitpod.io/#https://github.com/poets-ai/elegy 28 | 29 | We have built a `python 3.8` enviroment and all development dependencies will install when the enviroment starts. 30 | 31 | ## Creating Losses and Metrics 32 | For this you can follow these guidelines: 33 | 34 | * Each loss / metric should be defined in its own file. 35 | * Inherit from either `elegy.losses.loss.Loss` or `elegy.metrics.metric.Metric` or an existing class that inherits from them. 36 | * Try to use an existing metric or loss as a template 37 | * You must provide documentation for the following: 38 | * The class definition. 39 | * The `__init__` method. 40 | * The `call` method. 41 | * Try to port the documentation + signature from its Keras counter part. 42 | * If so you must give credits to the original source file. 43 | * You must include tests. 44 | * If you there exists an equivalent loss/metric in Keras you must test numerical equivalence between both. 45 | 46 | ## Testing 47 | To execute all the tests just run 48 | ```bash 49 | pytest 50 | ``` 51 | 52 | ## Documentation 53 | We use `mkdocs`. If you create a new object that requires documentation please do the following: 54 | 55 | * Add a markdown file inside `/docs/api` in the appropriate location according to the project's structure. This file must: 56 | * Contain the path of function / class as header 57 | * Use `mkdocstring` to render the API information. 58 | * Example: 59 | ```markdown 60 | # elegy.losses.BinaryCrossentropy 61 | 62 | ::: elegy.losses.BinaryCrossentropy 63 | selection: 64 | inherited_members: true 65 | members: 66 | - call 67 | - __init__ 68 | ``` 69 | * Add and entry to `mkdocs.yml` inside `nav` pointing to this file. Checkout `mkdocs.yml`. 70 | 71 | To build and visualize the documentation locally run 72 | ```bash 73 | mkdocs serve 74 | ``` 75 | 76 | ## Creating a PR 77 | Before sending a pull request make sure all test run and code is formatted with `black`: 78 | 79 | ```bash 80 | black . 
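# optional: the project's pre-commit hooks (see .pre-commit-config.yaml) also run
# isort with the black profile, so sorting imports locally avoids CI formatting failures
isort --profile black .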
81 | ``` 82 | 83 | ## Changelog 84 | `CHANGELOG.md` is automatically generated using [github-changelog-generator](https://github.com/github-changelog-generator/github-changelog-generator), to update the changelog just run: 85 | ```bash 86 | docker run -it --rm -v (pwd):/usr/local/src/your-app ferrarimarco/github-changelog-generator -u poets-ai -p elegy -t 87 | ``` 88 | where `` token can be obtained from Github at [Personal access tokens](https://github.com/settings/tokens), you only have to give permission for the `repo` section. 89 | -------------------------------------------------------------------------------- /docs/guides/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | This is a short guide on how to start contributing to Elegy along with some best practices for the project. 3 | 4 | ## Setup 5 | We use `poetry >= 1.1.4`, the easiest way to setup a development environment is run: 6 | 7 | ```bash 8 | poetry config virtualenvs.in-project true --local 9 | poetry install 10 | ``` 11 | 12 | In order for Jax to recognize your GPU, you will probably have to install it again using the command below. 13 | 14 | ```bash 15 | PYTHON_VERSION=cp38 16 | CUDA_VERSION=cuda101 # alternatives: cuda100, cuda101, cuda102, cuda110, check your cuda version 17 | PLATFORM=manylinux2010_x86_64 # alternatives: manylinux2010_x86_64 18 | BASE_URL='https://storage.googleapis.com/jax-releases' 19 | pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.55-$PYTHON_VERSION-none-$PLATFORM.whl 20 | pip install --upgrade jax 21 | ``` 22 | 23 | #### Gitpod 24 | An alternative way to contribute is using [gitpod](https://gitpod.io/) which creates a vscode-based cloud development enviroment. 25 | To get started just login at gitpod, grant the appropriate permissions to github, and open the following link: 26 | 27 | https://gitpod.io/#https://github.com/poets-ai/elegy 28 | 29 | We have built a `python 3.8` enviroment and all development dependencies will install when the enviroment starts. 30 | 31 | ## Creating Losses and Metrics 32 | For this you can follow these guidelines: 33 | 34 | * Each loss / metric should be defined in its own file. 35 | * Inherit from either `elegy.losses.loss.Loss` or `elegy.metrics.metric.Metric` or an existing class that inherits from them. 36 | * Try to use an existing metric or loss as a template 37 | * You must provide documentation for the following: 38 | * The class definition. 39 | * The `__init__` method. 40 | * The `call` method. 41 | * Try to port the documentation + signature from its Keras counter part. 42 | * If so you must give credits to the original source file. 43 | * You must include tests. 44 | * If you there exists an equivalent loss/metric in Keras you must test numerical equivalence between both. 45 | 46 | ## Testing 47 | To execute all the tests just run 48 | ```bash 49 | pytest 50 | ``` 51 | 52 | ## Documentation 53 | We use `mkdocs`. If you create a new object that requires documentation please do the following: 54 | 55 | * Add a markdown file inside `/docs/api` in the appropriate location according to the project's structure. This file must: 56 | * Contain the path of function / class as header 57 | * Use `mkdocstring` to render the API information. 
58 | * Example: 59 | ```markdown 60 | # elegy.losses.BinaryCrossentropy 61 | 62 | ::: elegy.losses.BinaryCrossentropy 63 | selection: 64 | inherited_members: true 65 | members: 66 | - call 67 | - __init__ 68 | ``` 69 | * Add and entry to `mkdocs.yml` inside `nav` pointing to this file. Checkout `mkdocs.yml`. 70 | 71 | To build and visualize the documentation locally run 72 | ```bash 73 | mkdocs serve 74 | ``` 75 | 76 | ## Creating a PR 77 | Before sending a pull request make sure all test run and code is formatted with `black`: 78 | 79 | ```bash 80 | black . 81 | ``` 82 | 83 | ## Changelog 84 | `CHANGELOG.md` is automatically generated using [github-changelog-generator](https://github.com/github-changelog-generator/github-changelog-generator), to update the changelog just run: 85 | ```bash 86 | docker run -it --rm -v (pwd):/usr/local/src/your-app ferrarimarco/github-changelog-generator -u poets-ai -p elegy -t 87 | ``` 88 | where `` token can be obtained from Github at [Personal access tokens](https://github.com/settings/tokens), you only have to give permission for the `repo` section. 89 | -------------------------------------------------------------------------------- /docs/images/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poets-ai/elegy/4709ce8dc9dde3925ce717e2358ce49112e36398/docs/images/favicon.png -------------------------------------------------------------------------------- /docs/low-level-api/basics.md: -------------------------------------------------------------------------------- 1 | # Low-level API 2 | Elegy's low-level API allows you to override some core methods in `Model` that specify what happens during training, inference, etc. This approach is perfect when you want to do things that are hard or simply not possible with the high-level API as it gives you the flexibility to do anything inside these methods as long as you return the expected types. 3 | 4 | 5 | ### Methods 6 | This is the list of all the overrideable methods: 7 | 8 | | Caller | Method | 9 | | :--------- | :------------- | 10 | | `predict` | `pred_step` | 11 | | `evaluate` | `test_step` | 12 | | | `grad_step` | 13 | | `fit` | `train_step` | 14 | | `init` | `init_step` | 15 | | `summary` | `summary_step` | 16 | | | `states_step` | 17 | | | `jit_step` | 18 | 19 | Each method has a default implementation which is what gives rise to the high-level API. 20 | 21 | ### Example 22 | Most overrideable methods take some input & state, perform some `jax` operations & updates the state, and returns some outputs & the new state. 
Let's see a simple example of a linear classifier using `test_step`:
23 | 
24 | ```python
25 | class LinearClassifier(elegy.Model):
26 |     def test_step(self, x, y_true, states, initializing):
27 |         x = jnp.reshape(x, (x.shape[0], -1)) / 255
28 | 
29 |         # initialize or use existing parameters
30 |         if initializing:
31 |             w = jax.random.uniform(
32 |                 jax.random.PRNGKey(42), shape=[np.prod(x.shape[1:]), 10]
33 |             )
34 |             b = jax.random.uniform(jax.random.PRNGKey(69), shape=[1])
35 |         else:
36 |             w, b = states.net_params
37 | 
38 |         # model
39 |         logits = jnp.dot(x, w) + b
40 | 
41 |         # categorical crossentropy loss
42 |         labels = jax.nn.one_hot(y_true, 10)
43 |         loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
44 |         accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
45 | 
46 |         # metrics
47 |         logs = dict(accuracy=accuracy, loss=loss)
48 | 
49 |         # update states
50 |         states = states.update(net_params=(w, b))
51 | 
52 |         return loss, logs, states
53 | 
54 | model = LinearClassifier(
55 |     optimizer=optax.adam(1e-3)
56 | )
57 | 
58 | model.fit(
59 |     x=X_train,
60 |     y=y_train,
61 |     epochs=100,
62 |     batch_size=64,
63 | )
64 | ```
65 | 
66 | As you can see, we do everything ourselves here: parameter initialization, modeling, calculating the main loss, and logging some metrics. Some notes about the previous example:
67 | 
68 | * The `states` argument of type `elegy.States` is an immutable Mapping to which you add / update fields via its `update` method.
69 | * `net_params` is one of the names used by the default implementation; check the [States](./states.md) guide for more information.
70 | * `initializing` tells you whether to initialize the parameters of the model or fetch the current ones from `states`; if you are using a Module framework, this usually tells you whether to call `init` or `apply`.
71 | * `test_step` must return 3 specific outputs (`loss`, `logs`, `states`); check the docs for each method to know what to return.
72 | -------------------------------------------------------------------------------- /docs/low-level-api/default-implementation.md: -------------------------------------------------------------------------------- 1 | # Default Implementation
2 | 
3 | ### Methods
4 | The default implementation favors composition by implementing each method in terms of another; specifically, it follows this call graph:
5 | 
6 | ```
7 | summary            predict          evaluate         fit               init
8 |    ⬇️                 ⬇️               ⬇️               ⬇️                ⬇️
9 | call_summary_step  call_pred_step   call_test_step   call_train_step   call_init_step
10 |    ⬇️                 ⬇️               ⬇️               ⬇️                ⬇️
11 | summary_step ➡️ pred_step ⬅ test_step ⬅ grad_step ⬅ train_step ⬅ init_step
12 | ```
13 | This structure allows you to, for example, override `test_step` and still be able to use `fit`, since `train_step` (called by `fit`) will call your `test_step` via `grad_step`. It also means that if you implement `test_step` but not `pred_step`, there is a high chance both `predict` and `summary` will not work.
14 | 
15 | #### call_* methods
16 | The `call_*` method family are _entrypoints_ that usually just redirect their inputs to the corresponding method; you override these if you need to perform some computation only when the method in question is the entry point, i.e. when it is not called by the other methods in the bottom path.
17 | For example, if you want to change the behavior of `evaluate` without affecting the behavior of `fit` while preserving most of the default implementation, you can override `call_test_step` to make the corresponding adjustments and then call `test_step`. 
Since `train_step` does not depend on `call_test_step`, the change will manifest during `evaluate` but not during `fit`. -------------------------------------------------------------------------------- /docs/low-level-api/methods/pred_step.md: -------------------------------------------------------------------------------- 1 | # pred_step
2 | The `pred_step` method computes the predictions of the main model; by overriding this method you can directly influence what happens during `predict`.
3 | 
4 | ### Inputs
5 | Any of the following input arguments are available for `pred_step`:
6 | 
7 | | name           | type     |                                           |
8 | | :------------- | :------- | :---------------------------------------- |
9 | | `x`            | `Any`    | Input data                                |
10 | | `states`       | `States` | Current state of the model                |
11 | | `initializing` | `bool`   | Whether the model is initializing or not  |
12 | | `training`     | `bool`   | Whether the model is training or not      |
13 | 
14 | You must request the arguments you want by **name**.
15 | 
16 | ### Outputs
17 | `pred_step` must output a tuple with the following values:
18 | 
19 | | name     | type     |                              |
20 | | :------- | :------- | :--------------------------- |
21 | | `y_pred` | `Any`    | The predictions of the model |
22 | | `states` | `States` | The new state of the model   |
23 | 
24 | 
25 | ### Callers
26 | | method         | when                   |
27 | | :------------- | :--------------------- |
28 | | `predict`      | always                 |
29 | | `test_step`    | default implementation |
30 | | `summary_step` | default implementation |
31 | 
32 | ### Examples
33 | If for some reason you wish to create a pure jax / Module-less model, you can define your own Model that implements `pred_step` like this:
34 | 
35 | ```python
36 | class LinearClassifier(elegy.Model):
37 |     def pred_step(self, x, states, initializing):
38 |         x = jnp.reshape(x, (x.shape[0], -1)) / 255
39 | 
40 |         # initialize or use existing parameters
41 |         if initializing:
42 |             w = jax.random.uniform(
43 |                 jax.random.PRNGKey(42), shape=[np.prod(x.shape[1:]), 10]
44 |             )
45 |             b = jax.random.uniform(jax.random.PRNGKey(69), shape=[1])
46 |         else:
47 |             w, b = states.net_params
48 | 
49 |         # model
50 |         y_pred = jnp.dot(x, w) + b
51 | 
52 |         return y_pred, states.update(net_params=(w, b))
53 | 
54 | model = LinearClassifier(
55 |     optimizer=optax.adam(1e-3),
56 |     loss=elegy.losses.Crossentropy(),
57 |     metrics=elegy.metrics.SparseCategoricalAccuracy(),
58 | )
59 | 
60 | model.fit(
61 |     x=X_train,
62 |     y=y_train,
63 |     epochs=100,
64 |     batch_size=64,
65 | )
66 | ```
67 | Here we implement the same `LinearClassifier` from the [basics](./basics) section, but we extract the definition of the model into `pred_step` and let the default implementation of `test_step` take care of the `loss` and `metrics`, which we provide to the `LinearClassifier`'s constructor.
68 | 
69 | ### Default Implementation
70 | The default implementation of `pred_step` does the following:
71 | 
72 | * Calls `api_module.init` or `api_module.apply` depending on the state of `initializing`. `api_module` of type `GeneralizedModule` is a wrapper over the `module` object passed by the user to the `Model`'s constructor. -------------------------------------------------------------------------------- /docs/low-level-api/methods/test_step.md: -------------------------------------------------------------------------------- 1 | # test_step
2 | The `test_step` computes the main `loss` of the model along with some `logs` for reporting; by overriding this method you can directly influence what happens during `evaluate`. 
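Before diving into the argument and output tables below, it helps to see the bare shape of an override. The following is a deliberately trivial sketch (a constant predictor reporting its mean squared error); the exact arguments you may request and the outputs you must return are specified in the sections that follow:

```python
import jax.numpy as jnp
import elegy

class ConstantModel(elegy.Model):
    # minimal override: predict zeros and report the resulting MSE
    def test_step(self, x, y_true, states):
        y_pred = jnp.zeros_like(y_true)
        loss = jnp.mean((y_true - y_pred) ** 2)
        logs = dict(loss=loss)      # anything here gets reported by `evaluate` / `fit`
        return loss, logs, states   # the three required outputs
```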
3 | 
4 | ### Inputs
5 | Any of the following input arguments are available for `test_step`:
6 | 
7 | | name            | type                |                                              |
8 | | :-------------- | :------------------ | :------------------------------------------- |
9 | | `x`             | `Any`               | Input data                                   |
10 | | `y_true`        | `Any`               | The target labels                            |
11 | | `sample_weight` | `Optional[ndarray]` | The weight of each sample in the total loss  |
12 | | `class_weight`  | `Optional[ndarray]` | The weight of each class in the total loss   |
13 | | `states`        | `States`            | Current state of the model                   |
14 | | `initializing`  | `bool`              | Whether the model is initializing or not     |
15 | | `training`      | `bool`              | Whether the model is training or not         |
16 | 
17 | 
18 | You must request the arguments you want by **name**.
19 | 
20 | ### Outputs
21 | `test_step` must output a tuple with the following values:
22 | 
23 | | name     | type                 |                                              |
24 | | :------- | :------------------- | :------------------------------------------- |
25 | | `loss`   | `ndarray`            | The loss of the model over the data          |
26 | | `logs`   | `Dict[str, ndarray]` | A dictionary with a set of values to report  |
27 | | `states` | `States`             | The new state of the model                   |
28 | 
29 | 
30 | ### Callers
31 | | method       | when                                               |
32 | | :----------- | :------------------------------------------------- |
33 | | `evaluate`   | always                                             |
34 | | `grad_step`  | default implementation                             |
35 | | `train_step` | default implementation during initialization only  |
36 | 
37 | ### Examples
38 | Let's review the example of `test_step` found in [basics](./basics):
39 | 
40 | ```python
41 | class LinearClassifier(elegy.Model):
42 |     def test_step(self, x, y_true, states, initializing):
43 |         x = jnp.reshape(x, (x.shape[0], -1)) / 255
44 | 
45 |         # initialize or use existing parameters
46 |         if initializing:
47 |             w = jax.random.uniform(
48 |                 jax.random.PRNGKey(42), shape=[np.prod(x.shape[1:]), 10]
49 |             )
50 |             b = jax.random.uniform(jax.random.PRNGKey(69), shape=[1])
51 |         else:
52 |             w, b = states.net_params
53 | 
54 |         # model
55 |         logits = jnp.dot(x, w) + b
56 | 
57 |         # categorical crossentropy loss
58 |         labels = jax.nn.one_hot(y_true, 10)
59 |         loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
60 |         accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
61 | 
62 |         # metrics
63 |         logs = dict(accuracy=accuracy, loss=loss)
64 | 
65 |         # update states
66 |         states = states.update(net_params=(w, b))
67 | 
68 |         return loss, logs, states
69 | 
70 | model = LinearClassifier(
71 |     optimizer=optax.adam(1e-3)
72 | )
73 | 
74 | model.fit(
75 |     x=X_train,
76 |     y=y_train,
77 |     epochs=100,
78 |     batch_size=64,
79 | )
80 | ```
81 | In this case `test_step` defines both the "forward" pass of the model and calculates the losses and metrics in a single place. However, since we are not defining `pred_step`, we lose the ability to call `predict`, which might not be desirable. 
The optimal way to fix this is to extract the calculation of the logits into `pred_step` and call it from `test_step`:
82 | 
83 | ```python
84 | class LinearClassifier(elegy.Model):
85 |     def pred_step(self, x, states, initializing):
86 |         x = jnp.reshape(x, (x.shape[0], -1)) / 255
87 | 
88 |         # initialize or use existing parameters
89 |         if initializing:
90 |             w = jax.random.uniform(
91 |                 jax.random.PRNGKey(42), shape=[np.prod(x.shape[1:]), 10]
92 |             )
93 |             b = jax.random.uniform(jax.random.PRNGKey(69), shape=[1])
94 |         else:
95 |             w, b = states.net_params
96 | 
97 |         # model
98 |         logits = jnp.dot(x, w) + b
99 | 
100 |         return logits, states.update(net_params=(w, b))
101 | 
102 |     def test_step(self, x, y_true, states, initializing):
103 |         # call pred_step
104 |         logits, states = self.pred_step(x, states, initializing)
105 | 
106 |         # categorical crossentropy loss
107 |         labels = jax.nn.one_hot(y_true, 10)
108 |         loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
109 |         accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
110 | 
111 |         # metrics
112 |         logs = dict(accuracy=accuracy, loss=loss)
113 | 
114 |         # `states` was already updated inside `pred_step`,
115 |         # so it can be returned as-is
116 | 
117 |         return loss, logs, states
118 | 
119 | model = LinearClassifier(
120 |     optimizer=optax.adam(1e-3),
121 | )
122 | 
123 | model.fit(
124 |     x=X_train,
125 |     y=y_train,
126 |     epochs=100,
127 |     batch_size=64,
128 | )
129 | ```
130 | This not only creates a separation of concerns, it also favors code reuse, and we can now use `predict`, `evaluate`, and `fit` as intended.
131 | 
132 | There are cases, however, where you might want to implement a forward pass inside `test_step` that is different from what you would define in `pred_step`; for example, you can create `VAE` or `GAN` models that use multiple modules to calculate the loss inside `test_step` (e.g. encoder, decoder, and discriminator) but only use the decoder inside `pred_step` to generate samples.
133 | 
134 | ### Default Implementation
135 | The default implementation of `test_step` does the following:
136 | 
137 | * Calls `pred_step` to get `y_pred`.
138 | * Calls `api_loss.init` or `api_loss.apply` depending on the state of `initializing`. `api_loss` of type `Losses` computes the aggregated batch loss from the loss functions passed by the user through the `loss` argument in the `Model`'s constructor, and also computes a running mean of each loss individually which is passed for reporting to `logs`.
139 | * Calls `api_metrics.init` or `api_metrics.apply` depending on the state of `initializing`. `api_metrics` of type `Metrics` calculates the metrics passed by the user through the `metrics` argument in the `Model`'s constructor and passes their values to `logs` for reporting. -------------------------------------------------------------------------------- /docs/low-level-api/states.md: -------------------------------------------------------------------------------- 1 | 
2 | # States
3 | `elegy.States` is an immutable `Mapping` that contains all the states needed by `Model`; the low-level API provides a simple state management system by passing the `states` parameter (of type `elegy.States`) to all methods. 
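To make the immutability of `states` concrete, here is a small sketch of a custom step that reads a field, logs it, and returns an updated copy (the `counter` field is purely illustrative and is not one of the default fields listed below; `update` and field access are covered in the next section):

```python
import jax.numpy as jnp
import elegy

class CounterModel(elegy.Model):
    def test_step(self, x, states):
        # read a field stored on a previous step, defaulting to 0 the first time
        counter = states["counter"] if "counter" in states else 0

        loss = jnp.asarray(0.0)
        logs = dict(steps_seen=counter)

        # `update` does not mutate `states`; it returns a new States instance
        states = states.update(counter=counter + 1)
        return loss, logs, states
```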
4 | 5 | ### Basic usage 6 | The most common way to use `States` is via its `update` method you can use to set or update field: 7 | ```python 8 | states = states.update(some_field=some_value) 9 | ``` 10 | 11 | You can access a field via index or field access notation: 12 | ```python 13 | some_value = states["some_field"] 14 | some_value = states.some_field 15 | ``` 16 | 17 | ### Default Implementation 18 | The default implementation uses the following fields: 19 | 20 | | name | description | 21 | | :----------------- | :----------------------------------------------------------------------- | 22 | | `rng` | contains an `elegy.RNGSeq` instance you can you to request random state. | 23 | | `net_params` | the trainable parameters of the model. | 24 | | `net_states` | the non-trainable parameters of the model. | 25 | | `metrics_states` | the states used to calculate cumulative metrics. | 26 | | `optimizer_states` | the states for the optimizer. | -------------------------------------------------------------------------------- /docs/na.md: -------------------------------------------------------------------------------- 1 | # Not Available 2 | 3 | 🚧 This page is not available yet, we are working on it 🚧 -------------------------------------------------------------------------------- /elegy/__init__.py: -------------------------------------------------------------------------------- 1 | # isort:skip_file 2 | 3 | __version__ = "0.8.6" 4 | 5 | from treex import * 6 | 7 | import elegy.types as types 8 | import elegy.utils as utils 9 | 10 | 11 | from . import ( 12 | callbacks, 13 | data, 14 | model, 15 | # nets, 16 | ) 17 | 18 | from .model.model import Model 19 | from .model.model_base import ModelBase, load 20 | from .model.model_core import ( 21 | GradStepOutput, 22 | PredStepOutput, 23 | TestStepOutput, 24 | TrainStepOutput, 25 | LossStepOutput, 26 | ModelCore, 27 | ) 28 | from .types import KeySeq 29 | from .utils import inject_dependencies 30 | -------------------------------------------------------------------------------- /elegy/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .callback import Callback 2 | from .callback_list import CallbackList 3 | from .csv_logger import CSVLogger 4 | from .early_stopping import EarlyStopping 5 | from .history import History 6 | from .lambda_callback import LambdaCallback 7 | from .model_checkpoint import ModelCheckpoint 8 | from .remote_monitor import RemoteMonitor 9 | from .sigint import SigInt 10 | from .tensorboard import TensorBoard 11 | from .terminate_nan import TerminateOnNaN 12 | from .wandb_callback import WandbCallback 13 | 14 | __all__ = [ 15 | "CallbackList", 16 | "Callback", 17 | "History", 18 | "ModelCheckpoint", 19 | "EarlyStopping", 20 | "LambdaCallback", 21 | "TerminateOnNaN", 22 | "RemoteMonitor", 23 | "CSVLogger", 24 | "TensorBoard", 25 | "WandbCallback", 26 | ] 27 | -------------------------------------------------------------------------------- /elegy/callbacks/csv_logger.py: -------------------------------------------------------------------------------- 1 | # Implementation based on tf.keras.callbacks.py 2 | # https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/callbacks.py 3 | 4 | import collections 5 | import csv 6 | import io 7 | import os 8 | import typing as tp 9 | 10 | import numpy as np 11 | import six 12 | 13 | from .callback import Callback 14 | 15 | 16 | class CSVLogger(Callback): 17 | """Callback that streams epoch results to a csv 
file. 18 | 19 | Supports all values that can be represented as a string, 20 | including 1D iterables such as `np.ndarray`. 21 | 22 | Example: 23 | 24 | ```python 25 | csv_logger = CSVLogger('training.log') 26 | model.fit(X_train, Y_train, callbacks=[csv_logger]) 27 | ``` 28 | """ 29 | 30 | def __init__(self, filename: str, separator: str = ",", append: bool = False): 31 | """ 32 | Arguments: 33 | filename: filename of the csv file, e.g. 'run/log.csv'. 34 | separator: string used to separate elements in the csv file. 35 | append: True: append if file exists (useful for continuing 36 | training). False: overwrite existing file, 37 | """ 38 | self.sep = separator 39 | self.filename = filename 40 | self.append = append 41 | self.writer = None 42 | self.keys = None 43 | self.append_header = True 44 | self.file_flags = "" 45 | self._open_args = {"newline": "\n"} 46 | super(CSVLogger, self).__init__() 47 | 48 | def on_train_begin(self, logs=None): 49 | if self.append: 50 | if os.path.exists(self.filename): 51 | with open(self.filename, "r" + self.file_flags) as f: 52 | self.append_header = not bool(len(f.readline())) 53 | mode = "a" 54 | else: 55 | mode = "w" 56 | self.csv_file = io.open( 57 | self.filename, mode + self.file_flags, **self._open_args 58 | ) 59 | 60 | def on_epoch_end(self, epoch, logs=None): 61 | logs = logs or {} 62 | 63 | def handle_value(k): 64 | is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0 65 | if isinstance(k, six.string_types): 66 | return k 67 | elif isinstance(k, tp.Iterable) and not is_zero_dim_ndarray: 68 | return '"[%s]"' % (", ".join(map(str, k))) 69 | else: 70 | return k 71 | 72 | if self.keys is None: 73 | self.keys = sorted(logs.keys()) 74 | 75 | if self.model.stop_training: 76 | # We set NA so that csv parsers do not fail for this last epoch. 77 | logs = dict([(k, logs[k]) if k in logs else (k, "NA") for k in self.keys]) 78 | 79 | if not self.writer: 80 | 81 | class CustomDialect(csv.excel): 82 | delimiter = self.sep 83 | 84 | fieldnames = ["epoch"] + self.keys 85 | 86 | self.writer = csv.DictWriter( 87 | self.csv_file, fieldnames=fieldnames, dialect=CustomDialect 88 | ) 89 | if self.append_header: 90 | self.writer.writeheader() 91 | 92 | row_dict = collections.OrderedDict({"epoch": epoch}) 93 | row_dict.update((key, handle_value(logs[key])) for key in self.keys) 94 | self.writer.writerow(row_dict) 95 | self.csv_file.flush() 96 | 97 | def on_train_end(self, logs=None): 98 | self.csv_file.close() 99 | self.writer = None 100 | -------------------------------------------------------------------------------- /elegy/callbacks/early_stopping.py: -------------------------------------------------------------------------------- 1 | # Implementation based on tf.keras.callbacks.py 2 | # https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/callbacks.py 3 | import logging 4 | import typing as tp 5 | 6 | import numpy as np 7 | 8 | from .callback import Callback 9 | 10 | 11 | class EarlyStopping(Callback): 12 | """ 13 | Stop training when a monitored metric has stopped improving. 14 | 15 | Assuming the goal of a training is to minimize the loss. With this, the 16 | metric to be monitored would be 'loss', and mode would be 'min'. A 17 | `model.fit()` training loop will check at end of every epoch whether 18 | the loss is no longer decreasing, considering the `min_delta` and 19 | `patience` if applicable. Once it's found no longer decreasing, 20 | `model.stop_training` is marked True and the training terminates. 
21 | 22 | The quantity to be monitored needs to be available in `logs` dict. 23 | To make it so, pass the loss or metrics at `model.__init__()`. 24 | 25 | Example: 26 | ```python 27 | np.random.seed(42) 28 | class MLP(elegy.Module): 29 | def call(self, input): 30 | mlp = elegy.Sequential([elegy.nn.Linear(10),]) 31 | return mlp(input) 32 | 33 | callback = elegy.callbacks.EarlyStopping(monitor="loss", patience=3) 34 | # This callback will stop the training when there is no improvement in 35 | # the for three consecutive epochs. 36 | model = elegy.Model( 37 | module=MLP(), 38 | loss=elegy.losses.MeanSquaredError(), 39 | optimizer=optax.rmsprop(0.01), 40 | ) 41 | history = model.fit( 42 | np.arange(100).reshape(5, 20).astype(np.float32), 43 | np.zeros(5), 44 | epochs=10, 45 | batch_size=1, 46 | callbacks=[callback], 47 | verbose=0, 48 | ) 49 | assert len(history.history["loss"]) == 7 # Only 7 epochs are run. 50 | ``` 51 | """ 52 | 53 | def __init__( 54 | self, 55 | monitor: str = "val_loss", 56 | min_delta: int = 0, 57 | patience: int = 0, 58 | verbose: int = 0, 59 | mode: str = "auto", 60 | baseline: tp.Optional[float] = None, 61 | restore_best_weights: bool = False, 62 | ): 63 | """Initialize an EarlyStopping callback. 64 | 65 | Arguments: 66 | monitor: Quantity to be monitored. 67 | min_delta: Minimum change in the monitored quantity 68 | to qualify as an improvement, i.e. an absolute 69 | change of less than min_delta, will count as no 70 | improvement. 71 | patience: Number of epochs with no improvement 72 | after which training will be stopped. 73 | verbose: verbosity mode. 74 | mode: One of `{"auto", "min", "max"}`. In `min` mode, 75 | training will stop when the quantity 76 | monitored has stopped decreasing; in `max` 77 | mode it will stop when the quantity 78 | monitored has stopped increasing; in `auto` 79 | mode, the direction is automatically inferred 80 | from the name of the monitored quantity. 81 | baseline: Baseline value for the monitored quantity. 82 | Training will stop if the model doesn't show improvement over the 83 | baseline. 84 | restore_best_weights: Whether to restore model weights from 85 | the epoch with the best value of the monitored quantity. 86 | If False, the model weights obtained at the last step of 87 | training are used. 
88 | """ 89 | super(EarlyStopping, self).__init__() 90 | 91 | self.monitor = monitor 92 | self.patience = patience 93 | self.verbose = verbose 94 | self.baseline = baseline 95 | self.min_delta = abs(min_delta) 96 | self.wait = 0 97 | self.stopped_epoch = 0 98 | self.restore_best_weights = restore_best_weights 99 | self.best_weights = None 100 | 101 | if mode not in ["auto", "min", "max"]: 102 | logging.warning( 103 | "EarlyStopping mode %s is unknown, " "fallback to auto mode.", mode 104 | ) 105 | mode = "auto" 106 | 107 | if mode == "min": 108 | self.monitor_op = np.less 109 | elif mode == "max": 110 | self.monitor_op = np.greater 111 | else: 112 | if "acc" in self.monitor: 113 | self.monitor_op = np.greater 114 | else: 115 | self.monitor_op = np.less 116 | 117 | if self.monitor_op == np.greater: 118 | self.min_delta *= 1 119 | else: 120 | self.min_delta *= -1 121 | 122 | def on_train_begin(self, logs=None): 123 | # Allow instances to be re-used 124 | self.wait = 0 125 | self.stopped_epoch = 0 126 | if self.baseline is not None: 127 | self.best = self.baseline 128 | else: 129 | self.best = np.Inf if self.monitor_op == np.less else -np.Inf 130 | 131 | def on_epoch_end(self, epoch, logs=None): 132 | current = self.get_monitor_value(logs) 133 | if current is None: 134 | return 135 | if self.monitor_op(current - self.min_delta, self.best): 136 | self.best = current 137 | self.wait = 0 138 | if self.restore_best_weights: 139 | # This will also save optimizer state 140 | self.best_state = self.model.full_state 141 | else: 142 | self.wait += 1 143 | if self.wait >= self.patience: 144 | self.stopped_epoch = epoch 145 | self.model.stop_training = True 146 | if self.restore_best_weights: 147 | if self.verbose > 0: 148 | print("Restoring model weights from the end of the best epoch.") 149 | self.model.full_state = self.best_state 150 | 151 | def on_train_end(self, logs=None): 152 | if self.stopped_epoch > 0 and self.verbose > 0: 153 | print("Epoch %05d: early stopping" % (self.stopped_epoch + 1)) 154 | 155 | def get_monitor_value(self, logs): 156 | logs = logs or {} 157 | monitor_value = logs.get(self.monitor) 158 | if monitor_value is None: 159 | logging.warning( 160 | "Early stopping conditioned on metric `%s` " 161 | "which is not available. Available metrics are: %s", 162 | self.monitor, 163 | ",".join(list(logs.keys())), 164 | ) 165 | return monitor_value 166 | -------------------------------------------------------------------------------- /elegy/callbacks/history.py: -------------------------------------------------------------------------------- 1 | # Implementation based on tf.keras.callbacks.py 2 | # https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/callbacks.py 3 | from .callback import Callback 4 | 5 | 6 | class History(Callback): 7 | """Callback that records events into a `History` object. 8 | 9 | This callback is automatically applied to 10 | every Keras model. The `History` object 11 | gets returned by the `fit` method of models. 12 | """ 13 | 14 | def __init__(self): 15 | super(History, self).__init__() 16 | self.history = {} 17 | 18 | def on_train_begin(self, logs=None): 19 | self.epoch = [] 20 | 21 | def on_epoch_end(self, epoch, logs=None): 22 | logs = logs or {} 23 | self.epoch.append(epoch) 24 | for k, v in logs.items(): 25 | self.history.setdefault(k, []).append(v) 26 | 27 | # Set the history attribute on the model after the epoch ends. This will 28 | # make sure that the state which is set is the latest one. 
29 | self.model.history = self 30 | -------------------------------------------------------------------------------- /elegy/callbacks/lambda_callback.py: -------------------------------------------------------------------------------- 1 | # Implementation based on tf.keras.callbacks.py 2 | # https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/callbacks.py 3 | import typing as tp 4 | 5 | import numpy as np 6 | 7 | from .callback import Callback 8 | 9 | 10 | class LambdaCallback(Callback): 11 | r"""Callback for creating simple, custom callbacks on-the-fly. 12 | 13 | This callback is constructed with anonymous functions that will be called 14 | at the appropriate time. Note that the callbacks expects positional 15 | arguments, as: 16 | 17 | - `on_epoch_begin` and `on_epoch_end` expect two positional arguments: 18 | `epoch`, `logs` 19 | - `on_train_batch_begin` and `on_train_batch_end` expect two positional arguments: 20 | `batch`, `logs` 21 | - `on_train_begin` and `on_train_end` expect one positional argument: 22 | `logs` 23 | 24 | Example: 25 | 26 | ```python 27 | # Print the batch number at the beginning of every batch. 28 | batch_print_callback = LambdaCallback( 29 | on_train_batch_begin=lambda batch,logs: print(batch)) 30 | 31 | # Stream the epoch loss to a file in JSON format. The file content 32 | # is not well-formed JSON but rather has a JSON object per line. 33 | import json 34 | json_log = open('loss_log.json', mode='wt', buffering=1) 35 | json_logging_callback = LambdaCallback( 36 | on_epoch_end=lambda epoch, logs: json_log.write( 37 | json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'), 38 | on_train_end=lambda logs: json_log.close() 39 | ) 40 | 41 | # Terminate some processes after having finished model training. 42 | processes = ... 43 | cleanup_callback = LambdaCallback( 44 | on_train_end=lambda logs: [ 45 | p.terminate() for p in processes if p.is_alive()]) 46 | 47 | model.fit(..., 48 | callbacks=[batch_print_callback, 49 | json_logging_callback, 50 | cleanup_callback]) 51 | ``` 52 | """ 53 | 54 | def __init__( 55 | self, 56 | on_epoch_begin: tp.Optional[ 57 | tp.Callable[[int, tp.Dict[str, np.ndarray]], None] 58 | ] = None, 59 | on_epoch_end: tp.Optional[ 60 | tp.Callable[[int, tp.Dict[str, np.ndarray]], None] 61 | ] = None, 62 | on_train_batch_begin: tp.Optional[ 63 | tp.Callable[[int, tp.Dict[str, np.ndarray]], None] 64 | ] = None, 65 | on_train_batch_end: tp.Optional[ 66 | tp.Callable[[int, tp.Dict[str, np.ndarray]], None] 67 | ] = None, 68 | on_train_begin: tp.Optional[ 69 | tp.Callable[[tp.Dict[str, np.ndarray]], None] 70 | ] = None, 71 | on_train_end: tp.Optional[tp.Callable[[tp.Dict[str, np.ndarray]], None]] = None, 72 | **kwargs 73 | ): 74 | """ 75 | Arguments: 76 | on_epoch_begin: called at the beginning of every epoch. 77 | on_epoch_end: called at the end of every epoch. 78 | on_train_batch_begin: called at the beginning of every batch. 79 | on_train_batch_end: called at the end of every batch. 80 | on_train_begin: called at the beginning of model training. 81 | on_train_end: called at the end of model training. 
82 | """ 83 | super(LambdaCallback, self).__init__() 84 | self.__dict__.update(kwargs) 85 | if on_epoch_begin is not None: 86 | self.on_epoch_begin = on_epoch_begin 87 | else: 88 | self.on_epoch_begin = lambda epoch, logs: None 89 | if on_epoch_end is not None: 90 | self.on_epoch_end = on_epoch_end 91 | else: 92 | self.on_epoch_end = lambda epoch, logs: None 93 | if on_train_batch_begin is not None: 94 | self.on_train_batch_begin = on_train_batch_begin 95 | else: 96 | self.on_train_batch_begin = lambda batch, logs: None 97 | if on_train_batch_end is not None: 98 | self.on_train_batch_end = on_train_batch_end 99 | else: 100 | self.on_train_batch_end = lambda batch, logs: None 101 | if on_train_begin is not None: 102 | self.on_train_begin = on_train_begin 103 | else: 104 | self.on_train_begin = lambda logs: None 105 | if on_train_end is not None: 106 | self.on_train_end = on_train_end 107 | else: 108 | self.on_train_end = lambda logs: None 109 | -------------------------------------------------------------------------------- /elegy/callbacks/remote_monitor.py: -------------------------------------------------------------------------------- 1 | # Implementation based on tf.keras.callbacks.py 2 | # https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/callbacks.py 3 | try: 4 | import requests 5 | except ImportError: 6 | requests = None 7 | import json 8 | import logging 9 | import typing as tp 10 | 11 | import numpy as np 12 | 13 | from .callback import Callback 14 | 15 | 16 | class RemoteMonitor(Callback): 17 | """Callback used to stream events to a server. 18 | 19 | Requires the `requests` library. 20 | Events are sent to `root + '/publish/epoch/end/'` by default. Calls are 21 | HTTP POST, with a `data` argument which is a 22 | JSON-encoded dictionary of event data. 23 | If send_as_json is set to True, the content type of the request will be 24 | application/json. Otherwise the serialized JSON will be sent within a form. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | root: str = "http://localhost:9000", 30 | path: str = "/publish/epoch/end/", 31 | field: str = "data", 32 | headers: tp.Optional[tp.Dict[str, str]] = None, 33 | send_as_json: bool = False, 34 | ): 35 | """ 36 | Arguments: 37 | root: String; root url of the target server. 38 | path: String; path relative to `root` to which the events will be sent. 39 | field: String; JSON field under which the data will be stored. 40 | The field is used only if the payload is sent within a form 41 | (i.e. send_as_json is set to False). 42 | headers: Dictionary; optional custom HTTP headers. 43 | send_as_json: Boolean; whether the request should be 44 | sent as application/json. 
45 | """ 46 | super(RemoteMonitor, self).__init__() 47 | 48 | self.root = root 49 | self.path = path 50 | self.field = field 51 | self.headers = headers 52 | self.send_as_json = send_as_json 53 | 54 | def on_epoch_end(self, epoch, logs=None): 55 | if requests is None: 56 | raise ImportError("RemoteMonitor requires the `requests` library.") 57 | logs = logs or {} 58 | send = {} 59 | send["epoch"] = epoch 60 | for k, v in logs.items(): 61 | # np.ndarray and np.generic are not scalar types 62 | # therefore we must unwrap their scalar values and 63 | # pass to the json-serializable dict 'send' 64 | if isinstance(v, (np.ndarray, np.generic)): 65 | send[k] = v.item() 66 | else: 67 | send[k] = v 68 | try: 69 | if self.send_as_json: 70 | requests.post(self.root + self.path, json=send, headers=self.headers) 71 | else: 72 | requests.post( 73 | self.root + self.path, 74 | {self.field: json.dumps(send)}, 75 | headers=self.headers, 76 | ) 77 | except requests.exceptions.RequestException: 78 | logging.warning( 79 | "Warning: could not reach RemoteMonitor " 80 | "root server at " + str(self.root) 81 | ) 82 | -------------------------------------------------------------------------------- /elegy/callbacks/sigint.py: -------------------------------------------------------------------------------- 1 | # Implementation based on tf.keras.callbacks.py 2 | # https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/callbacks.py 3 | import enum 4 | import logging 5 | import signal 6 | import typing as tp 7 | 8 | import numpy as np 9 | 10 | from .callback import Callback 11 | 12 | ORIGINAL_HANDLER = signal.getsignal(signal.SIGINT) 13 | 14 | 15 | class SigIntMode(enum.Enum): 16 | TRAIN = enum.auto() 17 | TEST = enum.auto() 18 | 19 | 20 | class SigInt(Callback): 21 | def __init__(self, mode: tp.Union[SigIntMode, str]): 22 | super().__init__() 23 | self.mode = mode if isinstance(mode, SigIntMode) else SigIntMode(mode.upper()) 24 | 25 | def signal_handler(self, signal, frame): 26 | print("\n\nStopping...\n\n") 27 | self.model.stop_training = True 28 | # signal.signal(signal.SIGINT, ORIGINAL_HANDLER) 29 | 30 | def on_train_begin(self, logs=None): 31 | signal.signal(signal.SIGINT, self.signal_handler) 32 | 33 | def on_train_end(self, logs=None): 34 | if self.mode == SigIntMode.TRAIN: 35 | signal.signal(signal.SIGINT, ORIGINAL_HANDLER) 36 | 37 | def on_test_begin(self, logs=None): 38 | signal.signal(signal.SIGINT, self.signal_handler) 39 | 40 | def on_test_end(self, logs=None): 41 | if self.mode == SigIntMode.TEST: 42 | signal.signal(signal.SIGINT, ORIGINAL_HANDLER) 43 | -------------------------------------------------------------------------------- /elegy/callbacks/tensorboard.py: -------------------------------------------------------------------------------- 1 | # Implementation based on tf.keras.callbacks.py 2 | # https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/callbacks.py 3 | 4 | 5 | import os 6 | import typing as tp 7 | from typing import Any, Dict, Optional, Union 8 | 9 | from tensorboardX.writer import SummaryWriter 10 | 11 | from .callback import Callback 12 | 13 | 14 | class TensorBoard(Callback): 15 | """ 16 | Callback that streams epoch results to tensorboard events folder. 17 | 18 | Supports all values that can be represented as a string, 19 | including 1D iterables such as `np.ndarray`. 
20 | 21 | 22 | ```python 23 | tensorboard_logger = TensorBoard('runs') 24 | model.fit(X_train, Y_train, callbacks=[tensorboard_logger]) 25 | ``` 26 | """ 27 | 28 | def __init__( 29 | self, 30 | logdir: Optional[str] = None, 31 | *, 32 | update_freq: Union[str, int] = "epoch", 33 | purge_step: Optional[int] = None, 34 | comment: str = "", 35 | ) -> None: 36 | """ 37 | Arguments: 38 | logdir: Save directory location. Default is 39 | runs/**CURRENT_DATETIME_HOSTNAME**/{train, val}, which changes after each run. 40 | Use hierarchical folder structure to compare 41 | between runs easily. e.g. pass in 'runs/exp1', 'runs/exp2', etc. 42 | for each new experiment to compare across them. 43 | update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`, 44 | writes the losses and metrics to TensorBoard after each batch. The same 45 | applies for `'epoch'`. If using an integer, let's say `1000`, the 46 | callback will write the metrics and losses to TensorBoard every 1000 47 | batches. Note that writing too frequently to TensorBoard can slow down 48 | your training. 49 | purge_step (int): 50 | When logging crashes at step :math:`T+X` and restarts at step :math:`T`, 51 | any events whose global_step larger or equal to :math:`T` will be 52 | purged and hidden from TensorBoard. 53 | Note that crashed and resumed experiments should have the same ``logdir``. 54 | comment (string): Comment logdir suffix appended to the default 55 | ``logdir``. If ``logdir`` is assigned, this argument has no effect. 56 | """ 57 | if not logdir: 58 | import socket 59 | from datetime import datetime 60 | 61 | current_time = datetime.now().strftime("%b%d_%H-%M-%S") 62 | self.logdir = os.path.join( 63 | "runs", current_time + "_" + socket.gethostname() + comment 64 | ) 65 | else: 66 | self.logdir = logdir 67 | self.train_writer = None 68 | self.val_writer = None 69 | self.keys = None 70 | self.write_per_batch = True 71 | try: 72 | self.update_freq = int(update_freq) 73 | except ValueError as e: 74 | self.update_freq = 1 75 | if update_freq == "batch": 76 | self.write_per_batch = True 77 | elif update_freq == "epoch": 78 | self.write_per_batch = False 79 | else: 80 | raise e 81 | self.purge_step = purge_step 82 | 83 | super(TensorBoard, self).__init__() 84 | 85 | def on_train_begin(self, logs=None): 86 | self.train_writer = SummaryWriter( 87 | os.path.join(self.logdir, "train"), purge_step=self.purge_step 88 | ) 89 | self.val_writer = SummaryWriter( 90 | os.path.join(self.logdir, "val"), purge_step=self.purge_step 91 | ) 92 | self.steps = self.params["steps"] 93 | self.global_step = 0 94 | 95 | def on_train_batch_end(self, batch: int, logs=None): 96 | if not self.write_per_batch: 97 | return 98 | logs = logs or {} 99 | self.global_step = batch + self.current_epoch * (self.steps) 100 | if self.global_step % self.update_freq == 0: 101 | if self.keys is None: 102 | self.keys = logs.keys() 103 | for key in self.keys: 104 | self.train_writer.add_scalar(key, logs[key], self.global_step) 105 | 106 | def on_epoch_begin(self, epoch: int, logs=None): 107 | self.current_epoch = epoch 108 | 109 | def on_epoch_end(self, epoch, logs=None): 110 | # if self.model.stop_training: 111 | # return 112 | 113 | logs = logs or {} 114 | 115 | if self.keys is None: 116 | self.keys = logs.keys() 117 | 118 | # logs on on_{train, test}_batch_end do not have val metrics 119 | if self.write_per_batch: 120 | for key in logs: 121 | if "val" in key: 122 | self.val_writer.add_scalar( 123 | key.replace("val_", ""), logs[key], self.global_step 124 | ) 125 | 
return 126 | 127 | elif epoch % self.update_freq == 0: 128 | 129 | for key in logs: 130 | if "val" in key and key: 131 | self.val_writer.add_scalar( 132 | key.replace("val_", ""), logs[key], epoch 133 | ) 134 | else: 135 | self.train_writer.add_scalar(key, logs[key], epoch) 136 | 137 | def on_train_end(self, logs=None): 138 | self.train_writer.close() 139 | self.val_writer.close() 140 | -------------------------------------------------------------------------------- /elegy/callbacks/terminate_nan.py: -------------------------------------------------------------------------------- 1 | # Implementation based on tf.keras.callbacks.py 2 | # https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/callbacks.py 3 | import numpy as np 4 | 5 | from .callback import Callback 6 | 7 | 8 | class TerminateOnNaN(Callback): 9 | """Callback that terminates training when a NaN loss is encountered.""" 10 | 11 | def on_batch_end(self, batch, logs=None): 12 | logs = logs or {} 13 | loss = logs.get("loss") 14 | if loss is not None: 15 | if np.isnan(loss) or np.isinf(loss): 16 | print("Batch %d: Invalid loss, terminating training" % (batch)) 17 | self.model.stop_training = True 18 | -------------------------------------------------------------------------------- /elegy/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_handler import DataHandler 2 | from .dataset import DataLoader, Dataset 3 | from .utils import ( 4 | map_append, 5 | map_structure, 6 | train_validation_split, 7 | unpack_x_y_sample_weight, 8 | ) 9 | 10 | __all__ = ["Dataset", "DataLoader"] 11 | -------------------------------------------------------------------------------- /elegy/data/array_adapter.py: -------------------------------------------------------------------------------- 1 | # Implementation based on tf.keras.engine.data_adapter.py 2 | # https://github.com/tensorflow/tensorflow/blob/2b96f3662bd776e277f86997659e61046b56c315/tensorflow/python/keras/engine/data_adapter.py 3 | 4 | 5 | import math 6 | import typing as tp 7 | from operator import itemgetter 8 | 9 | import jax.numpy as jnp 10 | import numpy as np 11 | 12 | from elegy import types 13 | 14 | from .data_adapter import DataAdapter 15 | from .utils import flatten, map_structure, pack_x_y_sample_weight 16 | 17 | DEFAULT_BATCH_SIZE = 32 18 | 19 | 20 | class ArrayDataAdapter(DataAdapter): 21 | """Adapter that handles NumPy and Jax numpy arrays.""" 22 | 23 | @staticmethod 24 | def can_handle(x, y=None): 25 | flat_inputs = list(flatten(x)) 26 | if y is not None: 27 | flat_inputs += list(flatten(y)) 28 | 29 | supported_types = (jnp.ndarray, np.ndarray) 30 | # if pd: 31 | # supported_types = (ops.Tensor, np.ndarray, pd.Series, pd.DataFrame) 32 | 33 | def _is_array(v): 34 | if isinstance(v, supported_types): 35 | return True 36 | return False 37 | 38 | return all(_is_array(v) for v in flat_inputs) 39 | 40 | def __init__( 41 | self, 42 | x: types.ArrayHolder, 43 | y: tp.Union[types.ArrayHolder, None] = None, 44 | sample_weights: tp.Union[jnp.ndarray, np.ndarray, None] = None, 45 | batch_size: tp.Optional[int] = None, 46 | epochs: int = 1, 47 | steps: tp.Optional[int] = None, 48 | shuffle: bool = False, 49 | drop_remainder: bool = False, 50 | **kwargs, 51 | ): 52 | super(ArrayDataAdapter, self).__init__(x, y, **kwargs) 53 | 54 | inputs = pack_x_y_sample_weight(x, y, sample_weights) 55 | 56 | num_samples = set(int(i.shape[0]) for i in flatten(inputs)) 57 | 58 | if len(num_samples) > 1: 59 | msg = "Data 
cardinality is ambiguous:\n" 60 | for label, data in zip(["x", "y", "sample_weight"], inputs): 61 | msg += " {} sizes: {}\n".format( 62 | label, ", ".join(str(i.shape[0]) for i in data) 63 | ) 64 | msg += "Please provide data which shares the same first dimension." 65 | raise ValueError(msg) 66 | 67 | num_samples = ( 68 | num_samples.pop() 69 | if num_samples 70 | else batch_size 71 | if batch_size is not None 72 | else DEFAULT_BATCH_SIZE 73 | ) 74 | 75 | # If batch_size is not passed but steps is, calculate from the input data. 76 | if batch_size is None: 77 | # if batch_size is None and steps is None: 78 | # raise ValueError("Please provide either batch_size or steps") 79 | batch_size = ( 80 | int(math.ceil(num_samples / steps)) if steps else DEFAULT_BATCH_SIZE 81 | ) 82 | 83 | self._size = int(math.ceil(num_samples / batch_size)) 84 | self._batch_size = batch_size 85 | 86 | num_full_batches = int(num_samples // batch_size) 87 | self._partial_batch_size = num_samples % batch_size 88 | 89 | self._shuffle = shuffle 90 | 91 | dataset_indices = np.arange(num_samples) 92 | 93 | def dataset_generator(): 94 | while True: 95 | if shuffle: 96 | np.random.shuffle(dataset_indices) 97 | 98 | for batch in range( 99 | num_full_batches + int(self._partial_batch_size != 0) 100 | ): 101 | indices = dataset_indices[ 102 | batch * batch_size : (batch + 1) * batch_size 103 | ] 104 | 105 | # # Drop last batch 106 | # if drop_remainder and len(indices) < batch_size: 107 | # print("Dropping!") 108 | # continue 109 | inputs_slices = map_structure(itemgetter(indices), inputs) 110 | 111 | yield inputs_slices 112 | 113 | self._dataset = dataset_generator 114 | 115 | def get_dataset(self): 116 | return self._dataset 117 | 118 | def get_size(self): 119 | return self._size 120 | 121 | @property 122 | def batch_size(self): 123 | return self._batch_size 124 | 125 | def has_partial_batch(self): 126 | return self._partial_batch_size > 0 127 | 128 | @property 129 | def partial_batch_size(self): 130 | return self._partial_batch_size or None 131 | 132 | def should_recreate_iterator(self): 133 | # An infinite dataset is always created here. 134 | return False 135 | -------------------------------------------------------------------------------- /elegy/data/data_adapter.py: -------------------------------------------------------------------------------- 1 | # Implementation based on elegy.engine.data_adapter.py 2 | # https://github.com/tensorflow/tensorflow/blob/2b96f3662bd776e277f86997659e61046b56c315/tensorflow/python/keras/engine/data_adapter.py 3 | 4 | 5 | import abc 6 | import typing as tp 7 | 8 | import six 9 | 10 | 11 | @six.add_metaclass(abc.ABCMeta) 12 | class DataAdapter(object): 13 | """Base class for input data adapter. 14 | In order to simplify the training code path, all the input data 15 | object will be converted to a `generator` if possible. 16 | The sample usage of this class is like: 17 | 18 | ``` 19 | x = list(range(100)) 20 | adapter_cls = [ArrayDataAdapter, ..., ListsOfScalarsDataAdapter] 21 | applicable_adapters = [cls for cls in adapter_cls if cls.can_handle(x)] 22 | if len(applicable_adapters) != 1: 23 | raise ValueError("Expect only one adapter class to handle the input") 24 | dataset = applicable_adapters[0](x).get_dataset() 25 | for data in dataset(): 26 | # training 27 | ``` 28 | """ 29 | 30 | @staticmethod 31 | def can_handle(x, y=None): 32 | """Whether the current DataAdapter could handle the input x and y. 
33 | Structure wise, x and y can be single object, or list of objects if there 34 | multiple input/output, or dictionary of objects when the input/output are 35 | named. 36 | Arguments: 37 | x: input features. 38 | y: target labels. Note that y could be None in the case of prediction. 39 | Returns: 40 | boolean 41 | """ 42 | raise NotImplementedError 43 | 44 | @abc.abstractmethod 45 | def __init__(self, x, y=None, **kwargs): 46 | """Create a DataAdapter based on data inputs. 47 | The caller must make sure to call `can_handle()` first before invoking this 48 | method. Provide unsupported data type will result into unexpected behavior. 49 | 50 | Arguments: 51 | x: input features. 52 | y: target labels. Note that y could be None in the case of prediction. 53 | **kwargs: Other keyword arguments for DataAdapter during the construction 54 | of the generator. For example: 55 | 56 | - Numpy data might have `sample_weights` which will be used for 57 | weighting the loss function during training. 58 | - Numpy data might need to have `batch_size` parameter when constructing 59 | the dataset and iterator. 60 | 61 | DataAdapter might choose to ignore any keyword argument if it doesn't 62 | use it, or raise exception if any required argument is not provide. 63 | """ 64 | if not self.can_handle(x, y): 65 | raise ValueError( 66 | "{} Cannot handle input {}, {}".format(self.__class__, x, y) 67 | ) 68 | 69 | @abc.abstractmethod 70 | def get_dataset(self): 71 | """Get a function that returns a generator for the current DataAdapter. 72 | Note that the generator wrapped in the function will repeat for each epoch, 73 | so the steps for traversing it should be known. 74 | Returns: 75 | An function wrapping a generator. 76 | """ 77 | raise NotImplementedError 78 | 79 | @abc.abstractmethod 80 | def get_size(self): 81 | """Return the size (number of batches) for the dataset created. 82 | For certain type of the data input, the number of batches is known, eg for 83 | Numpy data, the size is same as (number_of_element / batch_size). Whereas 84 | for python generator, the size is unknown since it may or may not 85 | have a end state. 86 | Returns: 87 | int, the number of batches for the dataset, or None if it is unknown. The 88 | caller could use this to control the loop of training, show progress bar, 89 | or handle unexpected StopIteration error. 90 | """ 91 | raise NotImplementedError 92 | 93 | @abc.abstractmethod 94 | def batch_size(self): 95 | """Return the batch size of the dataset created. 96 | For certain type of the data input, the batch size is known, and even 97 | required, like numpy array. Where as for generator, the batch is unknown 98 | unless we take a peek. 99 | Returns: 100 | int, the batch size of the dataset, or None if it is unknown. 101 | """ 102 | raise NotImplementedError 103 | 104 | def representative_batch_size(self): 105 | """Return a representative size for batches in the dataset. 106 | This is not guaranteed to be the batch size for all batches in the 107 | dataset. It just needs to be a rough approximation for batch sizes in 108 | the dataset. 109 | Returns: 110 | int, a representative size for batches found in the dataset, 111 | or None if it is unknown. 112 | """ 113 | return self.batch_size() 114 | 115 | @abc.abstractmethod 116 | def has_partial_batch(self): 117 | """Whether the dataset has partial batch at the end.""" 118 | raise NotImplementedError 119 | 120 | @abc.abstractmethod 121 | def partial_batch_size(self): 122 | """The size of the final partial batch for dataset. 
123 | Will return None if has_partial_batch is False or batch_size is None. 124 | """ 125 | raise NotImplementedError 126 | 127 | @abc.abstractmethod 128 | def should_recreate_iterator(self): 129 | """Returns whether a new iterator should be created every epoch.""" 130 | raise NotImplementedError 131 | 132 | def get_samples(self): 133 | """Returns number of samples in the data, or `None`.""" 134 | if not self.get_size() or not self.batch_size(): 135 | return None 136 | total_sample = self.get_size() * self.batch_size() 137 | if self.has_partial_batch(): 138 | total_sample -= self.batch_size() - self.partial_batch_size() 139 | return total_sample 140 | 141 | def on_epoch_end(self): 142 | """A hook called after each epoch.""" 143 | pass 144 | -------------------------------------------------------------------------------- /elegy/data/generator_adapter.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import typing as tp 3 | 4 | from .data_adapter import DataAdapter 5 | from .utils import ( 6 | assert_not_namedtuple, 7 | flatten, 8 | is_none_or_empty, 9 | pack_x_y_sample_weight, 10 | unpack_x_y_sample_weight, 11 | ) 12 | 13 | 14 | class GeneratorDataAdapter(DataAdapter): 15 | """Adapter that handles python generators and iterators.""" 16 | 17 | @staticmethod 18 | def can_handle(x, y=None): 19 | return (hasattr(x, "__next__") or hasattr(x, "next")) and hasattr(x, "__iter__") 20 | 21 | def __init__( 22 | self, 23 | x: tp.Union[tp.Iterable], 24 | y=None, 25 | sample_weights=None, 26 | **kwargs, 27 | ): 28 | # Generators should never shuffle as exhausting the generator in order to 29 | # shuffle the batches is inefficient. 30 | kwargs.pop("shuffle", None) 31 | 32 | if not is_none_or_empty(y): 33 | raise ValueError( 34 | "`y` argument is not supported when using " "python generator as input." 35 | ) 36 | if not is_none_or_empty(sample_weights): 37 | raise ValueError( 38 | "`sample_weight` argument is not supported when using " 39 | "python generator as input." 40 | ) 41 | 42 | super(GeneratorDataAdapter, self).__init__(x, y, **kwargs) 43 | 44 | # Since we have to know the dtype of the python generator when we build the 45 | # dataset, we have to look at a batch to infer the structure. 46 | peek, x = self._peek_and_restore(x) 47 | assert_not_namedtuple(peek) 48 | peek = self._standardize_batch(peek) 49 | 50 | self._first_batch_size = int(list(flatten(peek))[0].shape[0]) 51 | 52 | def wrapped_generator(): 53 | for data in x: 54 | yield self._standardize_batch(data) 55 | 56 | dataset = wrapped_generator 57 | 58 | self._dataset = dataset 59 | 60 | def _standardize_batch(self, data): 61 | """Standardizes a batch output by a generator.""" 62 | # Removes `None`s. 
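# A hedged illustration (assuming the pack/unpack helpers follow the Keras
# conventions this file is based on): a generator may yield a bare `x`, an
# `(x,)` tuple, an `(x, y)` pair or an `(x, y, sample_weight)` triple; the
# unpack and re-pack below normalize all of these so missing (None) entries
# are dropped, e.g. a bare `x` becomes `(x,)` and `(x, y, None)` becomes `(x, y)`.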
63 | x, y, sample_weight = unpack_x_y_sample_weight(data) 64 | data = pack_x_y_sample_weight(x, y, sample_weight) 65 | 66 | return data 67 | 68 | @staticmethod 69 | def _peek_and_restore(x): 70 | peek = next(x) 71 | return peek, itertools.chain([peek], x) 72 | 73 | def get_dataset(self): 74 | return self._dataset 75 | 76 | def get_size(self): 77 | return None 78 | 79 | @property 80 | def batch_size(self): 81 | return self.representative_batch_size 82 | 83 | @property 84 | def representative_batch_size(self): 85 | return self._first_batch_size 86 | 87 | def has_partial_batch(self): 88 | return False 89 | 90 | @property 91 | def partial_batch_size(self): 92 | return 93 | 94 | def should_recreate_iterator(self): 95 | return False 96 | -------------------------------------------------------------------------------- /elegy/data/list_adapter.py: -------------------------------------------------------------------------------- 1 | # Implementation based on tf.keras.engine.data_adapter.py 2 | # https://github.com/tensorflow/tensorflow/blob/2b96f3662bd776e277f86997659e61046b56c315/tensorflow/python/keras/engine/data_adapter.py 3 | 4 | 5 | import typing as tp 6 | 7 | import numpy as np 8 | 9 | from .array_adapter import ArrayDataAdapter 10 | from .data_adapter import DataAdapter 11 | 12 | scalar_types = (float, int, str) 13 | 14 | 15 | class ListsOfScalarsDataAdapter(DataAdapter): 16 | """Adapter that handles lists of scalars and lists of lists of scalars.""" 17 | 18 | @staticmethod 19 | def can_handle(x, y=None): 20 | handles_x = ListsOfScalarsDataAdapter._is_list_of_scalars(x) 21 | handles_y = True 22 | if y is not None: 23 | handles_y = ListsOfScalarsDataAdapter._is_list_of_scalars(y) 24 | return handles_x and handles_y 25 | 26 | @staticmethod 27 | def _is_list_of_scalars(inp): 28 | if isinstance(inp, scalar_types): 29 | return True 30 | if isinstance(inp, (list, tuple)): 31 | return ListsOfScalarsDataAdapter._is_list_of_scalars(inp[0]) 32 | return False 33 | 34 | def __init__( 35 | self, x, y=None, sample_weights=None, batch_size=None, shuffle=False, **kwargs 36 | ): 37 | super(ListsOfScalarsDataAdapter, self).__init__(x, y, **kwargs) 38 | x = np.asarray(x) 39 | if y is not None: 40 | y = np.asarray(y) 41 | if sample_weights is not None: 42 | sample_weights = np.asarray(sample_weights) 43 | 44 | self._internal_adapter = ArrayDataAdapter( 45 | x, 46 | y=y, 47 | sample_weights=sample_weights, 48 | batch_size=batch_size, 49 | shuffle=shuffle, 50 | **kwargs 51 | ) 52 | 53 | def get_dataset(self): 54 | return self._internal_adapter.get_dataset() 55 | 56 | def get_size(self): 57 | return self._internal_adapter.get_size() 58 | 59 | @property 60 | def batch_size(self): 61 | return self._internal_adapter.batch_size 62 | 63 | def has_partial_batch(self): 64 | return self._internal_adapter.has_partial_batch() 65 | 66 | @property 67 | def partial_batch_size(self): 68 | return self._internal_adapter.partial_batch_size 69 | 70 | def should_recreate_iterator(self): 71 | return True 72 | -------------------------------------------------------------------------------- /elegy/data/tf_dataset_adapter.py: -------------------------------------------------------------------------------- 1 | # Implementation based on tf.keras.engine.data_adapter.py 2 | # https://github.com/tensorflow/tensorflow/blob/2b96f3662bd776e277f86997659e61046b56c315/tensorflow/python/keras/engine/data_adapter.py 3 | 4 | 5 | from tensorflow.python.data.experimental.ops import cardinality 6 | from tensorflow.python.data.ops import dataset_ops 7 | 
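# A hedged sketch of two input types handled by the adapters above; `model` is
# an assumed placeholder elegy Model. The generator case needs `steps_per_epoch`
# because GeneratorDataAdapter.get_size() returns None.
import numpy as np

def batches(batch_size: int = 32):
    # an endless generator of already-batched (x, y) pairs
    while True:
        x = np.random.uniform(size=(batch_size, 28, 28))
        y = np.random.randint(0, 10, size=(batch_size,))
        yield x, y

# handled by GeneratorDataAdapter
model.fit(inputs=batches(), epochs=2, steps_per_epoch=100)

# handled by ListsOfScalarsDataAdapter (lists are converted to numpy arrays internally)
model.fit(inputs=[0.1, 0.5, 0.9], labels=[0.0, 1.0, 1.0], batch_size=2, epochs=1)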
8 | from .data_adapter import DataAdapter 9 | from .utils import flatten, is_none_or_empty, map_structure 10 | 11 | 12 | class TFDatasetAdapter(DataAdapter): 13 | """Adapter that handles `tf.data.Dataset`.""" 14 | 15 | @staticmethod 16 | def can_handle(x, y=None): 17 | return isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)) 18 | 19 | def __init__(self, x, y=None, sample_weights=None, steps=None, **kwargs): 20 | super().__init__(x, y, **kwargs) 21 | # Note that the dataset instance is immutable, its fine to reuse the user 22 | # provided dataset. 23 | self._dataset = x 24 | 25 | # The user-provided steps. 26 | self._user_steps = steps 27 | 28 | self._validate_args(y, sample_weights, steps) 29 | 30 | # Since we have to know the dtype of the dataset when we build the 31 | # dataset, we have to look at a batch to infer the structure. 32 | peek = next(iter(x)) 33 | 34 | self._first_batch_size = int(list(flatten(peek))[0].shape[0]) 35 | 36 | def get_dataset(self): 37 | def parse_tf_data_gen(): 38 | for batch in iter(self._dataset): 39 | batch = map_structure(lambda x: x.numpy(), batch) 40 | yield batch 41 | 42 | return parse_tf_data_gen 43 | 44 | def get_size(self): 45 | size = cardinality.cardinality(self._dataset) 46 | if size == cardinality.INFINITE and self._user_steps is None: 47 | raise ValueError( 48 | "When passing an infinitely repeating tf.data.Dataset, you " 49 | "must specify how many steps to draw." 50 | ) 51 | elif size == cardinality.INFINITE: 52 | return self._user_steps 53 | elif size >= 0: 54 | return size.numpy().item() 55 | 56 | @property 57 | def batch_size(self): 58 | return self.representative_batch_size 59 | 60 | @property 61 | def representative_batch_size(self): 62 | return self._first_batch_size 63 | 64 | @property 65 | def partial_batch_size(self): 66 | return 67 | 68 | def has_partial_batch(self): 69 | return False 70 | 71 | def should_recreate_iterator(self): 72 | # If user doesn't supply `steps`, or if they supply `steps` that 73 | # exactly equals the size of the `Dataset`, create a new iterator 74 | # each epoch. 75 | return ( 76 | self._user_steps is None 77 | or cardinality.cardinality(self._dataset).numpy() == self._user_steps 78 | ) 79 | 80 | def _validate_args(self, y, sample_weights, steps): 81 | """Validates `__init__` arguments.""" 82 | # Arguments that shouldn't be passed. 83 | if not is_none_or_empty(y): 84 | raise ValueError( 85 | "`y` argument is not supported when using " "tf.Data.dataset as input." 86 | ) 87 | if not is_none_or_empty(sample_weights): 88 | raise ValueError( 89 | "`sample_weight` argument is not supported when using " 90 | "tf.Data.dataset as input." 91 | ) 92 | 93 | size = cardinality.cardinality(self._dataset).numpy() 94 | if size == cardinality.INFINITE and steps is None: 95 | raise ValueError( 96 | "When providing an infinitely repeating tf.data.Dataset, you must specify " 97 | "the number of steps to run." 
98 | ) 99 | -------------------------------------------------------------------------------- /elegy/data/torch_dataloader_adapter.py: -------------------------------------------------------------------------------- 1 | from jax._src.lax.lax import remaining 2 | from torch.utils.data import DataLoader 3 | 4 | from .data_adapter import DataAdapter 5 | from .utils import is_none_or_empty, list_to_tuple, map_structure 6 | 7 | 8 | class TorchDataLoaderAdapter(DataAdapter): 9 | """Adapter that handles torch Dataloaders.""" 10 | 11 | @staticmethod 12 | def can_handle(x, y=None): 13 | return isinstance(x, DataLoader) 14 | 15 | def __init__( 16 | self, 17 | x: DataLoader, 18 | y=None, 19 | steps=None, 20 | sample_weights=None, 21 | training=True, 22 | **kwargs, 23 | ): 24 | 25 | if not is_none_or_empty(y): 26 | raise ValueError( 27 | "`y` argument is not supported when using " "torch Dataloader as input." 28 | ) 29 | if not is_none_or_empty(sample_weights): 30 | raise ValueError( 31 | "`sample_weight` argument is not supported when using " 32 | "torch Dataloader as input." 33 | ) 34 | 35 | super().__init__(x, y, **kwargs) 36 | 37 | self.training = training 38 | self.steps = steps 39 | self._batch_size = x.batch_size 40 | self._dataset = x 41 | 42 | self.current_step = 0 43 | 44 | def get_dataset(self): 45 | def parse_dataloader_gen(): 46 | self.current_step = 0 47 | for batch in iter(self._dataset): 48 | self.current_step += 1 49 | batch = map_structure(lambda x: x.cpu().numpy(), list_to_tuple(batch)) 50 | yield batch 51 | 52 | return parse_dataloader_gen 53 | 54 | def get_size(self): 55 | try: 56 | return len(self._dataset) 57 | except Exception: 58 | return None 59 | 60 | @property 61 | def batch_size(self): 62 | return self.representative_batch_size 63 | 64 | @property 65 | def representative_batch_size(self): 66 | return self._batch_size 67 | 68 | def has_partial_batch(self): 69 | return False 70 | 71 | @property 72 | def partial_batch_size(self): 73 | return 74 | 75 | def should_recreate_iterator(self): 76 | # if in eval mode should not recreate iterator 77 | # but if in train mode and steps not provided, should recreate at end of each epoch 78 | if not self.training or self.steps is None: 79 | return self.training 80 | 81 | steps_dataset = self.get_size() 82 | if steps_dataset is None: 83 | return False 84 | 85 | remaining_steps = steps_dataset - self.current_step 86 | # if remaining steps less than needed steps, should recreate dataloader 87 | # TODO: This will drop the last steps of data, how to avoid this? 
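# A hedged worked example of the check below: with a DataLoader of 100 batches
# (get_size() == 100), steps == 40 and current_step == 80 at the end of an
# epoch, remaining_steps is 20 < 40, so the iterator is recreated and the last
# 20 batches of this pass are skipped, which is the behaviour the TODO above refers to.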
88 | if remaining_steps < self.steps: 89 | return True 90 | else: 91 | return False 92 | -------------------------------------------------------------------------------- /elegy/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poets-ai/elegy/4709ce8dc9dde3925ce717e2358ce49112e36398/elegy/model/__init__.py -------------------------------------------------------------------------------- /elegy/nets/__init__.py: -------------------------------------------------------------------------------- 1 | # from .resnet import ( 2 | # ResNet, 3 | # ResNet18, 4 | # ResNet34, 5 | # ResNet50, 6 | # ResNet101, 7 | # ResNet152, 8 | # ResNet200, 9 | # ) 10 | 11 | # __all__ = [ 12 | # "ResNet", 13 | # "ResNet18", 14 | # "ResNet34", 15 | # "ResNet50", 16 | # "ResNet101", 17 | # "ResNet152", 18 | # "ResNet200", 19 | # ] 20 | -------------------------------------------------------------------------------- /elegy/types.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import typing as tp 3 | from copy import copy 4 | from dataclasses import dataclass, field 5 | from enum import Enum 6 | from functools import total_ordering 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | import jax.tree_util 11 | import numpy as np 12 | import treex as tx 13 | 14 | EPSILON = 1e-7 15 | F = tp.TypeVar("F", bound=tp.Callable) 16 | 17 | 18 | KeySeq = tx.KeySeq 19 | 20 | T = tp.TypeVar("T") 21 | Container = tp.Union[ 22 | T, 23 | tp.Tuple["Container", ...], 24 | tp.Dict[str, "Container"], 25 | ] 26 | ArrayHolder = tp.Union[Container[np.ndarray], np.ndarray] 27 | 28 | IndexLike = tp.Union[str, int, tp.Iterable[tp.Union[str, int]]] 29 | 30 | Shape = tp.Sequence[int] 31 | ShapeLike = tp.Union[int, Shape] 32 | FloatLike = tp.Union[float, np.ndarray] 33 | DType = tp.Any 34 | ParamName = str 35 | 36 | Params = tp.Mapping[str, tp.Mapping[ParamName, np.ndarray]] 37 | State = tp.Mapping[str, tp.Mapping[str, np.ndarray]] 38 | PadFn = tp.Callable[[int], tp.Tuple[int, int]] 39 | PadFnOrFns = tp.Union[PadFn, tp.Sequence[PadFn]] 40 | PRNGKey = np.ndarray 41 | Parameters = tp.Dict[str, tp.Any] 42 | Labels = tp.Mapping[str, tp.Any] 43 | ParameterCollection = tp.Dict[str, Parameters] 44 | Logs = tp.Dict[str, jnp.ndarray] 45 | Index = tp.Union[int, str] 46 | Path = tp.Tuple[Index, ...] 
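# A hedged illustration of the aliases above, with placeholder values:
# `Container`/`ArrayHolder` admit a bare array, tuples of arrays, or nested
# dicts of arrays (the structures the data adapters flatten), while `Logs`
# maps metric names to scalar arrays and `Path` indexes into a nested structure.
x_single: ArrayHolder = np.zeros((32, 28, 28))
x_named: ArrayHolder = {"image": np.zeros((32, 28, 28)), "label": np.zeros((32,))}
logs: Logs = {"loss": jnp.asarray(0.25), "accuracy": jnp.asarray(0.91)}
path: Path = ("encoder", "linear", 0)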
47 | Grads = tp.Any 48 | RNG = tp.Union[KeySeq, np.ndarray] 49 | Scalar = tp.Union[np.ndarray, float, int] 50 | SummaryModule = tp.Any 51 | SummaryValue = tp.Any 52 | 53 | NetParams = tp.Any 54 | NetStates = tp.Any 55 | ModuleParams = tp.Any 56 | ModuleStates = tp.Any 57 | MetricsStates = tp.Any 58 | OptimizerStates = tp.Any 59 | OptimizerStates = tp.Any 60 | Grads = tp.Any 61 | Pytree = tp.Any 62 | 63 | 64 | class MissingModule(Exception): 65 | pass 66 | 67 | 68 | class MissingOptimizer(Exception): 69 | pass 70 | 71 | 72 | class MissingMethod(Exception): 73 | pass 74 | 75 | 76 | class DependencyUnavailable(Exception): 77 | pass 78 | 79 | 80 | class ShapeMismatch(Exception): 81 | pass 82 | 83 | 84 | class MissingParameter(Exception): 85 | pass 86 | 87 | 88 | class NoContext(Exception): 89 | pass 90 | 91 | 92 | class ModuleOrderError(Exception): 93 | pass 94 | 95 | 96 | class SubmoduleNotRegistered(Exception): 97 | pass 98 | 99 | 100 | class ModelNotInitialized(Exception): 101 | pass 102 | -------------------------------------------------------------------------------- /examples/elegy/mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | from functools import partial 4 | from typing import Any, Generator, Mapping, Tuple 5 | 6 | import einops 7 | import jax 8 | import jax.numpy as jnp 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import optax 12 | import typer 13 | from datasets.load import load_dataset 14 | from tensorboardX.writer import SummaryWriter 15 | 16 | import elegy as eg 17 | 18 | 19 | class MLP(eg.Module): 20 | def __init__(self, n1: int = 300, n2: int = 100): 21 | self.n1 = n1 22 | self.n2 = n2 23 | 24 | @eg.compact 25 | def __call__(self, x: jnp.ndarray): 26 | x = x.astype(jnp.float32) / 255.0 27 | x = einops.rearrange(x, "batch ... 
-> batch (...)") 28 | x = eg.nn.Linear(self.n1)(x) 29 | x = jax.nn.relu(x) 30 | x = eg.nn.Linear(self.n2)(x) 31 | x = jax.nn.relu(x) 32 | x = eg.nn.Linear(10)(x) 33 | return x 34 | 35 | 36 | def main( 37 | debug: bool = False, 38 | eager: bool = False, 39 | logdir: str = "runs", 40 | steps_per_epoch: int = 200, 41 | batch_size: int = 64, 42 | epochs: int = 100, 43 | ): 44 | 45 | if debug: 46 | import debugpy 47 | 48 | print("Waiting for debugger...") 49 | debugpy.listen(5678) 50 | debugpy.wait_for_client() 51 | 52 | current_time = datetime.now().strftime("%b%d_%H-%M-%S") 53 | logdir = os.path.join(logdir, current_time) 54 | 55 | dataset = load_dataset("mnist") 56 | dataset.set_format("np") 57 | X_train = np.stack(dataset["train"]["image"]) 58 | y_train = dataset["train"]["label"] 59 | X_test = np.stack(dataset["test"]["image"]) 60 | y_test = dataset["test"]["label"] 61 | 62 | print("X_train:", X_train.shape, X_train.dtype) 63 | print("y_train:", y_train.shape, y_train.dtype) 64 | print("X_test:", X_test.shape, X_test.dtype) 65 | print("y_test:", y_test.shape, y_test.dtype) 66 | 67 | model = eg.Model( 68 | module=MLP(n1=300, n2=100), 69 | loss=[ 70 | eg.losses.Crossentropy(), 71 | eg.regularizers.L2(l=1e-4), 72 | ], 73 | metrics=eg.metrics.Accuracy(), 74 | optimizer=optax.adamw(1e-3), 75 | eager=eager, 76 | ) 77 | 78 | model.summary(X_train[:64]) 79 | 80 | history = model.fit( 81 | inputs=X_train, 82 | labels=y_train, 83 | epochs=epochs, 84 | steps_per_epoch=steps_per_epoch, 85 | batch_size=batch_size, 86 | validation_data=(X_test, y_test), 87 | shuffle=True, 88 | callbacks=[eg.callbacks.TensorBoard(logdir=logdir)], 89 | ) 90 | 91 | eg.utils.plot_history(history) 92 | 93 | # get random samples 94 | idxs = np.random.randint(0, 10000, size=(9,)) 95 | x_sample = X_test[idxs] 96 | 97 | # get predictions 98 | y_pred = model.predict(x=x_sample) 99 | 100 | # plot and save results 101 | with SummaryWriter(os.path.join(logdir, "val")) as tbwriter: 102 | figure = plt.figure(figsize=(12, 12)) 103 | for i in range(3): 104 | for j in range(3): 105 | k = 3 * i + j 106 | plt.subplot(3, 3, k + 1) 107 | plt.title(f"{np.argmax(y_pred[k])}") 108 | plt.imshow(x_sample[k], cmap="gray") 109 | # tbwriter.add_figure("Predictions", figure, 100) 110 | 111 | plt.show() 112 | 113 | print( 114 | "\n\n\nMetrics and images can be explored using tensorboard using:", 115 | f"\n \t\t\t tensorboard --logdir {logdir}", 116 | ) 117 | 118 | 119 | if __name__ == "__main__": 120 | typer.run(main) 121 | -------------------------------------------------------------------------------- /examples/elegy/mnist_autoencoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | from typing import Any, Generator, Mapping, Tuple 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import optax 10 | import typer 11 | from datasets.load import load_dataset 12 | from tensorboardX.writer import SummaryWriter 13 | 14 | import elegy as eg 15 | 16 | 17 | class MeanSquaredError(eg.losses.MeanSquaredError): 18 | # we request `x` instead of `y_true` since we are don't require labels in autoencoders 19 | def call(self, inputs, preds): 20 | return super().call(target=inputs, preds=preds) / 255 21 | 22 | 23 | class MLP(eg.Module): 24 | def __init__(self, n1: int = 300, n2: int = 100, **kwargs): 25 | super().__init__(**kwargs) 26 | self.n1 = n1 27 | self.n2 = n2 28 | 29 | @eg.compact 30 | def __call__(self, image: 
jnp.ndarray): 31 | x = image.astype(jnp.float32) / 255.0 32 | x = eg.Flatten()(x) 33 | x = eg.Linear(self.n1)(x) 34 | x = jax.nn.relu(x) 35 | x = eg.Linear(self.n2)(x) 36 | x = jax.nn.relu(x) 37 | x = eg.Linear(self.n1)(x) 38 | x = jax.nn.relu(x) 39 | x = eg.Linear(np.prod(image.shape[-2:]))(x) 40 | x = jax.nn.sigmoid(x) * 255 41 | x = x.reshape(image.shape) 42 | 43 | return x 44 | 45 | 46 | def main( 47 | debug: bool = False, 48 | eager: bool = False, 49 | logdir: str = "runs", 50 | steps_per_epoch: int = 200, 51 | epochs: int = 100, 52 | batch_size: int = 64, 53 | ): 54 | 55 | if debug: 56 | import debugpy 57 | 58 | print("Waiting for debugger...") 59 | debugpy.listen(5678) 60 | debugpy.wait_for_client() 61 | 62 | current_time = datetime.now().strftime("%b%d_%H-%M-%S") 63 | logdir = os.path.join(logdir, current_time) 64 | 65 | dataset = load_dataset("mnist") 66 | dataset.set_format("np") 67 | X_train = np.stack(dataset["train"]["image"]) 68 | X_test = np.stack(dataset["test"]["image"]) 69 | 70 | print("X_train:", X_train.shape, X_train.dtype) 71 | print("X_test:", X_test.shape, X_test.dtype) 72 | 73 | model = eg.Model( 74 | module=MLP(n1=256, n2=64), 75 | loss=MeanSquaredError(), 76 | optimizer=optax.rmsprop(0.001), 77 | eager=eager, 78 | ) 79 | 80 | model.summary(X_train[:64]) 81 | 82 | # Notice we are not passing `y` 83 | history = model.fit( 84 | inputs=X_train, 85 | epochs=epochs, 86 | steps_per_epoch=steps_per_epoch, 87 | batch_size=batch_size, 88 | validation_data=(X_test,), 89 | shuffle=True, 90 | callbacks=[eg.callbacks.TensorBoard(logdir=logdir, update_freq=300)], 91 | ) 92 | 93 | eg.utils.plot_history(history) 94 | 95 | # get random samples 96 | idxs = np.random.randint(0, 10000, size=(5,)) 97 | x_sample = X_test[idxs] 98 | 99 | # get predictions 100 | y_pred = model.predict(x=x_sample) 101 | 102 | # plot and save results 103 | with SummaryWriter(os.path.join(logdir, "val")) as tbwriter: 104 | 105 | figure = plt.figure(figsize=(12, 12)) 106 | for i in range(5): 107 | plt.subplot(2, 5, i + 1) 108 | plt.imshow(x_sample[i], cmap="gray") 109 | plt.subplot(2, 5, 5 + i + 1) 110 | plt.imshow(y_pred[i], cmap="gray") 111 | 112 | plt.show() 113 | 114 | 115 | if __name__ == "__main__": 116 | typer.run(main) 117 | -------------------------------------------------------------------------------- /examples/elegy/mnist_conv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import typing as tp 3 | from dataclasses import dataclass 4 | from datetime import datetime 5 | from functools import partial 6 | from typing import Any, Generator, Mapping, Tuple 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import optax 13 | import typer 14 | from datasets.load import load_dataset 15 | from tensorboardX.writer import SummaryWriter 16 | 17 | import elegy as eg 18 | 19 | 20 | class CNN(eg.Module): 21 | @eg.compact 22 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 23 | # Normalize the input 24 | x = x.astype(jnp.float32) / 255.0 25 | 26 | # Block 1 27 | x = eg.Conv(32, [3, 3], strides=[2, 2])(x) 28 | x = eg.Dropout(0.05)(x) 29 | x = jax.nn.relu(x) 30 | 31 | # Block 2 32 | x = eg.Conv(64, [3, 3], strides=[2, 2])(x) 33 | x = eg.BatchNorm()(x) 34 | x = eg.Dropout(0.1)(x) 35 | x = jax.nn.relu(x) 36 | 37 | # Block 3 38 | x = eg.Conv(128, [3, 3], strides=[2, 2])(x) 39 | 40 | # Global average pooling 41 | x = x.mean(axis=(1, 2)) 42 | 43 | # Classification layer 44 | x = eg.Linear(10)(x) 45 | 
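# A hedged shape walk-through for a batch of MNIST images (assuming the
# default "SAME" padding of eg.Conv):
#   (B, 28, 28, 1) -> Conv 32, stride 2 -> (B, 14, 14, 32)
#                  -> Conv 64, stride 2 -> (B, 7, 7, 64)
#                  -> Conv 128, stride 2 -> (B, 4, 4, 128)
#                  -> mean over (1, 2)   -> (B, 128)
#                  -> Linear(10)         -> (B, 10) logits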
46 | return x 47 | 48 | 49 | def main( 50 | debug: bool = False, 51 | eager: bool = False, 52 | logdir: str = "runs", 53 | steps_per_epoch: tp.Optional[int] = None, 54 | epochs: int = 100, 55 | batch_size: int = 32, 56 | distributed: bool = False, 57 | ): 58 | 59 | if debug: 60 | import debugpy 61 | 62 | print("Waiting for debugger...") 63 | debugpy.listen(5678) 64 | debugpy.wait_for_client() 65 | 66 | current_time = datetime.now().strftime("%b%d_%H-%M-%S") 67 | logdir = os.path.join(logdir, current_time) 68 | 69 | dataset = load_dataset("mnist") 70 | dataset.set_format("np") 71 | X_train = np.stack(dataset["train"]["image"])[..., None] 72 | y_train = dataset["train"]["label"] 73 | X_test = np.stack(dataset["test"]["image"])[..., None] 74 | y_test = dataset["test"]["label"] 75 | 76 | print("X_train:", X_train.shape, X_train.dtype) 77 | print("y_train:", y_train.shape, y_train.dtype) 78 | print("X_test:", X_test.shape, X_test.dtype) 79 | print("y_test:", y_test.shape, y_test.dtype) 80 | 81 | model = eg.Model( 82 | module=CNN(), 83 | loss=eg.losses.Crossentropy(), 84 | metrics=eg.metrics.Accuracy(), 85 | optimizer=optax.adam(1e-3), 86 | eager=eager, 87 | ) 88 | 89 | if distributed: 90 | model = model.distributed() 91 | 92 | # show model summary 93 | model.summary(X_train[:64], depth=1) 94 | 95 | history = model.fit( 96 | inputs=X_train, 97 | labels=y_train, 98 | epochs=epochs, 99 | steps_per_epoch=steps_per_epoch, 100 | batch_size=batch_size, 101 | validation_data=(X_test, y_test), 102 | shuffle=True, 103 | callbacks=[eg.callbacks.TensorBoard(logdir=logdir)], 104 | ) 105 | 106 | eg.utils.plot_history(history) 107 | 108 | print(model.evaluate(x=X_test, y=y_test)) 109 | 110 | # get random samples 111 | idxs = np.random.randint(0, 10000, size=(9,)) 112 | x_sample = X_test[idxs] 113 | 114 | # get predictions 115 | model = model.local() 116 | y_pred = model.predict(x=x_sample) 117 | 118 | # plot results 119 | figure = plt.figure(figsize=(12, 12)) 120 | for i in range(3): 121 | for j in range(3): 122 | k = 3 * i + j 123 | plt.subplot(3, 3, k + 1) 124 | 125 | plt.title(f"{np.argmax(y_pred[k])}") 126 | plt.imshow(x_sample[k], cmap="gray") 127 | 128 | plt.show() 129 | 130 | 131 | if __name__ == "__main__": 132 | typer.run(main) 133 | -------------------------------------------------------------------------------- /examples/elegy/mnist_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from datetime import datetime 4 | from typing import Any, Generator, Mapping, Tuple 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import optax 11 | import typer 12 | from datasets.load import load_dataset 13 | from tensorboardX.writer import SummaryWriter 14 | 15 | import elegy as eg 16 | 17 | 18 | class MNIST(eg.data.Dataset): 19 | def __init__(self, training: bool = True): 20 | 21 | dataset = load_dataset("mnist") 22 | dataset.set_format("np") 23 | X_train = np.stack(dataset["train"]["image"]) 24 | y_train = dataset["train"]["label"] 25 | X_test = np.stack(dataset["test"]["image"]) 26 | y_test = dataset["test"]["label"] 27 | 28 | if training: 29 | self.x = X_train 30 | self.y = y_train 31 | else: 32 | self.x = X_test 33 | self.y = y_test 34 | 35 | def __len__(self): 36 | return len(self.x) 37 | 38 | def __getitem__(self, i): 39 | return (self.x[i], self.y[i]) 40 | 41 | 42 | def main( 43 | debug: bool = False, 44 | eager: bool = False, 45 | logdir: str = 
"runs", 46 | steps_per_epoch: int = 200, 47 | epochs: int = 100, 48 | batch_size: int = 64, 49 | ): 50 | 51 | if debug: 52 | import debugpy 53 | 54 | print("Waiting for debugger...") 55 | debugpy.listen(5678) 56 | debugpy.wait_for_client() 57 | 58 | current_time = datetime.now().strftime("%b%d_%H-%M-%S") 59 | logdir = os.path.join(logdir, current_time) 60 | 61 | train_dataset = MNIST(training=True) 62 | test_dataset = MNIST(training=False) 63 | train_loader = eg.data.DataLoader( 64 | train_dataset, batch_size=batch_size, shuffle=True 65 | ) 66 | test_loader = eg.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True) 67 | 68 | print("X_train:", train_dataset.x.shape, train_dataset.x.dtype) 69 | print("y_train:", train_dataset.y.shape, train_dataset.y.dtype) 70 | print("X_test:", test_dataset.x.shape, test_dataset.x.dtype) 71 | print("y_test:", test_dataset.y.shape, test_dataset.y.dtype) 72 | 73 | @dataclass(unsafe_hash=True, repr=False) 74 | class MLP(eg.Module): 75 | """Standard LeNet-300-100 MLP network.""" 76 | 77 | n1: int = 300 78 | n2: int = 100 79 | 80 | @eg.compact 81 | def __call__(self, x: jnp.ndarray): 82 | x = x.astype(jnp.float32) / 255.0 83 | 84 | x = eg.Flatten()(x) 85 | x = eg.Linear(self.n1)(x) 86 | x = jax.nn.relu(x) 87 | x = eg.Linear(self.n2)(x) 88 | x = jax.nn.relu(x) 89 | x = eg.Linear(10)(x) 90 | 91 | return x 92 | 93 | model = eg.Model( 94 | module=MLP(n1=300, n2=100), 95 | loss=[ 96 | eg.losses.Crossentropy(), 97 | eg.regularizers.L2(l=1e-4), 98 | ], 99 | metrics=eg.metrics.Accuracy(), 100 | optimizer=optax.adamw(1e-3), 101 | eager=eager, 102 | ) 103 | 104 | x_sample, y_sample = next(iter(train_loader)) 105 | model.summary(x_sample) 106 | 107 | history = model.fit( 108 | inputs=train_loader, 109 | epochs=epochs, 110 | steps_per_epoch=steps_per_epoch, 111 | validation_data=test_loader, 112 | shuffle=True, 113 | callbacks=[eg.callbacks.TensorBoard(logdir=logdir)], 114 | ) 115 | 116 | eg.utils.plot_history(history) 117 | 118 | # get random samples 119 | idxs = np.random.randint(0, 10000, size=(9,)) 120 | x_sample, y_sample = next(iter(test_loader)) 121 | 122 | # get predictions 123 | y_pred = model.predict(x=x_sample) 124 | 125 | # plot and save results 126 | def make_plot(): 127 | plt.figure(figsize=(12, 12)) 128 | for i in range(3): 129 | for j in range(3): 130 | k = 3 * i + j 131 | plt.subplot(3, 3, k + 1) 132 | plt.title(f"{np.argmax(y_pred[k])}") 133 | plt.imshow(x_sample[k], cmap="gray") 134 | 135 | with SummaryWriter(os.path.join(logdir, "val")) as tbwriter: 136 | make_plot() 137 | # tbwriter.add_figure("Predictions", plt.gcf(), 100) 138 | 139 | make_plot() 140 | plt.show() 141 | 142 | print( 143 | "\n\n\nMetrics and images can be explored using tensorboard using:", 144 | f"\n \t\t\t tensorboard --logdir {logdir}", 145 | ) 146 | 147 | 148 | if __name__ == "__main__": 149 | typer.run(main) 150 | -------------------------------------------------------------------------------- /examples/elegy/mnist_tf_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | from typing import Tuple 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import optax 10 | import tensorflow as tf 11 | import typer 12 | 13 | import elegy as eg 14 | 15 | 16 | @eg.compact_module 17 | def ConvBlock( 18 | x, 19 | units: int, 20 | kernel: Tuple[int, int], 21 | stride: int = 1, 22 | ): 23 | x = eg.Conv( 24 | units, 25 | kernel, 26 | 
strides=[stride, stride], 27 | padding="same", 28 | )(x) 29 | x = eg.BatchNorm()(x) 30 | x = eg.Dropout(0.2)(x) 31 | return jax.nn.relu(x) 32 | 33 | 34 | class CNN(eg.Module): 35 | @eg.compact 36 | def __call__(self, x: jnp.ndarray): 37 | # normalize 38 | x = x.astype(jnp.float32) / 255.0 39 | 40 | # base 41 | x = ConvBlock()(x, 32, (3, 3)) 42 | x = ConvBlock()(x, 64, (3, 3), stride=2) 43 | x = ConvBlock()(x, 64, (3, 3), stride=2) 44 | x = ConvBlock()(x, 128, (3, 3), stride=2) 45 | 46 | # GlobalAveragePooling2D 47 | x = jnp.mean(x, axis=(1, 2)) 48 | 49 | # 1x1 Conv 50 | x = eg.Linear(10)(x) 51 | 52 | return x 53 | 54 | 55 | def main( 56 | debug: bool = False, 57 | eager: bool = False, 58 | logdir: str = "runs", 59 | steps_per_epoch: int = 200, 60 | epochs: int = 100, 61 | batch_size: int = 64, 62 | ): 63 | 64 | if debug: 65 | import debugpy 66 | 67 | print("Waiting for debugger...") 68 | debugpy.listen(5678) 69 | debugpy.wait_for_client() 70 | 71 | current_time = datetime.now().strftime("%b%d_%H-%M-%S") 72 | logdir = os.path.join(logdir, current_time) 73 | 74 | (X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data() 75 | 76 | def preprocess_images(images): 77 | images = images.reshape((images.shape[0], 28, 28, 1)) / 255.0 78 | return images.astype("float32") 79 | 80 | X_train = preprocess_images(X_train) 81 | X_test = preprocess_images(X_test) 82 | 83 | print("X_train:", X_train.shape, X_train.dtype) 84 | print("y_train:", y_train.shape, y_train.dtype) 85 | print("X_test:", X_test.shape, X_test.dtype) 86 | print("y_test:", y_test.shape, y_test.dtype) 87 | 88 | model = eg.Model( 89 | module=CNN(), 90 | loss=eg.losses.Crossentropy(), 91 | metrics=eg.metrics.Accuracy(), 92 | optimizer=optax.adam(1e-3), 93 | eager=eager, 94 | ) 95 | 96 | # show summary 97 | model.summary(X_train[:64]) 98 | 99 | batch_size = 64 100 | train_size = 60000 101 | test_size = 10000 102 | # Create tf datasets 103 | train_dataset = ( 104 | tf.data.Dataset.from_tensor_slices((X_train, y_train)) 105 | .shuffle(train_size) 106 | .batch(batch_size) 107 | .repeat() 108 | ) 109 | test_dataset = ( 110 | tf.data.Dataset.from_tensor_slices((X_test, y_test)) 111 | .shuffle(test_size) 112 | .batch(batch_size) 113 | ) 114 | 115 | history = model.fit( 116 | train_dataset, 117 | epochs=epochs, 118 | steps_per_epoch=steps_per_epoch, 119 | validation_data=test_dataset, 120 | callbacks=[eg.callbacks.TensorBoard(logdir=logdir)], 121 | ) 122 | 123 | eg.utils.plot_history(history) 124 | 125 | model.save("models/conv") 126 | 127 | model = eg.load("models/conv") 128 | 129 | print(model.evaluate(x=X_test, y=y_test)) 130 | 131 | # get random samples 132 | idxs = np.random.randint(0, 10000, size=(9,)) 133 | x_sample = X_test[idxs] 134 | 135 | # get predictions 136 | y_pred = model.predict(x=x_sample) 137 | 138 | # plot results 139 | figure = plt.figure(figsize=(12, 12)) 140 | for i in range(3): 141 | for j in range(3): 142 | k = 3 * i + j 143 | plt.subplot(3, 3, k + 1) 144 | 145 | plt.title(f"{np.argmax(y_pred[k])}") 146 | plt.imshow(x_sample[k], cmap="gray") 147 | 148 | plt.show() 149 | 150 | 151 | if __name__ == "__main__": 152 | typer.run(main) 153 | -------------------------------------------------------------------------------- /examples/elegy/mnist_torch_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import typing as tp 3 | from datetime import datetime 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import matplotlib.pyplot as plt 8 | import numpy 
as np 9 | import optax 10 | import torch 11 | import typer 12 | from datasets.load import load_dataset 13 | from tensorboardX.writer import SummaryWriter 14 | from torch.utils.data import DataLoader, TensorDataset 15 | 16 | import elegy as eg 17 | 18 | 19 | @eg.compact_module 20 | def ConvBlock( 21 | x, 22 | units: int, 23 | kernel: tp.Tuple[int, int], 24 | stride: int = 1, 25 | ): 26 | x = eg.Conv( 27 | units, 28 | kernel, 29 | strides=[stride, stride], 30 | padding="same", 31 | )(x) 32 | x = eg.BatchNorm()(x) 33 | x = eg.Dropout(0.2)(x) 34 | return jax.nn.relu(x) 35 | 36 | 37 | class CNN(eg.Module): 38 | @eg.compact 39 | def __call__(self, x: jnp.ndarray): 40 | # normalize 41 | x = x.astype(jnp.float32) / 255.0 42 | 43 | # base 44 | x = ConvBlock()(x, 32, (3, 3)) 45 | x = ConvBlock()(x, 64, (3, 3), stride=2) 46 | x = ConvBlock()(x, 64, (3, 3), stride=2) 47 | x = ConvBlock()(x, 128, (3, 3), stride=2) 48 | 49 | # GlobalAveragePooling2D 50 | x = jnp.mean(x, axis=(1, 2)) 51 | 52 | # 1x1 Conv 53 | x = eg.Linear(10)(x) 54 | 55 | return x 56 | 57 | 58 | def main( 59 | debug: bool = False, 60 | eager: bool = False, 61 | logdir: str = "runs", 62 | steps_per_epoch: int = 200, 63 | epochs: int = 100, 64 | batch_size: int = 64, 65 | ): 66 | 67 | if debug: 68 | import debugpy 69 | 70 | print("Waiting for debugger...") 71 | debugpy.listen(5678) 72 | debugpy.wait_for_client() 73 | 74 | current_time = datetime.now().strftime("%b%d_%H-%M-%S") 75 | logdir = os.path.join(logdir, current_time) 76 | 77 | dataset = load_dataset("mnist") 78 | dataset.set_format("np") 79 | X_train = np.stack(dataset["train"]["image"])[..., None] 80 | y_train = dataset["train"]["label"] 81 | X_test = np.stack(dataset["test"]["image"])[..., None] 82 | y_test = dataset["test"]["label"] 83 | 84 | print("X_train:", X_train.shape, X_train.dtype) 85 | print("y_train:", y_train.shape, y_train.dtype) 86 | print("X_test:", X_test.shape, X_test.dtype) 87 | print("y_test:", y_test.shape, y_test.dtype) 88 | 89 | model = eg.Model( 90 | module=CNN(), 91 | loss=eg.losses.Crossentropy(), 92 | metrics=eg.metrics.Accuracy(), 93 | optimizer=optax.adam(1e-3), 94 | eager=eager, 95 | ) 96 | 97 | # show summary 98 | model.summary(X_train[:64]) 99 | 100 | train_dataset = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train)) 101 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 102 | test_dataset = TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test)) 103 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size) 104 | 105 | history = model.fit( 106 | train_dataloader, 107 | epochs=epochs, 108 | steps_per_epoch=steps_per_epoch, 109 | validation_data=test_dataloader, 110 | callbacks=[eg.callbacks.TensorBoard(logdir=logdir)], 111 | ) 112 | 113 | eg.utils.plot_history(history) 114 | 115 | model.save("models/conv") 116 | 117 | model = eg.load("models/conv") 118 | 119 | print(model.evaluate(x=X_test, y=y_test)) 120 | 121 | # get random samples 122 | idxs = np.random.randint(0, 10000, size=(9,)) 123 | x_sample = X_test[idxs] 124 | 125 | # get predictions 126 | y_pred = model.predict(x=x_sample) 127 | 128 | # plot results 129 | with SummaryWriter(os.path.join(logdir, "val")) as tbwriter: 130 | figure = plt.figure(figsize=(12, 12)) 131 | for i in range(3): 132 | for j in range(3): 133 | k = 3 * i + j 134 | plt.subplot(3, 3, k + 1) 135 | 136 | plt.title(f"{np.argmax(y_pred[k])}") 137 | plt.imshow(x_sample[k], cmap="gray") 138 | # tbwriter.add_figure("Conv classifier", figure, 100) 139 | 140 | 
plt.show() 141 | 142 | 143 | if __name__ == "__main__": 144 | typer.run(main) 145 | -------------------------------------------------------------------------------- /examples/elegy/toy_mlp.py: -------------------------------------------------------------------------------- 1 | # isort:skip_file 2 | # fmt: off 3 | import jax, optax 4 | import numpy as np 5 | import elegy as eg 6 | import treex as tx 7 | 8 | 9 | # 1. create some data 10 | x = np.random.uniform(-1, 1, size=(100, 1)) 11 | y = 1.3 * x ** 2 - 0.3 + 0.1 * np.random.normal(size=x.shape) 12 | 13 | 14 | 15 | # 2. define the architecture 16 | class MLP(tx.Module): 17 | @eg.compact 18 | def __call__(self, x): 19 | x = tx.Linear(64)(x) 20 | x = jax.nn.relu(x) 21 | x = tx.Linear(1)(x) 22 | return x 23 | 24 | 25 | 26 | # 3. create the Model 27 | model = eg.Model( 28 | module=MLP(), 29 | loss=[ 30 | eg.losses.MeanSquaredError(), 31 | eg.regularizers.L2(0.001), 32 | ], 33 | optimizer=optax.adam(1e-2), 34 | ) 35 | 36 | 37 | 38 | # 4. train the model 39 | model.fit( 40 | inputs=x, 41 | labels=y, 42 | epochs=100, 43 | callbacks=[eg.callbacks.TensorBoard("models/mlp/treex")], 44 | ) 45 | 46 | 47 | 48 | # 5. visualize solution 49 | import matplotlib.pyplot as plt 50 | 51 | X_test = np.linspace(x.min(), x.max(), 100).reshape(-1, 1) 52 | y_pred = model.predict(X_test) 53 | 54 | plt.scatter(x, y) 55 | plt.plot(X_test, y_pred) 56 | plt.show() 57 | 58 | 59 | # 6. save the model 60 | model.save("models/mlp/treex/model") 61 | model.saved_model(x, "models/mlp/treex/saved_model") 62 | -------------------------------------------------------------------------------- /examples/flax/mnist_conv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import typing as tp 3 | from datetime import datetime 4 | 5 | import datasets 6 | import jax 7 | import jax.numpy as jnp 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import optax 11 | import typer 12 | from datasets import load_dataset 13 | from flax import linen 14 | 15 | import elegy as eg 16 | 17 | 18 | class CNN(linen.module.Module): 19 | @linen.compact 20 | def __call__(self, x: jnp.ndarray, training: bool) -> jnp.ndarray: 21 | # Normalize the input 22 | x = x.astype(jnp.float32) / 255.0 23 | 24 | # Block 1 25 | x = linen.Conv(32, [3, 3], strides=[2, 2])(x) 26 | x = linen.Dropout(0.05, deterministic=not training)(x) 27 | x = jax.nn.relu(x) 28 | 29 | # Block 2 30 | x = linen.Conv(64, [3, 3], strides=[2, 2])(x) 31 | x = linen.BatchNorm(use_running_average=not training)(x) 32 | x = linen.Dropout(0.1, deterministic=not training)(x) 33 | x = jax.nn.relu(x) 34 | 35 | # Block 3 36 | x = linen.Conv(128, [3, 3], strides=[2, 2])(x) 37 | 38 | # Global average pooling 39 | x = x.mean(axis=(1, 2)) 40 | 41 | # Classification layer 42 | x = linen.Dense(10)(x) 43 | 44 | return x 45 | 46 | 47 | def main( 48 | debug: bool = False, 49 | eager: bool = False, 50 | logdir: str = "runs", 51 | steps_per_epoch: tp.Optional[int] = None, 52 | epochs: int = 100, 53 | batch_size: int = 32, 54 | distributed: bool = False, 55 | ): 56 | 57 | if debug: 58 | import debugpy 59 | 60 | print("Waiting for debugger...") 61 | debugpy.listen(5678) 62 | debugpy.wait_for_client() 63 | 64 | current_time = datetime.now().strftime("%b%d_%H-%M-%S") 65 | logdir = os.path.join(logdir, current_time) 66 | 67 | dataset = load_dataset("mnist") 68 | dataset.set_format("np") 69 | X_train = np.stack(dataset["train"]["image"])[..., None] 70 | y_train = dataset["train"]["label"] 71 | X_test = 
np.stack(dataset["test"]["image"])[..., None] 72 | y_test = dataset["test"]["label"] 73 | 74 | print("X_train:", X_train.shape, X_train.dtype) 75 | print("y_train:", y_train.shape, y_train.dtype) 76 | print("X_test:", X_test.shape, X_test.dtype) 77 | print("y_test:", y_test.shape, y_test.dtype) 78 | 79 | model = eg.Model( 80 | module=CNN(), 81 | loss=eg.losses.Crossentropy(), 82 | metrics=eg.metrics.Accuracy(), 83 | optimizer=optax.adam(1e-3), 84 | eager=eager, 85 | ) 86 | 87 | if distributed: 88 | model = model.distributed() 89 | 90 | model.summary(X_train[:batch_size]) 91 | 92 | history = model.fit( 93 | inputs=X_train, 94 | labels=y_train, 95 | epochs=epochs, 96 | steps_per_epoch=steps_per_epoch, 97 | batch_size=batch_size, 98 | validation_data=(X_test, y_test), 99 | shuffle=True, 100 | callbacks=[eg.callbacks.TensorBoard(logdir=logdir)], 101 | ) 102 | 103 | eg.utils.plot_history(history) 104 | 105 | print(model.evaluate(x=X_test, y=y_test)) 106 | 107 | # get random samples 108 | idxs = np.random.randint(0, 10000, size=(9,)) 109 | x_sample = X_test[idxs] 110 | 111 | # get predictions 112 | model = model.local() 113 | y_pred = model.predict(x=x_sample) 114 | 115 | # plot results 116 | figure = plt.figure(figsize=(12, 12)) 117 | for i in range(3): 118 | for j in range(3): 119 | k = 3 * i + j 120 | plt.subplot(3, 3, k + 1) 121 | 122 | plt.title(f"{np.argmax(y_pred[k])}") 123 | plt.imshow(x_sample[k], cmap="gray") 124 | 125 | plt.show() 126 | 127 | 128 | if __name__ == "__main__": 129 | typer.run(main) 130 | -------------------------------------------------------------------------------- /examples/flax/mnist_vae.py: -------------------------------------------------------------------------------- 1 | import os 2 | import typing as tp 3 | from dataclasses import dataclass 4 | from datetime import datetime 5 | from typing import Any, Generator, Mapping, Tuple 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import optax 12 | import typer 13 | from datasets import load_dataset 14 | from flax import linen as nn 15 | from jax._src.numpy.lax_numpy import ndarray 16 | from tensorboardX.writer import SummaryWriter 17 | 18 | import elegy as eg 19 | 20 | # TODO: fix, not learning on par with the elegy version 21 | 22 | Batch = Mapping[str, jnp.ndarray] 23 | np.random.seed(42) 24 | 25 | LATENT_SIZE = 32 26 | MNIST_IMAGE_SHAPE: tp.Sequence[int] = (28, 28) 27 | 28 | 29 | @dataclass 30 | class Encoder(nn.Module): 31 | """Encoder model.""" 32 | 33 | hidden_size: int = 512 34 | latent_size: int = 128 35 | 36 | @nn.compact 37 | def __call__( 38 | self, x: jnp.ndarray 39 | ) -> tp.Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: 40 | x = x.reshape((x.shape[0], -1)) # flatten 41 | x = nn.Dense(self.hidden_size)(x) 42 | x = jax.nn.relu(x) 43 | 44 | mean = nn.Dense(self.latent_size, name="linear_mean")(x) 45 | log_stddev = nn.Dense(self.latent_size, name="linear_std")(x) 46 | stddev = jnp.exp(log_stddev) 47 | 48 | key = self.make_rng("dropout") 49 | z = mean + stddev * jax.random.normal(key, mean.shape) 50 | 51 | return z, mean, stddev 52 | 53 | 54 | @dataclass 55 | class Decoder(nn.Module): 56 | """Decoder model.""" 57 | 58 | hidden_size: int = 512 59 | output_shape: tp.Sequence[int] = MNIST_IMAGE_SHAPE 60 | 61 | @nn.compact 62 | def __call__(self, z: jnp.ndarray) -> jnp.ndarray: 63 | z = nn.Dense(self.hidden_size)(z) 64 | z = jax.nn.relu(z) 65 | 66 | logits = nn.Dense(np.prod(self.output_shape))(z) 67 | logits = jnp.reshape(logits, (-1, 
*self.output_shape)) 68 | 69 | return logits 70 | 71 | 72 | @dataclass 73 | class VAE(nn.Module): 74 | hidden_size: int = 512 75 | latent_size: int = 512 76 | output_shape: tp.Sequence[int] = MNIST_IMAGE_SHAPE 77 | 78 | @nn.compact 79 | def __call__(self, x): 80 | z, mean, std = Encoder( 81 | hidden_size=self.hidden_size, latent_size=self.latent_size 82 | )(x) 83 | logits = Decoder(hidden_size=self.hidden_size, output_shape=self.output_shape)( 84 | z 85 | ) 86 | return dict(logits=logits, mean=mean, std=std) 87 | 88 | def generate(self, z): 89 | return nn.sigmoid(self.decoder(z)) 90 | 91 | 92 | class KL(eg.Loss): 93 | def call(self, preds) -> jnp.ndarray: 94 | mean = preds["mean"] 95 | std = preds["std"] 96 | 97 | return 0.5 * jnp.mean(-jnp.log(std**2) - 1.0 + std**2 + mean**2, axis=-1) 98 | 99 | 100 | class BinaryCrossEntropy(eg.losses.Crossentropy): 101 | def __init__(self, **kwargs): 102 | super().__init__(binary=True, **kwargs) 103 | 104 | def call(self, inputs: jnp.ndarray, preds: jnp.ndarray) -> jnp.ndarray: 105 | return super().call(target=inputs, preds=preds) 106 | 107 | 108 | def main( 109 | steps_per_epoch: int = 200, 110 | batch_size: int = 64, 111 | epochs: int = 50, 112 | debug: bool = False, 113 | eager: bool = False, 114 | logdir: str = "runs", 115 | ): 116 | 117 | if debug: 118 | import debugpy 119 | 120 | print("Waiting for debugger...") 121 | debugpy.listen(5678) 122 | debugpy.wait_for_client() 123 | 124 | current_time = datetime.now().strftime("%b%d_%H-%M-%S") 125 | logdir = os.path.join(logdir, current_time) 126 | 127 | dataset = load_dataset("mnist") 128 | X_train = np.array(np.stack(dataset["train"]["image"]), dtype=np.uint8) 129 | X_test = np.array(np.stack(dataset["test"]["image"]), dtype=np.uint8) 130 | # Now binarize data 131 | X_train = (X_train > 0).astype(jnp.float32) 132 | X_test = (X_test > 0).astype(jnp.float32) 133 | 134 | print("X_train:", X_train.shape, X_train.dtype) 135 | print("X_test:", X_test.shape, X_test.dtype) 136 | 137 | model = eg.Model( 138 | module=VAE(latent_size=LATENT_SIZE), 139 | loss=[ 140 | BinaryCrossEntropy(on="logits"), 141 | KL(weight=0.1), 142 | ], 143 | optimizer=optax.adam(1e-3), 144 | eager=eager, 145 | ) 146 | 147 | model.summary(X_train[:batch_size]) 148 | 149 | # Fit with datasets in memory 150 | history = model.fit( 151 | inputs=X_train, 152 | epochs=epochs, 153 | batch_size=batch_size, 154 | steps_per_epoch=steps_per_epoch, 155 | validation_data=(X_test,), 156 | shuffle=True, 157 | callbacks=[eg.callbacks.TensorBoard(logdir)], 158 | ) 159 | 160 | print( 161 | "\n\n\nMetrics and images can be explored using tensorboard using:", 162 | f"\n \t\t\t tensorboard --logdir {logdir}", 163 | ) 164 | 165 | eg.utils.plot_history(history) 166 | 167 | # get random samples 168 | idxs = np.random.randint(0, len(X_test), size=(5,)) 169 | x_sample = X_test[idxs] 170 | 171 | # get predictions 172 | preds = model.predict(x=x_sample) 173 | y_pred = jax.nn.sigmoid(preds["logits"]) 174 | 175 | # plot and save results 176 | with SummaryWriter(os.path.join(logdir, "val")) as tbwriter: 177 | figure = plt.figure(figsize=(12, 12)) 178 | for i in range(5): 179 | plt.subplot(2, 5, i + 1) 180 | plt.imshow(x_sample[i], cmap="gray") 181 | plt.subplot(2, 5, 5 + i + 1) 182 | plt.imshow(y_pred[i], cmap="gray") 183 | # # tbwriter.add_figure("VAE Example", figure, epochs) 184 | 185 | plt.show() 186 | 187 | # TODO: implement parameter transfer to sample 188 | # sample 189 | # model_decoder = eg.Model(Decoder(latent_size=LATENT_SIZE)) 190 | 191 | # z_samples = 
np.random.normal(size=(12, LATENT_SIZE)) 192 | # samples = model_decoder.predict(z_samples) 193 | # samples = jax.nn.sigmoid(samples) 194 | 195 | # # plot and save results 196 | # # with SummaryWriter(os.path.join(logdir, "val")) as tbwriter: 197 | # figure = plt.figure(figsize=(5, 12)) 198 | # plt.title("Generative Samples") 199 | # for i in range(5): 200 | # plt.subplot(2, 5, 2 * i + 1) 201 | # plt.imshow(samples[i], cmap="gray") 202 | # plt.subplot(2, 5, 2 * i + 2) 203 | # plt.imshow(samples[i + 1], cmap="gray") 204 | # # # tbwriter.add_figure("VAE Generative Example", figure, epochs) 205 | 206 | # plt.show() 207 | 208 | 209 | if __name__ == "__main__": 210 | typer.run(main) 211 | -------------------------------------------------------------------------------- /examples/flax/toy_mlp.py: -------------------------------------------------------------------------------- 1 | # isort:skip_file 2 | # fmt: off 3 | import jax, optax 4 | import numpy as np 5 | import elegy as eg 6 | import flax.linen as nn 7 | 8 | 9 | 10 | # 1. create some data 11 | x = np.random.uniform(-1, 1, size=(100, 1)) 12 | y = 1.3 * x ** 2 - 0.3 + 0.1 * np.random.normal(size=x.shape) 13 | 14 | 15 | 16 | # 2. define the architecture 17 | class MLP(nn.Module): 18 | @nn.compact 19 | def __call__(self, x): 20 | x = nn.Dense(64)(x) 21 | x = jax.nn.relu(x) 22 | x = nn.Dense(1)(x) 23 | return x 24 | 25 | 26 | 27 | # 3. create the Model 28 | model = eg.Model( 29 | module=MLP(), 30 | loss=[ 31 | eg.losses.MeanSquaredError(), 32 | eg.regularizers.L2(0.001), 33 | ], 34 | optimizer=optax.adam(1e-2), 35 | ) 36 | 37 | 38 | 39 | # 4. train the model 40 | model.fit( 41 | inputs=x, 42 | labels=y, 43 | epochs=100, 44 | callbacks=[eg.callbacks.TensorBoard("models/mlp/flax")], 45 | ) 46 | 47 | 48 | 49 | # 5. visualize solution 50 | import matplotlib.pyplot as plt 51 | 52 | X_test = np.linspace(x.min(), x.max(), 100).reshape(-1, 1) 53 | y_pred = model.predict(X_test) 54 | 55 | plt.scatter(x, y) 56 | plt.plot(X_test, y_pred) 57 | plt.show() 58 | 59 | 60 | 61 | # 6. 
save the model 62 | model.save("models/mlp/flax/model") 63 | model.saved_model(x, "models/mlp/flax/saved_model") 64 | -------------------------------------------------------------------------------- /examples/haiku/mnist_conv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import typing as tp 3 | from datetime import datetime 4 | 5 | import datasets 6 | import haiku as hk 7 | import jax 8 | import jax.numpy as jnp 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import optax 12 | import typer 13 | from datasets import load_dataset 14 | 15 | import elegy as eg 16 | 17 | 18 | def forward(x: jnp.ndarray, training: bool): 19 | # Normalize input 20 | x = x.astype(jnp.float32) / 255.0 21 | 22 | # Block 1 23 | x = hk.Conv2D(32, [3, 3], stride=[2, 2])(x) 24 | x = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.99)( 25 | x, is_training=training 26 | ) 27 | x = hk.dropout(hk.next_rng_key(), 0.05, x) 28 | x = jax.nn.relu(x) 29 | 30 | # Block 2 31 | x = hk.Conv2D(64, [3, 3], stride=[2, 2])(x) 32 | x = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.99)( 33 | x, is_training=training 34 | ) 35 | x = hk.dropout(hk.next_rng_key(), 0.1, x) 36 | x = jax.nn.relu(x) 37 | 38 | # Block 3 39 | x = hk.Conv2D(128, [3, 3], stride=[2, 2])(x) 40 | 41 | # GlobalAveragePooling2D 42 | x = x.mean(axis=(1, 2)) 43 | 44 | # Classification layer 45 | x = hk.Linear(10)(x) 46 | 47 | return x 48 | 49 | 50 | def main( 51 | debug: bool = False, 52 | eager: bool = False, 53 | logdir: str = "runs", 54 | steps_per_epoch: tp.Optional[int] = None, 55 | epochs: int = 100, 56 | batch_size: int = 32, 57 | distributed: bool = False, 58 | ): 59 | 60 | if debug: 61 | import debugpy 62 | 63 | print("Waiting for debugger...") 64 | debugpy.listen(5678) 65 | debugpy.wait_for_client() 66 | 67 | current_time = datetime.now().strftime("%b%d_%H-%M-%S") 68 | logdir = os.path.join(logdir, current_time) 69 | 70 | dataset = load_dataset("mnist") 71 | dataset.set_format("np") 72 | X_train = np.stack(dataset["train"]["image"])[..., None] 73 | y_train = dataset["train"]["label"] 74 | X_test = np.stack(dataset["test"]["image"])[..., None] 75 | y_test = dataset["test"]["label"] 76 | 77 | print("X_train:", X_train.shape, X_train.dtype) 78 | print("y_train:", y_train.shape, y_train.dtype) 79 | print("X_test:", X_test.shape, X_test.dtype) 80 | print("y_test:", y_test.shape, y_test.dtype) 81 | 82 | model = eg.Model( 83 | module=hk.transform_with_state(forward), 84 | loss=eg.losses.Crossentropy(), 85 | metrics=eg.metrics.Accuracy(), 86 | optimizer=optax.adam(1e-3), 87 | eager=eager, 88 | ) 89 | 90 | if distributed: 91 | model = model.distributed() 92 | 93 | model.summary(X_train[:batch_size]) 94 | 95 | history = model.fit( 96 | inputs=X_train, 97 | labels=y_train, 98 | epochs=epochs, 99 | steps_per_epoch=steps_per_epoch, 100 | batch_size=batch_size, 101 | validation_data=(X_test, y_test), 102 | shuffle=True, 103 | callbacks=[eg.callbacks.TensorBoard(logdir=logdir)], 104 | ) 105 | 106 | eg.utils.plot_history(history) 107 | 108 | print(model.evaluate(x=X_test, y=y_test)) 109 | 110 | # get random samples 111 | idxs = np.random.randint(0, 10000, size=(9,)) 112 | x_sample = X_test[idxs] 113 | 114 | # get predictions 115 | model = model.local() 116 | y_pred = model.predict(x=x_sample) 117 | 118 | # plot results 119 | figure = plt.figure(figsize=(12, 12)) 120 | for i in range(3): 121 | for j in range(3): 122 | k = 3 * i + j 123 | plt.subplot(3, 3, k + 1) 124 | 125 | 
plt.title(f"{np.argmax(y_pred[k])}") 126 | plt.imshow(x_sample[k], cmap="gray") 127 | 128 | plt.show() 129 | 130 | 131 | if __name__ == "__main__": 132 | typer.run(main) 133 | -------------------------------------------------------------------------------- /examples/haiku/toy_mlp.py: -------------------------------------------------------------------------------- 1 | # isort:skip_file 2 | # fmt: off 3 | import jax, optax 4 | import numpy as np 5 | import elegy as eg 6 | import haiku as hk 7 | 8 | 9 | 10 | # 1. create some data 11 | x = np.random.uniform(-1, 1, size=(100, 1)) 12 | y = 1.3 * x ** 2 - 0.3 + 0.1 * np.random.normal(size=x.shape) 13 | 14 | 15 | 16 | # 2. define the architecture 17 | def forward(x): 18 | x = hk.Linear(64)(x) 19 | x = jax.nn.relu(x) 20 | x = hk.Linear(1)(x) 21 | return x 22 | 23 | 24 | 25 | # 3. create the Model 26 | model = eg.Model( 27 | module=hk.transform_with_state(forward), 28 | loss=[ 29 | eg.losses.MeanSquaredError(), 30 | eg.regularizers.L2(0.001), 31 | ], 32 | optimizer=optax.adam(1e-2), 33 | ) 34 | 35 | 36 | 37 | # 4. train the model 38 | model.fit( 39 | inputs=x, 40 | labels=y, 41 | epochs=100, 42 | callbacks=[eg.callbacks.TensorBoard("models/mlp/haiku")], 43 | ) 44 | 45 | 46 | 47 | # 5. visualize solution 48 | import matplotlib.pyplot as plt 49 | 50 | X_test = np.linspace(x.min(), x.max(), 100).reshape(-1, 1) 51 | y_pred = model.predict(X_test) 52 | 53 | plt.scatter(x, y) 54 | plt.plot(X_test, y_pred) 55 | plt.show() 56 | 57 | 58 | 59 | # 6. save the model 60 | model.save("models/mlp/haiku/model") 61 | model.saved_model(x, "models/mlp/haiku/saved_model") 62 | -------------------------------------------------------------------------------- /examples/jax/linear_classifier_test_step.py: -------------------------------------------------------------------------------- 1 | import os 2 | import typing as tp 3 | from datetime import datetime 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import optax 9 | import typer 10 | from datasets.load import load_dataset 11 | 12 | import elegy as eg 13 | 14 | M = tp.TypeVar("M", bound="Model") 15 | 16 | 17 | class Model(eg.Model): 18 | w: jnp.ndarray = eg.Parameter.node() 19 | b: jnp.ndarray = eg.Parameter.node() 20 | 21 | def __init__( 22 | self, 23 | features_out: int, 24 | loss: tp.Any = None, 25 | metrics: tp.Any = None, 26 | optimizer=None, 27 | seed: int = 42, 28 | eager: bool = False, 29 | ): 30 | self.features_out = features_out 31 | super().__init__( 32 | loss=loss, 33 | metrics=metrics, 34 | optimizer=optimizer, 35 | seed=seed, 36 | eager=eager, 37 | ) 38 | 39 | def init_step(self: M, key: jnp.ndarray, inputs: jnp.ndarray) -> M: 40 | features_in = np.prod(inputs.shape[1:]) 41 | 42 | self.w = jax.random.uniform( 43 | key, 44 | shape=[ 45 | features_in, 46 | self.features_out, 47 | ], 48 | ) 49 | self.b = jnp.zeros([self.features_out]) 50 | 51 | assert self.optimizer is not None 52 | self.optimizer = self.optimizer.init(self) 53 | 54 | return self 55 | 56 | def pred_step(self: M, inputs: tp.Any) -> eg.PredStepOutput[M]: 57 | logits = jnp.dot(inputs, self.w) + self.b 58 | return logits, self 59 | 60 | def test_step( 61 | self, 62 | inputs, 63 | labels, 64 | ): 65 | model = self 66 | # flatten + scale 67 | inputs = jnp.reshape(inputs, (inputs.shape[0], -1)) / 255 68 | 69 | # forward 70 | logits, model = model.pred_step(inputs) 71 | 72 | # crossentropy loss 73 | target = jax.nn.one_hot(labels["target"], self.features_out) 74 | loss = optax.softmax_cross_entropy(logits, target).mean() 75 | 
76 | # metrics 77 | logs = dict( 78 | acc=jnp.mean(jnp.argmax(logits, axis=-1) == labels["target"]), 79 | loss=loss, 80 | ) 81 | 82 | return loss, logs, model 83 | 84 | 85 | def main( 86 | debug: bool = False, 87 | eager: bool = False, 88 | logdir: str = "runs", 89 | steps_per_epoch: int = 200, 90 | batch_size: int = 64, 91 | epochs: int = 100, 92 | ): 93 | 94 | if debug: 95 | import debugpy 96 | 97 | print("Waiting for debugger...") 98 | debugpy.listen(5678) 99 | debugpy.wait_for_client() 100 | 101 | current_time = datetime.now().strftime("%b%d_%H-%M-%S") 102 | logdir = os.path.join(logdir, current_time) 103 | 104 | dataset = load_dataset("mnist") 105 | dataset.set_format("np") 106 | X_train = np.stack(dataset["train"]["image"]) 107 | y_train = dataset["train"]["label"] 108 | X_test = np.stack(dataset["test"]["image"]) 109 | y_test = dataset["test"]["label"] 110 | 111 | print("X_train:", X_train.shape, X_train.dtype) 112 | print("y_train:", y_train.shape, y_train.dtype) 113 | print("X_test:", X_test.shape, X_test.dtype) 114 | print("y_test:", y_test.shape, y_test.dtype) 115 | 116 | model = Model( 117 | features_out=10, 118 | optimizer=optax.adam(1e-3), 119 | eager=eager, 120 | ) 121 | 122 | history = model.fit( 123 | inputs=X_train, 124 | labels=y_train, 125 | epochs=epochs, 126 | steps_per_epoch=steps_per_epoch, 127 | batch_size=batch_size, 128 | validation_data=(X_test, y_test), 129 | shuffle=True, 130 | callbacks=[eg.callbacks.TensorBoard(logdir=logdir)], 131 | ) 132 | 133 | eg.utils.plot_history(history) 134 | 135 | 136 | if __name__ == "__main__": 137 | typer.run(main) 138 | -------------------------------------------------------------------------------- /examples/jax/linear_classifier_train_step.py: -------------------------------------------------------------------------------- 1 | import os 2 | import typing as tp 3 | from datetime import datetime 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import optax 9 | import typer 10 | from datasets.load import load_dataset 11 | 12 | import elegy as eg 13 | 14 | M = tp.TypeVar("M", bound="Model") 15 | 16 | 17 | class Model(eg.Model): 18 | w: jnp.ndarray = eg.Parameter.node() 19 | b: jnp.ndarray = eg.Parameter.node() 20 | 21 | def __init__( 22 | self, 23 | features_out: int, 24 | loss: tp.Any = None, 25 | metrics: tp.Any = None, 26 | optimizer=None, 27 | seed: int = 42, 28 | eager: bool = False, 29 | ): 30 | self.features_out = features_out 31 | super().__init__( 32 | module=None, 33 | loss=loss, 34 | metrics=metrics, 35 | optimizer=optimizer, 36 | seed=seed, 37 | eager=eager, 38 | ) 39 | 40 | def init_step(self: M, key: jnp.ndarray, inputs: jnp.ndarray) -> M: 41 | features_in = np.prod(inputs.shape[1:]) 42 | 43 | self.w = jax.random.uniform( 44 | key, 45 | shape=[ 46 | features_in, 47 | self.features_out, 48 | ], 49 | ) 50 | self.b = jnp.zeros([self.features_out]) 51 | 52 | assert self.optimizer is not None 53 | self.optimizer = self.optimizer.init(self) 54 | 55 | return self 56 | 57 | def pred_step(self: M, inputs: tp.Any) -> eg.PredStepOutput[M]: 58 | logits = jnp.dot(inputs, self.w) + self.b 59 | return logits, self 60 | 61 | def test_step( 62 | self: M, 63 | inputs, 64 | labels, 65 | ) -> eg.TestStepOutput[M]: 66 | model: M = self 67 | # flatten + scale 68 | inputs = jnp.reshape(inputs, (inputs.shape[0], -1)) / 255 69 | 70 | # forward 71 | logits, model = model.pred_step(inputs) 72 | 73 | # crossentropy loss 74 | target = jax.nn.one_hot(labels["target"], self.features_out) 75 | loss = 
optax.softmax_cross_entropy(logits, target).mean() 76 | 77 | # metrics 78 | logs = dict( 79 | acc=jnp.mean(jnp.argmax(logits, axis=-1) == labels["target"]), 80 | loss=loss, 81 | ) 82 | 83 | return loss, logs, model 84 | 85 | @staticmethod 86 | def loss_fn(params: M, model: M, inputs, labels) -> eg.LossStepOutput[M]: 87 | model = model.merge(params) 88 | loss, logs, model = model.test_step(inputs, labels) 89 | return loss, (logs, model) 90 | 91 | def train_step(self: M, inputs, labels) -> eg.TrainStepOutput[M]: 92 | model: M = self 93 | 94 | params = model.parameters() 95 | # train 96 | grads, (logs, model) = jax.grad(Model.loss_fn, has_aux=True)( 97 | params, 98 | model, 99 | inputs, 100 | labels, 101 | ) 102 | 103 | assert model.optimizer is not None 104 | 105 | params = model.optimizer.update(grads, params) 106 | model = model.merge(params) 107 | 108 | return logs, model 109 | 110 | 111 | def main( 112 | debug: bool = False, 113 | eager: bool = False, 114 | logdir: str = "runs", 115 | steps_per_epoch: int = 200, 116 | epochs: int = 100, 117 | batch_size: int = 64, 118 | ): 119 | 120 | if debug: 121 | import debugpy 122 | 123 | print("Waiting for debugger...") 124 | debugpy.listen(5678) 125 | debugpy.wait_for_client() 126 | 127 | current_time = datetime.now().strftime("%b%d_%H-%M-%S") 128 | logdir = os.path.join(logdir, current_time) 129 | 130 | dataset = load_dataset("mnist") 131 | dataset.set_format("np") 132 | X_train = np.stack(dataset["train"]["image"]) 133 | y_train = dataset["train"]["label"] 134 | X_test = np.stack(dataset["test"]["image"]) 135 | y_test = dataset["test"]["label"] 136 | 137 | print("X_train:", X_train.shape, X_train.dtype) 138 | print("y_train:", y_train.shape, y_train.dtype) 139 | print("X_test:", X_test.shape, X_test.dtype) 140 | print("y_test:", y_test.shape, y_test.dtype) 141 | 142 | model = Model( 143 | features_out=10, 144 | optimizer=optax.adam(1e-3), 145 | eager=eager, 146 | ) 147 | 148 | history = model.fit( 149 | inputs=X_train, 150 | labels=y_train, 151 | epochs=epochs, 152 | steps_per_epoch=steps_per_epoch, 153 | batch_size=batch_size, 154 | validation_data=(X_test, y_test), 155 | shuffle=True, 156 | callbacks=[eg.callbacks.TensorBoard(logdir=logdir)], 157 | ) 158 | 159 | eg.utils.plot_history(history) 160 | 161 | 162 | if __name__ == "__main__": 163 | typer.run(main) 164 | -------------------------------------------------------------------------------- /examples/need-fixing/WGAN-GP/README.md: -------------------------------------------------------------------------------- 1 | ## Using Elegy low-level API to train WGAN-GP on the CelebA dataset 2 | 3 | 4 | *** 5 | ### Usage 6 | ``` 7 | main.py --dataset=path/to/celeb_a/*.png --output_dir=<./output/path> [flags] 8 | 9 | 10 | flags: 11 | --dataset: Search path to the dataset images e.g: path/to/*.png 12 | --output_dir: Directory to save model checkpoints and tensorboard log data 13 | 14 | --batch_size: Input batch size (default: '64') 15 | --epochs: Number of epochs to train (default: '100') 16 | ``` 17 | 18 | *** 19 | ### Examples of generated images: 20 | 21 | After 10 epochs: ![Example of generated images after 10 epochs](images/epoch-0009.png) 22 | 23 | After 50 epochs: ![Example of generated images after 10 epochs](images/epoch-0049.png) 24 | 25 | After 100 epochs: ![Example of generated images after 10 epochs](images/epoch-0099.png) 26 | 27 | 28 | *** 29 | [1] Arjovsky, Martin, Soumith Chintala, and Léon Bottou. "Wasserstein generative adversarial networks." 
International conference on machine learning. PMLR, 2017. 30 | 31 | [2] Gulrajani, Ishaan, et al. "Improved training of wasserstein gans." arXiv preprint arXiv:1704.00028 (2017). 32 | 33 | [3] Liu, Ziwei, et al. "Large-scale celebfaces attributes (celeba) dataset." Retrieved August 15.2018 (2018): 11. 34 | -------------------------------------------------------------------------------- /examples/need-fixing/WGAN-GP/images/epoch-0009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poets-ai/elegy/4709ce8dc9dde3925ce717e2358ce49112e36398/examples/need-fixing/WGAN-GP/images/epoch-0009.png -------------------------------------------------------------------------------- /examples/need-fixing/WGAN-GP/images/epoch-0049.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poets-ai/elegy/4709ce8dc9dde3925ce717e2358ce49112e36398/examples/need-fixing/WGAN-GP/images/epoch-0049.png -------------------------------------------------------------------------------- /examples/need-fixing/WGAN-GP/images/epoch-0099.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poets-ai/elegy/4709ce8dc9dde3925ce717e2358ce49112e36398/examples/need-fixing/WGAN-GP/images/epoch-0099.png -------------------------------------------------------------------------------- /examples/need-fixing/WGAN-GP/main.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import pickle 4 | 5 | import numpy as np 6 | import PIL.Image 7 | from absl import app, flags 8 | from model import WGAN_GP 9 | 10 | import elegy 11 | 12 | FLAGS = flags.FLAGS 13 | 14 | flags.DEFINE_string( 15 | "output_dir", 16 | default=None, 17 | help="Directory to save model checkpoints and example generated images", 18 | ) 19 | flags.DEFINE_integer("epochs", default=100, help="Number of epochs to train") 20 | flags.DEFINE_integer("batch_size", default=64, help="Input batch size") 21 | 22 | flags.DEFINE_string( 23 | "dataset", default=None, help="Search path to the dataset images e.g: path/to/*.png" 24 | ) 25 | 26 | flags.mark_flag_as_required("dataset") 27 | flags.mark_flag_as_required("output_dir") 28 | 29 | 30 | class Dataset(elegy.data.Dataset): 31 | def __init__(self, path): 32 | self.files = glob.glob(os.path.expanduser(path)) 33 | if len(self.files) == 0: 34 | raise RuntimeError(f'Could not find any files in path "{path}"') 35 | print(f"Found {len(self.files)} files") 36 | 37 | def __len__(self): 38 | return len(self.files) 39 | 40 | def __getitem__(self, i): 41 | f = self.files[i] 42 | img = np.array(PIL.Image.open(f).resize((64, 64))) / np.float32(255) 43 | img = np.fliplr(img) if np.random.random() < 0.5 else img 44 | return img 45 | 46 | 47 | class SaveImagesCallback(elegy.callbacks.Callback): 48 | def __init__(self, model, path): 49 | self.model = model 50 | self.path = path 51 | 52 | def on_epoch_end(self, epoch, *args, **kwargs): 53 | x = self.model.predict(np.random.normal(size=[8, 128])) 54 | x = np.concatenate(list(x * 255), axis=1).astype(np.uint8) 55 | img = PIL.Image.fromarray(x) 56 | img.save(os.path.join(self.path, f"epoch-{epoch:04d}.png")) 57 | 58 | 59 | def main(argv): 60 | assert ( 61 | len(argv) == 1 62 | ), "Please specify arguments via flags. Use --help for instructions" 63 | 64 | assert not os.path.exists( 65 | FLAGS.output_dir 66 | ), "Output directory already exists. 
Delete manually or specify a new one." 67 | os.makedirs(FLAGS.output_dir) 68 | 69 | ds = Dataset(FLAGS.dataset) 70 | loader = elegy.data.DataLoader( 71 | ds, batch_size=FLAGS.batch_size, n_workers=os.cpu_count(), worker_type="process" 72 | ) 73 | 74 | wgan = WGAN_GP() 75 | wgan.init(np.zeros([8, 128])) 76 | 77 | wgan.fit( 78 | loader, 79 | epochs=FLAGS.epochs, 80 | verbose=4, 81 | callbacks=[ 82 | SaveImagesCallback(wgan, FLAGS.output_dir), 83 | elegy.callbacks.ModelCheckpoint(FLAGS.output_dir), 84 | ], 85 | ) 86 | 87 | 88 | if __name__ == "__main__": 89 | app.run(main) 90 | -------------------------------------------------------------------------------- /examples/need-fixing/WGAN-GP/model.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import optax 4 | 5 | import elegy 6 | 7 | 8 | # the generator architecture adapted from DCGAN 9 | class Generator(elegy.Module): 10 | def call(self, z): 11 | assert len(z.shape) == 2 12 | x = elegy.nn.Reshape([1, 1, z.shape[-1]])(z) 13 | for i, c in enumerate([1024, 512, 256, 128]): 14 | padding = "VALID" if i == 0 else "SAME" 15 | x = elegy.nn.conv.ConvTranspose(c, (4, 4), stride=(2, 2), padding=padding)( 16 | x 17 | ) 18 | x = elegy.nn.BatchNorm(decay_rate=0.9)(x) 19 | x = jax.nn.leaky_relu(x, negative_slope=0.2) 20 | x = elegy.nn.conv.ConvTranspose(3, (4, 4), stride=(2, 2))(x) 21 | x = jax.nn.sigmoid(x) 22 | return x 23 | 24 | 25 | # the discriminator architecture adapted from DCGAN 26 | # also called 'critic' in the WGAN paper 27 | class Discriminator(elegy.Module): 28 | def __call__(self, x): 29 | for c in [128, 256, 512, 1024]: 30 | x = elegy.nn.conv.Conv(c, (4, 4), stride=(2, 2))(x) 31 | x = jax.nn.leaky_relu(x, negative_slope=0.2) 32 | x = elegy.nn.Flatten()(x) 33 | x = elegy.nn.Linear(1)(x) 34 | return x 35 | 36 | 37 | # multiplier for gradient normalization 38 | LAMBDA_GP = 10 39 | 40 | # gradient regularization term 41 | def gradient_penalty(x_real, x_fake, applied_discriminator_fn, rngkey): 42 | assert len(x_real) == len(x_fake) 43 | alpha = jax.random.uniform(rngkey, shape=[len(x_real), 1, 1, 1]) 44 | x_hat = x_real * alpha + x_fake * (1 - alpha) 45 | grads = jax.grad(lambda x: applied_discriminator_fn(x)[0].mean())(x_hat) 46 | norm = jnp.sqrt((grads**2).sum(axis=[1, 2, 3])) 47 | penalty = (norm - 1) ** 2 48 | return penalty.mean() * LAMBDA_GP 49 | 50 | 51 | class WGAN_GP(elegy.Model): 52 | def __init__(self): 53 | super().__init__() 54 | self.generator = Generator() 55 | self.discriminator = Discriminator() 56 | self.g_optimizer = optax.adam(2e-4, b1=0.5) 57 | self.d_optimizer = optax.adam(2e-4, b1=0.5) 58 | 59 | def init_step(self, x): 60 | rng = elegy.KeySeq(0) 61 | gx, g_params, g_states = self.generator.init(rng=rng)(x) 62 | dx, d_params, d_states = self.discriminator.init(rng=rng)(gx) 63 | 64 | g_optimizer_states = self.g_optimizer.init(g_params) 65 | d_optimizer_states = self.d_optimizer.init(d_params) 66 | 67 | return elegy.States( 68 | g_states=g_states, 69 | d_states=d_states, 70 | g_params=g_params, 71 | d_params=d_params, 72 | g_opt_states=g_optimizer_states, 73 | d_opt_states=d_optimizer_states, 74 | rng=rng, 75 | step=0, 76 | ) 77 | 78 | def pred_step(self, x, states): 79 | z = x 80 | x_fake = self.generator.apply(states.g_params, states.g_states)(z)[0] 81 | return (x_fake, states) 82 | 83 | def train_step(self, x, states): 84 | # training the discriminator on every iteration 85 | d_loss, gp, states = self.discriminator_step(x, states) 86 | 87 | # 
training the generator only every 5 iterations as recommended in the original WGAN paper 88 | step = states.step + 1 89 | no_update = lambda states: (0.0, states) 90 | do_update = lambda states: self.generator_step(len(x), states) 91 | g_loss, states = jax.lax.cond(step % 5 == 0, do_update, no_update, states) 92 | 93 | return {"d_loss": d_loss, "g_loss": g_loss, "gp": gp}, states.update(step=step) 94 | 95 | def discriminator_step(self, x_real: jnp.ndarray, states: elegy.States): 96 | z = jax.random.normal(states.rng.next(), (len(x_real), 128)) 97 | x_fake = self.generator.apply(states.g_params, states.g_states)(z)[0] 98 | 99 | def d_loss_fn(d_params, states, x_real, x_fake): 100 | y_real, d_params, d_states = self.discriminator.apply( 101 | d_params, states.d_states 102 | )(x_real) 103 | y_fake, d_params, d_states = self.discriminator.apply(d_params, d_states)( 104 | x_fake 105 | ) 106 | loss = -y_real.mean() + y_fake.mean() 107 | gp = gradient_penalty( 108 | x_real, 109 | x_fake, 110 | self.discriminator.apply(d_params, d_states), 111 | states.rng.next(), 112 | ) 113 | loss = loss + gp 114 | return loss, (gp, states.update_known(**locals())) 115 | 116 | (d_loss, (gp, states)), d_grads = jax.value_and_grad(d_loss_fn, has_aux=True)( 117 | states.d_params, states, x_real, x_fake 118 | ) 119 | d_grads, d_opt_states = self.d_optimizer.update( 120 | d_grads, states.d_opt_states, states.d_params 121 | ) 122 | d_params = optax.apply_updates(states.d_params, d_grads) 123 | 124 | return d_loss, gp, states.update_known(**locals()) 125 | 126 | def generator_step(self, batch_size: int, states: elegy.States): 127 | z = jax.random.normal(states.rng.next(), (batch_size, 128)) 128 | 129 | def g_loss_fn(g_params, states, z): 130 | x_fake, g_params, g_states = self.generator.apply( 131 | g_params, states.g_states 132 | )(z) 133 | y_fake_scores = self.discriminator.apply(states.d_params, states.d_states)( 134 | x_fake 135 | )[0] 136 | y_fake_true = jnp.ones(len(z)) 137 | loss = -y_fake_scores.mean() 138 | return loss, states.update_known(**locals()) 139 | 140 | (g_loss, states), g_grads = jax.value_and_grad(g_loss_fn, has_aux=True)( 141 | states.g_params, states, z 142 | ) 143 | g_grads, g_opt_states = self.g_optimizer.update( 144 | g_grads, states.g_opt_states, states.g_params 145 | ) 146 | g_params = optax.apply_updates(states.g_params, g_grads) 147 | 148 | return g_loss, states.update_known(**locals()) 149 | -------------------------------------------------------------------------------- /examples/need-fixing/imagenet/README.md: -------------------------------------------------------------------------------- 1 | ## Training ResNet on ImageNet 2 | 3 | Adapted from the [Flax](https://github.com/google/flax) library. 4 | 5 | This example currently runs only on one device. 6 | 7 | See `requirements.txt` for required packages, additional to Elegy. 
8 | 9 | *** 10 | ### Usage 11 | ``` 12 | resnet_imagenet.py --model= --output_dir=<./output/path> [flags] 13 | 14 | 15 | flags: 16 | --model: : Type of ResNet to train 17 | --output_dir: Directory to save model checkpoints and tensorboard log data 18 | 19 | --base_lr: SGD optimizer base learning rate (default: '0.1') 20 | --batch_size: Input batch size (default: '64') 21 | --[no]cache: Whether to cache the data in RAM (default: 'false') 22 | --dataset: TFDS dataset name and version (default: 'imagenet2012:*.*.*') 23 | --dtype: : Mixed precision or normal mode (default: 'float32') 24 | --epochs: Number of epochs to train (default: '90') 25 | --image_size: Image size in pixels (default: '224') 26 | --L2_reg: L2 weight regularization (default: '0.0001') 27 | --loss_scale: Loss scale for numerical stability when dtype=float16 (default: '1.0') 28 | --momentum: SGD optimizer momentum (default: '0.9') 29 | --[no]nesterov: SGD optimizer Nesterov mode (default: 'true') 30 | ``` 31 | 32 | *** 33 | ### Pretrained Models 34 | 35 | | Model | Top-1 accuracy | Weight Files | 36 | | --- | --- | --- | 37 | | ResNet18 | 68.7% | [model.pkl](https://www.dropbox.com/s/ofwh7947y6t84zp/ResNet18_ImageNet.pkl?dl=1) | 38 | | ResNet50 | 76.5% | [model.pkl](https://www.dropbox.com/s/fmr7tm6rmah682s/ResNet50_ImageNet.pkl?dl=1) | 39 | 40 | Pretrained weights can be loaded with: `elegy.nets.ResNet18(weights='path/to/ResNet18_ImageNet.pkl')` 41 | 42 | or with automatic download: `elegy.nets.ResNet18(weights='imagenet')` 43 | 44 | 45 | *** 46 | [1] He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. 47 | 48 | [2] Russakovsky, Olga, et al. "Imagenet large scale visual recognition challenge." International journal of computer vision 115.3 (2015): 211-252. 
49 | -------------------------------------------------------------------------------- /examples/need-fixing/imagenet/requirements.txt: -------------------------------------------------------------------------------- 1 | #additional requirements 2 | tensorflow-datasets==4.0.1 3 | tensorflow-gpu==2.2.0 #tensorflow-cpu also ok, but with gpu faster -------------------------------------------------------------------------------- /examples/need-fixing/imagenet/resnet_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | if "miniconda3/envs" in os.__file__: 4 | # specify the cuda location for XLA when working with conda environments 5 | os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=" + os.sep.join( 6 | os.__file__.split(os.sep)[:-3] 7 | ) 8 | 9 | 10 | import jax 11 | import jax.numpy as jnp 12 | from absl import app, flags 13 | 14 | # importing tensorflow_datasets before performing any jax convolutions gives me a 'DNN Library not found' error later 15 | # workaround: do a dummy convolution before importing tfds 16 | _x0 = jnp.zeros((1, 1, 1, 1)) 17 | _x1 = jnp.zeros((1, 1, 1, 1)) 18 | jax.lax.conv(_x0, _x1, (1, 1), "SAME").block_until_ready() 19 | 20 | 21 | import input_pipeline 22 | import optax 23 | import tensorflow_datasets as tfds 24 | 25 | import elegy 26 | 27 | print("JAX version:", jax.__version__) 28 | print("Elegy version:", elegy.__version__) 29 | 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | flags.DEFINE_enum( 34 | "model", 35 | default=None, 36 | enum_values=[ 37 | "ResNet18", 38 | "ResNet34", 39 | "ResNet50", 40 | "ResNet101", 41 | "ResNet152", 42 | "ResNet200", 43 | ], 44 | help="Type of ResNet to train", 45 | ) 46 | 47 | flags.DEFINE_string( 48 | "output_dir", 49 | default=None, 50 | help="Directory to save model checkpoints and tensorboard log data", 51 | ) 52 | flags.DEFINE_integer("epochs", default=90, help="Number of epochs to train") 53 | flags.DEFINE_integer("batch_size", default=64, help="Input batch size") 54 | flags.DEFINE_integer("image_size", default=224, help="Image size in pixels") 55 | flags.DEFINE_string( 56 | "dataset", default="imagenet2012:*.*.*", help="TFDS dataset name and version" 57 | ) 58 | flags.DEFINE_enum( 59 | "dtype", 60 | default="float32", 61 | enum_values=["float16", "float32"], 62 | help="Mixed precision or normal mode", 63 | ) 64 | flags.DEFINE_float("base_lr", default=0.1, help="SGD optimizer base learning rate") 65 | flags.DEFINE_float("momentum", default=0.9, help="SGD optimizer momentum") 66 | flags.DEFINE_bool("nesterov", default=True, help="SGD optimizer Nesterov mode") 67 | flags.DEFINE_float("L2_reg", default=1e-4, help="L2 weight regularization") 68 | flags.DEFINE_bool("cache", default=False, help="Whether to cache the data in RAM") 69 | flags.DEFINE_float( 70 | "loss_scale", 71 | default=1.0, 72 | help="Loss scale for numerical stability when dtype=float16", 73 | ) 74 | 75 | flags.mark_flag_as_required("model") 76 | flags.mark_flag_as_required("output_dir") 77 | 78 | 79 | def main(argv): 80 | assert ( 81 | len(argv) == 1 82 | ), "Please specify arguments via flags. Use --help for instructions" 83 | 84 | assert ( 85 | getattr(elegy.nets.resnet, FLAGS.model, None) is not None 86 | ), f"{FLAGS.model} is not defined in elegy.nets.resnet" 87 | 88 | assert not os.path.exists( 89 | FLAGS.output_dir 90 | ), "Output directory already exists. Delete manually or specify a new one." 
91 | os.makedirs(FLAGS.output_dir) 92 | 93 | # dataset 94 | dataset_builder = tfds.builder(FLAGS.dataset) 95 | ds_train = input_pipeline.create_split( 96 | dataset_builder, 97 | batch_size=FLAGS.batch_size, 98 | image_size=FLAGS.image_size, 99 | dtype=FLAGS.dtype, 100 | train=True, 101 | cache=FLAGS.cache, 102 | ) 103 | ds_valid = input_pipeline.create_split( 104 | dataset_builder, 105 | batch_size=FLAGS.batch_size, 106 | image_size=FLAGS.image_size, 107 | dtype=FLAGS.dtype, 108 | train=False, 109 | cache=FLAGS.cache, 110 | ) 111 | N_BATCHES_TRAIN = ( 112 | dataset_builder.info.splits["train"].num_examples // FLAGS.batch_size 113 | ) 114 | N_BATCHES_VALID = ( 115 | dataset_builder.info.splits["validation"].num_examples // FLAGS.batch_size 116 | ) 117 | 118 | # generator that converts tfds dataset batches to jax arrays 119 | def tfds2jax_generator(tf_ds): 120 | for batch in tf_ds: 121 | yield jnp.asarray(batch["image"], dtype=FLAGS.dtype), jax.device_put( 122 | jnp.asarray(batch["label"]) 123 | ) 124 | 125 | # model and optimizer definition 126 | def build_optimizer( 127 | lr, momentum, steps_per_epoch, n_epochs, nesterov, warmup_epochs=5 128 | ): 129 | cosine_schedule = optax.cosine_decay_schedule( 130 | 1, decay_steps=n_epochs * steps_per_epoch, alpha=1e-10 131 | ) 132 | warmup_schedule = optax.polynomial_schedule( 133 | init_value=0.0, 134 | end_value=1.0, 135 | power=1, 136 | transition_steps=warmup_epochs * steps_per_epoch, 137 | ) 138 | schedule = lambda x: jnp.minimum(cosine_schedule(x), warmup_schedule(x)) 139 | optimizer = optax.sgd(lr, momentum, nesterov=nesterov) 140 | optimizer = optax.chain(optimizer, optax.scale_by_schedule(schedule)) 141 | return optimizer 142 | 143 | module = getattr(elegy.nets.resnet, FLAGS.model)(dtype=FLAGS.dtype) 144 | model = elegy.Model( 145 | module, 146 | loss=[ 147 | elegy.losses.Crossentropy(from_logits=True, weight=FLAGS.loss_scale), 148 | elegy.regularizers.L2(FLAGS.L2_reg / 2 * FLAGS.loss_scale), 149 | ], 150 | metrics=elegy.metrics.Accuracy(), 151 | optimizer=build_optimizer( 152 | FLAGS.base_lr / FLAGS.loss_scale, 153 | FLAGS.momentum, 154 | N_BATCHES_TRAIN, 155 | FLAGS.epochs, 156 | FLAGS.nesterov, 157 | ), 158 | ) 159 | 160 | # training 161 | model.fit( 162 | inputs=tfds2jax_generator(ds_train), 163 | validation_data=tfds2jax_generator(ds_valid), 164 | epochs=FLAGS.epochs, 165 | verbose=2, 166 | steps_per_epoch=N_BATCHES_TRAIN, 167 | validation_steps=N_BATCHES_VALID, 168 | callbacks=[ 169 | elegy.callbacks.ModelCheckpoint(FLAGS.output_dir, save_best_only=True), 170 | elegy.callbacks.TerminateOnNaN(), 171 | elegy.callbacks.TensorBoard(logdir=FLAGS.output_dir), 172 | ], 173 | ) 174 | 175 | 176 | if __name__ == "__main__": 177 | app.run(main) 178 | -------------------------------------------------------------------------------- /examples/requirements.txt: -------------------------------------------------------------------------------- 1 | typer 2 | matplotlib 3 | datasets -------------------------------------------------------------------------------- /gitpod.Dockerfile: -------------------------------------------------------------------------------- 1 | # You can update the PY_VERSION to pick a python version 2 | ARG PY_VERSION=3.8 3 | FROM docker.io/python:${PY_VERSION} 4 | 5 | RUN pip install poetry 6 | RUN poetry config virtualenvs.create false -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Elegy 2 | 
repo_name: poets-ai/elegy 3 | repo_url: https://github.com/poets-ai/elegy 4 | site_url: https://poets-ai.github.io/elegy 5 | nav: 6 | - Introduction: index.md 7 | - Getting Started: 8 | - High Level API: getting-started/high-level-api.ipynb 9 | - Low Level API: getting-started/low-level-api.ipynb 10 | - Contributing: contributing.md 11 | - API Reference: {} 12 | extra: 13 | search: 14 | language: en 15 | social: 16 | - icon: octicons/mark-github-16 17 | link: https://github.com/cgarciae 18 | - icon: octicons/mark-github-16 19 | link: https://github.com/charlielito 20 | - icon: fontawesome/brands/twitter 21 | link: https://twitter.com/cgarciae88 22 | - icon: fontawesome/brands/linkedin 23 | link: https://www.linkedin.com/in/cgarciae 24 | - icon: fontawesome/brands/linkedin 25 | link: https://www.linkedin.com/in/calvarez92 26 | theme: 27 | name: material 28 | favicon: images/favicon.png 29 | icon: 30 | logo: fontawesome/solid/infinity 31 | repo: octicons/mark-github-16 32 | markdown_extensions: 33 | - pymdownx.arithmatex 34 | - admonition 35 | - codehilite: 36 | guess_lang: false 37 | use_pygments: true 38 | noclasses: true 39 | pygments_style: default 40 | plugins: 41 | - search 42 | - mkdocs-jupyter 43 | - mkdocstrings 44 | extra_javascript: 45 | - https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-MML-AM_CHTML 46 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "elegy" 3 | description = "Elegy is a Neural Networks framework based on Jax and Haiku." 4 | authors = ["Cristian Garcia ", 5 | "Carlos Alvarez ", 6 | "David Cardozo ", 7 | "Sebastian Arango"] 8 | version = "0.8.6" 9 | license = "APACHE" 10 | readme = "README.md" 11 | repository = "https://github.com/poets-ai/elegy" 12 | homepage = "https://poets-ai.github.io/elegy" 13 | 14 | [[tool.poetry.source]] 15 | name = "torch" 16 | url = "https://eternalphane.github.io/pytorch-pypi/" 17 | secondary = true 18 | 19 | [tool.poetry.dependencies] 20 | python = ">=3.7,<3.10" 21 | cloudpickle = "^1.5.0" 22 | tensorboardx = "^2.1" 23 | wandb = { version = "^0.12.10", optional = true } 24 | treex = "^0.6.5" 25 | # treex = {path = "../treex", develop = true} 26 | 27 | [tool.poetry.dev-dependencies] 28 | jax = "^0.2.24" 29 | jaxlib = "^0.1.73" 30 | pytest = "^5.4.3" 31 | pytest-cov = "^2.10.0" 32 | dm-haiku = "^0.0.5" 33 | mkdocs = "^1.1.2" 34 | mkdocs-material = "^6.2.7" 35 | mkdocstrings = "^0.14.0" 36 | black = "^22.3.0" 37 | typer = "^0.4.1" 38 | mkdocs-jupyter = { version = "^0.15.1", python = ">=3.7" } 39 | matplotlib = "^3.3.0" 40 | debugpy = "^1.0.0-beta.12" 41 | jupyter = { version = "^1.0.0", python = ">=3.7" } 42 | jupyterlab = { version = "^3.0.6", python = ">=3.7" } 43 | ipython = { version = "^7.20.0", python = ">=3.7" } 44 | flax = "^0.3.6" 45 | torch = "1.9.1+cpu" 46 | einops = "^0.3.0" 47 | sh = "^1.14.1" 48 | pre-commit = "^2.15.0" 49 | datasets = "^1.14.0" 50 | livereload = "^2.6.3" 51 | libclang = "^13.0.0" 52 | wandb = "^0.12.10" 53 | tensorflow-cpu = "^2.9.0" 54 | 55 | [tool.poetry.extras] 56 | wandb = ["all"] 57 | 58 | [build-system] 59 | requires = ["poetry>=0.12"] 60 | build-backend = "poetry.masonry.api" 61 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = 3 | tests 4 | docs 5 | addopts = 
--doctest-modules --disable-pytest-warnings --doctest-glob="*.md" 6 | doctest_optionflags = NUMBER ELLIPSIS -------------------------------------------------------------------------------- /scripts/deploy-docs.sh: -------------------------------------------------------------------------------- 1 | 2 | cp README.md docs/index.md 3 | cp CONTRIBUTING.md docs/guides/contributing.md 4 | python scripts/update_docs.py 5 | mkdocs build 6 | mkdocs gh-deploy -------------------------------------------------------------------------------- /scripts/get-coverage.sh: -------------------------------------------------------------------------------- 1 | pytest --cov=elegy --cov-report=term-missing --cov-report=html 2 | rm .coverage 3 | rm .coverage.* -------------------------------------------------------------------------------- /scripts/install-remote.sh: -------------------------------------------------------------------------------- 1 | pip install poetry 2 | poetry install 3 | # pip install --upgrade jax jaxlib==0.1.60+cuda101 -f https://storage.googleapis.com/jax-releases/jax_releases.html 4 | poetry shell -------------------------------------------------------------------------------- /scripts/run-docs.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | 3 | cp README.md docs/index.md 4 | cp CONTRIBUTING.md docs/guides/contributing.md 5 | python scripts/update_docs.py 6 | 7 | mkdocs serve -------------------------------------------------------------------------------- /scripts/test-all-versions.sh: -------------------------------------------------------------------------------- 1 | for python_version in "3.6" "3.7" "3.8"; do 2 | if [[ $python_version =~ $kPYTHON_VERSIONS ]] || [[ -z "$python_version" ]]; then 3 | bash scripts/test-version.sh "$python_version" 4 | else 5 | echo "Check python version" 6 | fi 7 | done -------------------------------------------------------------------------------- /scripts/test-examples.sh: -------------------------------------------------------------------------------- 1 | 2 | set -e 3 | 4 | 5 | 6 | #---------------------------------------------------------------- 7 | # test docs/getting-started 8 | #---------------------------------------------------------------- 9 | # create tmp_dir 10 | tmp_dir=$(mktemp -d -t XXXXXXXXXX) 11 | 12 | # low-level-api 13 | file="docs/getting-started/low-level-api.ipynb" 14 | echo RUNNING: $file 15 | jupyter nbconvert --log-level "ERROR" --to python --output $tmp_dir/result.py $file > /dev/null 16 | sed -i "s/get_ipython/#get_ipython/" $tmp_dir/result.py 17 | sed -i "s/epochs=100/epochs=2/" $tmp_dir/result.py 18 | sed -i "s/steps_per_epoch=200/steps_per_epoch=2/" $tmp_dir/result.py 19 | sed -i "s/batch_size=64/batch_size=4/" $tmp_dir/result.py 20 | DISPLAY="" python $tmp_dir/result.py > /dev/null 21 | 22 | # high-level-api 23 | file="docs/getting-started/high-level-api.ipynb" 24 | echo RUNNING: $file 25 | jupyter nbconvert --log-level "ERROR" --to python --output $tmp_dir/result.py $file > /dev/null 26 | sed -i "s/get_ipython/#get_ipython/" $tmp_dir/result.py 27 | sed -i "s/epochs=100/epochs=2/" $tmp_dir/result.py 28 | sed -i "s/steps_per_epoch=200/steps_per_epoch=2/" $tmp_dir/result.py 29 | sed -i "s/batch_size=64/batch_size=4/" $tmp_dir/result.py 30 | DISPLAY="" python $tmp_dir/result.py > /dev/null 31 | 32 | # delete tmp_dir 33 | rm -fr $tmp_dir 34 | 35 | #---------------------------------------------------------------- 36 | # test examples 37 | 
#---------------------------------------------------------------- 38 | for file in $(find examples -name '*.py' -not -path '*/imagenet/*' -not -path '*/WGAN-GP/*') ; do 39 | cmd="python $file --epochs 2 --steps-per-epoch 1 --batch-size 3" 40 | echo RUNNING: $cmd 41 | DISPLAY="" $cmd > /dev/null 42 | done 43 | 44 | #WGAN example 45 | # tmpdir=`mktemp -d`; rm -r $tmpdir 46 | # cmd="python examples/WGAN-GP/main.py --epochs=2 --dataset=examples/WGAN-GP/images/*.png --output_dir=$tmpdir" 47 | # echo RUNNING: $cmd 48 | # DISPLAY="" $cmd > /dev/null -------------------------------------------------------------------------------- /scripts/test-gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # test-gpu - A script to run tests elegy in a gpu enabled container 3 | set -e 4 | 5 | 6 | container_runner () { 7 | 8 | if hash podman 2>/dev/null; then 9 | podman build -f Dockerfile_CUDA -t elegy:latest 10 | podman run --privileged -it --rm --security-opt=label=disable -v "$(pwd)":/usr/src/app:Z -e NVIDIA_VISIBLE_DEVICES=all elegy bash 11 | else 12 | docker build -f Dockerfile_CUDA -t elegy:latest 13 | docker run --privileged -it --rm --security-opt=label=disable -v "$(pwd)":/usr/src/app:Z -e NVIDIA_VISIBLE_DEVICES=all elegy bash 14 | 15 | fi 16 | } 17 | 18 | container_runner -------------------------------------------------------------------------------- /scripts/test-version.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # run-test - A script to run tests elegy in a container 3 | # can receive an optional declaring python version 4 | set -e 5 | 6 | kPYTHON_VERSIONS='^[3]\.[0-9]{1,2}$' 7 | kDEFAULT_VERSION=3.8 8 | 9 | 10 | container_runner () { 11 | if [[ -z "$1" ]]; then 12 | py_version="$kDEFAULT_VERSION" 13 | else 14 | py_version=$1 15 | fi 16 | 17 | if hash podman 2>/dev/null; then 18 | podman build --build-arg PY_VERSION="$py_version" -t elegy . 19 | podman run -it --rm -v "$(pwd)":/usr/src/app:Z elegy:latest 20 | else 21 | docker build --build-arg PY_VERSION="$py_version" -t elegy . 
22 | docker run -it --rm -v "$(pwd)":/usr/src/app:Z elegy:latest 23 | fi 24 | } 25 | 26 | if [[ $1 =~ $kPYTHON_VERSIONS ]] || [[ -z "$1" ]]; then 27 | container_runner "$1" 28 | else 29 | echo "Check python version" 30 | fi -------------------------------------------------------------------------------- /scripts/update_docs.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import typing as tp 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from types import ModuleType 6 | 7 | import jax 8 | import jinja2 9 | import yaml 10 | 11 | import elegy 12 | 13 | 14 | @dataclass 15 | class Structure: 16 | obj: tp.Any 17 | name_path: str 18 | module_path: str 19 | members: tp.List[str] 20 | 21 | 22 | def get(module, name_path): 23 | 24 | all_members = module.__all__ if hasattr(module, "__all__") else [] 25 | all_members = sorted(all_members) 26 | 27 | outputs = { 28 | name: get(module, f"{name_path}.{name}") 29 | if isinstance(module, ModuleType) 30 | else Structure( 31 | obj=module, 32 | name_path=f"{name_path}.{name}", 33 | module_path=f"{module.__module__}.{name}", 34 | members=module.__all__ if hasattr(module, "__all__") else [], 35 | ) 36 | for module, name in ((getattr(module, name), name) for name in all_members) 37 | } 38 | 39 | return {k: v for k, v in outputs.items() if v} 40 | 41 | 42 | docs_info = get(elegy, "elegy") 43 | 44 | # populate mkdocs 45 | with open("mkdocs.yml", "r") as f: 46 | docs = yaml.safe_load(f) 47 | 48 | 49 | [api_reference_index] = [ 50 | index for index, section in enumerate(docs["nav"]) if "API Reference" in section 51 | ] 52 | 53 | 54 | api_reference = jax.tree_map( 55 | lambda s: s.name_path.replace("elegy", "api").replace(".", "/") + ".md", docs_info 56 | ) 57 | 58 | docs["nav"][api_reference_index] = {"API Reference": api_reference} 59 | 60 | with open("mkdocs.yml", "w") as f: 61 | yaml.safe_dump(docs, f, default_flow_style=False, sort_keys=False) 62 | 63 | 64 | template = """ 65 | # {{name_path}} 66 | 67 | ::: {{module_path}} 68 | selection: 69 | inherited_members: true 70 | {%- if members %} 71 | members: 72 | {%- for member in members %} 73 | - {{member}} 74 | {%- endfor %} 75 | {% endif %} 76 | """ 77 | 78 | api_path = Path("docs/api") 79 | shutil.rmtree(api_path, ignore_errors=True) 80 | 81 | for structure in jax.tree_leaves(docs_info): 82 | filepath: Path = api_path / ( 83 | structure.name_path.replace("elegy.", "").replace(".", "/") + ".md" 84 | ) 85 | markdown = jinja2.Template(template).render( 86 | name_path=structure.name_path, 87 | module_path=structure.module_path, 88 | members=structure.members, 89 | ) 90 | 91 | filepath.parent.mkdir(parents=True, exist_ok=True) 92 | filepath.write_text(markdown) 93 | -------------------------------------------------------------------------------- /scripts/update_version.py: -------------------------------------------------------------------------------- 1 | import re 2 | from pathlib import Path 3 | 4 | import typer 5 | 6 | 7 | # NOTE: this script could be written bash using sed, but I'm not sure if it's worth it 8 | def main(release_name: str): 9 | release_name = release_name.replace("-create-release", "") 10 | 11 | # Update pyproject.toml 12 | pyproject_path = Path("pyproject.toml") 13 | pyproject_text = pyproject_path.read_text() 14 | pyproject_text = re.sub( 15 | r'version = ".*"', 16 | f'version = "{release_name}"', 17 | pyproject_text, 18 | count=1, 19 | ) 20 | pyproject_path.write_text(pyproject_text) 21 | 22 | # Update 
__init__.py 23 | init_path = Path("elegy", "__init__.py") 24 | init_text = init_path.read_text() 25 | init_text = re.sub( 26 | r'__version__ = "(.*?)"', 27 | f'__version__ = "{release_name}"', 28 | init_text, 29 | count=1, 30 | ) 31 | init_path.write_text(init_text) 32 | 33 | 34 | if __name__ == "__main__": 35 | typer.run(main) 36 | -------------------------------------------------------------------------------- /tests/callbacks/early_stopping_test.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import jax 4 | import numpy as np 5 | import optax 6 | import pytest 7 | 8 | import elegy as eg 9 | 10 | np.random.seed(42) 11 | 12 | 13 | class EarlyStoppingTest(TestCase): 14 | def test_example(self): 15 | class MLP(eg.Module): 16 | @eg.compact 17 | def __call__(self, x): 18 | x = eg.Linear(10)(x) 19 | x = jax.lax.stop_gradient(x) 20 | return x 21 | 22 | # This callback will stop the training when there is no improvement in 23 | # the loss for three consecutive epochs. 24 | model = eg.Model( 25 | module=MLP(), 26 | loss=eg.losses.MeanSquaredError(), 27 | optimizer=optax.rmsprop(0.01), 28 | ) 29 | history = model.fit( 30 | inputs=np.ones((5, 20)), 31 | labels=np.zeros((5, 10)), 32 | epochs=10, 33 | batch_size=1, 34 | callbacks=[ 35 | eg.callbacks.EarlyStopping( 36 | monitor="loss", 37 | patience=3, 38 | ) 39 | ], 40 | verbose=0, 41 | ) 42 | assert len(history.history["loss"]) == 4 # Only 4 epochs are run. 43 | 44 | @pytest.mark.skip("fix later") 45 | def test_example_restore(self): 46 | class MLP(eg.Module): 47 | @eg.compact 48 | def __call__(self, x): 49 | x = eg.Linear(10)(x) 50 | x = jax.lax.stop_gradient(x) 51 | return x 52 | 53 | # This callback will stop the training when there is no improvement in 54 | # the loss for three consecutive epochs. 55 | model = eg.Model( 56 | module=MLP(), 57 | loss=eg.losses.MeanSquaredError(), 58 | optimizer=optax.rmsprop(0.01), 59 | ) 60 | history = model.fit( 61 | inputs=np.ones((5, 20)), 62 | labels=np.zeros((5, 10)), 63 | epochs=10, 64 | batch_size=1, 65 | callbacks=[ 66 | eg.callbacks.EarlyStopping( 67 | monitor="loss", patience=3, restore_best_weights=True 68 | ) 69 | ], 70 | verbose=0, 71 | ) 72 | assert len(history.history["loss"]) == 4 # Only 4 epochs are run. 
73 | -------------------------------------------------------------------------------- /tests/data/array_adapter_test.py: -------------------------------------------------------------------------------- 1 | import math 2 | from unittest import TestCase 3 | 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import pytest 7 | 8 | from elegy.data.array_adapter import ArrayDataAdapter 9 | 10 | 11 | class ArrayDataAdapterTest(TestCase): 12 | def test_basic(self): 13 | x = np.random.uniform(size=(100, 32, 32, 3)) 14 | y = np.random.uniform(size=(100, 1)) 15 | batch_size = 10 16 | epochs = 1 17 | data_adapter = ArrayDataAdapter( 18 | x, 19 | y=y, 20 | sample_weights=None, 21 | batch_size=batch_size, 22 | epochs=epochs, 23 | steps=None, 24 | shuffle=False, 25 | ) 26 | num_steps = math.ceil(x.shape[0] / batch_size) * epochs 27 | iterator_fn = data_adapter.get_dataset() 28 | for i, batch in zip(range(num_steps), iterator_fn()): 29 | batch_x, batch_y = batch 30 | assert batch_x.shape == (batch_size, *x.shape[1:]) 31 | assert batch_y.shape == (batch_size, *y.shape[1:]) 32 | np.testing.assert_array_equal( 33 | batch_x, x[i * batch_size : (i + 1) * batch_size] 34 | ) 35 | 36 | data_adapter.get_size() == x.shape[0] 37 | data_adapter.partial_batch_size == 0 38 | 39 | def test_jax(self): 40 | x = jnp.array(np.random.uniform(size=(100, 32, 32, 3))) 41 | y = jnp.array(np.random.uniform(size=(100, 1))) 42 | batch_size = 10 43 | epochs = 1 44 | data_adapter = ArrayDataAdapter( 45 | x, 46 | y=y, 47 | sample_weights=None, 48 | batch_size=batch_size, 49 | epochs=epochs, 50 | steps=None, 51 | shuffle=False, 52 | ) 53 | num_steps = math.ceil(x.shape[0] / batch_size) * epochs 54 | iterator_fn = data_adapter.get_dataset() 55 | for i, batch in zip(range(num_steps), iterator_fn()): 56 | batch_x, batch_y = batch 57 | assert batch_x.shape == (batch_size, *x.shape[1:]) 58 | assert batch_y.shape == (batch_size, *y.shape[1:]) 59 | np.testing.assert_array_equal( 60 | batch_x, x[i * batch_size : (i + 1) * batch_size] 61 | ) 62 | 63 | data_adapter.get_size() == x.shape[0] 64 | data_adapter.partial_batch_size == 0 65 | 66 | def test_shuffle(self): 67 | x = np.random.uniform(size=(100, 32, 32, 3)) 68 | y = np.random.uniform(size=(100, 1)) 69 | batch_size = 10 70 | epochs = 1 71 | data_adapter = ArrayDataAdapter( 72 | x, 73 | y=y, 74 | sample_weights=None, 75 | batch_size=batch_size, 76 | epochs=epochs, 77 | steps=None, 78 | shuffle=True, 79 | ) 80 | num_steps = math.ceil(x.shape[0] / batch_size) * epochs 81 | iterator_fn = data_adapter.get_dataset() 82 | for i, batch in zip(range(num_steps), iterator_fn()): 83 | batch_x, batch_y = batch 84 | assert batch_x.shape == (batch_size, *x.shape[1:]) 85 | assert batch_y.shape == (batch_size, *y.shape[1:]) 86 | assert not np.array_equal(batch_x, x[i * batch_size : (i + 1) * batch_size]) 87 | 88 | data_adapter.get_size() == x.shape[0] 89 | data_adapter.partial_batch_size == 0 90 | 91 | def test_partial_batch(self): 92 | x = np.random.uniform(size=(100, 32, 32, 3)) 93 | y = np.random.uniform(size=(100, 1)) 94 | batch_size = 32 95 | epochs = 1 96 | data_adapter = ArrayDataAdapter( 97 | x, 98 | y=y, 99 | sample_weights=None, 100 | batch_size=batch_size, 101 | epochs=epochs, 102 | steps=None, 103 | shuffle=True, 104 | ) 105 | num_steps = math.ceil(x.shape[0] / batch_size) * epochs 106 | 107 | iterator_fn = data_adapter.get_dataset() 108 | for i, batch in zip(range(num_steps), iterator_fn()): 109 | batch_x, batch_y = batch 110 | if i < num_steps - 1: 111 | assert batch_x.shape == 
(batch_size, *x.shape[1:]) 112 | assert batch_y.shape == (batch_size, *y.shape[1:]) 113 | else: 114 | assert batch_x.shape == (x.shape[0] % batch_size, *x.shape[1:]) 115 | assert batch_y.shape == (x.shape[0] % batch_size, *y.shape[1:]) 116 | 117 | data_adapter.get_size() == x.shape[0] 118 | data_adapter.partial_batch_size == x.shape[0] % batch_size 119 | -------------------------------------------------------------------------------- /tests/data/data_utils_test.py: -------------------------------------------------------------------------------- 1 | import math 2 | from unittest import TestCase 3 | 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import pytest 7 | 8 | from elegy.data import utils 9 | 10 | 11 | class TrainValidationSplitTest(TestCase): 12 | def test_basic(self): 13 | x_all = np.random.uniform(size=(100, 32, 32, 3)) 14 | y_all = np.random.uniform(size=(100, 1)) 15 | sample_weight_all = None 16 | split = 0.2 17 | 18 | (x, y, sample_weight), validation_data = utils.train_validation_split( 19 | (x_all, y_all, sample_weight_all), validation_split=0.2, shuffle=False 20 | ) 21 | 22 | assert x.shape[0] == int(x_all.shape[0] * (1 - split)) 23 | assert y.shape[0] == int(y_all.shape[0] * (1 - split)) 24 | assert sample_weight is None 25 | 26 | (x, y, sample_weight) = validation_data 27 | assert x.shape[0] == int(x_all.shape[0] * split) 28 | assert y.shape[0] == int(y_all.shape[0] * split) 29 | assert sample_weight is None 30 | -------------------------------------------------------------------------------- /tests/data/list_adapter_test.py: -------------------------------------------------------------------------------- 1 | import math 2 | from unittest import TestCase 3 | 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import pytest 7 | 8 | from elegy.data.list_adapter import ListsOfScalarsDataAdapter 9 | 10 | 11 | class ListsOfScalarsDataAdapterTest(TestCase): 12 | def test_basic(self): 13 | x = np.random.uniform(size=(100, 32, 32, 3)) 14 | y = np.random.uniform(size=(100, 1)) 15 | batch_size = 10 16 | epochs = 1 17 | data_adapter = ListsOfScalarsDataAdapter( 18 | x.tolist(), 19 | y=y.tolist(), 20 | sample_weights=None, 21 | batch_size=batch_size, 22 | epochs=epochs, 23 | steps=None, 24 | shuffle=False, 25 | ) 26 | num_steps = math.ceil(x.shape[0] / batch_size) * epochs 27 | iterator_fn = data_adapter.get_dataset() 28 | for i, batch in zip(range(num_steps), iterator_fn()): 29 | batch_x, batch_y = batch 30 | assert batch_x.shape == (batch_size, *x.shape[1:]) 31 | assert batch_y.shape == (batch_size, *y.shape[1:]) 32 | np.testing.assert_array_equal( 33 | batch_x, x[i * batch_size : (i + 1) * batch_size] 34 | ) 35 | 36 | data_adapter.get_size() == x.shape[0] 37 | data_adapter.partial_batch_size == 0 38 | 39 | def test_jax(self): 40 | x = jnp.array(np.random.uniform(size=(100, 32, 32, 3))) 41 | y = jnp.array(np.random.uniform(size=(100, 1))) 42 | batch_size = 10 43 | epochs = 1 44 | data_adapter = ListsOfScalarsDataAdapter( 45 | x.tolist(), 46 | y=y.tolist(), 47 | sample_weights=None, 48 | batch_size=batch_size, 49 | epochs=epochs, 50 | steps=None, 51 | shuffle=False, 52 | ) 53 | num_steps = math.ceil(x.shape[0] / batch_size) * epochs 54 | iterator_fn = data_adapter.get_dataset() 55 | for i, batch in zip(range(num_steps), iterator_fn()): 56 | batch_x, batch_y = batch 57 | assert batch_x.shape == (batch_size, *x.shape[1:]) 58 | assert batch_y.shape == (batch_size, *y.shape[1:]) 59 | np.testing.assert_array_equal( 60 | batch_x, x[i * batch_size : (i + 1) * batch_size] 61 | ) 62 
| 63 | data_adapter.get_size() == x.shape[0] 64 | data_adapter.partial_batch_size == 0 65 | 66 | def test_shuffle(self): 67 | x = np.random.uniform(size=(100, 32, 32, 3)) 68 | y = np.random.uniform(size=(100, 1)) 69 | batch_size = 10 70 | epochs = 1 71 | data_adapter = ListsOfScalarsDataAdapter( 72 | x.tolist(), 73 | y=y.tolist(), 74 | sample_weights=None, 75 | batch_size=batch_size, 76 | epochs=epochs, 77 | steps=None, 78 | shuffle=True, 79 | ) 80 | num_steps = math.ceil(x.shape[0] / batch_size) * epochs 81 | iterator_fn = data_adapter.get_dataset() 82 | for i, batch in zip(range(num_steps), iterator_fn()): 83 | batch_x, batch_y = batch 84 | assert batch_x.shape == (batch_size, *x.shape[1:]) 85 | assert batch_y.shape == (batch_size, *y.shape[1:]) 86 | assert not np.array_equal(batch_x, x[i * batch_size : (i + 1) * batch_size]) 87 | 88 | data_adapter.get_size() == x.shape[0] 89 | data_adapter.partial_batch_size == 0 90 | 91 | def test_partial_batch(self): 92 | x = np.random.uniform(size=(100, 32, 32, 3)) 93 | y = np.random.uniform(size=(100, 1)) 94 | batch_size = 32 95 | epochs = 1 96 | data_adapter = ListsOfScalarsDataAdapter( 97 | x.tolist(), 98 | y=y.tolist(), 99 | sample_weights=None, 100 | batch_size=batch_size, 101 | epochs=epochs, 102 | steps=None, 103 | shuffle=True, 104 | ) 105 | num_steps = math.ceil(x.shape[0] / batch_size) * epochs 106 | 107 | iterator_fn = data_adapter.get_dataset() 108 | for i, batch in zip(range(num_steps), iterator_fn()): 109 | batch_x, batch_y = batch 110 | if i < num_steps - 1: 111 | assert batch_x.shape == (batch_size, *x.shape[1:]) 112 | assert batch_y.shape == (batch_size, *y.shape[1:]) 113 | else: 114 | assert batch_x.shape == (x.shape[0] % batch_size, *x.shape[1:]) 115 | assert batch_y.shape == (x.shape[0] % batch_size, *y.shape[1:]) 116 | 117 | data_adapter.get_size() == x.shape[0] 118 | data_adapter.partial_batch_size == x.shape[0] % batch_size 119 | -------------------------------------------------------------------------------- /tests/data/tf_dataset_adapter_test.py: -------------------------------------------------------------------------------- 1 | import math 2 | from unittest import TestCase 3 | 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from elegy.data.tf_dataset_adapter import TFDatasetAdapter 9 | 10 | 11 | class ArrayDataAdapterTest(TestCase): 12 | def test_basic(self): 13 | batch_size = 10 14 | epochs = 1 15 | x = np.array(np.random.uniform(size=(100, 32, 32, 3))) 16 | y = np.array(np.random.uniform(size=(100, 1))) 17 | dataset = tf.data.Dataset.from_tensor_slices((x, y)) 18 | dataset = dataset.batch(batch_size) 19 | 20 | data_adapter = TFDatasetAdapter(dataset, steps=None) 21 | 22 | num_steps = math.ceil(x.shape[0] / batch_size) * epochs 23 | iterator_fn = data_adapter.get_dataset() 24 | for i, batch in zip(range(num_steps), iterator_fn()): 25 | batch_x, batch_y = batch 26 | assert batch_x.shape == (batch_size, *x.shape[1:]) 27 | assert batch_y.shape == (batch_size, *y.shape[1:]) 28 | np.testing.assert_array_equal( 29 | batch_x, x[i * batch_size : (i + 1) * batch_size] 30 | ) 31 | 32 | assert data_adapter.get_size() * batch_size == x.shape[0] 33 | assert data_adapter.batch_size == batch_size 34 | 35 | def test_only_x_repeat(self): 36 | batch_size = 10 37 | epochs = 2 38 | 39 | x = np.array(np.random.uniform(size=(100, 32, 32, 3))) 40 | dataset = tf.data.Dataset.from_tensor_slices(x) 41 | dataset = dataset.batch(batch_size) 42 | dataset = dataset.repeat() 43 | 44 | dataset_length = x.shape[0] 45 
| num_steps = math.ceil(dataset_length / batch_size) * epochs 46 | 47 | data_adapter = TFDatasetAdapter( 48 | dataset, steps=math.ceil(dataset_length / batch_size) 49 | ) 50 | 51 | iterator_fn = data_adapter.get_dataset() 52 | for i, batch in zip(range(num_steps), iterator_fn()): 53 | batch_x = batch 54 | assert batch_x.shape == (batch_size, *x.shape[1:]) 55 | np.testing.assert_array_equal( 56 | batch_x, 57 | x[ 58 | (i * batch_size) 59 | % dataset_length : (i * batch_size) 60 | % dataset_length 61 | + batch_size 62 | ], 63 | ) 64 | 65 | assert data_adapter.get_size() * batch_size == x.shape[0] 66 | assert data_adapter.batch_size == batch_size 67 | assert i == num_steps - 1 68 | 69 | def test_error(self): 70 | batch_size = 10 71 | epochs = 2 72 | x = np.array(np.random.uniform(size=(100, 32, 32, 3))) 73 | dataset = tf.data.Dataset.from_tensor_slices(x) 74 | dataset = dataset.batch(batch_size) 75 | 76 | data_adapter = TFDatasetAdapter(dataset, steps=None) 77 | 78 | num_steps = math.ceil(x.shape[0] / batch_size) * epochs 79 | iterator_fn = data_adapter.get_dataset() 80 | iterator = iterator_fn() 81 | 82 | with self.assertRaises(StopIteration): 83 | for i in range(num_steps): 84 | batch = next(iterator) 85 | batch_x = batch 86 | assert batch_x.shape == (batch_size, *x.shape[1:]) 87 | np.testing.assert_array_equal( 88 | batch_x, x[i * batch_size : (i + 1) * batch_size] 89 | ) 90 | -------------------------------------------------------------------------------- /tests/data/torch_dataloader_adapter_test.py: -------------------------------------------------------------------------------- 1 | import math 2 | from unittest import TestCase 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import DataLoader, TensorDataset 7 | 8 | from elegy.data.torch_dataloader_adapter import TorchDataLoaderAdapter 9 | 10 | 11 | class ArrayDataAdapterTest(TestCase): 12 | def test_basic(self): 13 | batch_size = 10 14 | epochs = 1 15 | x = np.array(np.random.uniform(size=(100, 32, 32, 3))) 16 | y = np.array(np.random.uniform(size=(100, 1))) 17 | 18 | dataset = TensorDataset(torch.from_numpy(x), torch.from_numpy(y)) 19 | dataloader = DataLoader(dataset, batch_size=batch_size) 20 | 21 | data_adapter = TorchDataLoaderAdapter(dataloader) 22 | 23 | dataset_length = x.shape[0] 24 | num_steps = math.ceil(dataset_length / batch_size) * epochs 25 | iterator_fn = data_adapter.get_dataset() 26 | for i, batch in zip(range(num_steps), iterator_fn()): 27 | batch_x, batch_y = batch 28 | assert batch_x.shape == (batch_size, *x.shape[1:]) 29 | assert batch_y.shape == (batch_size, *y.shape[1:]) 30 | np.testing.assert_array_equal( 31 | batch_x, 32 | x[ 33 | (i * batch_size) 34 | % dataset_length : (i * batch_size) 35 | % dataset_length 36 | + batch_size 37 | ], 38 | ) 39 | 40 | assert data_adapter.get_size() * batch_size == x.shape[0] 41 | assert data_adapter.batch_size == batch_size 42 | assert i == num_steps - 1 43 | -------------------------------------------------------------------------------- /tests/model/model_base_test.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | import unittest 3 | 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import DataLoader, TensorDataset 8 | 9 | import elegy as eg 10 | 11 | 12 | class TestModelBase(unittest.TestCase): 13 | def test_predict(self): 14 | 15 | N = 0 16 | 17 | class Model(eg.ModelBase): 18 | a: jnp.ndarray = eg.node() 19 | 20 | def init_step( 21 | self, 22 | 
key: jnp.ndarray, 23 | inputs: tp.Any, 24 | ) -> "Model": 25 | self.a = jnp.array(0, dtype=jnp.int32) 26 | return self 27 | 28 | def pred_step(self, inputs): 29 | nonlocal N 30 | N += 1 31 | 32 | preds = inputs + 1.0 33 | self.a += 1 34 | 35 | return preds, self 36 | 37 | def reset_metrics(self): 38 | pass 39 | 40 | model = Model() 41 | 42 | x = np.random.uniform(size=(100, 1)) 43 | y = model.predict(x, batch_size=50) 44 | 45 | assert np.allclose(y, x + 1.0) 46 | assert model.a == 2 47 | assert N == 1 48 | 49 | y = model.predict(x, batch_size=50) 50 | assert np.allclose(y, x + 1.0) 51 | assert model.a == 4 52 | assert N == 1 53 | 54 | model = model.eager() 55 | 56 | y = model.predict(x, batch_size=50) 57 | assert np.allclose(y, x + 1.0) 58 | assert model.a == 6 59 | assert N == 3 60 | 61 | def test_evaluate(self): 62 | N = 0 63 | 64 | class Model(eg.ModelBase): 65 | a: jnp.ndarray = eg.node() 66 | 67 | def init_step( 68 | self, 69 | key: jnp.ndarray, 70 | inputs: tp.Any, 71 | ) -> "Model": 72 | self.a = jnp.array(0, dtype=jnp.int32) 73 | return self 74 | 75 | def test_step(self, inputs, labels): 76 | nonlocal N 77 | N += 1 78 | 79 | preds = inputs + 1.0 80 | self.a += 1 81 | 82 | loss = 0.1 83 | logs = dict(loss=jnp.sum(inputs)) 84 | 85 | return loss, logs, self 86 | 87 | def reset_metrics(self): 88 | pass 89 | 90 | model = Model() 91 | 92 | x = np.random.uniform(size=(100, 1)) 93 | 94 | logs = model.evaluate(x, batch_size=100) 95 | assert np.allclose(logs["loss"], np.sum(x)) 96 | assert N == 1 97 | assert model.a == 1 98 | 99 | logs = model.evaluate(x, batch_size=50) 100 | assert np.allclose(logs["loss"], np.sum(x[50:])) 101 | assert N == 2 102 | assert model.a == 3 103 | 104 | logs = model.evaluate(x, batch_size=50) 105 | assert np.allclose(logs["loss"], np.sum(x[50:])) 106 | assert N == 2 107 | assert model.a == 5 108 | 109 | model = model.eager() 110 | 111 | logs = model.evaluate(x, batch_size=50) 112 | assert np.allclose(logs["loss"], np.sum(x[50:])) 113 | assert N == 4 114 | assert model.a == 7 115 | 116 | def test_fit(self): 117 | N = 0 118 | 119 | class Model(eg.ModelBase): 120 | a: jnp.ndarray = eg.node() 121 | 122 | def init_step( 123 | self, 124 | key: jnp.ndarray, 125 | inputs: tp.Any, 126 | ) -> "Model": 127 | self.a = jnp.array(0, dtype=jnp.int32) 128 | return self 129 | 130 | def train_step(self, inputs, labels): 131 | nonlocal N 132 | N += 1 133 | 134 | self.a += 1 135 | 136 | logs = dict(loss=jnp.sum(inputs)) 137 | 138 | return logs, self 139 | 140 | def reset_metrics(self): 141 | pass 142 | 143 | model = Model() 144 | 145 | x = np.random.uniform(size=(100, 1)) 146 | 147 | history = model.fit(x, batch_size=100) 148 | assert np.allclose(history.history["loss"], np.sum(x)) 149 | assert N == 1 150 | assert model.a == 1 151 | 152 | history = model.fit(x, batch_size=50, shuffle=False) 153 | assert np.allclose(history.history["loss"][0], np.sum(x[50:])) 154 | assert N == 2 155 | assert model.a == 3 156 | 157 | history = model.fit(x, batch_size=50, shuffle=False) 158 | assert np.allclose(history.history["loss"], np.sum(x[50:])) 159 | assert N == 2 160 | assert model.a == 5 161 | 162 | model = model.eager() 163 | 164 | history = model.fit(x, batch_size=50, shuffle=False) 165 | assert np.allclose(history.history["loss"], np.sum(x[50:])) 166 | assert N == 4 167 | assert model.a == 7 168 | 169 | def test_dataloader(self): 170 | N = 0 171 | 172 | class Model(eg.ModelBase): 173 | a: jnp.ndarray = eg.node() 174 | 175 | def init_step( 176 | self, 177 | key: jnp.ndarray, 178 | inputs: 
tp.Any, 179 | ) -> "Model": 180 | self.a = jnp.array(0, dtype=jnp.int32) 181 | return self 182 | 183 | def pred_step(self, inputs): 184 | nonlocal N 185 | N += 1 186 | 187 | preds = inputs + 1.0 188 | self.a += 1 189 | 190 | return preds, self 191 | 192 | def reset_metrics(self): 193 | pass 194 | 195 | model = Model() 196 | 197 | x = np.random.uniform(size=(10, 1)) 198 | y = np.random.uniform(size=(10, 3)) 199 | 200 | dataset = TensorDataset(torch.from_numpy(x), torch.from_numpy(y)) 201 | dataloader = DataLoader(dataset, batch_size=2) 202 | 203 | y_pred = model.predict(x=dataloader) 204 | assert jnp.allclose(y_pred, x + 1) 205 | y_pred = model.predict(x=dataloader) 206 | assert jnp.allclose(y_pred, x + 1) 207 | y_pred 208 | 209 | 210 | if __name__ == "__main__": 211 | TestModelBase().test_fit() 212 | -------------------------------------------------------------------------------- /tests/model/model_core_test.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | import unittest 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | import treex as tx 8 | 9 | import elegy 10 | from elegy.model.model_core import ModelCore 11 | 12 | 13 | class ModelCoreTest(unittest.TestCase): 14 | def test_init(self): 15 | N = 0 16 | 17 | class Model(ModelCore): 18 | a: jnp.ndarray = tx.node() 19 | 20 | def __init__(self): 21 | super().__init__() 22 | self.a = jnp.array(-1, dtype=jnp.int32) 23 | 24 | def init_step( 25 | self, 26 | key: jnp.ndarray, 27 | inputs: tp.Any, 28 | ) -> "Model": 29 | nonlocal N 30 | 31 | N += 1 32 | self.a = jnp.array(0, dtype=jnp.int32) 33 | print("JITTING") 34 | return self 35 | 36 | model = Model() 37 | inputs = np.array(1.0) 38 | 39 | assert N == 0 40 | assert model.a == -1 41 | 42 | model.init_on_batch(inputs) 43 | assert N == 1 44 | 45 | # jits again because _initialized changed 46 | model.init_on_batch(inputs) 47 | assert N == 2 48 | 49 | # no jit change this time 50 | model.init_on_batch(inputs) 51 | assert N == 2 52 | 53 | def test_pred_step(self): 54 | N = 0 55 | 56 | class Model(ModelCore): 57 | a: jnp.ndarray = tx.node() 58 | 59 | def init_step( 60 | self, 61 | key: jnp.ndarray, 62 | inputs: tp.Any, 63 | ) -> "Model": 64 | self.a = jnp.array(0, dtype=jnp.int32) 65 | return self 66 | 67 | def pred_step(self, inputs): 68 | nonlocal N 69 | N += 1 70 | 71 | self.a += 1 72 | 73 | return 1, self 74 | 75 | model = Model() 76 | 77 | preds = model.predict_on_batch(inputs=np.array(1.0)) 78 | assert N == 1 79 | assert preds == 1 80 | assert model.a == 1 81 | 82 | preds = model.predict_on_batch(inputs=np.array(1.0)) 83 | assert N == 1 84 | assert preds == 1 85 | assert model.a == 2 86 | 87 | model.eager = True 88 | 89 | preds = model.predict_on_batch(inputs=(np.array(1.0))) 90 | assert N == 2 91 | assert preds == 1 92 | assert model.a == 3 93 | 94 | def test_test_step(self): 95 | N = 0 96 | 97 | class Model(ModelCore): 98 | a: jnp.ndarray = tx.node() 99 | 100 | def init_step( 101 | self, 102 | key: jnp.ndarray, 103 | inputs: tp.Any, 104 | ) -> "Model": 105 | self.a = jnp.array(0, dtype=jnp.int32) 106 | return self 107 | 108 | def test_step(self, inputs, labels) -> elegy.TestStepOutput["Model"]: 109 | nonlocal N 110 | N += 1 111 | self.a += 1 112 | loss = 1.0 113 | 114 | return loss, dict(loss=loss), self 115 | 116 | model = Model() 117 | 118 | logs = model.test_on_batch(inputs=(np.array(1.0)), labels=(1.0,)) 119 | assert N == 1 120 | assert logs["loss"] == 1.0 121 | assert model.a == 1 122 | 123 | logs = 
model.test_on_batch(inputs=(np.array(1.0)), labels=(1.0,)) 124 | assert N == 1 125 | assert logs["loss"] == 1.0 126 | assert model.a == 2 127 | 128 | model.eager = True 129 | 130 | logs = model.test_on_batch(inputs=(np.array(1.0)), labels=(1.0,)) 131 | assert N == 2 132 | assert logs["loss"] == 1 133 | assert model.a == 3 134 | 135 | def test_train_step(self): 136 | N = 0 137 | 138 | class Model(ModelCore): 139 | a: jnp.ndarray = tx.node() 140 | 141 | def init_step( 142 | self, 143 | key: jnp.ndarray, 144 | inputs: tp.Any, 145 | ) -> "Model": 146 | self.a = jnp.array(0, dtype=jnp.int32) 147 | return self 148 | 149 | def train_step(self, inputs, labels) -> elegy.TrainStepOutput["Model"]: 150 | nonlocal N 151 | N += 1 152 | self.a += 1 153 | loss = 2.0 154 | 155 | return dict(loss=loss), self 156 | 157 | model = Model() 158 | 159 | logs = model.train_on_batch(inputs=(np.array(1.0)), labels=(1.0,)) 160 | assert N == 1 161 | assert logs["loss"] == 2.0 162 | assert model.a == 1 163 | 164 | logs = model.train_on_batch(inputs=(np.array(1.0)), labels=(1.0,)) 165 | assert N == 1 166 | assert logs["loss"] == 2.0 167 | assert model.a == 2 168 | 169 | model.eager = True 170 | 171 | logs = model.train_on_batch(inputs=(np.array(1.0)), labels=(1.0,)) 172 | assert N == 2 173 | assert logs["loss"] == 2.0 174 | assert model.a == 3 175 | -------------------------------------------------------------------------------- /tests/model/model_test.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | import unittest 3 | from dataclasses import dataclass 4 | from hashlib import new 5 | from pathlib import Path 6 | from tempfile import TemporaryDirectory 7 | 8 | import cloudpickle 9 | import jax 10 | import jax.numpy as jnp 11 | import numpy as np 12 | import optax 13 | import pytest 14 | import sh 15 | import tensorflow as tf 16 | from treex.nn.linear import Linear 17 | 18 | import elegy as eg 19 | 20 | 21 | @dataclass 22 | class MLP(eg.Module): 23 | dmid: int 24 | dout: int 25 | 26 | @eg.compact 27 | def __call__(self, x: jnp.ndarray): 28 | x = eg.Linear(self.dmid)(x) 29 | x = eg.BatchNorm()(x) 30 | x = jax.nn.relu(x) 31 | 32 | x = eg.Linear(self.dout)(x) 33 | return x 34 | 35 | 36 | class ModelBasicTest(unittest.TestCase): 37 | def test_predict(self): 38 | 39 | model = eg.Model(module=eg.Linear(1)) 40 | 41 | X = np.random.uniform(size=(5, 2)) 42 | y = np.random.randint(10, size=(5, 1)) 43 | 44 | y_pred = model.predict(X) 45 | 46 | assert y_pred.shape == (5, 1) 47 | 48 | def test_evaluate(self): 49 | class mse(eg.Loss): 50 | def call(self, target, preds): 51 | return jnp.mean((target - preds) ** 2) 52 | 53 | class mae(eg.Metric): 54 | value: eg.MetricState = eg.MetricState.node( 55 | default=jnp.array(0.0, jnp.float32) 56 | ) 57 | 58 | def update(self, target, preds): 59 | return jnp.mean(jnp.abs(target - preds)) 60 | 61 | def compute(self) -> tp.Any: 62 | return self.value 63 | 64 | model = eg.Model( 65 | module=eg.Linear(1), 66 | loss=dict(a=mse()), 67 | metrics=dict(b=mae()), 68 | optimizer=optax.adamw(1e-3), 69 | eager=True, 70 | ) 71 | 72 | X = np.random.uniform(size=(5, 2)) 73 | y = np.random.uniform(size=(5, 1)) 74 | 75 | logs = model.evaluate(x=X, y=y) 76 | 77 | assert "a/mse_loss" in logs 78 | assert "b/mae" in logs 79 | assert "loss" in logs 80 | 81 | 82 | class ModelTest(unittest.TestCase): 83 | def test_evaluate(self): 84 | 85 | model = eg.Model( 86 | module=MLP(dmid=3, dout=4), 87 | loss=[ 88 | eg.losses.Crossentropy(), 89 | 
eg.regularizers.L2(l=1e-4), 90 | ], 91 | metrics=eg.metrics.Accuracy(), 92 | optimizer=optax.adamw(1e-3), 93 | eager=True, 94 | ) 95 | 96 | X = np.random.uniform(size=(5, 2)) 97 | y = np.random.randint(4, size=(5,)) 98 | 99 | history = model.fit( 100 | inputs=X, 101 | labels=y, 102 | epochs=1, 103 | steps_per_epoch=1, 104 | batch_size=5, 105 | validation_data=(X, y), 106 | shuffle=True, 107 | verbose=1, 108 | ) 109 | 110 | logs = model.evaluate(X, y) 111 | 112 | eval_acc = logs["accuracy"] 113 | predict_acc = (model.predict(X).argmax(-1) == y).mean() 114 | 115 | assert eval_acc == predict_acc 116 | 117 | def test_saved_model(self): 118 | 119 | with TemporaryDirectory() as model_dir: 120 | 121 | model = eg.Model(module=eg.Linear(4)) 122 | 123 | x = np.random.uniform(size=(5, 6)) 124 | 125 | model.merge 126 | 127 | model.saved_model(x, model_dir, batch_size=[1, 2, 4, 8]) 128 | 129 | output = str(sh.ls(model_dir)) 130 | 131 | assert "saved_model.pb" in output 132 | assert "variables" in output 133 | 134 | saved_model = tf.saved_model.load(model_dir) 135 | 136 | saved_model 137 | 138 | def test_saved_model_poly(self): 139 | 140 | with TemporaryDirectory() as model_dir: 141 | 142 | model = eg.Model(module=eg.Linear(4)) 143 | 144 | x = np.random.uniform(size=(5, 6)).astype(np.float32) 145 | 146 | model.saved_model(x, model_dir, batch_size=None) 147 | 148 | output = str(sh.ls(model_dir)) 149 | 150 | assert "saved_model.pb" in output 151 | assert "variables" in output 152 | 153 | saved_model = tf.saved_model.load(model_dir) 154 | 155 | # change batch 156 | x = np.random.uniform(size=(3, 6)).astype(np.float32) 157 | y = saved_model(x) 158 | 159 | assert y.shape == (3, 4) 160 | 161 | @pytest.mark.skip("only failing within pytest for some reason") 162 | def test_cloudpickle(self): 163 | model = eg.Model( 164 | module=eg.Linear(10), 165 | loss=[ 166 | eg.losses.Crossentropy(), 167 | eg.regularizers.L2(1e-4), 168 | ], 169 | metrics=eg.metrics.Accuracy(), 170 | optimizer=optax.adamw(1e-3), 171 | eager=True, 172 | ) 173 | 174 | X = np.random.uniform(size=(5, 2)) 175 | y = np.random.randint(10, size=(5,)) 176 | 177 | y0 = model.predict(X) 178 | 179 | with TemporaryDirectory() as model_dir: 180 | model.save(model_dir) 181 | newmodel = eg.load(model_dir) 182 | 183 | y1 = newmodel.predict(X) 184 | assert np.all(y0 == y1) 185 | 186 | def test_distributed_init(self): 187 | n_devices = jax.device_count() 188 | batch_size = 5 * n_devices 189 | 190 | x = np.random.uniform(size=(batch_size, 1)) 191 | y = 1.4 * x + 0.1 * np.random.uniform(size=(batch_size, 2)) 192 | 193 | model = eg.Model( 194 | eg.Linear(2), 195 | loss=[eg.losses.MeanSquaredError()], 196 | ) 197 | 198 | model = model.distributed() 199 | 200 | model.init_on_batch(x) 201 | 202 | assert model.module.kernel.shape == (n_devices, 1, 2) 203 | assert model.module.bias.shape == (n_devices, 2) 204 | 205 | 206 | # DELETE THIS 207 | if __name__ == "__main__": 208 | ModelTest().test_distributed_init() 209 | 210 | # DONT REMOVE THIS, CI WILL RUN THIS 211 | if __name__ == "__main__": 212 | ModelTest().test_cloudpickle() 213 | -------------------------------------------------------------------------------- /tests/nets/resnet_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import tempfile 4 | import urllib 5 | from unittest import TestCase 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | import numpy as np 10 | import PIL 11 | import pytest 12 | 13 | import elegy 14 | from elegy import 
utils 15 | 16 | 17 | class ResNetTest(TestCase): 18 | @pytest.mark.skip("pending fix") 19 | def test_basic_predict(self): 20 | # FIXME: test succeeds if run alone or if run on the cpu-only version of jax 21 | # test fails with "DNN library is not found" if run on gpu with all other tests together 22 | 23 | model = elegy.Model(elegy.nets.resnet.ResNet18(), eager=True) 24 | assert isinstance(model.module, elegy.Module) 25 | 26 | x = np.random.random((2, 224, 224, 3)).astype(np.float32) 27 | 28 | model.init(x) 29 | y = model.predict(x) 30 | 31 | # update_modules results in a call to `set_default_parameters` for elegy Modules 32 | # it might be better to have the user call this explicitly to avoid potential OOM 33 | model.update_modules() 34 | 35 | assert jnp.all(y.shape == (2, 1000)) 36 | 37 | # test loading weights from file 38 | with tempfile.TemporaryDirectory() as tempdir: 39 | pklpath = os.path.join(tempdir, "delete_me.pkl") 40 | open(pklpath, "wb").write( 41 | pickle.dumps(model.module.get_default_parameters()) 42 | ) 43 | 44 | new_r18 = elegy.nets.resnet.ResNet18(weights=pklpath) 45 | y2 = elegy.Model(new_r18, eager=True).predict(x, initialize=True) 46 | 47 | assert np.allclose(y, y2, rtol=0.001) 48 | 49 | @pytest.mark.skip("pending fix") 50 | def test_autodownload_pretrained_r18(self): 51 | fname, _ = urllib.request.urlretrieve( 52 | "https://upload.wikimedia.org/wikipedia/commons/e/e4/A_French_Bulldog.jpg" 53 | ) 54 | im = np.array(PIL.Image.open(fname).resize([224, 224])) / np.float32(255) 55 | 56 | r18 = elegy.nets.resnet.ResNet18(weights="imagenet") 57 | with jax.disable_jit(): 58 | assert ( 59 | elegy.Model(r18).predict(im[np.newaxis], initialize=True).argmax() 60 | == 245 61 | ) 62 | 63 | @pytest.mark.skip("pending fix") 64 | def test_autodownload_pretrained_r50(self): 65 | fname, _ = urllib.request.urlretrieve( 66 | "https://upload.wikimedia.org/wikipedia/commons/e/e4/A_French_Bulldog.jpg" 67 | ) 68 | im = np.array(PIL.Image.open(fname).resize([224, 224])) / np.float32(255) 69 | 70 | r50 = elegy.nets.resnet.ResNet50(weights="imagenet") 71 | with jax.disable_jit(): 72 | assert ( 73 | elegy.Model(r50).predict(im[np.newaxis], initialize=True).argmax() 74 | == 245 75 | ) 76 | -------------------------------------------------------------------------------- /tests/utils_test.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import pytest 4 | 5 | from elegy import utils 6 | 7 | 8 | class DIFunctionTests(TestCase): 9 | def test_positional(self): 10 | def f(a, b, c): 11 | return a + b + c 12 | 13 | g = utils.inject_dependencies(f) 14 | 15 | y = g("a", "b", "c") 16 | 17 | assert y == "abc" 18 | 19 | def test_positional_error_missing(self): 20 | def f(a, b, c): 21 | return a + b + c 22 | 23 | g = utils.inject_dependencies(f) 24 | 25 | with pytest.raises(TypeError): 26 | g("a", "b") 27 | 28 | def test_positional_error_remaining(self): 29 | def f(a, b, c): 30 | return a + b + c 31 | 32 | g = utils.inject_dependencies(f) 33 | 34 | with pytest.raises(TypeError): 35 | g("a", "b", "c", "d") 36 | 37 | def test_positional_extras_ok(self): 38 | def f(a, b, c): 39 | return a + b + c 40 | 41 | g = utils.inject_dependencies(f) 42 | 43 | y = g("a", "b", "c", d="d") 44 | 45 | assert y == "abc" 46 | 47 | def test_keyword(self): 48 | def f(a, b, c): 49 | return a + b + c 50 | 51 | g = utils.inject_dependencies(f) 52 | 53 | y = g(b="b", c="c", a="a") 54 | 55 | assert y == "abc" 56 | 57 | def test_keyword_extras_ok(self): 58 | def 
f(a, b, c): 59 | return a + b + c 60 | 61 | g = utils.inject_dependencies(f) 62 | 63 | y = g(b="b", c="c", a="a", d="d") 64 | 65 | assert y == "abc" 66 | 67 | def test_keyword_error_missing(self): 68 | def f(a, b, c): 69 | return a + b + c 70 | 71 | g = utils.inject_dependencies(f) 72 | 73 | with pytest.raises(TypeError): 74 | g(b="b", c="c") 75 | 76 | def test_mixed(self): 77 | def f(a, b, c): 78 | return a + b + c 79 | 80 | g = utils.inject_dependencies(f) 81 | 82 | y = g("a", c="c", b="b") 83 | 84 | assert y == "abc" 85 | 86 | def test_mixed_ignore_duplicated_kwarg_in_arg(self): 87 | def f(a, b, c): 88 | return a + b + c 89 | 90 | g = utils.inject_dependencies(f) 91 | 92 | y = g("a", c="c", b="b", a="f") 93 | 94 | assert y == "abc" 95 | 96 | def test_override_defaults(self): 97 | def f(a, b, c="x"): 98 | return a + b + c 99 | 100 | g = utils.inject_dependencies(f) 101 | 102 | y = g("a", c="c", b="b") 103 | 104 | assert y == "abc" 105 | 106 | 107 | class TestMergeStructs: 108 | def test_basic(self): 109 | a = dict(x=1) 110 | b = dict(y=2) 111 | 112 | c = utils.merge_params(a, b) 113 | 114 | assert c == {"x": 1, "y": 2} 115 | 116 | def test_hierarchy(self): 117 | a = dict(a=dict(x=1)) 118 | b = dict(a=dict(y=2)) 119 | 120 | c = utils.merge_params(a, b) 121 | 122 | assert c == {"a": {"x": 1, "y": 2}} 123 | 124 | def test_repeated_leafs(self): 125 | a = dict(a=dict(x=1)) 126 | b = dict(a=dict(x=2)) 127 | 128 | with pytest.raises(ValueError): 129 | c = utils.merge_params(a, b) 130 | 131 | def test_list(self): 132 | a = [dict(x=1)] 133 | b = [dict(y=2)] 134 | 135 | c = utils.merge_params(a, b) 136 | 137 | assert c == [{"x": 1, "y": 2}] 138 | 139 | def test_different_lengths(self): 140 | a = [dict(x=1)] 141 | b = [] 142 | 143 | with pytest.raises(ValueError): 144 | c = utils.merge_params(a, b) 145 | -------------------------------------------------------------------------------- /tmp/test.py: -------------------------------------------------------------------------------- 1 | import elegy as eg 2 | --------------------------------------------------------------------------------