├── .github └── workflows │ └── unit_tests.yaml ├── .gitignore ├── .python-version ├── LICENSE ├── README.md ├── docs ├── code-of-conduct.md └── contributing.md ├── pytest.ini ├── recml ├── __init__.py ├── core │ ├── data │ │ ├── __init__.py │ │ ├── iterator.py │ │ ├── preprocessing.py │ │ └── tf_dataset_factory.py │ ├── metrics │ │ ├── __init__.py │ │ ├── base_metrics.py │ │ ├── base_metrics_test.py │ │ ├── confusion_metrics.py │ │ ├── confusion_metrics_test.py │ │ ├── mean_metrics.py │ │ ├── mean_metrics_test.py │ │ ├── reduction_metrics.py │ │ ├── reduction_metrics_test.py │ │ └── tools.py │ ├── ops │ │ └── embedding_ops.py │ ├── training │ │ ├── core.py │ │ ├── jax_trainer.py │ │ ├── jax_trainer_quality_test.py │ │ ├── jax_trainer_test.py │ │ ├── keras_trainer.py │ │ ├── keras_trainer_test.py │ │ ├── optax_factory.py │ │ ├── optax_factory_test.py │ │ ├── partitioning.py │ │ └── partitioning_test.py │ └── utils │ │ ├── __init__.py │ │ ├── config.py │ │ ├── config_test.py │ │ ├── keras_utils.py │ │ ├── keras_utils_test.py │ │ ├── py_utils.py │ │ ├── types.py │ │ └── types_test.py ├── examples │ ├── dlrm_experiment.py │ └── dlrm_experiment_test.py └── layers │ ├── keras │ ├── README.md │ ├── bert4rec.py │ ├── bert4rec_test.py │ ├── hstu.py │ ├── hstu_test.py │ ├── mamba.py │ ├── mamba_test.py │ ├── sasrec.py │ ├── sasrec_test.py │ ├── utils.py │ └── utils_test.py │ └── linen │ ├── sparsecore.py │ └── sparsecore_test.py └── requirements.txt /.github/workflows/unit_tests.yaml: -------------------------------------------------------------------------------- 1 | name: Run Unit Tests 2 | 3 | on: 4 | push: 5 | branches: [ "main", "master" ] 6 | pull_request: 7 | branches: [ "main", "master" ] 8 | workflow_dispatch: 9 | 10 | jobs: 11 | run_tests: 12 | name: Run Tests 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Checkout code 16 | uses: actions/checkout@v4 17 | - name: Install Dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | pip install -r requirements.txt # If you have a requirements.txt file 21 | - name: Run Pytest 22 | run: | 23 | export KERAS_BACKEND=jax 24 | python -m pytest -v -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RecML: High-Performance Recommender Library 2 | 3 | ## Vision 4 | 5 | RecML is envisioned as a high-performance, large-scale deep learning recommender 6 | system library optimized for Cloud TPUs. It aims to provide researchers and 7 | practitioners state-of-the-art reference implementations, tools, and best 8 | practice guidelines for building and deploying recommender systems. 
9 | 10 | The key goals of RecML are: 11 | 12 | * **Performance & Scalability:** Leverage Cloud TPUs (including SparseCore 13 | acceleration) to deliver exceptional performance for training and serving 14 | massive models with large embeddings on datasets with millions or billions 15 | of items/users. RecML can additionally target Cloud GPUs. 16 | * **State-of-the-Art Models:** Provide production-ready, easy-to-understand 17 | reference implementations of popular and cutting-edge models, with a strong 18 | focus on LLM-based recommenders. 19 | * **Ease of Use:** Offer a user-friendly API, intuitive abstractions, and 20 | comprehensive documentation/examples for rapid prototyping and deployment. 21 | * **Flexibility:** Primarily built with Keras and JAX, but designed with 22 | potential future expansion to other frameworks like PyTorch/XLA. 23 | * **Open Source:** Foster community collaboration and provide components to 24 | help users get started with advanced recommender workloads on Google Cloud. 25 | 26 | ## Features 27 | 28 | * **High Performance:** Optimized for Cloud TPU (SparseCore) training and 29 | inference. 30 | * **Scalable Architecture:** Designed for massive datasets and models with 31 | large embedding tables. Includes support for efficient data loading 32 | (tf.data, potentially Grain) and sharding/SPMD. 33 | * **State-of-the-Art Model Implementations:** Reference implementations for 34 | various recommendation tasks (ranking, retrieval, sequential). 35 | * **Reusable Building Blocks:** 36 | * Common recommendation layers (e.g., DCN, BERT4Rec). 37 | * Specialized Embedding APIs (e.g. JAX Embedding API for SparseCore). 38 | * Standardized metrics (e.g., AUC, Accuracy, NDCG@K, MRR, Recall@K). 39 | * Common loss functions. 40 | * **Unified Trainer:** A high-level trainer abstraction capable of targeting 41 | different hardware (TPU/GPU) and frameworks. Includes customizable training 42 | and evaluation loops. 43 | * **End-to-End Support:** Covers aspects from data pipelines to training, 44 | evaluation, checkpointing, metrics logging (e.g., to BigQuery), and model 45 | export/serving considerations. 46 | 47 | ## Models Included 48 | 49 | This library aims to house implementations for a variety of recommender models, 50 | including: 51 | 52 | * **SASRec:** Self-Attention based Sequential Recommendation 53 | * **BERT4Rec:** Bidirectional Encoder Representations from Transformer for 54 | Sequential Recommendation. 55 | * **Mamba4Rec:** Efficient Sequential Recommendation with Selective State 56 | Space Models. 57 | * **HSTU:** Hierarchical Sequential Transduction Units for Generative 58 | Recommendations. 59 | * **DLRM v2:** Deep Learning Recommendation Model 60 | 61 | ## Roadmap / Future Work 62 | 63 | * Expand reference model implementations (Retrieval, Uplift, foundation user 64 | model). 65 | * Add support for optimized configurations and lower precision training 66 | (bfloat16, fp16). 67 | * Improve support for Cloud GPU training and inference 68 | * Enhance sharding and quantization support. 69 | * Improve integration with Keras (and Keras Recommenders) and potentially 70 | PyTorch/XLA. 71 | * Develop comprehensive model serving examples and integrations. 72 | * Refine data loading pipelines (e.g., Grain support). 73 | * Add more common layers, losses, and metrics. 74 | 75 | ## Responsible Use 76 | 77 | As with any machine learning model, potential risks exist. 
The performance and 78 | behavior depend heavily on the training data, which may contain biases reflected 79 | in the recommendations. Developers should carefully evaluate the model's 80 | fairness and potential limitations in their specific application context. 81 | 82 | ## License 83 | 84 | RecML is released under the Apache 2.0. Please see the `LICENSE` file for full 85 | details. 86 | -------------------------------------------------------------------------------- /docs/code-of-conduct.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, gender identity and expression, level of 9 | experience, education, socio-economic status, nationality, personal appearance, 10 | race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or reject 41 | comments, commits, code, wiki edits, issues, and other contributions that are 42 | not aligned to this Code of Conduct, or to ban temporarily or permanently any 43 | contributor for other behaviors that they deem inappropriate, threatening, 44 | offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when the Project 56 | Steward has a reasonable belief that an individual's behavior may have a 57 | negative impact on the project or its community. 58 | 59 | ## Conflict Resolution 60 | 61 | We do not believe that all conflict is bad; healthy debate and disagreement 62 | often yield positive results. 
However, it is never okay to be disrespectful or 63 | to engage in behavior that violates the project’s code of conduct. 64 | 65 | If you see someone violating the code of conduct, you are encouraged to address 66 | the behavior directly with those involved. Many issues can be resolved quickly 67 | and easily, and this gives people more control over the outcome of their 68 | dispute. If you are unable to resolve the matter for any reason, or if the 69 | behavior is threatening or harassing, report it. We are dedicated to providing 70 | an environment where participants feel welcome and safe. 71 | 72 | Reports should be directed to *[PROJECT STEWARD NAME(s) AND EMAIL(s)]*, the 73 | Project Steward(s) for *[PROJECT NAME]*. It is the Project Steward’s duty to 74 | receive and address reported violations of the code of conduct. They will then 75 | work with a committee consisting of representatives from the Open Source 76 | Programs Office and the Google Open Source Strategy team. If for any reason you 77 | are uncomfortable reaching out to the Project Steward, please email 78 | opensource@google.com. 79 | 80 | We will investigate every complaint, but you may not receive a direct response. 81 | We will use our discretion in determining when and how to follow up on reported 82 | incidents, which may range from not taking action to permanent expulsion from 83 | the project and project-sponsored spaces. We will notify the accused of the 84 | report and provide them an opportunity to discuss it before any action is taken. 85 | The identity of the reporter will be omitted from the details of the report 86 | supplied to the accused. In potentially harmful situations, such as ongoing 87 | harassment or threats to anyone's safety, we may take action without notice. 88 | 89 | ## Attribution 90 | 91 | This Code of Conduct is adapted from the Contributor Covenant, version 1.4, 92 | available at 93 | https://www.contributor-covenant.org/version/1/4/code-of-conduct/ -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We would love to accept your patches and contributions to this project. 4 | 5 | ## Before you begin 6 | 7 | ### Sign our Contributor License Agreement 8 | 9 | Contributions to this project must be accompanied by a 10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA). 11 | You (or your employer) retain the copyright to your contribution; this simply 12 | gives us permission to use and redistribute your contributions as part of the 13 | project. 14 | 15 | If you or your current employer have already signed the Google CLA (even if it 16 | was for a different project), you probably don't need to do it again. 17 | 18 | Visit to see your current agreements or to 19 | sign a new one. 20 | 21 | ### Review our Community Guidelines 22 | 23 | This project follows [Google's Open Source Community 24 | Guidelines](https://opensource.google/conduct/). 25 | 26 | ## Contribution process 27 | 28 | ### Code Reviews 29 | 30 | All submissions, including submissions by project members, require review. We 31 | use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests) 32 | for this purpose. 
-------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | required_plugins = pytest-env 3 | env = 4 | KERAS_BACKEND=jax -------------------------------------------------------------------------------- /recml/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Public API for RecML.""" 15 | 16 | # pylint: disable=g-importing-member 17 | 18 | from recml.core import data 19 | from recml.core import metrics 20 | from recml.core import utils 21 | from recml.core.metrics.base_metrics import Metric 22 | from recml.core.training.core import Experiment 23 | from recml.core.training.core import run_experiment 24 | from recml.core.training.core import Trainer 25 | from recml.core.training.jax_trainer import JaxState 26 | from recml.core.training.jax_trainer import JaxTask 27 | from recml.core.training.jax_trainer import JaxTrainer 28 | from recml.core.training.jax_trainer import KerasState 29 | from recml.core.training.keras_trainer import KerasTask 30 | from recml.core.training.keras_trainer import KerasTrainer 31 | from recml.core.training.optax_factory import AdagradFactory 32 | from recml.core.training.optax_factory import AdamFactory 33 | from recml.core.training.optax_factory import OptimizerFactory 34 | from recml.core.training.partitioning import DataParallelPartitioner 35 | from recml.core.training.partitioning import ModelParallelPartitioner 36 | from recml.core.training.partitioning import NullPartitioner 37 | from recml.core.training.partitioning import Partitioner 38 | from recml.core.utils.types import Factory 39 | from recml.core.utils.types import FactoryProtocol 40 | from recml.core.utils.types import ObjectFactory 41 | -------------------------------------------------------------------------------- /recml/core/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
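The `TFDatasetIterator` exported by this package (and implemented in `iterator.py` below) wraps a `tf.data.Dataset` and yields NumPy batches for feeding JAX models. A minimal usage sketch with assumed toy inputs — the feature names and shapes here are illustrative only, not taken from this repo:

```python
import tensorflow as tf

from recml.core.data.iterator import TFDatasetIterator

# Illustrative in-memory dataset; in practice this would come from a
# TFDatasetFactory-built pipeline.
ds = tf.data.Dataset.from_tensor_slices(
    {"feature": [[1.0], [2.0], [3.0], [4.0]], "label": [0, 1, 0, 1]}
).batch(2)

# postprocessor and checkpoint are optional; defaults shown for clarity.
it = TFDatasetIterator(ds, postprocessor=None, checkpoint=False)

batch = next(it)               # leaves are read-only NumPy arrays
print(batch["feature"].shape)  # (2, 1)
print(it.element_spec)         # clu ArraySpecs describing each leaf

it.reset()                     # restart iteration from the beginning
```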
14 | """Public API for RecML data.""" 15 | 16 | # pylint: disable=g-importing-member 17 | 18 | from recml.core.data.iterator import Iterator 19 | from recml.core.data.iterator import TFDatasetIterator 20 | from recml.core.data.preprocessing import PreprocessingMode 21 | from recml.core.data.tf_dataset_factory import DatasetShardingInfo 22 | from recml.core.data.tf_dataset_factory import TFDatasetFactory 23 | from recml.core.data.tf_dataset_factory import TFDSMetadata 24 | -------------------------------------------------------------------------------- /recml/core/data/iterator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Data loading and preprocessing for feeding Jax models.""" 15 | 16 | from collections.abc import Callable 17 | import os 18 | from typing import Any 19 | 20 | import clu.data as clu_data 21 | from etils import epath 22 | import numpy as np 23 | import tensorflow as tf 24 | 25 | 26 | Iterator = clu_data.DatasetIterator 27 | 28 | 29 | class TFDatasetIterator(clu_data.DatasetIterator): 30 | """An iterator for TF Datasets that supports postprocessing.""" 31 | 32 | def __init__( 33 | self, 34 | dataset: tf.data.Dataset, 35 | postprocessor: Callable[..., Any] | None = None, 36 | checkpoint: bool = False, 37 | ): 38 | """Initializes the iterator. 39 | 40 | Args: 41 | dataset: The TF Dataset to iterate over. 42 | postprocessor: An optional postprocessor to apply to each batch. This is 43 | useful for sending embedded ID features to a separate accelerator. 44 | checkpoint: Whether to checkpoint the iterator state. 45 | """ 46 | self._dataset = dataset 47 | self._iterator = iter(dataset) 48 | self._postprocessor = postprocessor 49 | self._prefetched_batch = None 50 | self._element_spec = None 51 | self._checkpoint = None 52 | if checkpoint: 53 | self._checkpoint = tf.train.Checkpoint(ds=self._iterator) 54 | 55 | def __next__(self) -> clu_data.Element: 56 | """Returns the next batch.""" 57 | if self._prefetched_batch is not None: 58 | batch = self._prefetched_batch 59 | self._prefetched_batch = None 60 | else: 61 | batch = next(self._iterator) 62 | if self._postprocessor is not None: 63 | batch = self._postprocessor(batch) 64 | 65 | def _maybe_to_numpy( 66 | x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray, 67 | ) -> np.ndarray | tf.SparseTensor | tf.RaggedTensor: 68 | if isinstance(x, (tf.SparseTensor, tf.RaggedTensor, np.ndarray)): 69 | return x 70 | if hasattr(x, "_numpy"): 71 | numpy = x._numpy() # pylint: disable=protected-access 72 | else: 73 | numpy = x.numpy() 74 | if isinstance(numpy, np.ndarray): 75 | # `numpy` shares the same underlying buffer as the `x` Tensor. 76 | # Tensors are expected to be immutable, so we disable writes. 
77 | numpy.setflags(write=False) 78 | return numpy 79 | 80 | return tf.nest.map_structure(_maybe_to_numpy, batch) 81 | 82 | @property 83 | def element_spec(self) -> clu_data.ElementSpec: 84 | if self._element_spec is not None: 85 | return self._element_spec 86 | 87 | batch = next(self._iterator) 88 | if self._postprocessor is not None: 89 | batch = self._postprocessor(batch) 90 | 91 | self._prefetched_batch = batch 92 | 93 | def _to_element_spec( 94 | x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray, 95 | ) -> clu_data.ArraySpec: 96 | if isinstance(x, tf.SparseTensor): 97 | return clu_data.ArraySpec( 98 | dtype=x.dtype.as_numpy_dtype, 99 | shape=(x.shape[0], *[None for _ in x.shape[1:]]), 100 | ) 101 | if isinstance(x, tf.RaggedTensor): 102 | return clu_data.ArraySpec( 103 | dtype=x.dtype.as_numpy_dtype, # pylint: disable=attribute-error 104 | shape=tuple(x.shape.as_list()), # pylint: disable=attribute-error 105 | ) 106 | if isinstance(x, tf.Tensor): 107 | return clu_data.ArraySpec( 108 | dtype=x.dtype.as_numpy_dtype, shape=tuple(x.shape.as_list()) 109 | ) 110 | return clu_data.ArraySpec(dtype=x.dtype, shape=tuple(x.shape)) 111 | 112 | element_spec = tf.nest.map_structure(_to_element_spec, batch) 113 | self._element_spec = element_spec 114 | return element_spec 115 | 116 | def reset(self): 117 | self._iterator = iter(self._dataset) 118 | if self._checkpoint is not None: 119 | self._checkpoint = tf.train.Checkpoint(ds=self._iterator) 120 | 121 | def save(self, filename: epath.Path): 122 | if self._checkpoint is not None: 123 | self._checkpoint.write(os.fspath(filename)) 124 | 125 | def restore(self, filename: epath.Path): 126 | if self._checkpoint is not None: 127 | self._checkpoint.read(os.fspath(filename)).assert_consumed() 128 | -------------------------------------------------------------------------------- /recml/core/data/preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Preprocessing utilities.""" 15 | 16 | import enum 17 | 18 | 19 | class PreprocessingMode(enum.StrEnum): 20 | """Mode for data preprocessing.""" 21 | 22 | TRAINING = "training" 23 | EVAL = "eval" 24 | SERVING = "serving" 25 | -------------------------------------------------------------------------------- /recml/core/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Public API for metrics.""" 15 | 16 | # pylint: disable=g-importing-member 17 | 18 | from recml.core.metrics.base_metrics import Metric 19 | from recml.core.metrics.base_metrics import scalar 20 | from recml.core.metrics.confusion_metrics import aucpr 21 | from recml.core.metrics.confusion_metrics import aucroc 22 | from recml.core.metrics.confusion_metrics import estimate_confusion_matrix 23 | from recml.core.metrics.confusion_metrics import f1_score 24 | from recml.core.metrics.confusion_metrics import fbeta_score 25 | from recml.core.metrics.confusion_metrics import precision 26 | from recml.core.metrics.confusion_metrics import precision_at_recall 27 | from recml.core.metrics.confusion_metrics import recall 28 | from recml.core.metrics.mean_metrics import accuracy 29 | from recml.core.metrics.mean_metrics import binary_accuracy 30 | from recml.core.metrics.mean_metrics import mean_squared_error 31 | from recml.core.metrics.mean_metrics import top_k_accuracy 32 | from recml.core.metrics.reduction_metrics import mean 33 | from recml.core.metrics.reduction_metrics import sum # pylint: disable=redefined-builtin 34 | from recml.core.metrics.tools import MetricAccumulator 35 | -------------------------------------------------------------------------------- /recml/core/metrics/base_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Functional metrics inspired by the CLU interface and Keras semantics.""" 15 | 16 | import abc 17 | from collections.abc import Mapping, Sequence 18 | import math 19 | from typing import Self, dataclass_transform 20 | 21 | import clu.metrics as clu_metrics 22 | from flax import struct 23 | import jax 24 | import jax.numpy as jnp 25 | import numpy as np 26 | 27 | Scalar = float | Sequence[float] | Mapping[str, float] | jax.Array | np.ndarray 28 | 29 | # TODO(aahil): Look into why pytype doesn't respect the Self type as a generic 30 | # type. We should not be violating LSP. 31 | 32 | # TODO(b/387463777): Consider removing the dependency on CLU metrics longer term 33 | # since it's just an interface. 34 | @dataclass_transform(field_specifiers=(struct.field,)) # pytype: disable=not-supported-yet 35 | class Metric(abc.ABC, clu_metrics.Metric, struct.PyTreeNode): 36 | """PyTree node representing the state of a metric. 37 | 38 | Note: This class follows the same interface as CLU metrics and can be used 39 | interchangeably. 40 | 41 | There are a few subtle differences between subclasses and standard CLU metrics: 42 | 1. Inheriting from this automatically makes the metric a PyTree node. 43 | 2. `mask` has been replaced by `weights` to be consistent with Keras. 44 | 3. Subclasses do not implement methods apart from the ones listed below. 45 | 4.
The `localize` method is added to specify how to localize the metric from 46 | device to host. 47 | """ 48 | 49 | @classmethod 50 | def from_model_output(cls, *args, **kwargs) -> Self: 51 | """Creates a metric from observations. 52 | 53 | Args: 54 | *args: Positional arguments to pass to the metric. 55 | **kwargs: Keyword arguments to pass to the metric. 56 | 57 | Returns: 58 | A new instance of the metric. 59 | 60 | NOTE: This metric is always called on the device and should therefore use 61 | only jax ops. 62 | """ 63 | raise NotImplementedError() 64 | 65 | @abc.abstractmethod 66 | def merge(self, other: Self) -> Self: # pytype: disable=signature-mismatch 67 | """Merges two metrics. 68 | 69 | Args: 70 | other: Another metric of the same type to merge with. 71 | 72 | Returns: 73 | A new instance of the same class that is the merge of the two. 74 | 75 | NOTE: This method is almost always called on the host which means that it 76 | should *never* call jax ops - this will implicitly move the state of the 77 | metric back to the device for computation. It is safest to rely on dunder 78 | methods via `+` / `-`. 79 | """ 80 | 81 | @abc.abstractmethod 82 | def compute(self) -> Scalar: # pytype: disable=signature-mismatch 83 | """Computes the value of the metric. 84 | 85 | NOTE: This method is almost always called on the host which means that it 86 | should *never* call jax ops - this will implicitly move the state of the 87 | metric back to the device for computation. Use numpy ops instead. 88 | """ 89 | 90 | def localize(self) -> Self: 91 | """Localizes the metric from device to host. 92 | 93 | Returns: 94 | A new instance of the same class that is localized, i.e. jax arrays on the 95 | metric are replaced by numpy arrays. 96 | """ 97 | 98 | def _localize(x): 99 | x = jax.device_get(x) 100 | if isinstance(x, jax.Array) and not isinstance(x, jax.core.Tracer): 101 | return x.addressable_data(0) 102 | return x 103 | 104 | return jax.tree.map(_localize, self) 105 | 106 | 107 | class ScalarMetric(Metric): 108 | """A metric for reporting scalar values without aggregation.""" 109 | 110 | value: jax.Array 111 | 112 | @classmethod 113 | def from_model_output(cls, value: jax.Array | float) -> Self: 114 | if hasattr(value, "shape") and math.prod(value.shape) != 1: 115 | raise ValueError( 116 | f"Scalar metric values must be scalars. Got shape: {value.shape}" 117 | " instead." 118 | ) 119 | return cls(value=jnp.squeeze(jnp.asarray(value, dtype=jnp.float32))) 120 | 121 | def merge(self, other: Self) -> Self: 122 | return other 123 | 124 | def compute(self) -> Scalar: 125 | return self.value 126 | 127 | 128 | def scalar(value: float | jax.Array) -> ScalarMetric: 129 | """Creates a scalar metric from a scalar value at a specific step. 130 | 131 | This is useful for reporting batch metrics during training. When merged with 132 | other instances, effectively the last value observed is reported. 133 | 134 | Note that using this metric during evaluation will result in multiple values 135 | being reported for the same step, which is generally undesirable. 136 | 137 | Example usage: 138 | 139 | ``` 140 | state = ... 141 | metrics = { 142 | "average_loss": mean(loss), 143 | "per_batch_loss": scalar(loss), 144 | "learning_rate": scalar(learning_rate), 145 | } 146 | ``` 147 | 148 | Args: 149 | value: The scalar value to report. 150 | 151 | Returns: 152 | A scalar metric reporting the value. 
153 | """ 154 | return ScalarMetric.from_model_output(value) 155 | -------------------------------------------------------------------------------- /recml/core/metrics/base_metrics_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for base metrics.""" 15 | 16 | from absl.testing import absltest 17 | import jax.numpy as jnp 18 | from recml.core.metrics import base_metrics 19 | 20 | 21 | class BaseMetricsTest(absltest.TestCase): 22 | 23 | def test_scalar(self): 24 | m1 = base_metrics.scalar(1.0) 25 | m2 = base_metrics.scalar(2.0) 26 | m3 = m1.merge(m2) 27 | self.assertEqual(m3.compute(), 2.0) 28 | 29 | self.assertRaises(ValueError, base_metrics.scalar, jnp.array([1.0, 2.0])) 30 | 31 | 32 | if __name__ == "__main__": 33 | absltest.main() 34 | -------------------------------------------------------------------------------- /recml/core/metrics/mean_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Mean metrics.""" 15 | 16 | import jax 17 | import keras 18 | from recml.core.metrics import reduction_metrics 19 | 20 | 21 | def accuracy( 22 | y_true: jax.Array, 23 | y_pred: jax.Array, 24 | weights: jax.Array | None = None, 25 | **_, 26 | ) -> reduction_metrics.Mean: 27 | """Computes accuracy from observations. 28 | 29 | Args: 30 | y_true: The true labels of shape [D1, ..., D_N] 31 | y_pred: The predicted logits of shape [D1, ..., D_N, num_classes]. 32 | weights: Optional weights of shape broadcastable to [D1, ... D_N]. 33 | **_: Unused kwargs. 34 | 35 | Returns: 36 | A metric accumulation of the accuracy. 37 | """ 38 | assert keras.backend.backend() == 'jax' 39 | acc = keras.metrics.sparse_categorical_accuracy(y_true, y_pred) 40 | return reduction_metrics.mean(acc, weights) 41 | 42 | 43 | def top_k_accuracy( 44 | y_true: jax.Array, 45 | y_pred: jax.Array, 46 | weights: jax.Array | None = None, 47 | *, 48 | k: int, 49 | **_, 50 | ) -> reduction_metrics.Mean: 51 | """Computes top-k accuracy from observations. 52 | 53 | Args: 54 | y_true: The true labels of shape [D1, ..., D_N] 55 | y_pred: The predicted logits of shape [D1, ..., D_N, num_classes]. 56 | weights: Optional weights of shape broadcastable to [D1, ... D_N]. 
57 | k: The number of top classes to consider. Must be less than num_classes. 58 | **_: Unused kwargs. 59 | 60 | Returns: 61 | A metric accumulation of the top-k accuracy. 62 | """ 63 | assert keras.backend.backend() == 'jax' 64 | acc = keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=k) 65 | return reduction_metrics.mean(acc, weights) 66 | 67 | 68 | def binary_accuracy( 69 | y_true: jax.Array, 70 | y_pred: jax.Array, 71 | weights: jax.Array | None = None, 72 | *, 73 | threshold: float = 0.5, 74 | **_, 75 | ) -> reduction_metrics.Mean: 76 | """Computes binary accuracy from observations. 77 | 78 | Args: 79 | y_true: The true labels of shape [D1, ..., D_N] 80 | y_pred: The binary predictions of shape [D1, ..., D_N]. 81 | weights: Optional weights of shape broadcastable to [D1, ... D_N]. 82 | threshold: The threshold to use for binary classification. 83 | **_: Unused kwargs. 84 | 85 | Returns: 86 | A metric accumulation of the binary accuracy. 87 | """ 88 | assert keras.backend.backend() == 'jax' 89 | bin_acc = keras.metrics.binary_accuracy(y_true, y_pred, threshold=threshold) 90 | return reduction_metrics.mean(bin_acc, weights) 91 | 92 | 93 | def mean_squared_error( 94 | y_true: jax.Array, 95 | y_pred: jax.Array, 96 | weights: jax.Array | None = None, 97 | **_, 98 | ) -> reduction_metrics.Mean: 99 | """Computes mean squared error from observations. 100 | 101 | Args: 102 | y_true: The true labels of shape [D1, ..., D_N]. 103 | y_pred: The predictions of shape [D1, ..., D_N]. 104 | weights: Optional weights of shape broadcastable to [D1, ... D_N]. 105 | **_: Unused kwargs. 106 | 107 | Returns: 108 | A metric accumulation of the mean squared error. 109 | """ 110 | assert keras.backend.backend() == 'jax' 111 | mse = keras.metrics.mean_squared_error(y_true, y_pred) 112 | return reduction_metrics.mean(mse, weights) 113 | -------------------------------------------------------------------------------- /recml/core/metrics/mean_metrics_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
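A short sketch of how the mean metrics above accumulate across batches (the tests below exercise the same functions). It assumes the `jax` Keras backend, which the asserts in `mean_metrics.py` require, and uses toy arrays that are not from this repo:

```python
import jax.numpy as jnp

from recml.core.metrics import mean_metrics

# Per-batch accuracy accumulations: each call returns a Mean metric holding a
# running total and count rather than a final value.
m1 = mean_metrics.accuracy(
    y_true=jnp.array([1, 0]), y_pred=jnp.array([[0.2, 0.8], [0.9, 0.1]])
)
m2 = mean_metrics.accuracy(
    y_true=jnp.array([1, 1]), y_pred=jnp.array([[0.7, 0.3], [0.4, 0.6]])
)

# Merge the running states and compute the aggregate accuracy.
print(float(m1.merge(m2).compute()))  # 0.75 -- 3 of 4 examples are correct

# Weighted variants broadcast the weights against the per-example values.
m3 = mean_metrics.binary_accuracy(
    y_true=jnp.array([1.0, 0.0]),
    y_pred=jnp.array([0.9, 0.8]),
    weights=jnp.array([1.0, 3.0]),
)
print(float(m3.compute()))  # 0.25 -- the misclassified example carries weight 3
```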
14 | """Tests for mean metrics.""" 15 | 16 | from collections.abc import Sequence 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import numpy as np 21 | from recml.core.metrics import mean_metrics 22 | 23 | 24 | class MeanMetricsTest(parameterized.TestCase): 25 | 26 | @parameterized.named_parameters( 27 | { 28 | 'testcase_name': 'unweighted', 29 | 'y_true': np.array([1, 0, 1, 0]), 30 | 'y_pred': np.array([[0.2, 0.8], [0.8, 0.2], [0.1, 0.9], [0.6, 0.4]]), 31 | 'weights': None, 32 | 'expected_output': 1.0, 33 | }, 34 | { 35 | 'testcase_name': 'weighted', 36 | 'y_true': np.array([[1, 0, 1, 0]]), 37 | 'y_pred': np.array( 38 | [[[0.2, 0.8], [0.1, 0.9], [0.3, 0.7], [0.4, 0.6]]] 39 | ), 40 | 'weights': np.array([[1, 2, 3, 4]]), 41 | 'expected_output': 0.4, 42 | }, 43 | ) 44 | def test_accuracy( 45 | self, 46 | y_true: np.ndarray, 47 | y_pred: np.ndarray, 48 | weights: np.ndarray | None, 49 | expected_output: np.ndarray, 50 | ): 51 | accuracy = mean_metrics.accuracy(y_true, y_pred, weights) 52 | np.testing.assert_allclose(expected_output, accuracy.compute()) 53 | np.testing.assert_allclose(expected_output, accuracy.localize().compute()) 54 | 55 | @parameterized.named_parameters( 56 | { 57 | 'testcase_name': 'unweighted', 58 | 'y_true': np.array([[1], [4], [2], [3], [3], [1], [0], [5]]), 59 | 'y_pred': np.array([ 60 | [[0.1, 0.7, 0.5, 0.3, 0.2, 0.0]], # [1, 2, 3, 4, 0, 5] 61 | [[0.2, 0.8, 0.0, 0.1, 0.4, 0.3]], # [1, 4, 5, 0, 3, 2] 62 | [[0.1, 0.2, 0.4, 0.8, 0.0, 0.3]], # [3, 2, 5, 1, 0, 4] 63 | [[1.0, 0.9, 0.1, 0.3, 0.2, 0.0]], # [0, 1, 3, 4, 2, 5] 64 | [[0.1, 0.7, 0.5, 0.3, 0.2, 0.0]], # [1, 2, 3, 4, 0, 5] 65 | [[0.2, 0.8, 0.0, 0.1, 0.4, 0.3]], # [1, 4, 5, 0, 3, 2] 66 | [[0.1, 0.2, 0.4, 0.8, 0.0, 0.3]], # [3, 2, 5, 1, 0, 4] 67 | [[1.0, 0.9, 0.1, 0.3, 0.2, 0.0]], # [0, 1, 3, 4, 2, 5] 68 | ]), 69 | 'weights': None, 70 | 'ks': [1, 2, 3, 4, 5, 6], 71 | 'expected_outputs': [0.25, 0.5, 0.75, 0.75, 0.875, 1.0], 72 | }, 73 | { 74 | 'testcase_name': 'weighted', 75 | 'y_true': np.array([0, 1, 1, 0]), 76 | 'y_pred': np.array([[0.1, 0.7], [0.2, 0.4], [0.1, 0.3], [0.2, 0.1]]), 77 | 'weights': np.array([0.2, 0.6, 0.1, 0.1]), 78 | 'ks': [1, 2], 79 | 'expected_outputs': np.array([0.8, 1.0]), 80 | }, 81 | ) 82 | def test_top_k_accuracies( 83 | self, 84 | y_true: np.ndarray, 85 | y_pred: np.ndarray, 86 | weights: np.ndarray | None, 87 | ks: Sequence[int], 88 | expected_outputs: Sequence[float], 89 | ): 90 | for k, expected_output in zip(ks, expected_outputs): 91 | accuracy = mean_metrics.top_k_accuracy(y_true, y_pred, weights, k=k) 92 | np.testing.assert_allclose(expected_output, accuracy.compute()) 93 | np.testing.assert_allclose(expected_output, accuracy.localize().compute()) 94 | 95 | @parameterized.named_parameters( 96 | { 97 | 'testcase_name': 'unweighted', 98 | 'y_true': np.array([1, 0, 1, 0]), 99 | 'y_pred': np.array([0.4, 0.6, 0.8, 0.2]), 100 | 'weights': None, 101 | 'threshold': 0.5, 102 | 'expected_output': 0.5, 103 | }, 104 | { 105 | 'testcase_name': 'weighted', 106 | 'y_true': np.array([[1, 0, 1, 0]]), 107 | 'y_pred': np.array([[0.8, 0.6, 0.7, 0.6]]), 108 | 'weights': np.array([[1, 2, 3, 4]]), 109 | 'threshold': 0.75, 110 | 'expected_output': 0.7, 111 | }, 112 | ) 113 | def test_binary_accuracy( 114 | self, 115 | y_true: np.ndarray, 116 | y_pred: np.ndarray, 117 | weights: np.ndarray | None, 118 | threshold: float, 119 | expected_output: np.ndarray, 120 | ): 121 | accuracy = mean_metrics.binary_accuracy( 122 | y_true, y_pred, weights, threshold=threshold 123 
| ) 124 | np.testing.assert_allclose(expected_output, accuracy.compute()) 125 | np.testing.assert_allclose(expected_output, accuracy.localize().compute()) 126 | 127 | @parameterized.named_parameters( 128 | { 129 | 'testcase_name': 'unweighted', 130 | 'y_true': np.array([0.3, 0.5, 0.7, 0.9]), 131 | 'y_pred': np.array([0.4, 0.6, 0.8, 0.2]), 132 | 'weights': None, 133 | 'expected_output': 0.13, 134 | }, 135 | { 136 | 'testcase_name': 'weighted', 137 | 'y_true': np.array([[0.3, 0.6, 0.2, 0.6]]), 138 | 'y_pred': np.array([[0.8, 0.6, 0.7, 0.6]]), 139 | 'weights': np.array([0.5]), 140 | 'expected_output': 0.125, 141 | }, 142 | ) 143 | def test_mean_squared_error( 144 | self, 145 | y_true: np.ndarray, 146 | y_pred: np.ndarray, 147 | weights: np.ndarray | None, 148 | expected_output: np.ndarray, 149 | ): 150 | mse = mean_metrics.mean_squared_error(y_true, y_pred, weights) 151 | np.testing.assert_allclose(expected_output, mse.compute(), rtol=1e-3) 152 | np.testing.assert_allclose( 153 | expected_output, mse.localize().compute(), rtol=1e-3 154 | ) 155 | 156 | 157 | if __name__ == '__main__': 158 | absltest.main() 159 | -------------------------------------------------------------------------------- /recml/core/metrics/reduction_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
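The module below implements the `total`/`count` bookkeeping described in its docstrings; a minimal sketch of that behavior, using toy arrays that are not from this repo:

```python
import jax.numpy as jnp

from recml.core.metrics import reduction_metrics

# Weighted mean over one batch: total = sum(values * weights),
# count = sum(weights), and compute() returns total / count.
m1 = reduction_metrics.mean(
    jnp.array([1.0, 2.0, 3.0, 4.0]), jnp.array([1.0, 1.0, 0.0, 0.0])
)
print(float(m1.compute()))  # 1.5 -- the last two values are masked out

# A second, unweighted batch merges by adding totals and counts.
m2 = reduction_metrics.mean(jnp.array([5.0, 7.0]))
print(float(m1.merge(m2).compute()))  # (1 + 2 + 5 + 7) / 4 = 3.75

# sum() tracks only the weighted total.
s = reduction_metrics.sum(jnp.array([1.0, 2.0]), jnp.array([2.0, 2.0]))
print(float(s.compute()))  # 6.0
```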
14 | """Reduction metrics.""" 15 | 16 | from __future__ import annotations 17 | 18 | from collections.abc import Callable 19 | import math 20 | from typing import Any, Self 21 | 22 | import jax 23 | import jax.numpy as jnp 24 | from recml.core.metrics import base_metrics 25 | 26 | 27 | class ReductionMetric(base_metrics.Metric): 28 | """A base class for reduction metrics.""" 29 | 30 | @classmethod 31 | def from_model_output( 32 | cls, values: jax.Array, weights: jax.Array | None = None, **_ 33 | ) -> Self: 34 | raise NotImplementedError() 35 | 36 | @classmethod 37 | def from_fun(cls, fun: Callable[..., Any], **kwargs) -> type[Self]: 38 | """Returns a reduction metric class that is computed from a function.""" 39 | base_cls = cls 40 | bound_kwargs = kwargs 41 | 42 | class _FromFun(cls): 43 | """A reduction metric that is computed from a function.""" 44 | 45 | @classmethod 46 | def from_model_output(cls, *args, **kwargs) -> ReductionMetric: 47 | if "weights" in kwargs: 48 | weights = kwargs.pop("weights") 49 | else: 50 | weights = None 51 | 52 | values = fun(*args, **bound_kwargs, **kwargs) 53 | 54 | return base_cls.from_model_output(values, weights) 55 | 56 | return _FromFun 57 | 58 | 59 | class Sum(ReductionMetric): 60 | """Computes the sum of observations over multiple batches.""" 61 | 62 | total: jax.Array 63 | 64 | @classmethod 65 | def from_model_output( 66 | cls, values: jax.Array, weights: jax.Array | None = None, **_ 67 | ) -> Self: 68 | values = jnp.asarray(values, dtype=jnp.float32) 69 | if weights is not None: 70 | weights = jnp.asarray(weights, dtype=jnp.float32) 71 | values, weights = _maybe_reshape_or_broadcast(values, weights) 72 | total = jnp.sum(values * weights) 73 | else: 74 | total = jnp.sum(values) 75 | 76 | return cls(total=total) 77 | 78 | def merge(self, other: Self) -> Self: # pytype: disable=signature-mismatch 79 | return type(self)(total=self.total + other.total) 80 | 81 | def compute(self) -> base_metrics.Scalar: 82 | return self.total 83 | 84 | 85 | class Mean(ReductionMetric): 86 | """Computes the mean of observations over multiple batches. 87 | 88 | This is done by tracking a total and a count over multiple observations and 89 | aggregating their mean over multiple batches when `compute` is called. 90 | """ 91 | 92 | total: jax.Array 93 | count: jax.Array 94 | 95 | @classmethod 96 | def from_model_output( 97 | cls, values: jax.Array, weights: jax.Array | None = None, **_ 98 | ) -> Self: 99 | values = jnp.asarray(values, dtype=jnp.float32) 100 | if weights is not None: 101 | weights = jnp.asarray(weights, dtype=jnp.float32) 102 | values, weights = _maybe_reshape_or_broadcast(values, weights) 103 | total = jnp.sum(values * weights) 104 | count = jnp.sum(weights) 105 | elif values.ndim >= 1: 106 | total = jnp.sum(values) 107 | count = jnp.asarray(math.prod(values.shape), jnp.float32) 108 | else: 109 | total = values 110 | count = jnp.ones((), dtype=jnp.float32) 111 | 112 | return cls(total=total, count=count) 113 | 114 | def merge(self, other: Self) -> Self: # pytype: disable=signature-mismatch 115 | return type(self)( 116 | total=self.total + other.total, 117 | count=self.count + other.count, 118 | ) 119 | 120 | def compute(self) -> base_metrics.Scalar: 121 | return self.total / self.count 122 | 123 | 124 | def mean(values: jax.Array, weights: jax.Array | None = None, **_) -> Mean: 125 | """Computes a mean metric from values and optional weights. 
126 | 127 | The resulting metric instance is a reduction metric that will aggregate the 128 | mean of the values over multiple batches. 129 | 130 | The total and counts are computed as follows: 131 | weights = broadcast_to(weights, values.shape) 132 | total = sum(values * weights) 133 | count = sum(weights) 134 | 135 | Where the output of an aggregated metric is the total / count. 136 | 137 | Example usage: 138 | 139 | ``` 140 | metrics = { 141 | # Reports the mean accuracy over multiple batches. 142 | 'accuracy': mean(y_true == y_pred), 143 | # Reports the mean loss over multiple batches. 144 | 'loss': mean(loss), 145 | } 146 | ``` 147 | 148 | Args: 149 | values: The values to compute the mean over of shape [D1, ..., DN]. 150 | weights: Optional weights to apply to the values. If provided, the shape of 151 | the weights must be broadcastable to the shape of the values. If not 152 | provided, all values will effectively have a weight of 1.0. 153 | **_: Unused keyword arguments. 154 | 155 | Returns: 156 | A mean metric accumulation. 157 | """ 158 | return Mean.from_model_output(values, weights) 159 | 160 | 161 | def sum(values: jax.Array, weights: jax.Array | None = None, **_) -> Sum: # pylint: disable=redefined-builtin 162 | """Computes a sum metric from values and optional weights. 163 | 164 | The sum is computed as follows: 165 | weights = broadcast_to(weights, values.shape) 166 | total = sum(values * weights) 167 | 168 | Where total is the output of an aggregated metric. 169 | 170 | Example usage: 171 | 172 | ``` 173 | metrics = { 174 | # Reports the total number of hits over multiple batches. 175 | 'number_of_hits': sum(y_true == y_pred), 176 | } 177 | ``` 178 | 179 | Args: 180 | values: The values to compute the sum over of shape [D1, ..., DN]. 181 | weights: Optional weights to apply to the values. If provided, the shape of 182 | the weights must be broadcastable to the shape of the values. If not 183 | provided, all values will effectively have a weight of 1.0. 184 | **_: Unused keyword arguments. 185 | 186 | Returns: 187 | A sum metric accumulation. 188 | """ 189 | return Sum.from_model_output(values, weights) 190 | 191 | 192 | def _maybe_reshape_or_broadcast( 193 | values: jax.Array, weights: jax.Array 194 | ) -> tuple[jax.Array, jax.Array]: 195 | """Reshapes or broadcasts arrays to have the same shape or throws an error.""" 196 | # Note that we broadcast the weights explicitly so that the sum of the weights 197 | # is not performed on the non-broadcasted array. 198 | if values.shape == weights.shape: 199 | return values, weights 200 | elif values.ndim == weights.ndim and all( 201 | v_d == w_d or w_d == 1 for v_d, w_d in zip(values.shape, weights.shape) 202 | ): 203 | return values, jnp.broadcast_to(weights, values.shape) 204 | elif values.ndim == weights.ndim: 205 | raise ValueError( 206 | f"Got incompatible shapes {values.shape} and {weights.shape}." 
207 | ) 208 | elif ( 209 | values.ndim > weights.ndim 210 | and values.shape[: weights.ndim] == weights.shape 211 | ): 212 | weights = jax.lax.expand_dims( 213 | weights, list(range(weights.ndim, values.ndim)) 214 | ) 215 | return values, jnp.broadcast_to(weights, values.shape) 216 | elif ( 217 | weights.ndim > values.ndim 218 | and weights.shape[: values.ndim] == values.shape 219 | ): 220 | values = jax.lax.expand_dims(values, list(range(values.ndim, weights.ndim))) 221 | return values, weights 222 | 223 | raise ValueError( 224 | "The arrays must have the same shape or the shape of one array must be" 225 | f" a broadcastable to the other. Got shapes: {values.shape} and" 226 | f" {weights.shape}." 227 | ) 228 | -------------------------------------------------------------------------------- /recml/core/metrics/reduction_metrics_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for reduction metrics.""" 15 | 16 | from collections.abc import Callable, Mapping, Sequence 17 | from typing import Any 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | import numpy as np 22 | from recml.core.metrics import reduction_metrics 23 | 24 | 25 | def mse(y_true, y_pred): 26 | return (y_true - y_pred) ** 2 27 | 28 | 29 | class ReductionMetricsTest(parameterized.TestCase): 30 | 31 | @parameterized.named_parameters( 32 | { 33 | 'testcase_name': 'scalar_weighted_sum', 34 | 'metric': reduction_metrics.sum, 35 | 'args': [0.5, 0.5], 36 | 'kwargs': {}, 37 | 'expected_output': 0.25, 38 | }, 39 | { 40 | 'testcase_name': 'unweighted_sum', 41 | 'metric': reduction_metrics.sum, 42 | 'args': [np.array([1, 3, 5, 7])], 43 | 'kwargs': {}, 44 | 'expected_output': 16.0, 45 | }, 46 | { 47 | 'testcase_name': 'weighted_sum', 48 | 'metric': reduction_metrics.sum, 49 | 'args': [np.array([1, 3, 5, 7]), np.array([1, 1, 0, 0])], 50 | 'kwargs': {}, 51 | 'expected_output': 4.0, 52 | }, 53 | { 54 | 'testcase_name': 'weighted_sum_2d', 55 | 'metric': reduction_metrics.sum, 56 | 'args': [np.array([[1, 3], [5, 7]])], 57 | 'kwargs': {'weights': np.array([[1, 1], [1, 0]])}, 58 | 'expected_output': 9.0, 59 | }, 60 | { 61 | 'testcase_name': 'weighted_sum_2d_broadcast', 62 | 'metric': reduction_metrics.sum, 63 | 'args': [np.array([[1, 3], [5, 7]]), np.array([[1, 0]])], 64 | 'kwargs': {}, 65 | 'expected_output': 6.0, 66 | }, 67 | { 68 | 'testcase_name': 'weighted_sum_3d_broadcast', 69 | 'metric': reduction_metrics.sum, 70 | 'args': [ 71 | np.array([ 72 | [[0.3, 0.7, 0.4, 0.6], [0.5, 0.75, 0.25, 1.5]], 73 | [[0.6, 0.3, 0.1, 1.0], [0.3, 0.7, 0.75, 0.25]], 74 | ]) 75 | ], 76 | 'kwargs': {'weights': np.array([[1, 1], [1, 0]])}, 77 | 'expected_output': 7.0, 78 | }, 79 | { 80 | 'testcase_name': 'unweighted_sum_from_fun', 81 | 'metric': reduction_metrics.Sum.from_fun(mse).from_model_output, 82 | 'args': [ 83 | np.array([ 84 | [0, 1, 0, 1, 
0], 85 | [0, 0, 1, 1, 1], 86 | [1, 1, 1, 1, 0], 87 | [0, 0, 0, 0, 1], 88 | ]), 89 | np.array([ 90 | [0, 0, 1, 1, 0], 91 | [1, 1, 1, 1, 1], 92 | [0, 1, 0, 1, 0], 93 | [1, 1, 1, 1, 1], 94 | ]), 95 | ], 96 | 'kwargs': {}, 97 | 'expected_output': 10.0, 98 | }, 99 | { 100 | 'testcase_name': 'scalar_weighted_mean', 101 | 'metric': reduction_metrics.mean, 102 | 'args': [0.5, 0.5], 103 | 'kwargs': {}, 104 | 'expected_output': 0.5, 105 | }, 106 | { 107 | 'testcase_name': 'unweighted_mean', 108 | 'metric': reduction_metrics.mean, 109 | 'args': [np.array([1, 3, 5, 7])], 110 | 'kwargs': {}, 111 | 'expected_output': 4.0, 112 | }, 113 | { 114 | 'testcase_name': 'weighted_mean', 115 | 'metric': reduction_metrics.mean, 116 | 'args': [np.array([1, 3, 5, 7]), np.array([1, 1, 0, 0])], 117 | 'kwargs': {}, 118 | 'expected_output': 2.0, 119 | }, 120 | { 121 | 'testcase_name': 'weighted_mean_neg_weights', 122 | 'metric': reduction_metrics.mean, 123 | 'args': [np.array([1, 3, 5, 7]), np.array([-1, -1, 0, 0])], 124 | 'kwargs': {}, 125 | 'expected_output': 2.0, 126 | }, 127 | { 128 | 'testcase_name': 'weighted_mean_2d', 129 | 'metric': reduction_metrics.mean, 130 | 'args': [np.array([[1, 3], [5, 7]])], 131 | 'kwargs': {'weights': np.array([[1, 1], [1, 0]])}, 132 | 'expected_output': 3.0, 133 | }, 134 | { 135 | 'testcase_name': 'weighted_mean_2d_broadcast', 136 | 'metric': reduction_metrics.mean, 137 | 'args': [np.array([[1, 3], [5, 7]]), np.array([[1, 0]])], 138 | 'kwargs': {}, 139 | 'expected_output': 3.0, 140 | }, 141 | { 142 | 'testcase_name': 'weighted_mean_3d_broadcast', 143 | 'metric': reduction_metrics.mean, 144 | 'args': [ 145 | np.array([ 146 | [[0.3, 0.7, 0.4, 0.6], [0.5, 0.75, 0.25, 1.5]], 147 | [[0.6, 0.3, 0.1, 1.0], [0.3, 0.7, 0.75, 0.25]], 148 | ]) 149 | ], 150 | 'kwargs': {'weights': np.array([[1, 1], [1, 0]])}, 151 | 'expected_output': 7 / 12, 152 | }, 153 | { 154 | 'testcase_name': 'unweighted_mean_from_fun', 155 | 'metric': reduction_metrics.Mean.from_fun(mse).from_model_output, 156 | 'args': [ 157 | np.array([ 158 | [0, 1, 0, 1, 0], 159 | [0, 0, 1, 1, 1], 160 | [1, 1, 1, 1, 0], 161 | [0, 0, 0, 0, 1], 162 | ]), 163 | np.array([ 164 | [0, 0, 1, 1, 0], 165 | [1, 1, 1, 1, 1], 166 | [0, 1, 0, 1, 0], 167 | [1, 1, 1, 1, 1], 168 | ]), 169 | ], 170 | 'kwargs': {}, 171 | 'expected_output': 0.5, 172 | }, 173 | { 174 | 'testcase_name': 'weighted_mean_from_fun', 175 | 'metric': reduction_metrics.Mean.from_fun(mse).from_model_output, 176 | 'args': [ 177 | np.array([ 178 | [0, 1, 0, 1, 0], 179 | [0, 0, 1, 1, 1], 180 | [1, 1, 1, 1, 0], 181 | [0, 0, 0, 0, 1], 182 | ]), 183 | np.array([ 184 | [0, 0, 1, 1, 0], 185 | [1, 1, 1, 1, 1], 186 | [0, 1, 0, 1, 0], 187 | [1, 1, 1, 1, 1], 188 | ]), 189 | ], 190 | 'kwargs': {'weights': np.array([1.0, 1.5, 2.0, 2.5])}, 191 | 'expected_output': 0.542857, 192 | }, 193 | ) 194 | def test_reduction_metric( 195 | self, 196 | metric: Callable[..., reduction_metrics.ReductionMetric], 197 | args: Sequence[Any], 198 | kwargs: Mapping[str, Any], 199 | expected_output: float | np.ndarray, 200 | ): 201 | instance = metric(*args, **kwargs) 202 | np.testing.assert_allclose(expected_output, instance.compute(), 1e-3) 203 | np.testing.assert_allclose( 204 | expected_output, instance.localize().compute(), 1e-3 205 | ) 206 | 207 | 208 | if __name__ == '__main__': 209 | absltest.main() 210 | -------------------------------------------------------------------------------- /recml/core/metrics/tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 
2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tools for RecML metrics.""" 15 | 16 | from collections.abc import Mapping 17 | import concurrent.futures 18 | import functools 19 | import os 20 | import re 21 | from typing import Any 22 | 23 | from absl import flags 24 | from clu import metric_writers 25 | import clu.metrics as clu_metrics 26 | import jax 27 | from recml.core.metrics import base_metrics 28 | 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | 33 | class AsyncMultiWriter(metric_writers.AsyncMultiWriter): 34 | """A multi writer that logs to a summary writer and a logging writer.""" 35 | 36 | def __init__(self, *, log_dir: str, name: str): 37 | summary_writer = metric_writers.SummaryWriter( 38 | os.fspath(os.path.join(log_dir, name)) 39 | ) 40 | writers = [summary_writer] 41 | 42 | super().__init__(writers) 43 | self._summary_writer = summary_writer 44 | 45 | @property 46 | def summary_writer(self) -> metric_writers.SummaryWriter: 47 | return self._summary_writer 48 | 49 | 50 | class MetricAccumulator: 51 | """A utility for asynchronously accumulating metrics.""" 52 | 53 | def __init__(self, writer: AsyncMultiWriter, max_workers: int = 1): 54 | if not isinstance(writer, AsyncMultiWriter): 55 | raise ValueError( 56 | "`summary_writer` must be an instance of AsyncMultiWriter, got" 57 | f" {type(writer)}." 58 | ) 59 | 60 | self._writer = writer 61 | self._metrics: list[Mapping[str, clu_metrics.Metric]] = [] 62 | self._scalar_log_pool = concurrent.futures.ThreadPoolExecutor( 63 | max_workers=max_workers 64 | ) 65 | self._scalar_log_futures: list[concurrent.futures.Future[None]] = [] 66 | 67 | def accumulate( 68 | self, metrics_accum: Mapping[str, clu_metrics.Metric], step: int 69 | ): 70 | """Asynchronously accumulates a set of metrics and logs scalars.""" 71 | self._metrics.append(metrics_accum) 72 | 73 | scalar_metrics_accum = { 74 | k: v 75 | for k, v in metrics_accum.items() 76 | if isinstance(v, base_metrics.ScalarMetric) 77 | } 78 | 79 | self._scalar_log_futures.append( 80 | self._scalar_log_pool.submit( 81 | _localize_and_log_scalars, 82 | # We only want to log per-step scalars via the summary writer. 83 | # Logging per-step scalars via other writers can be expensive. 84 | self._writer.summary_writer, 85 | step, 86 | scalar_metrics_accum, 87 | ) 88 | ) 89 | 90 | def compute_and_log_scalars( 91 | self, step: int 92 | ) -> Mapping[str, base_metrics.Scalar]: 93 | """Computes the scalars from the accumulated metrics and logs them.""" 94 | 95 | if not self._metrics: 96 | return {} 97 | 98 | for future in self._scalar_log_futures: 99 | future.result() 100 | 101 | self._scalar_log_futures.clear() 102 | 103 | metrics = functools.reduce( 104 | merge_metrics, [jax.tree.map(_localize, ms) for ms in self._metrics] 105 | ) 106 | self._metrics.clear() 107 | scalars = compute_metrics(metrics) 108 | 109 | # Log only non-reported scalars but return all for tracking in checkpoints. 
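# Scalar metrics were already written to the summary writer per step inside `accumulate`, so writing them again here would produce duplicate summary entries.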
110 | non_reported_scalars = { 111 | k: v 112 | for k, v in scalars.items() 113 | if not isinstance(metrics[k], base_metrics.ScalarMetric) 114 | } 115 | self._writer.write_scalars(step, non_reported_scalars) 116 | self._writer.flush() 117 | 118 | return scalars 119 | 120 | 121 | def compute_metrics( 122 | metrics: Mapping[str, clu_metrics.Metric | base_metrics.Metric], 123 | ) -> Mapping[str, base_metrics.Scalar]: 124 | """Collects the merged metrics and returns the computed scalars.""" 125 | return {k: m.compute() for k, m in metrics.items()} 126 | 127 | 128 | def merge_metrics( 129 | a: Mapping[str, clu_metrics.Metric | base_metrics.Metric], 130 | b: Mapping[str, clu_metrics.Metric | base_metrics.Metric], 131 | ) -> Mapping[str, clu_metrics.Metric]: 132 | """Merges two mappings of metrics.""" 133 | merged_metrics = {} 134 | for k in [*a.keys(), *b.keys()]: 135 | if k in a and k in b: 136 | merged_metrics[k] = a[k].merge(b[k]) 137 | elif k in a: 138 | merged_metrics[k] = a[k] 139 | elif k in b: 140 | merged_metrics[k] = b[k] 141 | return merged_metrics 142 | 143 | 144 | def _localize(x: Any) -> Any: 145 | """Returns the localized data for an object.""" 146 | x = jax.device_get(x) 147 | if isinstance(x, jax.Array) and not isinstance(x, jax.core.Tracer): 148 | return x.addressable_data(0) 149 | return x 150 | 151 | 152 | def _localize_and_log_scalars( 153 | summary_writer: metric_writers.SummaryWriter, 154 | step: int, 155 | scalar_metrics: Mapping[str, base_metrics.ScalarMetric], 156 | ) -> None: 157 | """Localizes the metrics from device to host and logs scalars.""" 158 | scalar_metrics = jax.tree.map(_localize, scalar_metrics) 159 | summary_writer.write_scalars(step, compute_metrics(scalar_metrics)) 160 | -------------------------------------------------------------------------------- /recml/core/ops/embedding_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Embedding lookup ops.""" 15 | 16 | from collections.abc import Mapping, Sequence 17 | import dataclasses 18 | import functools 19 | from typing import Any, TypeVar 20 | 21 | from etils import epy 22 | import jax 23 | from jax.experimental import shard_map 24 | 25 | with epy.lazy_imports(): 26 | # pylint: disable=g-import-not-at-top 27 | from jax_tpu_embedding.sparsecore.lib.nn import embedding 28 | # pylint: enable=g-import-not-at-top 29 | 30 | 31 | T = TypeVar("T") 32 | Nested = T | Sequence[T] | Mapping[str, T] 33 | FeatureSpec = Any 34 | 35 | 36 | @dataclasses.dataclass 37 | class SparsecoreParams: 38 | """Embedding parameters.""" 39 | 40 | feature_specs: Nested[FeatureSpec] 41 | abstract_mesh: jax.sharding.AbstractMesh 42 | data_axes: Sequence[str | None] 43 | embedding_axes: Sequence[str | None] 44 | sharding_strategy: str 45 | 46 | 47 | @functools.partial(jax.custom_vjp, nondiff_argnums=(0,)) 48 | def sparsecore_lookup( 49 | sparsecore_params: SparsecoreParams, 50 | tables: Mapping[str, tuple[jax.Array, ...]], 51 | csr_inputs: tuple[jax.Array, ...], 52 | ): 53 | return shard_map.shard_map( 54 | functools.partial( 55 | embedding.tpu_sparse_dense_matmul, 56 | global_device_count=sparsecore_params.abstract_mesh.size, 57 | feature_specs=sparsecore_params.feature_specs, 58 | sharding_strategy=sparsecore_params.sharding_strategy, 59 | ), 60 | mesh=sparsecore_params.abstract_mesh, 61 | in_specs=( 62 | jax.sharding.PartitionSpec(*sparsecore_params.data_axes), 63 | jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes), 64 | ), 65 | out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes), 66 | check_rep=False, 67 | )(csr_inputs, tables) 68 | 69 | 70 | def _emb_lookup_fwd( 71 | sparsecore_params: SparsecoreParams, 72 | tables: Mapping[str, tuple[jax.Array, ...]], 73 | csr_inputs: tuple[jax.Array, ...], 74 | ): 75 | out = sparsecore_lookup(sparsecore_params, tables, csr_inputs) 76 | return out, (tables, csr_inputs) 77 | 78 | 79 | def _emb_lookup_bwd( 80 | sparsecore_params: SparsecoreParams, 81 | res: tuple[Mapping[str, tuple[jax.Array, ...]], tuple[jax.Array, ...]], 82 | gradients: Nested[jax.Array], 83 | ) -> tuple[Nested[jax.Array], None]: 84 | """Backward pass for embedding lookup.""" 85 | (tables, csr_inputs) = res 86 | 87 | emb_table_grads = shard_map.shard_map( 88 | functools.partial( 89 | embedding.tpu_sparse_dense_matmul_grad, 90 | feature_specs=sparsecore_params.feature_specs, 91 | sharding_strategy=sparsecore_params.sharding_strategy, 92 | ), 93 | mesh=sparsecore_params.abstract_mesh, 94 | in_specs=( 95 | jax.sharding.PartitionSpec(*sparsecore_params.data_axes), 96 | jax.sharding.PartitionSpec(*sparsecore_params.data_axes), 97 | jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes), 98 | ), 99 | out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes), 100 | check_rep=False, 101 | )(gradients, csr_inputs, tables) 102 | 103 | # `tpu_sparse_dense_matmul_grad` returns a general mapping (usually a dict). 104 | # It may not be the same type as the embedding table (e.g. FrozenDict). 105 | # Here we use flatten / unflatten to ensure the types are the same. 
106 | emb_table_grads = jax.tree.unflatten( 107 | jax.tree.structure(tables), jax.tree.leaves(emb_table_grads) 108 | ) 109 | 110 | return emb_table_grads, None 111 | 112 | 113 | sparsecore_lookup.defvjp(_emb_lookup_fwd, _emb_lookup_bwd) 114 | -------------------------------------------------------------------------------- /recml/core/training/core.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Core training library for Jax.""" 15 | 16 | import abc 17 | from collections.abc import Mapping, Sequence 18 | import dataclasses 19 | import enum 20 | from typing import Any, Generic, TypeVar 21 | 22 | import jax 23 | import jax.numpy as jnp 24 | from recml.core.data import iterator 25 | import tensorflow as tf 26 | 27 | 28 | # pylint: disable=logging-fstring-interpolation 29 | 30 | LOG_DIR = "logs" 31 | BACKUP_DIR = "backup" 32 | CHECKPOINT_DIR = "checkpoints" 33 | TRAINING_COMPLETE_MARKER_FILE = "marker.txt" 34 | TRAIN_LOG_DIRNAME = "train" 35 | EVAL_LOG_DIRNAME = "val" 36 | KERAS_MODEL_SAVEFILE = "model.keras" 37 | ORBAX_CHECKPOINT_DEFAULT_KEY = "default" 38 | 39 | DEFAULT_RNG_SEED = 0 40 | IN_TRAINER_CONTEXT = False # Set to true when run from the main trainer. 41 | STATE_CHECKPOINT_KEY = "state" 42 | 43 | TaskT = TypeVar("TaskT") 44 | DatasetT = TypeVar( 45 | "DatasetT", 46 | tf.data.Dataset, 47 | tuple[tf.data.Dataset, tf.data.Dataset], 48 | tuple[tf.data.Dataset, Mapping[str, tf.data.Dataset]], 49 | iterator.Iterator, 50 | tuple[iterator.Iterator, iterator.Iterator], 51 | tuple[iterator.Iterator, Mapping[str, iterator.Iterator]], 52 | ) 53 | MetaT = TypeVar("MetaT") 54 | Logs = Any # Any metric logs returned by the training or evaluation task. 55 | 56 | 57 | class Trainer(abc.ABC, Generic[TaskT]): 58 | """A base trainer interface for training and evaluation.""" 59 | 60 | @abc.abstractmethod 61 | def __init__(self, model_dir: str, *args, **kwargs): 62 | """Initializes the instance.""" 63 | 64 | @abc.abstractmethod 65 | def train(self, task: TaskT, *args, **kwargs) -> Logs | None: 66 | """Performs training for a fixed number of steps.""" 67 | 68 | @abc.abstractmethod 69 | def evaluate(self, task: TaskT, *args, **kwargs) -> Logs | None: 70 | """Performs evaluation for a fixed number of steps.""" 71 | 72 | @abc.abstractmethod 73 | def train_and_evaluate(self, task: TaskT, *args, **kwargs) -> Logs | None: 74 | """Performs training and evaluation for a fixed number of steps.""" 75 | 76 | @abc.abstractmethod 77 | def evaluate_continuously(self, task: TaskT, *args, **kwargs) -> Logs | None: 78 | """Performs continuous evaluation until a condition is met.""" 79 | 80 | 81 | @dataclasses.dataclass(frozen=True) 82 | class Experiment(Generic[TaskT]): 83 | """Experiment definition. 84 | 85 | Properties: 86 | Mode: The mode to run the experiment in. 
87 | 88 | Attributes: 89 | task: A user defined task that defines the training and evaluation logic. 90 | trainer: The trainer to use for the experiment. 91 | """ 92 | 93 | class Mode(enum.StrEnum): 94 | """Mode to run an experiment.""" 95 | 96 | TRAIN = "train" 97 | EVAL = "eval" 98 | TRAIN_AND_EVAL = "train_and_eval" 99 | CONTINUOUS_EVAL = "continuous_eval" 100 | 101 | task: TaskT 102 | trainer: Trainer[TaskT] 103 | 104 | 105 | def run_experiment( 106 | experiment: Experiment, mode: Experiment.Mode 107 | ) -> Logs | None: 108 | """Runs an experiment.""" 109 | if mode == Experiment.Mode.TRAIN_AND_EVAL: 110 | return experiment.trainer.train_and_evaluate(experiment.task) 111 | elif mode == Experiment.Mode.TRAIN: 112 | return experiment.trainer.train(experiment.task) 113 | elif mode == Experiment.Mode.EVAL: 114 | return experiment.trainer.evaluate(experiment.task) 115 | elif mode == Experiment.Mode.CONTINUOUS_EVAL: 116 | return experiment.trainer.evaluate_continuously(experiment.task) 117 | else: 118 | raise ValueError(f"The job mode provided is not supported: {mode}.") 119 | 120 | 121 | def get_iterators( 122 | datasets: DatasetT, 123 | ) -> tuple[iterator.Iterator, Mapping[str, iterator.Iterator]]: 124 | """Creates and unpacks the datasets returned by the task.""" 125 | if isinstance(datasets, (iterator.Iterator, tf.data.Dataset)): 126 | if isinstance(datasets, tf.data.Dataset): 127 | datasets = iterator.TFDatasetIterator(datasets) 128 | return datasets, {} 129 | elif not isinstance(datasets, tuple) and len(datasets) != 2: 130 | raise ValueError( 131 | "Expected `datasets` to be a single dataset or a tuple of training" 132 | f" and evaluation datasets, but got {type(datasets)}." 133 | ) 134 | 135 | train_dataset, eval_datasets = datasets 136 | if isinstance(train_dataset, (iterator.Iterator, tf.data.Dataset)): 137 | if isinstance(train_dataset, tf.data.Dataset): 138 | train_dataset = iterator.TFDatasetIterator(train_dataset) 139 | else: 140 | raise ValueError( 141 | "Expected the training dataset in `datasets` to be a" 142 | " `tf.data.Dataset` or CLU `DatasetIterator` instance, but" 143 | f" {type(train_dataset)}." 144 | ) 145 | 146 | if isinstance(eval_datasets, (iterator.Iterator, tf.data.Dataset)): 147 | if isinstance(eval_datasets, tf.data.Dataset): 148 | eval_datasets = iterator.TFDatasetIterator(eval_datasets) 149 | return train_dataset, {"": eval_datasets} 150 | 151 | if not isinstance(eval_datasets, Mapping): 152 | raise ValueError( 153 | "Expected the evaluation dataset in `datasets` to either be a" 154 | " `tf.data.Dataset` or CLU `DatasetIterator` instance or be a" 155 | " mapping of datasets keyed by name, but got" 156 | f" {type(eval_datasets)}." 157 | ) 158 | 159 | if all(isinstance(v, tf.data.Dataset) for v in eval_datasets.values()): 160 | eval_datasets = { 161 | k: iterator.TFDatasetIterator(v) for k, v in eval_datasets.items() 162 | } 163 | 164 | if not all( 165 | isinstance(v, iterator.Iterator) for v in eval_datasets.values() 166 | ): 167 | raise ValueError( 168 | "Expected all values in the evaluation datasets mapping to be either" 169 | " `tf.data.Dataset` instances or CLU `DatasetIterator` instances," 170 | f" but got {eval_datasets}. You cannot mix both." 
171 | ) 172 | 173 | return train_dataset, eval_datasets # pytype: disable=bad-return-type 174 | 175 | 176 | def get_shape( 177 | x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | tf.TensorSpec, 178 | ) -> Sequence[int | None]: 179 | """Gets the shape of a dense / sparse / ragged tensor or tensor spec.""" 180 | if isinstance(x, tf.SparseTensor): 181 | return [x.shape[0]] + [None for _ in x.shape[1:]] 182 | return x.shape.as_list() 183 | 184 | 185 | def in_tracing_context() -> bool: 186 | """Returns whether the current context is a tracing context.""" 187 | return isinstance(jnp.ones(()), jax.core.Tracer) 188 | -------------------------------------------------------------------------------- /recml/core/training/jax_trainer_quality_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for the quality of training loops.""" 15 | 16 | from collections.abc import Mapping 17 | import functools 18 | 19 | from absl import flags 20 | from absl.testing import absltest 21 | import clu.metrics as clu_metrics 22 | import flax.linen as nn 23 | from flax.training import train_state as ts 24 | import jax 25 | import jax.numpy as jnp 26 | import jaxtyping as jt 27 | import optax 28 | from recml.core.training import jax_trainer 29 | from recml.core.training import partitioning 30 | import tensorflow as tf 31 | import tensorflow_datasets as tfds 32 | 33 | 34 | class _MNISTTask(jax_trainer.JaxTask): 35 | """Task for fitting a CNN on MNIST.""" 36 | 37 | def create_datasets(self) -> tuple[tf.data.Dataset, tf.data.Dataset]: 38 | 39 | def _preprocessor(batch: jt.PyTree) -> jt.PyTree: 40 | images = batch['image'] 41 | labels = batch['label'] 42 | images = tf.cast(images, tf.float32) / 255.0 43 | labels = tf.cast(labels, tf.int32) 44 | return images, labels 45 | 46 | def _create_dataset(training: bool) -> tf.data.Dataset: 47 | dataset = tfds.load( 48 | name='mnist', 49 | split='train' if training else 'test', 50 | batch_size=32, 51 | shuffle_files=training, 52 | ) 53 | return dataset.map(_preprocessor).prefetch(buffer_size=tf.data.AUTOTUNE) 54 | 55 | return _create_dataset(training=True), _create_dataset(training=False) 56 | 57 | def create_state(self, batch: jt.PyTree, rng: jax.Array) -> ts.TrainState: 58 | images, _ = batch 59 | model = nn.Sequential([ 60 | nn.Conv(32, kernel_size=(3, 3)), 61 | nn.relu, 62 | functools.partial(nn.max_pool, window_shape=(2, 2), strides=(2, 2)), 63 | nn.Conv(64, kernel_size=(3, 3)), 64 | nn.relu, 65 | functools.partial(nn.max_pool, window_shape=(2, 2), strides=(2, 2)), 66 | lambda x: x.reshape((x.shape[0], -1)), 67 | nn.Dense(256), 68 | nn.relu, 69 | nn.Dense(10), 70 | ]) 71 | variables = model.init(rng, jnp.zeros_like(images)) 72 | optimizer = optax.sgd(0.1) 73 | return ts.TrainState.create( 74 | apply_fn=model.apply, params=variables, tx=optimizer 75 | ) 76 | 77 | def train_step( 78 | self, batch: jt.PyTree, 
state: ts.TrainState, rng: jax.Array 79 | ) -> tuple[ts.TrainState, Mapping[str, clu_metrics.Metric]]: 80 | images, labels = batch 81 | 82 | def _loss_fn(params): 83 | logits = state.apply_fn(params, images) 84 | loss = jnp.mean( 85 | optax.softmax_cross_entropy_with_integer_labels(logits, labels), 86 | axis=0, 87 | ) 88 | return loss, (logits, labels) 89 | 90 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) 91 | (loss, (logits, labels)), grads = grad_fn(state.params) 92 | state = state.apply_gradients(grads=grads) 93 | metrics = { 94 | 'loss': clu_metrics.Average.from_model_output(loss), 95 | 'accuracy': clu_metrics.Accuracy.from_model_output( 96 | logits=logits, labels=labels 97 | ), 98 | } 99 | return state, metrics 100 | 101 | def eval_step( 102 | self, batch: jt.PyTree, state: ts.TrainState 103 | ) -> Mapping[str, clu_metrics.Metric]: 104 | images, labels = batch 105 | logits = state.apply_fn(state.params, images) 106 | loss = jnp.mean( 107 | optax.softmax_cross_entropy_with_integer_labels(logits, labels) 108 | ) 109 | metrics = { 110 | 'loss': clu_metrics.Average.from_model_output(loss), 111 | 'accuracy': clu_metrics.Accuracy.from_model_output( 112 | logits=logits, labels=labels 113 | ), 114 | } 115 | return metrics 116 | 117 | 118 | class JaxQualityTest(absltest.TestCase): 119 | 120 | def setUp(self): 121 | super().setUp() 122 | # Workaround to make `create_tempdir` work with pytest. 123 | if not flags.FLAGS.is_parsed(): 124 | flags.FLAGS.mark_as_parsed() 125 | 126 | def test_mnist_e2e(self): 127 | model_dir = self.create_tempdir().full_path 128 | task = _MNISTTask() 129 | trainer = jax_trainer.JaxTrainer( 130 | partitioner=partitioning.DataParallelPartitioner(), 131 | train_steps=1000, 132 | steps_per_eval=50, 133 | steps_per_loop=100, 134 | continuous_eval_timeout=5, 135 | model_dir=model_dir, 136 | rng_seed=42, 137 | ) 138 | logs = trainer.train_and_evaluate(task) 139 | self.assertGreater(logs['train']['accuracy'], 0.95) 140 | self.assertGreater(logs['val']['accuracy'], 0.95) 141 | 142 | self.assertTrue(tf.io.gfile.exists(model_dir)) 143 | continuous_eval_logs = trainer.evaluate_continuously(task) 144 | self.assertGreater(continuous_eval_logs['val']['accuracy'], 0.95) 145 | 146 | 147 | if __name__ == '__main__': 148 | absltest.main() 149 | -------------------------------------------------------------------------------- /recml/core/training/jax_trainer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for Jax task and trainer.""" 15 | 16 | from collections.abc import Mapping, Sequence 17 | import dataclasses 18 | import os 19 | 20 | from absl import flags 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | import clu.metrics as clu_metrics 24 | import flax.linen as nn 25 | from flax.training import train_state as ts 26 | import jax 27 | import jax.numpy as jnp 28 | import jaxtyping as jt 29 | import keras 30 | import optax 31 | import orbax.checkpoint as ocp 32 | from recml.core.training import core 33 | from recml.core.training import jax_trainer 34 | from recml.core.training import partitioning 35 | import tensorflow as tf 36 | 37 | 38 | class _DummyFlaxModel(nn.Module): 39 | 40 | @nn.compact 41 | def __call__(self, inputs: jax.Array) -> jax.Array: 42 | return nn.Dense(1, kernel_init=nn.initializers.constant(-1.0))(inputs) 43 | 44 | 45 | class _JaxTask(jax_trainer.JaxTask): 46 | 47 | def create_datasets( 48 | self, 49 | ) -> tuple[tf.data.Dataset, Mapping[str, tf.data.Dataset]]: 50 | def _map_fn(x: int): 51 | return (tf.cast(x, tf.float32), 0.1 * tf.cast(x, tf.float32) + 3) 52 | 53 | return tf.data.Dataset.range(1000).map(_map_fn).batch(2), { 54 | "eval_on_train": tf.data.Dataset.range(1000).map(_map_fn).batch(2), 55 | "eval_on_test": tf.data.Dataset.range(2000).map(_map_fn).batch(2), 56 | } 57 | 58 | def create_state(self, batch: jt.PyTree, rng: jax.Array) -> ts.TrainState: 59 | x, _ = batch 60 | model = _DummyFlaxModel() 61 | params = model.init(rng, x) 62 | optimizer = optax.adagrad(0.1) 63 | return ts.TrainState.create( 64 | apply_fn=model.apply, 65 | params=params, 66 | tx=optimizer, 67 | ) 68 | 69 | def train_step( 70 | self, batch: jt.PyTree, state: ts.TrainState, rng: jax.Array 71 | ) -> tuple[ts.TrainState, Mapping[str, clu_metrics.Metric]]: 72 | x, y = batch 73 | 74 | def _loss_fn(params): 75 | y_pred = state.apply_fn(params, x) 76 | loss = keras.losses.mean_squared_error(y, y_pred) 77 | return loss 78 | 79 | grad_fn = jax.value_and_grad(_loss_fn) 80 | loss, grads = grad_fn(state.params) 81 | state = state.apply_gradients(grads=grads) 82 | return state, {"loss": clu_metrics.Average.from_model_output(loss)} 83 | 84 | def eval_step( 85 | self, batch: jt.PyTree, state: ts.TrainState 86 | ) -> Mapping[str, clu_metrics.Metric]: 87 | x, y = batch 88 | y_pred = state.apply_fn(state.params, x) 89 | loss = keras.losses.mean_squared_error(y, y_pred) 90 | return {"loss": clu_metrics.Average.from_model_output(loss)} 91 | 92 | 93 | class _KerasJaxTask(jax_trainer.JaxTask): 94 | 95 | def create_datasets(self) -> tf.data.Dataset: 96 | def _map_fn(x: int): 97 | return ( 98 | tf.expand_dims(tf.cast(x, tf.float32), axis=-1), 99 | 0.1 * tf.cast(x, tf.float32) + 3, 100 | ) 101 | 102 | return ( 103 | tf.data.Dataset.range(1000).map(_map_fn).batch(2), 104 | tf.data.Dataset.range(2000).map(_map_fn).batch(2), 105 | ) 106 | 107 | def create_state( 108 | self, batch: jt.PyTree, rng: jax.Array 109 | ) -> jax_trainer.KerasState: 110 | x, _ = batch 111 | 112 | model = keras.Sequential( 113 | [ 114 | keras.layers.Dense( 115 | 1, 116 | kernel_initializer=keras.initializers.constant(-1.0), 117 | name="dense", 118 | ), 119 | ], 120 | name="model", 121 | ) 122 | model.build(x.shape) 123 | 124 | optimizer = optax.adagrad(0.1) 125 | return jax_trainer.KerasState.create(model=model, tx=optimizer) 126 | 127 | def train_step( 128 | self, batch: jt.PyTree, state: jax_trainer.KerasState, rng: jax.Array 129 | ) -> tuple[jax_trainer.KerasState, Mapping[str, 
clu_metrics.Metric]]: 130 | x, y = batch 131 | 132 | def _loss_fn(tvars): 133 | y_pred, _ = state.model.stateless_call(tvars, state.ntvars, x) 134 | loss = keras.ops.mean(keras.losses.mean_squared_error(y, y_pred)) 135 | return loss 136 | 137 | grad_fn = jax.value_and_grad(_loss_fn) 138 | loss, grads = grad_fn(state.tvars) 139 | state = state.update(grads=grads) 140 | return state, {"loss": clu_metrics.Average.from_model_output(loss)} 141 | 142 | def eval_step( 143 | self, batch: jt.PyTree, state: jax_trainer.KerasState 144 | ) -> Mapping[str, clu_metrics.Metric]: 145 | x, y = batch 146 | y_pred, _ = state.model.stateless_call(state.tvars, state.ntvars, x) 147 | loss = keras.losses.mean_squared_error(y, y_pred) 148 | return {"loss": clu_metrics.Average.from_model_output(loss)} 149 | 150 | 151 | class JaxTest(parameterized.TestCase): 152 | 153 | def setUp(self): 154 | super().setUp() 155 | # Workaround to make `create_tempdir` work with pytest. 156 | if not flags.FLAGS.is_parsed(): 157 | flags.FLAGS.mark_as_parsed() 158 | 159 | @parameterized.named_parameters( 160 | { 161 | "testcase_name": "jax_task_train", 162 | "task_cls": _JaxTask, 163 | "mode": core.Experiment.Mode.TRAIN, 164 | "expected_keys": ["train"], 165 | }, 166 | { 167 | "testcase_name": "keras_jax_task_train", 168 | "task_cls": _KerasJaxTask, 169 | "mode": core.Experiment.Mode.TRAIN, 170 | "expected_keys": ["train"], 171 | }, 172 | { 173 | "testcase_name": "jax_task_eval", 174 | "task_cls": _JaxTask, 175 | "mode": core.Experiment.Mode.EVAL, 176 | "expected_keys": ["val_eval_on_train", "val_eval_on_test"], 177 | }, 178 | { 179 | "testcase_name": "keras_jax_task_eval", 180 | "task_cls": _KerasJaxTask, 181 | "mode": core.Experiment.Mode.EVAL, 182 | "expected_keys": ["val"], 183 | }, 184 | { 185 | "testcase_name": "jax_task_train_and_eval", 186 | "task_cls": _JaxTask, 187 | "mode": core.Experiment.Mode.TRAIN_AND_EVAL, 188 | "expected_keys": ["train", "val_eval_on_train", "val_eval_on_test"], 189 | }, 190 | { 191 | "testcase_name": "keras_jax_task_train_and_eval", 192 | "task_cls": _KerasJaxTask, 193 | "mode": core.Experiment.Mode.TRAIN_AND_EVAL, 194 | "expected_keys": ["train", "val"], 195 | }, 196 | { 197 | "testcase_name": "jax_task_continuous_eval", 198 | "task_cls": _JaxTask, 199 | "mode": core.Experiment.Mode.CONTINUOUS_EVAL, 200 | "expected_keys": ["val_eval_on_train", "val_eval_on_test"], 201 | }, 202 | { 203 | "testcase_name": "keras_jax_task_continuous_eval", 204 | "task_cls": _KerasJaxTask, 205 | "mode": core.Experiment.Mode.CONTINUOUS_EVAL, 206 | "expected_keys": ["val"], 207 | }, 208 | ) 209 | def test_jax_trainer( 210 | self, 211 | task_cls: type[jax_trainer.JaxTask], 212 | mode: str, 213 | expected_keys: Sequence[str], 214 | ): 215 | model_dir = self.create_tempdir().full_path 216 | task = task_cls() 217 | trainer = jax_trainer.JaxTrainer( 218 | partitioner=partitioning.DataParallelPartitioner(data_axis="batch"), 219 | train_steps=12, 220 | steps_per_eval=3, 221 | steps_per_loop=4, 222 | model_dir=model_dir, 223 | continuous_eval_timeout=5, 224 | ) 225 | experiment = core.Experiment(task, trainer) 226 | if mode == core.Experiment.Mode.CONTINUOUS_EVAL: 227 | # Produce one checkpoint so there is something to evaluate. 
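# Continuous eval polls the model directory for new checkpoints (giving up after `continuous_eval_timeout` seconds), so a prior training run guarantees at least one checkpoint to consume.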
228 | core.run_experiment(experiment, core.Experiment.Mode.TRAIN) 229 | logs = core.run_experiment(experiment, mode) 230 | 231 | for key in expected_keys: 232 | self.assertIn(key, logs) 233 | self.assertIn("loss", logs[key]) 234 | 235 | if mode in [ 236 | core.Experiment.Mode.TRAIN, 237 | core.Experiment.Mode.TRAIN_AND_EVAL, 238 | ]: 239 | checkpointed_steps = ocp.utils.checkpoint_steps( 240 | os.path.join(model_dir, core.CHECKPOINT_DIR) 241 | ) 242 | self.assertEqual([3, 7, 11], sorted(checkpointed_steps)) 243 | 244 | # TODO(aahil): Check the logs for the correct summaries. 245 | # TODO(aahil): Test exporting here. 246 | 247 | def test_optimizer_metrics(self): 248 | @dataclasses.dataclass 249 | class State: 250 | step: int 251 | opt_state: optax.OptState 252 | 253 | tx = optax.chain( 254 | optax.clip_by_global_norm(1.0), 255 | optax.scale_by_adam(), 256 | optax.inject_stateful_hyperparams(optax.scale_by_learning_rate)( 257 | learning_rate=0.1 258 | ), 259 | ) 260 | state = State(step=10, opt_state=tx.init({"a": jnp.ones((10, 10))})) 261 | metrics = jax_trainer._state_metrics(state) 262 | self.assertIn("optimizer/learning_rate", metrics) 263 | self.assertEqual(metrics["optimizer/learning_rate"].compute(), 0.1) 264 | 265 | 266 | if __name__ == "__main__": 267 | absltest.main() 268 | -------------------------------------------------------------------------------- /recml/core/training/keras_trainer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for Jax training library.""" 15 | 16 | from absl import flags 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import keras 20 | from recml.core.training import core 21 | from recml.core.training import keras_trainer 22 | import tensorflow as tf 23 | 24 | 25 | class _KerasTask(keras_trainer.KerasTask): 26 | 27 | def create_dataset(self, training: bool) -> tf.data.Dataset: 28 | def _map_fn(x: int): 29 | return (tf.cast(x, tf.float32), 0.1 * tf.cast(x, tf.float32) + 3) 30 | 31 | return tf.data.Dataset.range(1000).map(_map_fn).batch(2) 32 | 33 | def create_model(self) -> keras.Model: 34 | inputs = keras.Input(shape=(1,), dtype=tf.float32) 35 | outputs = keras.layers.Dense( 36 | 1, kernel_initializer=keras.initializers.constant(-1.0) 37 | )(inputs) 38 | model = keras.Model(inputs=inputs, outputs=outputs) 39 | model.compile( 40 | optimizer=keras.optimizers.Adagrad(0.1), 41 | loss=keras.losses.MeanSquaredError(), 42 | ) 43 | return model 44 | 45 | 46 | class KerasTrainerTest(parameterized.TestCase): 47 | 48 | def setUp(self): 49 | super().setUp() 50 | # Workaround to make `create_tempdir` work with pytest. 
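# `create_tempdir` reads absl flags, which pytest never parses, so mark them as parsed before the test body runs.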
51 | if not flags.FLAGS.is_parsed(): 52 | flags.FLAGS.mark_as_parsed() 53 | 54 | @parameterized.named_parameters( 55 | {"testcase_name": "train", "mode": core.Experiment.Mode.TRAIN}, 56 | {"testcase_name": "eval", "mode": core.Experiment.Mode.EVAL}, 57 | { 58 | "testcase_name": "train_and_eval", 59 | "mode": core.Experiment.Mode.TRAIN_AND_EVAL, 60 | }, 61 | { 62 | "testcase_name": "continuous_eval", 63 | "mode": core.Experiment.Mode.CONTINUOUS_EVAL, 64 | }, 65 | ) 66 | def test_keras_task_and_trainer(self, mode: str): 67 | if keras.backend.backend() == "jax": 68 | distribution = keras.distribution.DataParallel() 69 | else: 70 | distribution = None 71 | if mode == core.Experiment.Mode.CONTINUOUS_EVAL: 72 | self.skipTest("Continuous eval is only supported on the Jax backend.") 73 | 74 | trainer = keras_trainer.KerasTrainer( 75 | distribution=distribution, 76 | train_steps=5, 77 | steps_per_eval=3, 78 | steps_per_loop=2, 79 | model_dir=self.create_tempdir().full_path, 80 | continuous_eval_timeout=5, 81 | ) 82 | experiment = core.Experiment(_KerasTask(), trainer) 83 | 84 | if mode == core.Experiment.Mode.CONTINUOUS_EVAL: 85 | # Produce one checkpoint so there is something to evaluate. 86 | core.run_experiment(experiment, core.Experiment.Mode.TRAIN) 87 | 88 | history = core.run_experiment(experiment, mode) 89 | 90 | if ( 91 | mode 92 | in [core.Experiment.Mode.TRAIN, core.Experiment.Mode.TRAIN_AND_EVAL] 93 | and keras.backend.backend() == "jax" 94 | ): 95 | self.assertEqual(history.history["num_params/trainable"][0], 2) 96 | 97 | 98 | if __name__ == "__main__": 99 | absltest.main() 100 | -------------------------------------------------------------------------------- /recml/core/training/optax_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Optax optimizer factories.""" 15 | 16 | from collections.abc import Callable 17 | import dataclasses 18 | import re 19 | from typing import Any 20 | 21 | import jax 22 | import optax 23 | from recml.core.utils import types 24 | 25 | 26 | def _default_weight_decay_mask(params: optax.Params) -> optax.Params: 27 | """Default weight decay mask that only applies to non-1D parameters.""" 28 | return jax.tree.map(lambda p: p.ndim > 1, params) 29 | 30 | 31 | def _regex_mask(regex: str) -> Callable[[optax.Params], optax.Params]: 32 | """Returns a mask that applies to parameters matching a regex.""" 33 | 34 | def _matches_regex(path: tuple[str, ...], _: Any) -> bool: 35 | key = '/'.join([jax.tree_util.keystr((k,), simple=True) for k in path]) 36 | return re.fullmatch(regex, key) is not None 37 | 38 | def _mask(params: optax.Params) -> optax.Params: 39 | return jax.tree.map_with_path(_matches_regex, params) 40 | 41 | return _mask 42 | 43 | 44 | class OptimizerFactory(types.Factory[optax.GradientTransformation]): 45 | """Standard optimizer factory for Optax optimizers. 
46 | 47 | Attributes: 48 | learning_rate: The learning rate to use for the optimizer. 49 | scaling: The gradient scaling transformation to use during optimization. 50 | Defaults to identity. 51 | weight_decay: Optional weight decay to apply to variables during 52 | optimization. Defaults to None. 53 | grad_clip_norm: Optional gradient clipping norm to limit the maximum 54 | magnitude of the gradients during optimization. Defaults to None. 55 | weight_decay_mask: The weight decay mask to use when applying weight decay. 56 | Defaults to applying weight decay to all non-1D parameters. 57 | freeze_mask: Optional mask to freeze parameters during optimization. 58 | Defaults to None. 59 | 60 | Example usage: 61 | 62 | ``` 63 | sgd = OptimizerFactory(learning_rate=0.001).make() 64 | 65 | adamw = OptimizerFactory( 66 | learning_rate=1e-3, 67 | scaling=optax.scale_by_adam(), 68 | weight_decay=1e-7, 69 | grad_clip_norm=1.0, 70 | ).make() 71 | ``` 72 | """ 73 | 74 | learning_rate: optax.ScalarOrSchedule 75 | scaling: optax.GradientTransformation = dataclasses.field( 76 | default_factory=optax.identity 77 | ) 78 | weight_decay: float | None = None 79 | grad_clip_norm: float | None = None 80 | weight_decay_mask: str | Callable[[optax.Params], optax.Params] = ( 81 | _default_weight_decay_mask 82 | ) 83 | freeze_mask: str | Callable[[optax.Params], optax.Params] | None = None 84 | 85 | def make(self) -> optax.GradientTransformation: 86 | if self.grad_clip_norm is not None: 87 | apply_clipping = optax.clip_by_global_norm(self.grad_clip_norm) 88 | else: 89 | apply_clipping = optax.identity() 90 | 91 | # Tags the learning rate as a stateful hyperparameter so it can be logged. 92 | lr_scaling = optax.inject_stateful_hyperparams( 93 | optax.scale_by_learning_rate 94 | )(learning_rate=self.learning_rate) 95 | 96 | if self.weight_decay is not None: 97 | if isinstance(self.weight_decay_mask, str): 98 | mask = _regex_mask(self.weight_decay_mask) 99 | else: 100 | mask = self.weight_decay_mask 101 | weight_decay = optax.add_decayed_weights(self.weight_decay, mask=mask) 102 | else: 103 | weight_decay = optax.identity() 104 | 105 | tx = optax.chain(*[ 106 | apply_clipping, 107 | self.scaling, 108 | weight_decay, 109 | lr_scaling, 110 | ]) 111 | 112 | if self.freeze_mask is not None: 113 | if isinstance(self.freeze_mask, str): 114 | mask = _regex_mask(self.freeze_mask) 115 | else: 116 | mask = self.freeze_mask 117 | 118 | def _param_labels(params: optax.Params) -> optax.Params: 119 | return jax.tree.map( 120 | lambda m: 'frozen' if m else 'trainable', mask(params) 121 | ) 122 | 123 | tx = optax.multi_transform( 124 | transforms={'trainable': tx, 'frozen': optax.set_to_zero()}, 125 | param_labels=_param_labels, 126 | ) 127 | return tx 128 | 129 | 130 | class AdamFactory(types.Factory[optax.GradientTransformation]): 131 | """Adam optimizer factory. 132 | 133 | Attributes: 134 | learning_rate: The learning rate to use for the optimizer. 135 | b1: The beta1 coefficient for the Adam optimizer. Defaults to 0.9. 136 | b2: The beta2 coefficient for the Adam optimizer. Defaults to 0.999. 137 | eps: The epsilon coefficient for the Adam optimizer. Defaults to 1e-8. 138 | weight_decay: Optional weight decay to apply to variables during 139 | optimization. Defaults to None. 140 | grad_clip_norm: Optional gradient clipping norm to limit the maximum 141 | magnitude of the gradients during optimization. Defaults to None. 142 | weight_decay_mask: The weight decay mask to use when applying weight decay.
143 | Defaults to applying weight decay to all non-1D parameters. 144 | freeze_mask: Optional mask to freeze parameters during optimization. 145 | Defaults to None. 146 | 147 | Example usage: 148 | ``` 149 | adam = AdamFactory(learning_rate=1e-3).make() 150 | 151 | adamw = AdamFactory( 152 | learning_rate=1e-3, 153 | weight_decay=1e-7, 154 | grad_clip_norm=1.0, 155 | ).make() 156 | ``` 157 | """ 158 | 159 | learning_rate: optax.ScalarOrSchedule 160 | b1: float = 0.9 161 | b2: float = 0.999 162 | eps: float = 1e-8 163 | weight_decay: float | None = None 164 | grad_clip_norm: float | None = None 165 | weight_decay_mask: str | Callable[[optax.Params], optax.Params] = ( 166 | _default_weight_decay_mask 167 | ) 168 | freeze_mask: str | Callable[[optax.Params], optax.Params] | None = None 169 | 170 | def make(self) -> optax.GradientTransformation: 171 | return OptimizerFactory( 172 | learning_rate=self.learning_rate, 173 | scaling=optax.scale_by_adam(b1=self.b1, b2=self.b2, eps=self.eps), 174 | weight_decay=self.weight_decay, 175 | grad_clip_norm=self.grad_clip_norm, 176 | weight_decay_mask=self.weight_decay_mask, freeze_mask=self.freeze_mask, 177 | ).make() 178 | 179 | 180 | class AdagradFactory(types.Factory[optax.GradientTransformation]): 181 | """Adagrad optimizer factory. 182 | 183 | Attributes: 184 | learning_rate: The learning rate to use for the optimizer. 185 | initial_accumulator_value: The initial accumulator value for the Adagrad 186 | optimizer. Defaults to 0.1. 187 | eps: The epsilon coefficient for the Adagrad optimizer. Defaults to 1e-7. 188 | grad_clip_norm: Optional gradient clipping norm to limit the maximum 189 | magnitude of the gradients during optimization. Defaults to None. 190 | freeze_mask: Optional mask to freeze parameters during optimization. 191 | Defaults to None. 192 | 193 | Example usage: 194 | ``` 195 | adagrad = AdagradFactory(learning_rate=1e-3).make() 196 | ``` 197 | """ 198 | 199 | learning_rate: optax.ScalarOrSchedule 200 | initial_accumulator_value: float = 0.1 201 | eps: float = 1e-7 202 | grad_clip_norm: float | None = None 203 | freeze_mask: str | Callable[[optax.Params], optax.Params] | None = None 204 | 205 | def make(self) -> optax.GradientTransformation: 206 | return OptimizerFactory( 207 | learning_rate=self.learning_rate, 208 | scaling=optax.scale_by_rss( 209 | initial_accumulator_value=self.initial_accumulator_value, 210 | eps=self.eps, 211 | ), 212 | grad_clip_norm=self.grad_clip_norm, freeze_mask=self.freeze_mask, 213 | ).make() 214 | -------------------------------------------------------------------------------- /recml/core/training/optax_factory_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | """Test for optax optimizer factories.""" 15 | 16 | from absl.testing import absltest 17 | import jax 18 | import numpy as np 19 | import optax 20 | from recml.core.training import optax_factory 21 | 22 | 23 | class OptaxFactoryTest(absltest.TestCase): 24 | 25 | def assertOptimizersEqual( 26 | self, 27 | a: optax.GradientTransformation, 28 | b: optax.GradientTransformation, 29 | steps: int = 10, 30 | ): 31 | k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4) 32 | params = { 33 | "x": jax.random.uniform(k1, (128, 128)), 34 | "y": jax.random.uniform(k2, (128, 128)), 35 | "z": jax.random.uniform(k3, (128, 128)), 36 | } 37 | grads = jax.tree.map(lambda p: jax.random.uniform(k4, p.shape), params) 38 | 39 | opt_state_a = a.init(params) 40 | opt_state_b = b.init(params) 41 | 42 | for _ in range(steps): 43 | updates_a, opt_state_a = a.update(grads, opt_state_a, params) 44 | updates_b, opt_state_b = b.update(grads, opt_state_b, params) 45 | for k in params: 46 | np.testing.assert_allclose(updates_a[k], updates_b[k]) 47 | 48 | def test_optimizer_factory(self): 49 | optimizer_a = optax_factory.OptimizerFactory( 50 | learning_rate=optax.warmup_cosine_decay_schedule( 51 | init_value=0.0, 52 | peak_value=1e-3, 53 | warmup_steps=5, 54 | decay_steps=10, 55 | end_value=0, 56 | ), 57 | scaling=optax.scale_by_rms(), 58 | weight_decay=1e-4, 59 | weight_decay_mask=r"^(?!.*(?:x|y)$).*", 60 | grad_clip_norm=1.0, 61 | ).make() 62 | optimizer_b = optax.chain( 63 | optax.clip_by_global_norm(1.0), 64 | optax.scale_by_rms(), 65 | optax.add_decayed_weights( 66 | 1e-4, mask=optax_factory._regex_mask(r"^(?!.*(?:x|y)$).*") 67 | ), 68 | optax.scale_by_learning_rate( 69 | optax.warmup_cosine_decay_schedule( 70 | init_value=0.0, 71 | peak_value=1e-3, 72 | warmup_steps=5, 73 | decay_steps=10, 74 | end_value=0, 75 | ) 76 | ), 77 | ) 78 | optimizer_c = optax_factory.OptimizerFactory( 79 | learning_rate=optax.warmup_cosine_decay_schedule( 80 | init_value=0.0, 81 | peak_value=1e-3, 82 | warmup_steps=5, 83 | decay_steps=10, 84 | end_value=0, 85 | ), 86 | scaling=optax.scale_by_rms(), 87 | weight_decay=1e-4, 88 | weight_decay_mask=r"^(?!.*(?:z)$).*", 89 | grad_clip_norm=1.0, 90 | ).make() 91 | self.assertOptimizersEqual(optimizer_a, optimizer_b, steps=10) 92 | self.assertRaises( 93 | AssertionError, 94 | self.assertOptimizersEqual, 95 | optimizer_a, 96 | optimizer_c, 97 | steps=10, 98 | ) 99 | 100 | def test_adam_factory(self): 101 | optimizer_a = optax_factory.AdamFactory( 102 | learning_rate=optax.warmup_cosine_decay_schedule( 103 | init_value=0.0, 104 | peak_value=1e-3, 105 | warmup_steps=5, 106 | decay_steps=10, 107 | end_value=0, 108 | ), 109 | b1=0.9, 110 | b2=0.999, 111 | eps=1e-8, 112 | weight_decay=1e-4, 113 | grad_clip_norm=1.0, 114 | ).make() 115 | optimizer_b = optax.chain( 116 | optax.clip_by_global_norm(1.0), 117 | optax.adamw( 118 | learning_rate=optax.warmup_cosine_decay_schedule( 119 | init_value=0.0, 120 | peak_value=1e-3, 121 | warmup_steps=5, 122 | decay_steps=10, 123 | end_value=0, 124 | ), 125 | b1=0.9, 126 | b2=0.999, 127 | eps=1e-8, 128 | weight_decay=1e-4, 129 | mask=optax_factory._default_weight_decay_mask, 130 | ), 131 | ) 132 | self.assertOptimizersEqual(optimizer_a, optimizer_b, steps=10) 133 | 134 | def test_adagrad_factory(self): 135 | optimizer_a = optax_factory.AdagradFactory( 136 | learning_rate=optax.warmup_cosine_decay_schedule( 137 | init_value=0.0, 138 | peak_value=1e-3, 139 | warmup_steps=5, 140 | decay_steps=10, 141 | end_value=0, 142 | ), 143 | initial_accumulator_value=0.1, 144 | 
eps=1e-7, 145 | grad_clip_norm=1.0, 146 | ).make() 147 | optimizer_b = optax.chain( 148 | optax.clip_by_global_norm(1.0), 149 | optax.adagrad( 150 | learning_rate=optax.warmup_cosine_decay_schedule( 151 | init_value=0.0, 152 | peak_value=1e-3, 153 | warmup_steps=5, 154 | decay_steps=10, 155 | end_value=0, 156 | ), 157 | initial_accumulator_value=0.1, 158 | eps=1e-7, 159 | ), 160 | ) 161 | self.assertOptimizersEqual(optimizer_a, optimizer_b, steps=10) 162 | 163 | 164 | if __name__ == "__main__": 165 | absltest.main() 166 | -------------------------------------------------------------------------------- /recml/core/training/partitioning.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utilities for partitioning.""" 15 | 16 | import abc 17 | from collections.abc import Callable, Mapping, Sequence 18 | import math 19 | from typing import Any, ContextManager 20 | 21 | import flax.linen as nn 22 | import jax 23 | from jax.experimental import mesh_utils 24 | import numpy as np 25 | 26 | 27 | PyTree = Any 28 | State = Any 29 | CreateStateFn = Callable[[PyTree], State] 30 | InitFn = Callable[[PyTree, jax.Array], State] 31 | StepFn = Callable[[PyTree, State], Any] 32 | 33 | 34 | class Partitioner(abc.ABC): 35 | """An abstract class defining partitioning logic for data and computation.""" 36 | 37 | @abc.abstractmethod 38 | def shard_inputs(self, inputs: Any) -> PyTree: 39 | """Shards the input batches and put them on the device.""" 40 | 41 | @abc.abstractmethod 42 | def partition_init( 43 | self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None 44 | ) -> CreateStateFn: 45 | """Shards the initialization function.""" 46 | 47 | @abc.abstractmethod 48 | def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn: 49 | """Shards the training and evaluation steps.""" 50 | 51 | 52 | class NullPartitioner(Partitioner): 53 | """A null partitioner.""" 54 | 55 | def shard_inputs(self, inputs: PyTree) -> PyTree: 56 | return inputs 57 | 58 | def partition_init( 59 | self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None 60 | ) -> CreateStateFn: 61 | return init_fn 62 | 63 | def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn: 64 | return fn 65 | 66 | 67 | class DataParallelPartitioner(Partitioner): 68 | """Data parallel partitioner.""" 69 | 70 | def __init__(self, data_axis: str = "batch"): 71 | self.mesh = jax.sharding.Mesh(jax.devices(), (data_axis,)) 72 | self.data_sharding = jax.sharding.NamedSharding( 73 | self.mesh, jax.sharding.PartitionSpec(data_axis) 74 | ) 75 | self.state_sharding = jax.sharding.NamedSharding( 76 | self.mesh, jax.sharding.PartitionSpec() 77 | ) 78 | 79 | def shard_inputs(self, inputs: PyTree) -> PyTree: 80 | local_devices = self.mesh.local_devices 81 | local_device_count = len(local_devices) 82 | device_count = len(self.mesh.devices) 83 | 84 | def _shard(x: 
np.ndarray) -> jax.Array: 85 | per_proc_batch_size = x.shape[0] 86 | per_replica_batch_size = per_proc_batch_size // local_device_count 87 | if per_proc_batch_size % local_device_count != 0: 88 | raise ValueError( 89 | "The per process batch size must be divisible by the number of" 90 | " local devices. Got per process batch size:" 91 | f" {per_proc_batch_size} and local device count:" 92 | f" {local_device_count}." 93 | ) 94 | 95 | per_device_arrays = np.split(x, local_device_count, axis=0) 96 | device_buffers = [ 97 | jax.device_put(arr, device) 98 | for arr, device in zip(per_device_arrays, local_devices) 99 | ] 100 | 101 | global_batch_size = per_replica_batch_size * device_count 102 | return jax.make_array_from_single_device_arrays( 103 | (global_batch_size,) + x.shape[1:], self.data_sharding, device_buffers 104 | ) 105 | 106 | return jax.tree.map(_shard, inputs) 107 | 108 | def partition_init( 109 | self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None 110 | ) -> CreateStateFn: 111 | with jax.sharding.use_mesh(self.mesh): 112 | init_fn = jax.jit(init_fn, out_shardings=self.state_sharding) 113 | 114 | def _wrapped_init(batch: PyTree) -> State: 115 | with jax.sharding.use_mesh(self.mesh): 116 | state = init_fn(batch) 117 | state = _maybe_unbox_state(state) 118 | return state 119 | 120 | return _wrapped_init 121 | 122 | def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn: 123 | jit_kws = {} 124 | if training: 125 | jit_kws["out_shardings"] = (self.state_sharding, None) 126 | jit_kws["donate_argnums"] = (1,) 127 | 128 | with jax.sharding.use_mesh(self.mesh): 129 | step_fn = jax.jit( 130 | fn, 131 | in_shardings=(self.data_sharding, self.state_sharding), 132 | **jit_kws, 133 | ) 134 | 135 | def _wrapped_step(batch: PyTree, state: State) -> Any: 136 | with jax.sharding.use_mesh(self.mesh): 137 | return step_fn(batch, state) 138 | 139 | return _wrapped_step 140 | 141 | 142 | class ModelParallelPartitioner(Partitioner): 143 | """Model parallel partitioner. 144 | 145 | This only works with multi-controller Jax, i.e. communications along the ICI 146 | for TPUs. For scaling beyond a single TPU slice this needs to be extended to 147 | support Megascale XLA or single-controller Pathways. Consider using T5X, Pax, 148 | or Gemax for these use cases. 149 | 150 | Note: This assumes that all axes of the inputs except the final one are used 151 | for data parallelism while the final one is used for model parallelism. 152 | This tends to work well for 2D and 3D torus topologies since network latency 153 | tends to be much higher for the leading axes. 154 | 155 | IMPORTANT: `shard_inputs` operates on a per process batch. This means that the 156 | input batch size on CPU must already be the per process batch size, 157 | i.e. global batch size // jax.process_count(). It is the responsibility of the 158 | CPU input pipeline to ensure that inputs are different across processes. 159 | """ 160 | 161 | def __init__( 162 | self, 163 | axes: Sequence[tuple[str, int]], 164 | rules: Mapping[str, str] | None = None, 165 | aot_compile: bool = False, 166 | options: jax.stages.CompilerOptions | None = None, 167 | ): 168 | if len(axes) < 2: 169 | raise ValueError( 170 | "`axes` cannot less than 2D, use data-parallel" 171 | f" partitioner instead. Got axes: {axes}." 
172 | ) 173 | 174 | mesh_devices = mesh_utils.create_device_mesh([dim for _, dim, in axes]) 175 | self.mesh = jax.sharding.Mesh(mesh_devices, [axis for axis, _ in axes]) 176 | self.rules = rules 177 | self.aot_compile = aot_compile 178 | self.options = options 179 | 180 | dp_axes, dp_dims = zip(*axes[:-1]) 181 | _, mp_dim = axes[-1] 182 | 183 | if math.prod(dp_dims) % jax.process_count() != 0: 184 | raise ValueError( 185 | "The data parallel dimensions in the mesh must be divisible by the" 186 | " number of processes as we assume data parallelism across" 187 | f" processes. Got process count: {jax.process_count()} and data" 188 | f" parallelism dimensions: {dp_dims} for axes: {axes} and mesh" 189 | f" devices: {self.mesh.devices}." 190 | ) 191 | if jax.local_device_count() % mp_dim != 0: 192 | raise ValueError( 193 | "The number of local devices on each host must be divisible by the" 194 | " model dimension as we assume model parallelism across local" 195 | f" devices. Got local device count: {jax.local_device_count()} and" 196 | f" model parallelism dimension: {mp_dim} for axes: {axes} and mesh" 197 | f" devices: {self.mesh.devices}." 198 | ) 199 | 200 | self.data_sharding = jax.sharding.NamedSharding( 201 | self.mesh, jax.sharding.PartitionSpec(dp_axes) 202 | ) 203 | self.state_sharding = None 204 | self.abstract_batch = None 205 | self.abstract_state = None 206 | 207 | @property 208 | def mesh_context_manager( 209 | self, 210 | ) -> Callable[[jax.sharding.Mesh], ContextManager[None]]: 211 | return jax.sharding.use_mesh 212 | 213 | def shard_inputs(self, inputs: PyTree) -> PyTree: 214 | def _shard(x: np.ndarray) -> jax.Array: 215 | return jax.make_array_from_process_local_data(self.data_sharding, x) 216 | 217 | return jax.tree.map(_shard, inputs) 218 | 219 | def partition_init( 220 | self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None 221 | ) -> CreateStateFn: 222 | if abstract_batch is None: 223 | raise ValueError( 224 | "An `abstract_batch` is required for partitioning `init_fn` with a" 225 | " model parallel partitioner." 
226 | ) 227 | 228 | with self.mesh_context_manager(self.mesh): 229 | abstract_state = jax.eval_shape(init_fn, abstract_batch) 230 | specs = nn.get_partition_spec(abstract_state) 231 | 232 | if self.rules is not None: 233 | specs = nn.logical_to_mesh(specs, self.rules) 234 | 235 | state_sharding = jax.tree.map( 236 | lambda x: jax.sharding.NamedSharding(self.mesh, x), specs 237 | ) 238 | compiled_init_fn = jax.jit(init_fn, out_shardings=state_sharding) 239 | 240 | def _init(batch: PyTree) -> State: 241 | with self.mesh_context_manager(self.mesh): 242 | state = compiled_init_fn(batch) 243 | state = _maybe_unbox_state(state) 244 | return state 245 | 246 | self.abstract_batch = abstract_batch 247 | self.abstract_state = abstract_state 248 | self.state_sharding = state_sharding 249 | return _init 250 | 251 | def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn: 252 | jit_kws = {} 253 | if training: 254 | jit_kws["out_shardings"] = (self.state_sharding, None) 255 | jit_kws["donate_argnums"] = (1,) 256 | else: 257 | jit_kws["out_shardings"] = None 258 | 259 | with self.mesh_context_manager(self.mesh): 260 | step_fn = jax.jit( 261 | fn, 262 | in_shardings=(self.data_sharding, self.state_sharding), 263 | compiler_options=(self.options if not self.aot_compile else None), 264 | **jit_kws, 265 | ) 266 | if self.aot_compile: 267 | if self.abstract_batch is None or self.abstract_state is None: 268 | raise ValueError( 269 | "An `abstract_batch` and `abstract_state` must be set on the model" 270 | " parallel partitioner when `aot_compile` is set to True in order" 271 | " to compile the step. Make sure you call" 272 | " `partitioner.partition_init(...)` first." 273 | ) 274 | 275 | step_fn = step_fn.lower(self.abstract_batch, self.abstract_state).compile( 276 | self.options 277 | ) 278 | 279 | def _step(batch: PyTree, state: State) -> Any: 280 | with self.mesh_context_manager(self.mesh): 281 | return step_fn(batch, state) 282 | 283 | return _step 284 | 285 | 286 | def _maybe_unbox_state(x: Any) -> Any: 287 | def _maybe_unbox(x: Any) -> Any: 288 | if isinstance(x, nn.Partitioned): 289 | return x.unbox() 290 | return x 291 | 292 | return jax.tree.map( 293 | _maybe_unbox, 294 | x, 295 | is_leaf=lambda k: isinstance(k, nn.Partitioned), 296 | ) 297 | -------------------------------------------------------------------------------- /recml/core/training/partitioning_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for Jax partitioners.""" 15 | 16 | from collections.abc import Mapping 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import flax.linen as nn 21 | import jax 22 | import jax.numpy as jnp 23 | import numpy as np 24 | from recml.core.training import partitioning 25 | 26 | 27 | class PartitioningTest(parameterized.TestCase): 28 | 29 | @parameterized.named_parameters( 30 | { 31 | "testcase_name": "data_parallel_partitioner", 32 | "partitioner_cls": partitioning.DataParallelPartitioner, 33 | }, 34 | { 35 | "testcase_name": "model_parallel_partitioner", 36 | "partitioner_cls": partitioning.ModelParallelPartitioner, 37 | }, 38 | ) 39 | def test_data_parallelism( 40 | self, partitioner_cls: type[partitioning.Partitioner] 41 | ): 42 | if partitioner_cls is partitioning.ModelParallelPartitioner: 43 | kwargs = {"axes": [("data", jax.device_count()), ("model", 1)]} 44 | else: 45 | kwargs = {} 46 | partitioner = partitioner_cls(**kwargs) 47 | 48 | inputs = np.zeros((128, 16), dtype=np.float32) 49 | sharded_inputs = partitioner.shard_inputs(inputs) 50 | 51 | self.assertIsInstance(sharded_inputs, jax.Array) 52 | self.assertSequenceEqual(sharded_inputs.shape, (128, 16)) 53 | self.assertEqual(sharded_inputs.sharding, partitioner.data_sharding) 54 | 55 | def _init(batch: jax.Array) -> jax.Array: 56 | return jnp.ones_like(batch) 57 | 58 | def _train_step( 59 | batch: jax.Array, state: jax.Array 60 | ) -> tuple[jax.Array, Mapping[str, jax.Array]]: 61 | return batch + state, { 62 | "batch_mean": jnp.mean(batch), 63 | "state_mean": jnp.mean(state), 64 | } 65 | 66 | def _eval_step( 67 | batch: jax.Array, state: jax.Array 68 | ) -> Mapping[str, jax.Array]: 69 | return {"batch_mean": jnp.mean(batch), "state_mean": jnp.mean(state)} 70 | 71 | state = partitioner.partition_init(_init, abstract_batch=sharded_inputs)( 72 | sharded_inputs 73 | ) 74 | self.assertIsInstance(state, jax.Array) 75 | self.assertSequenceEqual(state.shape, (128, 16)) 76 | self.assertEqual(state.sharding, partitioner.state_sharding) 77 | 78 | new_state, metrics = partitioner.partition_step(_train_step, training=True)( 79 | sharded_inputs, state 80 | ) 81 | self.assertTrue(state.is_deleted()) # Buffer should be donated. 82 | self.assertIsInstance(new_state, jax.Array) 83 | self.assertSequenceEqual(new_state.shape, (128, 16)) 84 | self.assertEqual(new_state.sharding, partitioner.state_sharding) 85 | for metric in jax.tree.flatten(metrics)[0]: 86 | self.assertIsInstance(metric, jax.Array) 87 | self.assertEqual( 88 | metric.sharding, 89 | jax.sharding.NamedSharding( 90 | partitioner.mesh, jax.sharding.PartitionSpec() 91 | ), 92 | ) 93 | 94 | metrics = partitioner.partition_step(_eval_step, training=False)( 95 | sharded_inputs, new_state 96 | ) 97 | self.assertFalse(new_state.is_deleted()) # Buffer should not be donated. 
98 | for metric in jax.tree.flatten(metrics)[0]: 99 | self.assertIsInstance(metric, jax.Array) 100 | self.assertEqual( 101 | metric.sharding, 102 | jax.sharding.NamedSharding( 103 | partitioner.mesh, jax.sharding.PartitionSpec() 104 | ), 105 | ) 106 | 107 | self.assertEqual( 108 | partitioner.state_sharding, 109 | jax.sharding.NamedSharding( 110 | partitioner.mesh, jax.sharding.PartitionSpec() 111 | ), 112 | ) 113 | 114 | def test_model_parallelism(self): 115 | partitioner = partitioning.ModelParallelPartitioner( 116 | axes=[("data", 1), ("model", jax.device_count())] 117 | ) 118 | 119 | inputs = np.zeros((128, 16), dtype=np.float32) 120 | sharded_inputs = partitioner.shard_inputs(inputs) 121 | 122 | self.assertIsInstance(sharded_inputs, jax.Array) 123 | self.assertSequenceEqual(sharded_inputs.shape, (128, 16)) 124 | self.assertEqual( 125 | sharded_inputs.sharding, 126 | jax.sharding.NamedSharding( 127 | partitioner.mesh, jax.sharding.PartitionSpec("data") 128 | ), 129 | ) 130 | 131 | def _init(batch: jax.Array) -> jax.Array: 132 | return nn.with_partitioning( 133 | jnp.ones_like, ("data", "model"), partitioner.mesh 134 | )(batch) 135 | 136 | state = partitioner.partition_init(_init, abstract_batch=sharded_inputs)( 137 | sharded_inputs 138 | ) 139 | 140 | self.assertIsInstance(state, jax.Array) 141 | self.assertSequenceEqual(state.shape, (128, 16)) 142 | self.assertEqual(state.sharding, partitioner.state_sharding) 143 | self.assertEqual( 144 | partitioner.state_sharding, 145 | jax.sharding.NamedSharding( 146 | partitioner.mesh, 147 | jax.sharding.PartitionSpec("data", "model"), 148 | ), 149 | ) 150 | 151 | # TODO(aahil): Add tests for the steps. 152 | 153 | 154 | if __name__ == "__main__": 155 | absltest.main() 156 | -------------------------------------------------------------------------------- /recml/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Public utilities API.""" 15 | 16 | # pylint: disable=g-importing-member 17 | 18 | from recml.core.utils.config import DEFINE_fiddle_config 19 | from recml.core.utils.config import FiddleFlag 20 | from recml.core.utils.types import Dataclass 21 | from recml.core.utils.types import FrozenDataclass 22 | -------------------------------------------------------------------------------- /recml/core/utils/config_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for configuration utilities.""" 15 | 16 | from collections.abc import Sequence 17 | import dataclasses 18 | import sys 19 | 20 | from absl import flags 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | import fiddle as fdl 24 | from recml.core.utils import config as config_lib 25 | 26 | # Pytest may use the test from a different module, otherwise this should be 27 | # __main__. 28 | _TEST_MODULE_NAME = sys.modules[__name__].__name__ 29 | 30 | 31 | @dataclasses.dataclass 32 | class _X: 33 | value: int 34 | 35 | 36 | @dataclasses.dataclass 37 | class _Y: 38 | value: int 39 | 40 | 41 | @dataclasses.dataclass 42 | class _Object: 43 | x: _X 44 | y: _Y 45 | 46 | 47 | def config_1() -> fdl.Config[_Object]: 48 | return fdl.Config( 49 | _Object, 50 | x=fdl.Config(_X, value=1), 51 | y=fdl.Config(_Y, value=2), 52 | ) 53 | 54 | 55 | def fiddler_1(cfg: fdl.Config[_Object]): 56 | cfg.x.value = 3 57 | 58 | 59 | def fiddler_2(cfg: fdl.Config[_Object], value: int): 60 | cfg.y.value = value 61 | 62 | 63 | class ConfigTest(parameterized.TestCase): 64 | 65 | @parameterized.named_parameters( 66 | { 67 | 'testcase_name': 'base_config', 68 | 'args': [f'config:{_TEST_MODULE_NAME}.config_1'], 69 | 'expected_config': config_1(), 70 | }, 71 | { 72 | 'testcase_name': 'relative_fiddler', 73 | 'args': [ 74 | f'config:{_TEST_MODULE_NAME}.config_1', 75 | 'fiddler:fiddler_1', 76 | 'fiddler:fiddler_2(value=4)', 77 | ], 78 | 'expected_config': fdl.Config( 79 | _Object, 80 | x=fdl.Config(_X, value=3), 81 | y=fdl.Config(_Y, value=4), 82 | ), 83 | }, 84 | { 85 | 'testcase_name': 'absolute_fiddler', 86 | 'args': [ 87 | f'config:{_TEST_MODULE_NAME}.config_1', 88 | f'fiddler:{_TEST_MODULE_NAME}.fiddler_2(3)', 89 | ], 90 | 'expected_config': fdl.Config( 91 | _Object, 92 | x=fdl.Config(_X, value=1), 93 | y=fdl.Config(_Y, value=3), 94 | ), 95 | }, 96 | { 97 | 'testcase_name': 'set', 98 | 'args': [ 99 | f'config:{_TEST_MODULE_NAME}.config_1', 100 | 'set:x.value=0', 101 | 'set:y.value=0', 102 | ], 103 | 'expected_config': fdl.Config( 104 | _Object, 105 | x=fdl.Config(_X, value=0), 106 | y=fdl.Config(_Y, value=0), 107 | ), 108 | }, 109 | ) 110 | def test_fiddle_flag( 111 | self, args: Sequence[str], expected_config: fdl.Config[_Object] 112 | ): 113 | fdl_flag = config_lib.FiddleFlag( 114 | name='test_flag', 115 | default=None, 116 | parser=flags.ArgumentParser(), 117 | serializer=None, 118 | help_string='My fiddle flag', 119 | ) 120 | fdl_flag.parse(args) 121 | self.assertEqual(expected_config, fdl_flag.value) 122 | 123 | @parameterized.named_parameters( 124 | { 125 | 'testcase_name': 'bad_base_config', 126 | 'args': [f'config:{_TEST_MODULE_NAME}.config_3'], 127 | 'expected_error': AttributeError, 128 | 'expected_error_regex': 'Could not init a buildable from .*', 129 | }, 130 | { 131 | 'testcase_name': 'bad_fiddler', 132 | 'args': [f'config:{_TEST_MODULE_NAME}.config_1', 'fiddler:fiddler_3'], 133 | 'expected_error': ValueError, 134 | # TODO(aahil): Figure out why the error regex is different in 3P. 
135 | 'expected_error_regex': '.*', 136 | }, 137 | ) 138 | def test_invalid_fiddle_flag( 139 | self, 140 | args: Sequence[str], 141 | expected_error: type[Exception], 142 | expected_error_regex: str, 143 | ): 144 | fdl_flag = config_lib.FiddleFlag( 145 | name='test_flag', 146 | default=None, 147 | parser=flags.ArgumentParser(), 148 | serializer=None, 149 | help_string='My fiddle flag', 150 | ) 151 | 152 | def _value(args: Sequence[str]): 153 | fdl_flag.parse(args) 154 | return fdl_flag.value 155 | 156 | self.assertRaisesRegex(expected_error, expected_error_regex, _value, args) 157 | 158 | 159 | if __name__ == '__main__': 160 | absltest.main() 161 | -------------------------------------------------------------------------------- /recml/core/utils/py_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Miscellaneous utilities.""" 15 | 16 | from collections.abc import Callable 17 | import inspect 18 | from typing import Any 19 | 20 | 21 | def has_argument(fn: Callable[..., Any], arg_name: str) -> bool: 22 | """Checks if a function has an argument with a given name.""" 23 | params = inspect.signature(fn).parameters.values() 24 | param_names = [v.name for v in params] 25 | has_arg = arg_name in param_names 26 | has_kw_args = any([v.kind == inspect.Parameter.VAR_KEYWORD for v in params]) 27 | return has_arg or has_kw_args 28 | -------------------------------------------------------------------------------- /recml/core/utils/types.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration tools for dealing with abstract types.""" 15 | 16 | import abc 17 | import dataclasses 18 | from typing import Generic, Protocol, TypeVar 19 | 20 | from typing_extensions import dataclass_transform 21 | 22 | 23 | T = TypeVar("T") 24 | 25 | 26 | @dataclass_transform(field_specifiers=dataclasses.field) # type: ignore[literal-required] 27 | class Dataclass: 28 | """A dataclass transform that converts a class to a dataclass.""" 29 | 30 | def __init_subclass__(cls, **kwargs): 31 | def replace(self, **updates): 32 | return dataclasses.replace(self, **updates) 33 | 34 | data_cls = dataclasses.dataclass(**kwargs)(cls) 35 | data_cls.replace = replace 36 | 37 | def __init__(self, *args, **kwargs): 38 | # stub for pytype 39 | raise NotImplementedError 40 | 41 | def replace(self: T, **overrides) -> T: 42 | # stub for pytype 43 | raise NotImplementedError 44 | 45 | 46 | # TODO(aahil): Share code with `Dataclass`. 47 | @dataclass_transform(field_specifiers=dataclasses.field) # type: ignore[literal-required] 48 | class FrozenDataclass: 49 | """A dataclass transform that converts a class to a frozen dataclass.""" 50 | 51 | def __init_subclass__(cls, **kwargs): 52 | if "frozen" not in kwargs: 53 | kwargs["frozen"] = True 54 | 55 | def replace(self, **updates): 56 | return dataclasses.replace(self, **updates) 57 | 58 | data_cls = dataclasses.dataclass(**kwargs)(cls) 59 | data_cls.replace = replace 60 | 61 | def __init__(self, *args, **kwargs): 62 | # stub for pytype 63 | raise NotImplementedError 64 | 65 | def replace(self: T, **overrides) -> T: 66 | # stub for pytype 67 | raise NotImplementedError 68 | 69 | 70 | class Factory(abc.ABC, Generic[T], Dataclass): 71 | """A factory interface for configuring an arbitary object via a dataclass. 72 | 73 | This is useful for creating objects that require run-time information. 74 | """ 75 | 76 | @abc.abstractmethod 77 | def make(self, *args, **kwargs) -> T: 78 | """Builds the object instance.""" 79 | 80 | 81 | class FactoryProtocol(Protocol, Generic[T]): 82 | """A protocol for typing factories.""" 83 | 84 | def make(self, *args, **kwargs) -> T: 85 | """Builds the object instance.""" 86 | 87 | def replace(self, **overrides) -> T: 88 | """Replaces the object instance.""" 89 | 90 | 91 | class ObjectFactory(Factory[T]): 92 | """A factory that wraps around the constructor of an object. 93 | 94 | This is useful when a library only accepts a factory but creating a factory 95 | introduces unnecessary boilerplate. 96 | 97 | Example usage: 98 | ``` 99 | class MyObject: 100 | def __init__(self, x: int, y: int): 101 | self._x = x 102 | self._y = y 103 | 104 | 105 | factory = ObjectFactory(MyObject, x=1, y=2) 106 | obj = factory.make() 107 | assert obj._x == 1 108 | assert obj._y == 2 109 | ``` 110 | """ 111 | 112 | def __new__(cls, *args, **kwargs) -> Factory[T]: 113 | if args: 114 | raise ValueError( 115 | "`StaticFactory` does not accept positional arguments. Got args:" 116 | f" {args}." 117 | ) 118 | if "type" not in kwargs: 119 | raise ValueError( 120 | "`StaticFactory` requires a `type` keyword argument. Got kwargs:" 121 | f" {kwargs}." 122 | ) 123 | 124 | class _ObjectFactory(Factory): 125 | 126 | def __init_subclass__(cls, **kwargs): 127 | # Override the dataclass transform from the base class. 
128 | pass 129 | 130 | def make(self): 131 | return getattr(self, "type")(**{ 132 | f.name: getattr(self, f.name) 133 | for f in dataclasses.fields(self) 134 | if f.name != "type" 135 | }) 136 | 137 | sub_cls = dataclasses.make_dataclass( 138 | cls_name=cls.__name__, 139 | fields=[(k, type(v)) for k, v in kwargs.items()], 140 | bases=(_ObjectFactory,), 141 | kw_only=True, 142 | ) 143 | obj = sub_cls(**kwargs) 144 | return obj 145 | 146 | def __init__(self, *, type: type[T], **kwargs): # pylint: disable=redefined-builtin 147 | # Stub for pytype. 148 | raise NotImplementedError() 149 | 150 | @property 151 | def type(self) -> type[T]: 152 | # Stub for pytype. 153 | raise NotImplementedError() 154 | 155 | def make(self) -> T: 156 | # Stub for pytype. 157 | raise NotImplementedError() 158 | -------------------------------------------------------------------------------- /recml/core/utils/types_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for type utilities.""" 15 | 16 | import dataclasses 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from recml.core.utils import types 21 | 22 | 23 | class TypesTest(parameterized.TestCase): 24 | 25 | @parameterized.named_parameters( 26 | {'testcase_name': 'dataclass', 'cls': types.Dataclass}, 27 | {'testcase_name': 'frozen_dataclass', 'cls': types.FrozenDataclass}, 28 | ) 29 | def test_dataclass_transform(self, cls: type[types.Dataclass]): 30 | class Foo(cls): 31 | x: int 32 | y: int 33 | z: int = dataclasses.field(default_factory=lambda: 1) 34 | 35 | class Bar(Foo): 36 | u: int = dataclasses.field(default_factory=lambda: 2) 37 | 38 | foo = Foo(x=1, y=2) 39 | self.assertEqual(foo.x, 1) 40 | self.assertEqual(foo.y, 2) 41 | self.assertEqual(foo.z, 1) 42 | self.assertTrue(dataclasses.is_dataclass(Foo)) 43 | self.assertTrue(dataclasses.is_dataclass(foo)) 44 | 45 | bar = Bar(x=1, y=2, u=3) 46 | self.assertEqual(bar.x, 1) 47 | self.assertEqual(bar.y, 2) 48 | self.assertEqual(bar.z, 1) 49 | self.assertEqual(bar.u, 3) 50 | self.assertTrue(dataclasses.is_dataclass(Bar)) 51 | self.assertTrue(dataclasses.is_dataclass(bar)) 52 | 53 | def test_frozen_dataclass(self): 54 | 55 | class Foo(types.Dataclass): 56 | x: int 57 | y: int 58 | 59 | class Bar(types.FrozenDataclass): 60 | x: int 61 | y: int 62 | 63 | def _mutate_foo_or_bar(foo_or_bar: Foo | Bar): 64 | foo_or_bar.x = 2 65 | 66 | # Mutating Foo is allowed. 
67 | _mutate_foo_or_bar(Foo(x=1, y=2)) 68 | 69 | self.assertRaises( 70 | dataclasses.FrozenInstanceError, 71 | _mutate_foo_or_bar, 72 | Bar(x=1, y=2), 73 | ) 74 | 75 | def test_object_factory(self): 76 | class Foo(types.Dataclass): 77 | x: int 78 | y: int 79 | 80 | factory = types.ObjectFactory(type=Foo, x=1, y=2) 81 | self.assertEqual(factory.type, Foo) 82 | self.assertEqual(factory.x, 1) 83 | self.assertEqual(factory.y, 2) 84 | self.assertEqual(factory.make(), Foo(x=1, y=2)) 85 | 86 | 87 | if __name__ == '__main__': 88 | absltest.main() 89 | -------------------------------------------------------------------------------- /recml/examples/dlrm_experiment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """DLRM experiment.""" 15 | 16 | from __future__ import annotations 17 | 18 | from collections.abc import Iterator, Mapping, Sequence 19 | import dataclasses 20 | from typing import Generic, Literal, TypeVar 21 | 22 | from etils import epy 23 | import fiddle as fdl 24 | import flax.linen as nn 25 | import jax 26 | import jax.numpy as jnp 27 | import jaxtyping as jt 28 | import numpy as np 29 | import optax 30 | import recml 31 | from recml.layers.linen import sparsecore 32 | import tensorflow as tf 33 | 34 | with epy.lazy_imports(): 35 | from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec # pylint: disable=g-import-not-at-top 36 | 37 | 38 | @dataclasses.dataclass 39 | class Feature: 40 | name: str 41 | 42 | 43 | FeatureT = TypeVar('FeatureT', bound=Feature) 44 | 45 | 46 | @dataclasses.dataclass 47 | class DenseFeature(Feature): 48 | """Dense feature.""" 49 | 50 | 51 | @dataclasses.dataclass 52 | class SparseFeature(Feature): 53 | """Sparse feature.""" 54 | 55 | vocab_size: int 56 | embedding_dim: int 57 | max_sequence_length: int | None = None 58 | combiner: Literal['mean', 'sum', 'sqrtn'] = 'mean' 59 | sparsity: float = 0.8 60 | 61 | 62 | @dataclasses.dataclass 63 | class FeatureSet(Generic[FeatureT]): 64 | """A collection of features.""" 65 | 66 | features: Sequence[FeatureT] 67 | 68 | def __post_init__(self): 69 | feature_names = [f.name for f in self.features] 70 | if len(feature_names) != len(set(feature_names)): 71 | raise ValueError( 72 | f'Feature names must be unique. Got names: {feature_names}.' 
73 | ) 74 | 75 | def dense_features(self) -> FeatureSet[DenseFeature]: 76 | return FeatureSet[DenseFeature]( 77 | [f for f in self if isinstance(f, DenseFeature)] 78 | ) 79 | 80 | def sparse_features(self) -> FeatureSet[SparseFeature]: 81 | return FeatureSet[SparseFeature]( 82 | [f for f in self if isinstance(f, SparseFeature)] 83 | ) 84 | 85 | def __iter__(self) -> Iterator[FeatureT]: 86 | return iter(self.features) 87 | 88 | def __or__(self, other: FeatureSet[Feature]) -> FeatureSet[Feature]: 89 | return FeatureSet([*self.features, *other.features]) 90 | 91 | 92 | class DLRMModel(nn.Module): 93 | """DLRM DCN v2 model.""" 94 | 95 | features: FeatureSet 96 | embedding_optimizer: sparsecore.OptimizerSpec 97 | bottom_mlp_dims: Sequence[int] 98 | top_mlp_dims: Sequence[int] 99 | dcn_layers: int 100 | dcn_inner_dim: int 101 | 102 | # We need to track the embedder on the Flax module to ensure it is not 103 | # re-created on cloning. It is not possible to create an embedder inside 104 | # setup() because it is called lazily at compile time. The embedder needs 105 | # to be created before `model.init` so we can use it to create a preprocessor. 106 | # A simpler pattern that works is passing `embedder` directly to the module. 107 | _embedder: sparsecore.SparsecoreEmbedder | None = None 108 | 109 | @property 110 | def embedder(self) -> sparsecore.SparsecoreEmbedder: 111 | if self._embedder is not None: 112 | return self._embedder 113 | 114 | embedder = sparsecore.SparsecoreEmbedder( 115 | specs={ 116 | f.name: sparsecore.EmbeddingSpec( 117 | input_dim=f.vocab_size, 118 | embedding_dim=f.embedding_dim, 119 | max_sequence_length=f.max_sequence_length, 120 | combiner=f.combiner, 121 | ) 122 | for f in self.features.sparse_features() 123 | }, 124 | optimizer=self.embedding_optimizer, 125 | ) 126 | object.__setattr__(self, '_embedder', embedder) 127 | return embedder 128 | 129 | def bottom_mlp(self, inputs: Mapping[str, jt.Array]) -> jt.Array: 130 | x = jnp.concatenate( 131 | [inputs[f.name] for f in self.features.dense_features()], axis=-1 132 | ) 133 | 134 | for dim in self.bottom_mlp_dims: 135 | x = nn.Dense(dim)(x) 136 | x = nn.relu(x) 137 | return x 138 | 139 | def top_mlp(self, x: jt.Array) -> jt.Array: 140 | for dim in self.top_mlp_dims[:-1]: 141 | x = nn.Dense(dim)(x) 142 | x = nn.relu(x) 143 | 144 | x = nn.Dense(self.top_mlp_dims[-1])(x) 145 | return x 146 | 147 | def dcn(self, x0: jt.Array) -> jt.Array: 148 | xl = x0 149 | input_dim = x0.shape[-1] 150 | 151 | for i in range(self.dcn_layers): 152 | u_kernel = self.param( 153 | f'u_kernel_{i}', 154 | nn.initializers.xavier_normal(), 155 | (input_dim, self.dcn_inner_dim), 156 | ) 157 | v_kernel = self.param( 158 | f'v_kernel_{i}', 159 | nn.initializers.xavier_normal(), 160 | (self.dcn_inner_dim, input_dim), 161 | ) 162 | bias = self.param(f'bias_{i}', nn.initializers.zeros, (input_dim,)) 163 | 164 | u = jnp.matmul(xl, u_kernel) 165 | v = jnp.matmul(u, v_kernel) 166 | v += bias 167 | 168 | xl = x0 * v + xl 169 | 170 | return xl 171 | 172 | @nn.compact 173 | def __call__( 174 | self, inputs: Mapping[str, jt.Array], training: bool = False 175 | ) -> jt.Array: 176 | dense_embeddings = self.bottom_mlp(inputs) 177 | sparse_embeddings = self.embedder.make_sparsecore_module()(inputs) 178 | sparse_embeddings = jax.tree.flatten(sparse_embeddings)[0] 179 | concatenated_embeddings = jnp.concatenate( 180 | (dense_embeddings, *sparse_embeddings), axis=-1 181 | ) 182 | interaction_outputs = self.dcn(concatenated_embeddings) 183 | predictions = 
self.top_mlp(interaction_outputs) 184 | predictions = jnp.reshape(predictions, (-1,)) 185 | return predictions 186 | 187 | 188 | class CriteoFactory(recml.Factory[tf.data.Dataset]): 189 | """Data loader for dummy Criteo data optimized for Jax training.""" 190 | 191 | features: FeatureSet 192 | global_batch_size: int 193 | use_cached_data: bool = False 194 | 195 | def make(self) -> tf.data.Dataset: 196 | data = {} 197 | batch_size = self.global_batch_size // jax.process_count() 198 | 199 | for f in self.features.dense_features(): 200 | feature = np.random.normal(0.0, 1.0, size=(batch_size, 1)) 201 | data[f.name] = feature.astype(np.float32) 202 | 203 | for f in self.features.sparse_features(): 204 | non_zero_mask = ( 205 | np.random.normal(size=(batch_size, f.embedding_dim)) > f.sparsity 206 | ) 207 | sparse_feature = np.random.randint( 208 | low=0, 209 | high=f.vocab_size, 210 | size=(batch_size, f.embedding_dim), 211 | ) 212 | sparse_feature = np.where( 213 | non_zero_mask, sparse_feature, np.zeros_like(sparse_feature) 214 | ) 215 | data[f.name] = tf.constant(sparse_feature, dtype=tf.int64) 216 | 217 | label = np.random.randint(0, 2, size=(batch_size,)) 218 | 219 | dataset = tf.data.Dataset.from_tensors((data, label)) 220 | dataset = dataset.take(1).repeat() 221 | dataset = dataset.prefetch(buffer_size=2048) 222 | options = tf.data.Options() 223 | options.deterministic = False 224 | options.threading.private_threadpool_size = 96 225 | dataset = dataset.with_options(options) 226 | return dataset 227 | 228 | 229 | @dataclasses.dataclass 230 | class PredictionTask(recml.JaxTask): 231 | """Prediction task.""" 232 | 233 | train_data: CriteoFactory 234 | eval_data: CriteoFactory 235 | model: DLRMModel 236 | optimizer: recml.Factory[optax.GradientTransformation] 237 | 238 | def create_datasets(self) -> tuple[recml.data.Iterator, recml.data.Iterator]: 239 | global_batch_size = self.train_data.global_batch_size 240 | train_iter = recml.data.TFDatasetIterator( 241 | dataset=self.train_data.make(), 242 | postprocessor=self.model.embedder.make_preprocessor(global_batch_size), 243 | ) 244 | eval_iter = recml.data.TFDatasetIterator( 245 | dataset=self.eval_data.make(), 246 | postprocessor=self.model.embedder.make_preprocessor(global_batch_size), 247 | ) 248 | return train_iter, eval_iter 249 | 250 | def create_state(self, batch: jt.PyTree, rng: jt.Array) -> recml.JaxState: 251 | inputs, _ = batch 252 | params = self.model.init(rng, inputs) 253 | optimizer = self.optimizer.make() 254 | return recml.JaxState.create(params=params, tx=optimizer) 255 | 256 | def train_step( 257 | self, batch: jt.PyTree, state: recml.JaxState, rng: jt.Array 258 | ) -> tuple[recml.JaxState, Mapping[str, recml.Metric]]: 259 | inputs, label = batch 260 | 261 | def _loss_fn(params: jt.PyTree) -> tuple[jt.Scalar, jt.Array]: 262 | logits = self.model.apply(params, inputs, training=True) 263 | loss = jnp.mean(optax.sigmoid_binary_cross_entropy(logits, label), axis=0) 264 | return loss, logits 265 | 266 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True, allow_int=True) 267 | (loss, logits), grads = grad_fn(state.params) 268 | state = state.update(grads=grads) 269 | 270 | metrics = { 271 | 'loss': recml.metrics.scalar(loss), 272 | 'accuracy': recml.metrics.binary_accuracy(label, logits, threshold=0.0), 273 | 'auc': recml.metrics.aucpr(label, logits, from_logits=True), 274 | 'aucroc': recml.metrics.aucroc(label, logits, from_logits=True), 275 | 'label/mean': recml.metrics.mean(label), 276 | 'prediction/mean': 
recml.metrics.mean(jax.nn.sigmoid(logits)), 277 | } 278 | return state, metrics 279 | 280 | def eval_step( 281 | self, batch: jt.PyTree, state: recml.JaxState 282 | ) -> Mapping[str, recml.Metric]: 283 | inputs, label = batch 284 | logits = self.model.apply(state.params, inputs, training=False) 285 | loss = jnp.mean(optax.sigmoid_binary_cross_entropy(logits, label), axis=0) 286 | 287 | metrics = { 288 | 'loss': recml.metrics.mean(loss), 289 | 'accuracy': recml.metrics.binary_accuracy(label, logits, threshold=0.0), 290 | 'auc': recml.metrics.aucpr(label, logits, from_logits=True), 291 | 'aucroc': recml.metrics.aucroc(label, logits, from_logits=True), 292 | 'label/mean': recml.metrics.mean(label), 293 | 'prediction/mean': recml.metrics.mean(jax.nn.sigmoid(logits)), 294 | } 295 | return metrics 296 | 297 | 298 | def features() -> fdl.Config[FeatureSet]: 299 | """Creates a feature collection for the DLRM model.""" 300 | table_sizes = [ 301 | (40000000, 3), 302 | (39060, 2), 303 | (17295, 1), 304 | (7424, 2), 305 | (20265, 6), 306 | (3, 1), 307 | (7122, 1), 308 | (1543, 1), 309 | (63, 1), 310 | (40000000, 7), 311 | (3067956, 3), 312 | (405282, 8), 313 | (10, 1), 314 | (2209, 6), 315 | (11938, 9), 316 | (155, 5), 317 | (4, 1), 318 | (976, 1), 319 | (14, 1), 320 | (40000000, 12), 321 | (40000000, 100), 322 | (40000000, 27), 323 | (590152, 10), 324 | (12973, 3), 325 | (108, 1), 326 | (36, 1), 327 | ] 328 | return fdl.Config( 329 | FeatureSet, 330 | features=[ 331 | fdl.Config(DenseFeature, name=f'float-feature-{i}') for i in range(13) 332 | ] 333 | + [ 334 | fdl.Config( 335 | SparseFeature, 336 | vocab_size=vocab_size, 337 | embedding_dim=embedding_dim, 338 | name=f'categorical-feature-{i}', 339 | ) 340 | for i, (vocab_size, embedding_dim) in enumerate(table_sizes) 341 | ], 342 | ) 343 | 344 | 345 | def experiment() -> fdl.Config[recml.Experiment]: 346 | """DLRM experiment.""" 347 | 348 | feature_set = features() 349 | 350 | task = fdl.Config( 351 | PredictionTask, 352 | train_data=fdl.Config( 353 | CriteoFactory, 354 | features=feature_set, 355 | global_batch_size=131_072, 356 | ), 357 | eval_data=fdl.Config( 358 | CriteoFactory, 359 | features=feature_set, 360 | global_batch_size=131_072, 361 | use_cached_data=True, 362 | ), 363 | model=fdl.Config( 364 | DLRMModel, 365 | features=feature_set, 366 | embedding_optimizer=fdl.Config( 367 | embedding_spec.AdagradOptimizerSpec, 368 | learning_rate=0.01, 369 | ), 370 | bottom_mlp_dims=[512, 256, 128], 371 | top_mlp_dims=[1024, 1024, 512, 256, 1], 372 | dcn_layers=3, 373 | dcn_inner_dim=512, 374 | ), 375 | optimizer=fdl.Config( 376 | recml.AdagradFactory, 377 | learning_rate=0.01, 378 | # Sparsecore embedding parameters are optimized in the backward pass. 379 | freeze_mask=rf'.*{sparsecore.EMBEDDING_PARAM_NAME}.*', 380 | ), 381 | ) 382 | trainer = fdl.Config( 383 | recml.JaxTrainer, 384 | partitioner=fdl.Config(recml.DataParallelPartitioner), 385 | train_steps=1_000, 386 | steps_per_eval=100, 387 | steps_per_loop=100, 388 | ) 389 | return fdl.Config(recml.Experiment, task=task, trainer=trainer) 390 | -------------------------------------------------------------------------------- /recml/examples/dlrm_experiment_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for the DLRM experiment.""" 15 | 16 | from absl.testing import absltest 17 | import fiddle as fdl 18 | from fiddle import selectors 19 | import jax 20 | import numpy as np 21 | import recml 22 | from recml.examples import dlrm_experiment 23 | 24 | 25 | class DLRMExperimentTest(absltest.TestCase): 26 | 27 | def test_dlrm_experiment(self): 28 | if jax.devices()[0].platform != "tpu": 29 | self.skipTest("Test only supported on TPUs.") 30 | 31 | np.random.seed(1337) 32 | 33 | experiment = dlrm_experiment.experiment() 34 | 35 | experiment.task.train_data.global_batch_size = 4 36 | experiment.task.eval_data.global_batch_size = 4 37 | experiment.trainer.train_steps = 12 38 | experiment.trainer.steps_per_loop = 4 39 | experiment.trainer.steps_per_eval = 4 40 | 41 | for cfg in selectors.select(experiment, dlrm_experiment.SparseFeature): 42 | cfg.vocab_size = 200 43 | cfg.embedding_dim = 8 44 | 45 | experiment = fdl.build(experiment) 46 | recml.run_experiment(experiment, recml.Experiment.Mode.TRAIN_AND_EVAL) 47 | 48 | 49 | if __name__ == "__main__": 50 | absltest.main() 51 | -------------------------------------------------------------------------------- /recml/layers/keras/README.md: -------------------------------------------------------------------------------- 1 | ## Models 2 | 3 | ### SASRec 4 | 5 | Uses self-attention to predict a user's next action based on their past 6 | activities. It aims to understand long-term user interests while also making 7 | good predictions based on just the most recent actions. It smartly adapts which 8 | past actions to focus on depending on how much history a user has. Built 9 | entirely with efficient attention blocks, SASRec avoids the complex structures 10 | of older RNN or CNN models, leading to faster training and better performance on 11 | diverse datasets. 12 | 13 | #### Architecture Overview 14 | 15 | - **Embedding Layer** - Converts item IDs into dense vectors. Adds a learnable 16 | absolute positional embedding to the item embedding to incorporate sequence 17 | order information. Dropout is applied to the combined embedding. 18 | - **Multi-Head Self-Attention Layer** Computes attention scores between all 19 | pairs of items within the allowed sequence window. Employs causality by 20 | masking out attention to future positions to prevent information leakage 21 | when training with a causal prediction objective. 22 | - **Feed-Forward Network** Applied independently to each embedding vector 23 | output by the attention layer. Uses two linear layers with a GeLU activation 24 | in between to add non-linearity. 25 | - **Residual Connections and Pre-Layernorm** Applied around both the 26 | self-attention and feed-forward network sub-layers for stable and faster 27 | training of deeper models. Dropout is also used within the block. 28 | - **Prediction Head** Decodes the sequence embeddings into logits using the 29 | input item embedding table and computes a causal categorical cross entropy 30 | loss between the inputs and the inputs shifted right. 
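
The prediction head described above amounts to a next-item cross entropy: the logits at step t are scored against the item observed at step t + 1. As a rough illustration (not the implementation in `sasrec.py`), a minimal sketch using the Keras 3 ops API might look like the following; the function name and the `pad_id=0` padding convention are assumptions made for this example.

```python
import keras


def causal_next_item_loss(item_ids, logits, pad_id=0):
  """Cross entropy between logits at step t and the item id at step t + 1.

  item_ids: [batch, seq_len] integer item ids, padded with `pad_id`.
  logits: [batch, seq_len, vocab_size] decoded sequence embeddings.
  """
  targets = item_ids[:, 1:]    # each position is supervised by the next item
  preds = logits[:, :-1, :]    # the last position has no next item to predict
  mask = keras.ops.cast(keras.ops.not_equal(targets, pad_id), preds.dtype)
  per_token = keras.losses.sparse_categorical_crossentropy(
      targets, preds, from_logits=True
  )
  # Average only over non-padding positions.
  return keras.ops.sum(per_token * mask) / keras.ops.maximum(
      keras.ops.sum(mask), 1.0
  )
```

The HSTU and Mamba4Rec sections below describe the same shifted-target objective; only the sequence encoder that produces the logits differs.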
31 | 32 | ### BERT4Rec 33 | 34 | Models how user preferences change based on their past actions for 35 | recommendations. Unlike older methods that only look at history in chronological 36 | order, BERT4Rec uses a transformer-based approach to look at the user's sequence 37 | of actions in both directions. This helps capture context better, as user 38 | behavior isn't always strictly ordered. To learn effectively, it is trained 39 | using a masked prediction objective: some items are randomly masked and the model 40 | learns to predict them based on the context. BERT4Rec consistently performs 41 | better than many standard sequential models. 42 | 43 | #### Architecture Overview 44 | 45 | - **Embedding Layer** - Converts item IDs into dense vectors. Adds a learnable 46 | absolute positional embedding to the item embedding to incorporate sequence 47 | order information. An optional type embedding can be added to the item 48 | embedding. Embedding dropout is applied to the combined embedding. Uses a 49 | separate embedding for masked features to prevent other item tokens from 50 | attending to them. 51 | 52 | - **Multi-Head Self-Attention Layer** Computes attention scores between all 53 | pairs of items within the allowed sequence window. Uses a separate embedding 54 | for masked features to prevent other item tokens from attending to them. 55 | 56 | - **Feed-Forward Network** Applied independently to each embedding vector 57 | output by the attention layer. Uses two linear layers with a GeLU activation 58 | in between to add non-linearity. 59 | 60 | - **Residual Connections and Post-Layernorm** Applied around both the 61 | self-attention and feed-forward network sub-layers for stable and faster 62 | training of deeper models. Dropout is also used within the block. 63 | 64 | - **Masked Prediction Head** Gathers and projects the masked sequence 65 | embeddings, and decodes them using the item embedding layer. Computes a 66 | categorical cross entropy loss between the masked item ids and the predicted 67 | logits for the corresponding masked item embeddings. 68 | 69 | ### HSTU 70 | 71 | HSTU is a novel architecture designed for sequential recommendation, 72 | particularly suited for high cardinality, non-stationary streaming data. It 73 | reformulates recommendation as a sequential transduction task within a 74 | generative modeling framework ("Generative Recommenders"). HSTU aims to provide 75 | state-of-the-art results while being highly scalable and efficient, capable of 76 | handling models with up to trillions of parameters. It has demonstrated 77 | significant improvements over baselines in offline benchmarks and online A/B 78 | tests, leading to deployment on large-scale internet platforms. 79 | 80 | #### Architecture Overview 81 | 82 | - **Embedding Layer** Converts various action tokens into dense vectors in the 83 | same space. Optionally, adds a learnable absolute positional embedding to 84 | incorporate sequence order information. Embedding dropout is applied to the 85 | combined embedding. 86 | - **Gated Pointwise Aggregated Attention** - Uses a multi-head gated pointwise 87 | attention mechanism with a Layernorm on the attention outputs before 88 | projecting them. This captures the intensity of interactions between 89 | actions, which is lost in softmax attention. 90 | - **Relative Attention Bias** - Uses a T5-style relative attention bias 91 | computed using the positions and timestamps of the actions to improve the 92 | position encoding.
93 | - **Residual Connections and Pre-Layernorm** Applied around both the pointwise 94 | attention blocks for stable and faster training of deeper models. 95 | - **No Feedforward Network** - The feedforward network is removed. 96 | - **Prediction Head** - Decodes the sequence embeddings into logits using 97 | separately learnt weights and computes a causal categorical cross entropy 98 | loss between the inputs and the inputs shifted right. 99 | 100 | ### Mamba4Rec 101 | 102 | A linear recurrent Mamba 2 architecture to model sequences of items for 103 | recommendations. This scales better on longer sequences than attention based 104 | methods due to its linear complexity compared to the former's quadratic 105 | complexity. Mamba4Rec performs better than RNNs and matches the quality of 106 | standard attention models while being more efficient at both training and 107 | inference time. 108 | 109 | #### Architecture Overview 110 | 111 | - **Embedding Layer** Converts item IDs into dense vectors. No position 112 | embedding is used since the recurrent nature of Mamba inherently encodes 113 | positional information as an inductive bias. 114 | - **Mamba SSD** Computes a causal interaction between different item 115 | embeddings in the sequence using the Mamba state space duality algorithm. 116 | - **Feedforward Network** Applied independently to each embedding vector 117 | output by the Mamba layer. Uses two linear layers with a GeLU activation in 118 | between to add non-linearity. 119 | - **Residual Connections and Post-Layernorm** Applied around both the Mamba 120 | and feed-forward network sub-layers for stable and faster training of deeper 121 | models. Dropout is also used within the block. 122 | - **Prediction Head** Decodes the sequence embeddings into logits using the 123 | input item embedding table and computes a causal categorical cross entropy 124 | loss between the inputs and the inputs shifted right. 125 | 126 | ## References 127 | 128 | - SASRec Paper: Kang, W. C., & McAuley, J. (2018). Self-Attentive Sequential 129 | Recommendation. arXiv preprint arXiv:1808.09781v1. 130 | https://arxiv.org/abs/1808.09781 131 | - Transformer Paper: Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., 132 | Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you 133 | need. Advances in neural information processing systems, 30. 134 | - Mamba4Rec Paper: Liu, C., Lin, J., Liu, H., Wang, J., & Caverlee, J. (2024). 135 | Mamba4Rec: Towards Efficient Sequential Recommendation with Selective State 136 | Space Models. arXiv preprint arXiv:2403.03900v2. 137 | https://arxiv.org/abs/2403.03900 138 | - Mamba Paper: Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling 139 | with Selective State Spaces. arXiv preprint arXiv:2312.00752. 140 | - BERT4Rec Paper: Sun, F., Liu, J., Wu, J., Pei, C., Lin, X., Ou, W., & Jiang, 141 | P. (2019). BERT4Rec: Sequential Recommendation with Bidirectional Encoder 142 | Representations from Transformer. arXiv preprint arXiv:1904.06690v2. 143 | https://arxiv.org/abs/1904.06690 144 | - BERT Paper: Devlin, J., Chang, M. W., Lee, K., & Toutanova, K. (2018). BERT: 145 | Pre-training of Deep Bidirectional Transformers for Language Understanding. 146 | arXiv preprint arXiv:1810.04805. 
147 | - HSTU Paper: Actions Speak Louder than Words: Trillion-Parameter Sequential 148 | Transducers for Generative Recommendations (arXiv:2402.17152) 149 | -------------------------------------------------------------------------------- /recml/layers/keras/bert4rec.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Models baselined.""" 15 | 16 | from collections.abc import Mapping, Sequence 17 | from typing import Any 18 | 19 | import keras 20 | import keras_hub 21 | from recml.layers.keras import utils 22 | 23 | Tensor = Any 24 | 25 | 26 | @keras.saving.register_keras_serializable("recml") 27 | class BERT4Rec(keras.layers.Layer): 28 | """BERT4Rec architecture as in [1]. 29 | 30 | Implements the BERT4Rec model architecture as described in 'BERT4Rec: 31 | Sequential Recommendation with Bidirectional Encoder Representations from 32 | Transformer' [1]. 33 | 34 | [1] https://arxiv.org/abs/1904.06690 35 | """ 36 | 37 | def __init__( 38 | self, 39 | *, 40 | vocab_size: int, 41 | max_positions: int, 42 | num_types: int | None = None, 43 | model_dim: int, 44 | mlp_dim: int, 45 | num_heads: int, 46 | num_layers: int, 47 | dropout: float = 0.0, 48 | norm_eps: float = 1e-12, 49 | add_head: bool = True, 50 | **kwargs, 51 | ): 52 | """Initializes the instance. 53 | 54 | Args: 55 | vocab_size: The size of the item vocabulary. 56 | max_positions: The maximum number of positions in a sequence. 57 | num_types: The number of types. If None, no type embedding is used. 58 | Defaults to None. 59 | model_dim: The width of the embeddings in the model. 60 | mlp_dim: The width of the MLP in each transformer block. 61 | num_heads: The number of attention heads in each transformer block. 62 | num_layers: The number of transformer blocks in the model. 63 | dropout: The dropout rate. Defaults to 0. 64 | norm_eps: The epsilon for layer normalization. 65 | add_head: Whether to add a masked language modeling head. 66 | **kwargs: Passed through to the super class. 
67 | """ 68 | 69 | super().__init__(**kwargs) 70 | 71 | self.item_embedding = keras_hub.layers.ReversibleEmbedding( 72 | input_dim=vocab_size, 73 | output_dim=model_dim, 74 | embeddings_initializer=keras.initializers.TruncatedNormal(stddev=0.02), 75 | dtype=self.dtype_policy, 76 | reverse_dtype=self.compute_dtype, 77 | name="item_embedding", 78 | ) 79 | if num_types is not None: 80 | self.type_embedding = keras.layers.Embedding( 81 | input_dim=num_types, 82 | output_dim=model_dim, 83 | embeddings_initializer=keras.initializers.TruncatedNormal( 84 | stddev=0.02 85 | ), 86 | dtype=self.dtype_policy, 87 | name="type_embedding", 88 | ) 89 | else: 90 | self.type_embedding = None 91 | 92 | self.position_embedding = keras_hub.layers.PositionEmbedding( 93 | sequence_length=max_positions, 94 | initializer=keras.initializers.TruncatedNormal(stddev=0.02), 95 | dtype=self.dtype_policy, 96 | name="position_embedding", 97 | ) 98 | 99 | self.embeddings_norm = keras.layers.LayerNormalization( 100 | epsilon=1e-12, name="embedding_norm" 101 | ) 102 | self.embeddings_dropout = keras.layers.Dropout( 103 | dropout, name="embedding_dropout" 104 | ) 105 | 106 | self.encoder_blocks = [ 107 | keras_hub.layers.TransformerEncoder( 108 | intermediate_dim=mlp_dim, 109 | num_heads=num_heads, 110 | dropout=dropout, 111 | activation=utils.gelu_approximate, 112 | layer_norm_epsilon=norm_eps, 113 | normalize_first=False, 114 | dtype=self.dtype_policy, 115 | name=f"encoder_block_{i}", 116 | ) 117 | for i in range(num_layers) 118 | ] 119 | if add_head: 120 | self.head = keras_hub.layers.MaskedLMHead( 121 | vocabulary_size=vocab_size, 122 | token_embedding=self.item_embedding, 123 | intermediate_activation=utils.gelu_approximate, 124 | kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02), 125 | dtype=self.dtype_policy, 126 | name="mlm_head", 127 | ) 128 | else: 129 | self.head = None 130 | 131 | self._vocab_size = vocab_size 132 | self._model_dim = model_dim 133 | self._config = { 134 | "vocab_size": vocab_size, 135 | "max_positions": max_positions, 136 | "num_types": num_types, 137 | "model_dim": model_dim, 138 | "mlp_dim": mlp_dim, 139 | "num_heads": num_heads, 140 | "num_layers": num_layers, 141 | "dropout": dropout, 142 | "norm_eps": norm_eps, 143 | "add_head": add_head, 144 | } 145 | 146 | def build(self, inputs_shape: Sequence[int]): 147 | self.item_embedding.build(inputs_shape) 148 | if self.type_embedding is not None: 149 | self.type_embedding.build(inputs_shape) 150 | 151 | self.position_embedding.build((*inputs_shape, self._model_dim)) 152 | self.embeddings_norm.build((*inputs_shape, self._model_dim)) 153 | 154 | for encoder_block in self.encoder_blocks: 155 | encoder_block.build((*inputs_shape, self._model_dim)) 156 | 157 | if self.head is not None: 158 | self.head.build((*inputs_shape, self._model_dim)) 159 | 160 | def call( 161 | self, 162 | inputs: Tensor, 163 | type_ids: Tensor | None = None, 164 | padding_mask: Tensor | None = None, 165 | attention_mask: Tensor | None = None, 166 | mask_positions: Tensor | None = None, 167 | training: bool = False, 168 | ) -> Tensor: 169 | embeddings = self.item_embedding(inputs) 170 | if self.type_embedding is not None: 171 | if type_ids is None: 172 | raise ValueError( 173 | "`type_ids` cannot be None when `num_types` is not None." 
174 | ) 175 | embeddings += self.type_embedding(type_ids) 176 | embeddings += self.position_embedding(embeddings) 177 | 178 | embeddings = self.embeddings_norm(embeddings) 179 | embeddings = self.embeddings_dropout(embeddings, training=training) 180 | 181 | for encoder_block in self.encoder_blocks: 182 | embeddings = encoder_block( 183 | embeddings, 184 | padding_mask=padding_mask, 185 | attention_mask=attention_mask, 186 | training=training, 187 | ) 188 | 189 | if self.head is None: 190 | return embeddings 191 | 192 | return self.head(embeddings, mask_positions) 193 | 194 | def compute_output_shape( 195 | self, 196 | inputs_shape: Sequence[int], 197 | mask_positions_shape: Tensor | None = None, 198 | ) -> Sequence[int | None]: 199 | if self.head is not None: 200 | if mask_positions_shape is None: 201 | raise ValueError( 202 | "`mask_positions_shape` cannot be None when `add_head` is True." 203 | ) 204 | return (*inputs_shape[:-1], mask_positions_shape[-1], self._vocab_size) 205 | return (*inputs_shape, self._model_dim) 206 | 207 | def get_config(self) -> Mapping[str, Any]: 208 | return {**super().get_config(), **self._config} 209 | -------------------------------------------------------------------------------- /recml/layers/keras/bert4rec_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for Keras architectures.""" 15 | 16 | from absl.testing import absltest 17 | import keras 18 | from keras.src import testing 19 | from recml.layers.keras import bert4rec 20 | 21 | 22 | class BERT4RecTest(testing.TestCase): 23 | 24 | def test_bert4rec(self): 25 | item_ids = keras.ops.array([[1, 2, 3], [4, 5, 0]], "int32") 26 | item_type_ids = keras.ops.array([[1, 2, 3], [4, 4, 0]], "int32") 27 | mask = keras.ops.array([[1, 1, 1], [1, 1, 0]], "int32") 28 | mask_positions = keras.ops.array([[0], [0]], "int32") 29 | init_kws = { 30 | "vocab_size": 500, 31 | "num_types": 5, 32 | "max_positions": 20, 33 | "model_dim": 32, 34 | "mlp_dim": 64, 35 | "num_heads": 4, 36 | "num_layers": 3, 37 | "dropout": 0.1, 38 | } 39 | 40 | tvars = ( 41 | (500 * 32) # Item embedding 42 | + (5 * 32) # Type embedding 43 | + (20 * 32) # Position embedding 44 | + (2 * 32) # Embedding norm 45 | + 3 # 3 encoder blocks 46 | * ( 47 | ((32 + 1) * 32 * 3 + (32 + 1) * 32) # Attention QKVO 48 | + (2 * 32) # Attention block norm 49 | + ((32 + 1) * 64) # MLP inner projection 50 | + ((64 + 1) * 32) # MLP outer projection 51 | + (2 * 32) # MLP block norm 52 | ) 53 | + (32 + 1) * 32 # Head projection 54 | + (2 * 32) # Head norm 55 | + 500 # Head bias 56 | ) 57 | seed_generators = 1 + 3 * 3 # 1 seed generator for each dropout layer. 
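    # For reference, the `tvars` expression above works out to 44,116
    # trainable parameters for this configuration (hand-computed:
    # 16,864 embedding/norm + 25,632 encoder + 1,620 head weights).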
58 | 59 | model = bert4rec.BERT4Rec(**init_kws) 60 | model.build(keras.ops.shape(item_ids)) 61 | self.assertEqual(model.count_params(), tvars) 62 | 63 | self.run_layer_test( 64 | bert4rec.BERT4Rec, 65 | init_kwargs={**init_kws, "add_head": False}, 66 | input_data=item_ids, 67 | call_kwargs={ 68 | "type_ids": item_type_ids, 69 | "padding_mask": mask, 70 | "mask_positions": mask_positions, 71 | }, 72 | expected_output_shape=(2, 3, 32), 73 | expected_output_dtype="float32", 74 | expected_num_seed_generators=seed_generators, 75 | run_training_check=False, 76 | ) 77 | 78 | 79 | if __name__ == "__main__": 80 | absltest.main() 81 | -------------------------------------------------------------------------------- /recml/layers/keras/hstu_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for the HSTU implementation.""" 15 | 16 | from absl.testing import absltest 17 | from absl.testing import parameterized 18 | import keras 19 | from keras.src import testing 20 | import numpy as np 21 | from recml.layers.keras import hstu 22 | 23 | 24 | class HSTUTest(testing.TestCase): 25 | 26 | def test_hstu(self): 27 | item_ids = keras.ops.array([[1, 2, 3], [4, 5, 0]], "int32") 28 | padding_mask = keras.ops.array([[1, 1, 1], [1, 1, 0]], "int32") 29 | init_kws = { 30 | "vocab_size": 500, 31 | "max_positions": 20, 32 | "model_dim": 32, 33 | "num_heads": 4, 34 | "num_layers": 3, 35 | "dropout": 0.1, 36 | } 37 | 38 | tvars = ( 39 | (500 * 32) # Item embedding 40 | + (20 * 32) # Position embedding 41 | + 3 # 3 decoder blocks 42 | * ( 43 | (32 * 32 * 4 + 32 * 32) # UQKV + output 44 | + (2 * 32) * 2 # Input + attention Layer norms. 45 | ) 46 | + (2 * 32) # Final norm 47 | + (500 * 32) # Output embedding 48 | ) 49 | seed_generators = 1 + 3 # 1 seed generator for each dropout layer. 50 | model = hstu.HSTU(**init_kws) 51 | model.build(keras.ops.shape(item_ids)) 52 | self.assertEqual(model.count_params(), tvars) 53 | 54 | self.run_layer_test( 55 | hstu.HSTU, 56 | init_kwargs=init_kws, 57 | input_data=item_ids, 58 | call_kwargs={"padding_mask": padding_mask}, 59 | expected_output_shape=(2, 3, 500), 60 | expected_output_dtype="float32", 61 | expected_num_seed_generators=seed_generators, 62 | run_training_check=False, 63 | ) 64 | 65 | 66 | if __name__ == "__main__": 67 | absltest.main() 68 | -------------------------------------------------------------------------------- /recml/layers/keras/mamba_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for the Mamba implementation.""" 15 | 16 | from absl.testing import absltest 17 | from absl.testing import parameterized 18 | import einops 19 | import keras 20 | from keras.src import testing 21 | import numpy as np 22 | from recml.layers.keras import mamba 23 | import tensorflow as tf 24 | import torch 25 | import torch.nn.functional as F 26 | 27 | 28 | # originally found here: 29 | # https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py 30 | def segsum(x): 31 | """More stable segment sum calculation.""" 32 | t = x.size(-1) 33 | x = einops.repeat(x, "... d -> ... d e", e=t) 34 | mask = torch.tril(torch.ones(t, t, device=x.device, dtype=bool), diagonal=-1) 35 | x = x.masked_fill(~mask, 0) 36 | x_segsum = torch.cumsum(x, dim=-2) 37 | mask = torch.tril(torch.ones(t, t, device=x.device, dtype=bool), diagonal=0) 38 | x_segsum = x_segsum.masked_fill(~mask, -torch.inf) 39 | return x_segsum 40 | 41 | 42 | # originally found here: 43 | # https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py 44 | def ssd_minimal_discrete(x, a, b, c, block_len, initial_states=None): 45 | """Original Pytorch implementation of Mamba2. 46 | 47 | Args: 48 | x: (batch, length, n_heads, d_head) 49 | a: (batch, length, n_heads) 50 | b: (batch, length, n_groups, d_state) 51 | c: (batch, length, n_groups, d_state) 52 | block_len: int 53 | initial_states: tensor of initial state values. 54 | 55 | Returns: 56 | Y: (batch, length, n_heads, d_head) 57 | """ 58 | assert x.dtype == a.dtype == b.dtype == c.dtype 59 | assert x.shape[1] % block_len == 0 60 | 61 | # Rearrange into blocks/chunks 62 | x, a, b, c = [ 63 | einops.rearrange(x, "b (c l) ... -> b c l ...", l=block_len) 64 | for x in (x, a, b, c) 65 | ] 66 | 67 | a = einops.rearrange(a, "b c l h -> b h c l") 68 | a_cumsum = torch.cumsum(a, dim=-1) 69 | 70 | # 1. Compute the output for each intra-chunk (diagonal blocks) 71 | length = torch.exp(segsum(a)) 72 | y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", c, b, length, x) 73 | 74 | # 2. Compute the state for each intra-chunk 75 | # (right term of low-rank factorization of off-diagonal blocks; B terms) 76 | decay_states = torch.exp((a_cumsum[:, :, :, -1:] - a_cumsum)) 77 | states = torch.einsum("bclhn,bhcl,bclhp->bchpn", b, decay_states, x) 78 | 79 | # 3. Compute the inter-chunk SSM recurrence; 80 | # produces correct SSM states at chunk boundaries 81 | # (middle term of factorization of off-diag blocks; A terms) 82 | if initial_states is None: 83 | initial_states = torch.zeros_like(states[:, :1]) 84 | states = torch.cat([initial_states, states], dim=1) 85 | decay_chunk = torch.exp(segsum(F.pad(a_cumsum[:, :, :, -1], (1, 0)))) 86 | new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) 87 | states, final_state = new_states[:, :-1], new_states[:, -1] 88 | 89 | # 4. 
Compute state -> output conversion per chunk
90 | # (left term of low-rank factorization of off-diagonal blocks; C terms)
91 | state_decay_out = torch.exp(a_cumsum)
92 | y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", c, states, state_decay_out)
93 | 
94 | # Add output of intra-chunk and inter-chunk terms
95 | # (diagonal and off-diagonal blocks)
96 | y = einops.rearrange(y_diag + y_off, "b c l h p -> b (c l) h p")
97 | return y, final_state
98 | 
99 | 
100 | class MambaSSDTest(testing.TestCase):
101 | 
102 | # Simple equivalence test
103 | @parameterized.parameters(dict(seed=40), dict(seed=50), dict(seed=70))
104 | def test_ssd_correctness(self, seed: int):
105 | keras.utils.set_random_seed(seed)
106 | 
107 | ## Dimensions
108 | # Denoted (B, T, Q, D, P) in the paper
109 | batch, seqlen, chunk_size, dim, nheads = 1, 2048, 64, 2048, 32
110 | ngroups = 1 # (G) in the paper
111 | dstate = 64 # (N) in the paper
112 | 
113 | dtype = "float32"
114 | x = keras.random.normal((batch, seqlen, nheads, dim // nheads), dtype=dtype)
115 | dt = keras.ops.nn.softplus(
116 | keras.random.normal((batch, seqlen, nheads), dtype=dtype) - 4
117 | )
118 | a = keras.ops.multiply(
119 | -1, keras.ops.exp(keras.random.normal((nheads,), dtype=dtype))
120 | )
121 | b = keras.random.normal((batch, seqlen, ngroups, dstate), dtype=dtype)
122 | c = keras.random.normal((batch, seqlen, ngroups, dstate), dtype=dtype)
123 | 
124 | torch_a = torch.tensor(np.array(a))
125 | torch_b = torch.tensor(np.array(b))
126 | torch_c = torch.tensor(np.array(c))
127 | torch_dt = torch.tensor(np.array(dt))
128 | torch_x = torch.tensor(np.array(x))
129 | ground_truth = ssd_minimal_discrete(
130 | torch_x * torch_dt.unsqueeze(-1),
131 | torch_a * torch_dt,
132 | torch_b,
133 | torch_c,
134 | chunk_size,
135 | )
136 | ours = mamba.ssd_minimal_discrete(
137 | keras.ops.multiply(x, keras.ops.expand_dims(dt, axis=-1)),
138 | keras.ops.multiply(a, dt),
139 | b,
140 | c,
141 | chunk_size,
142 | )
143 | 
144 | self.assertAllClose(ground_truth[0], ours, atol=1e-5, rtol=1e-5)
145 | 
146 | 
147 | class Mamba4RecTest(testing.TestCase):
148 | 
149 | def test_mamba4rec(self):
150 | item_ids = keras.ops.array([[1, 2, 3, 4], [4, 5, 0, 0]], "int32")
151 | padding_mask = keras.ops.array([[1, 1, 1, 0], [1, 1, 0, 0]], "int32")
152 | init_kws = {
153 | "vocab_size": 500,
154 | "model_dim": 32,
155 | "mlp_expand": 4,
156 | "num_heads": 4,
157 | "num_layers": 3,
158 | "dropout": 0.1,
159 | "d_expand": 128,
160 | "d_state": 64,
161 | "d_conv": 4,
162 | "chunk_size": 2,
163 | }
164 | 
165 | self.run_layer_test(
166 | mamba.Mamba4Rec,
167 | init_kwargs=init_kws,
168 | input_data=item_ids,
169 | call_kwargs={"padding_mask": padding_mask},
170 | expected_output_shape=(2, 4, 500),
171 | expected_output_dtype="float32",
172 | expected_num_seed_generators=1 + 3 * 3,
173 | )
174 | 
175 | 
176 | if __name__ == "__main__":
177 | absltest.main()
178 | -------------------------------------------------------------------------------- /recml/layers/keras/sasrec.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Models baselined.""" 15 | 16 | from collections.abc import Mapping, Sequence 17 | from typing import Any 18 | 19 | import keras 20 | import keras_hub 21 | from recml.layers.keras import utils 22 | 23 | Tensor = Any 24 | 25 | 26 | @keras.saving.register_keras_serializable("recml") 27 | class SASRec(keras.layers.Layer): 28 | """SASRec architecture as in [1]. 29 | 30 | Implements the SASRec model architecture as described in 'Self-Attentive 31 | Sequential Recommendation' [1]. 32 | 33 | [1] https://arxiv.org/abs/1808.09781 34 | """ 35 | 36 | def __init__( 37 | self, 38 | *, 39 | vocab_size: int, 40 | max_positions: int, 41 | model_dim: int, 42 | mlp_dim: int, 43 | num_heads: int, 44 | num_layers: int, 45 | dropout: float = 0.0, 46 | norm_eps: float = 1e-6, 47 | scale_by_sqrt_dim: bool = False, 48 | add_head: bool = True, 49 | **kwargs, 50 | ): 51 | """Initializes the instance. 52 | 53 | Args: 54 | vocab_size: The size of the item vocabulary. 55 | max_positions: The maximum number of positions in a sequence. 56 | model_dim: The width of the embeddings in the model. 57 | mlp_dim: The width of the MLP in each transformer block. 58 | num_heads: The number of attention heads in each transformer block. 59 | num_layers: The number of transformer blocks in the model. 60 | dropout: The dropout rate. Defaults to 0. 61 | norm_eps: The epsilon for RMS normalization. 62 | scale_by_sqrt_dim: Whether to scale the item embeddings by 63 | sqrt(model_dim). Defaults to False. 64 | add_head: Whether to decode the sequence embeddings to logits. 65 | **kwargs: Passed through to the super class. 
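    Example (an illustrative sketch; `item_ids` and `padding_mask` are
    placeholder int32 tensors of shape [batch, seq_len], with hyperparameters
    mirroring the configuration used in `sasrec_test.py`):

      model = SASRec(
          vocab_size=500,
          max_positions=20,
          model_dim=32,
          mlp_dim=64,
          num_heads=4,
          num_layers=3,
          dropout=0.1,
      )
      logits = model(item_ids, padding_mask=padding_mask)
      # -> [batch, seq_len, vocab_size] when add_head=True,
      #    [batch, seq_len, model_dim] otherwise.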
66 | """ 67 | super().__init__(**kwargs) 68 | 69 | self.item_embedding = keras_hub.layers.ReversibleEmbedding( 70 | input_dim=vocab_size, 71 | output_dim=model_dim, 72 | embeddings_initializer=keras.initializers.RandomNormal(stddev=0.02), 73 | dtype=self.dtype_policy, 74 | reverse_dtype=self.compute_dtype, 75 | name="item_embedding", 76 | ) 77 | 78 | self.position_embedding = keras_hub.layers.PositionEmbedding( 79 | sequence_length=max_positions, 80 | initializer=keras.initializers.RandomNormal(stddev=0.02), 81 | dtype=self.dtype_policy, 82 | name="position_embedding", 83 | ) 84 | 85 | self.embeddings_dropout = keras.layers.Dropout( 86 | dropout, name="embedding_dropout" 87 | ) 88 | 89 | self.decoder_blocks = [ 90 | keras_hub.layers.TransformerDecoder( 91 | intermediate_dim=mlp_dim, 92 | num_heads=num_heads, 93 | dropout=dropout, 94 | activation=utils.gelu_approximate, 95 | layer_norm_epsilon=norm_eps, 96 | normalize_first=True, 97 | dtype=self.dtype_policy, 98 | name=f"decoder_block_{i}", 99 | ) 100 | for i in range(num_layers) 101 | ] 102 | self.final_norm = keras.layers.LayerNormalization( 103 | epsilon=norm_eps, name="final_norm" 104 | ) 105 | 106 | self._vocab_size = vocab_size 107 | self._model_dim = model_dim 108 | self._scale_by_sqrt_dim = scale_by_sqrt_dim 109 | self._add_head = add_head 110 | self._config = { 111 | "vocab_size": vocab_size, 112 | "max_positions": max_positions, 113 | "model_dim": model_dim, 114 | "mlp_dim": mlp_dim, 115 | "num_heads": num_heads, 116 | "num_layers": num_layers, 117 | "dropout": dropout, 118 | "norm_eps": norm_eps, 119 | "scale_by_sqrt_dim": scale_by_sqrt_dim, 120 | "add_head": add_head, 121 | } 122 | 123 | def build(self, inputs_shape: Sequence[int]): 124 | self.item_embedding.build(inputs_shape) 125 | self.position_embedding.build((*inputs_shape, self._model_dim)) 126 | 127 | for decoder_block in self.decoder_blocks: 128 | decoder_block.build((*inputs_shape, self._model_dim)) 129 | 130 | self.final_norm.build((*inputs_shape, self._model_dim)) 131 | 132 | def call( 133 | self, 134 | inputs: Tensor, 135 | padding_mask: Tensor | None = None, 136 | attention_mask: Tensor | None = None, 137 | mask_positions: Tensor | None = None, 138 | training: bool = False, 139 | ) -> Tensor: 140 | embeddings = self.item_embedding(inputs) 141 | if self._scale_by_sqrt_dim: 142 | embeddings *= keras.ops.cast( 143 | self._model_dim**0.5, keras.ops.dtype(embeddings) 144 | ) 145 | embeddings += self.position_embedding(embeddings) 146 | 147 | embeddings = self.final_norm(embeddings) 148 | embeddings = self.embeddings_dropout(embeddings, training=training) 149 | 150 | for decoder_block in self.decoder_blocks: 151 | embeddings = decoder_block( 152 | embeddings, 153 | decoder_padding_mask=padding_mask, 154 | decoder_attention_mask=attention_mask, 155 | training=training, 156 | ) 157 | 158 | embeddings = self.final_norm(embeddings) 159 | 160 | if not self._add_head: 161 | return embeddings 162 | 163 | return self.item_embedding(embeddings, reverse=True) 164 | 165 | def compute_output_shape(self, inputs_shape: Sequence[int]) -> Sequence[int]: 166 | output_dim = self._vocab_size if self._add_head else self._model_dim 167 | return (*inputs_shape, output_dim) 168 | 169 | def get_config(self) -> Mapping[str, Any]: 170 | return {**super().get_config(), **self._config} 171 | -------------------------------------------------------------------------------- /recml/layers/keras/sasrec_test.py: -------------------------------------------------------------------------------- 1 | # 
Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for Keras architectures.""" 15 | 16 | from absl.testing import absltest 17 | import keras 18 | from keras.src import testing 19 | from recml.layers.keras import sasrec 20 | 21 | 22 | class SASRecTest(testing.TestCase): 23 | 24 | def test_sasrec(self): 25 | item_ids = keras.ops.array([[1, 2, 3], [4, 5, 0]], "int32") 26 | padding_mask = keras.ops.array([[1, 1, 1], [1, 1, 0]], "int32") 27 | init_kws = { 28 | "vocab_size": 500, 29 | "max_positions": 20, 30 | "model_dim": 32, 31 | "mlp_dim": 64, 32 | "num_heads": 4, 33 | "num_layers": 3, 34 | "dropout": 0.1, 35 | } 36 | 37 | tvars = ( 38 | (500 * 32) # Item embedding 39 | + (20 * 32) # Position embedding 40 | + 3 # 3 decoder blocks 41 | * ( 42 | ((32 + 1) * 32 * 3 + (32 + 1) * 32) # Attention QKVO 43 | + (2 * 32) # Attention block norm 44 | + ((32 + 1) * 64) # MLP inner projection 45 | + ((64 + 1) * 32) # MLP outer projection 46 | + (2 * 32) # MLP block norm 47 | ) 48 | + (2 * 32) # Final norm 49 | ) 50 | seed_generators = 1 + 3 * 3 # 1 seed generator for each dropout layer. 51 | model = sasrec.SASRec(**init_kws) 52 | model.build(keras.ops.shape(item_ids)) 53 | self.assertEqual(model.count_params(), tvars) 54 | 55 | self.run_layer_test( 56 | sasrec.SASRec, 57 | init_kwargs=init_kws, 58 | input_data=item_ids, 59 | call_kwargs={"padding_mask": padding_mask}, 60 | expected_output_shape=(2, 3, 500), 61 | expected_output_dtype="float32", 62 | expected_num_seed_generators=seed_generators, 63 | run_training_check=False, 64 | ) 65 | 66 | 67 | if __name__ == "__main__": 68 | absltest.main() 69 | -------------------------------------------------------------------------------- /recml/layers/keras/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Layer utilities.""" 15 | 16 | from typing import Any 17 | import keras 18 | 19 | Tensor = Any 20 | 21 | 22 | def clone_initializer(initializer: Any) -> Any: 23 | """Clones an initializer.""" 24 | if isinstance(initializer, keras.initializers.Initializer): 25 | return initializer.clone() 26 | return initializer 27 | 28 | 29 | def make_attention_mask(mask: Tensor, dtype: str = "float32") -> Tensor: 30 | """Creates a 3D self-attention mask from a padding mask.""" 31 | # Element wise pairwise function on [B, L, 1], [B, 1, L]. 32 | attention_mask = keras.ops.multiply( 33 | keras.ops.expand_dims(mask, axis=-1), 34 | keras.ops.expand_dims(mask, axis=-2), 35 | ) 36 | return keras.ops.cast(attention_mask, dtype=dtype) 37 | 38 | 39 | def make_causal_mask(mask: Tensor, dtype: str = "float32") -> Tensor: 40 | """Creates a 3D causal self-attention mask from a padding mask.""" 41 | return keras.ops.tril(make_attention_mask(mask, dtype=dtype)) 42 | 43 | 44 | @keras.saving.register_keras_serializable("recml") 45 | def gelu_approximate(x: Tensor) -> Tensor: 46 | """Approximate GELU activation function.""" 47 | return keras.activations.gelu(x, approximate=True) 48 | 49 | 50 | @keras.saving.register_keras_serializable("recml") 51 | def relu_squared(x: Tensor) -> Tensor: 52 | """RELU squared activation function.""" 53 | return keras.ops.square(keras.activations.relu(x)) 54 | 55 | 56 | @keras.saving.register_keras_serializable("recml") 57 | def norm_embedding_post_processor(inputs: Tensor, eps: float = 1e-6) -> Tensor: 58 | """L2 Normalization Post Processor for HSTU. 59 | 60 | Take output embeddings and normalize them to unit length. 61 | 62 | Args: 63 | inputs: The input sequence tensor. shape = [B, N, D] 64 | eps: Epsilon to use for division. 65 | 66 | Returns: 67 | The normalized output embeddings. 68 | """ 69 | return keras.ops.divide( 70 | inputs, 71 | keras.ops.clip( 72 | keras.ops.norm(inputs, ord=None, axis=-1, keepdims=True), 73 | x_min=eps, 74 | x_max=None, 75 | ), 76 | ) 77 | 78 | 79 | def apply_rotary_encoding( 80 | x: Tensor, *, positions: Tensor | None = None, max_wavelength: int 81 | ) -> Tensor: 82 | """Returns the rotary positional encodings. 83 | 84 | Args: 85 | x: Array of embeddings of shape [*batch_size, seq_len, num_heads, head_dim]. 86 | Where head_dim must be even. 87 | positions: Optional array of shape [*batch_size, seq_len] holding the 88 | position of each token in the sequence. If not provided, the input is 89 | assumed to be a contiguous sequence and the positions are therefore [0, 1, 90 | ..., seq_len - 1] for each example. 91 | max_wavelength: Maximum wavelength that will appear in sin / cosine 92 | waveforms. This specifies the maximum sequence length for identifying 93 | unique positions. 94 | 95 | Returns: 96 | Array of rotary encoded input of shape [batch_size, seq_len, num_heads, 97 | head_dim]. 98 | """ 99 | x_shape = keras.ops.shape(x) 100 | b = (x_shape[i] for i in range(len(x_shape) - 3)) 101 | seq_len = x_shape[-3] 102 | if x_shape[-1] % 2 != 0: 103 | raise ValueError( 104 | "Embedding dimension must be even, but got" 105 | f" {x_shape[-1]} for input of shape {x_shape}." 106 | ) 107 | if len(x_shape) < 4: 108 | raise ValueError( 109 | f"Unexpected input shape: {x_shape}. Expected shape of rank 4 or" 110 | " greater." 111 | ) 112 | if positions is None: 113 | positions = keras.ops.tile( 114 | keras.ops.arange(seq_len)[None, :], 115 | (*[d if d is not None else -1 for d in b], 1), 116 | ) 117 | # Only do shape checks on not TF backends. 
118 | if keras.backend.backend() != "tensorflow": 119 | if keras.ops.shape(positions) != x_shape[:-2]: 120 | raise ValueError( 121 | f"Positions must be of shape: {(x_shape[:-2])} but got shape:" 122 | f" {keras.ops.shape(positions)}." 123 | ) 124 | freq_exponents = (2.0 / x_shape[-1]) * keras.ops.arange( 125 | x_shape[-1] // 2, dtype="float32" 126 | ) 127 | timescale = max_wavelength**freq_exponents 128 | timescale = timescale[ 129 | (*[None for _ in b], None, slice(None)) 130 | ] # timescale[None, None, :] when len(b) == 1 131 | radians = keras.ops.cast(positions[..., None], "float32") / timescale 132 | radians = radians[..., None, :] 133 | # radians.shape = [...,L,1,d=D/2] 134 | sin, cos = keras.ops.sin(radians), keras.ops.cos(radians) 135 | x1, x2 = keras.ops.split(x, 2, axis=-1) 136 | x1, x2 = keras.ops.cast(x1, "float32"), keras.ops.cast(x2, "float32") 137 | res = keras.ops.concatenate( 138 | [x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1 139 | ) 140 | return keras.ops.cast(res, keras.ops.dtype(x)) 141 | 142 | 143 | def large_negative_for_attention(dtype: Any) -> float: 144 | """Return a large negative number based on dtype.""" 145 | if keras.backend.standardize_dtype(dtype) == "float16": 146 | return -3e4 147 | return -1e9 148 | -------------------------------------------------------------------------------- /recml/layers/keras/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for layer utilities.""" 15 | 16 | from absl.testing import absltest 17 | from absl.testing import parameterized 18 | import keras 19 | from keras.src import testing 20 | import numpy as np 21 | from recml.layers.keras import utils 22 | 23 | 24 | class UtilsTest(testing.TestCase): 25 | 26 | def test_clone_initializer(self): 27 | random_initializer = keras.initializers.RandomNormal(stddev=1.0) 28 | random_initializer_clone = utils.clone_initializer(random_initializer) 29 | self.assertNotEqual(random_initializer.seed, random_initializer_clone.seed) 30 | 31 | lecun_initializer = keras.initializers.LecunNormal(seed=1) 32 | lecun_initializer_clone = utils.clone_initializer(lecun_initializer) 33 | self.assertEqual(lecun_initializer.seed, lecun_initializer_clone.seed) 34 | 35 | self.assertEqual(utils.clone_initializer("lecun_normal"), "lecun_normal") 36 | 37 | # Remember to read these sideways =)) 38 | @parameterized.parameters( 39 | dict( 40 | inputs=np.array([[1, 1, 1], [1, 0, 0], [1, 1, 0]], dtype=np.float32), 41 | expected_outputs=np.array( 42 | [ 43 | [[1, 1, 1], [1, 1, 1], [1, 1, 1]], 44 | [[1, 0, 0], [0, 0, 0], [0, 0, 0]], 45 | [[1, 1, 0], [1, 1, 0], [0, 0, 0]], 46 | ], 47 | dtype=np.float32, 48 | ), 49 | ), 50 | ) 51 | def test_make_attention_mask( 52 | self, inputs: np.array, expected_outputs: np.array 53 | ): 54 | self.assertAllClose( 55 | utils.make_attention_mask(keras.ops.array(inputs)), 56 | keras.ops.array(expected_outputs), 57 | ) 58 | 59 | @parameterized.parameters( 60 | dict( 61 | inputs=np.array([[1, 1, 1], [1, 0, 0], [1, 1, 0]], dtype=np.float32), 62 | expected_outputs=np.array( 63 | [ 64 | [[1, 0, 0], [1, 1, 0], [1, 1, 1]], 65 | [[1, 0, 0], [0, 0, 0], [0, 0, 0]], 66 | [[1, 0, 0], [1, 1, 0], [0, 0, 0]], 67 | ], 68 | dtype=np.float32, 69 | ), 70 | ), 71 | ) 72 | def test_make_causal_mask(self, inputs: np.array, expected_outputs: np.array): 73 | self.assertAllClose( 74 | utils.make_causal_mask(keras.ops.array(inputs)), 75 | keras.ops.array(expected_outputs), 76 | ) 77 | 78 | @parameterized.parameters( 79 | dict( 80 | inputs=np.array([[[1.0, 1.0], [0.0, 1.0]]]), 81 | eps=1e-6, 82 | expected_output=np.array([[[0.70710678, 0.70710678], [0.0, 1.0]]]), 83 | ), 84 | dict( 85 | inputs=np.array([[[0.0, 1.0, 0.0], [1.0, 1.0, 1.0]]]), 86 | eps=1e-06, 87 | expected_output=np.array( 88 | [[[0.0, 1.0, 0.0], [0.5773502, 0.5773502, 0.5773502]]] 89 | ), 90 | ), 91 | ) 92 | def test_l2_norm_embedding_postprocessor_output( 93 | self, inputs, eps, expected_output 94 | ): 95 | inputs = keras.ops.array(inputs) 96 | expected_output = keras.ops.array(expected_output) 97 | 98 | got = utils.norm_embedding_post_processor(inputs, eps=eps) 99 | self.assertAllClose(got, expected_output) 100 | 101 | @parameterized.named_parameters( 102 | dict( 103 | # wavelength == 1 so radians == expand_dims(positions, axes=[-1, -2]) 104 | testcase_name="with_positions", 105 | inputs=np.array([[[[1.0, 2.0]], [[3.0, 4.0]]]]), 106 | positions=np.array([[4, 1]]), 107 | max_wavelength=1, 108 | expected_outputs=np.array([[ 109 | [[ 110 | 1.0 * np.cos(4) - 2.0 * np.sin(4), 111 | 2.0 * np.cos(4) + 1.0 * np.sin(4), 112 | ]], 113 | [[ 114 | 3.0 * np.cos(1) - 4.0 * np.sin(1), 115 | 4.0 * np.cos(1) + 3.0 * np.sin(1), 116 | ]], 117 | ]]), 118 | ), 119 | dict( 120 | # wavelength == 1 so radians == expand_dims(positions, axes=[-1, -2]) 121 | testcase_name="no_positions", 122 | inputs=np.array([[[[1.0, 2.0]], [[3.0, 4.0]]]]), 123 | positions=None, # Should evaluate to [[0, 1]] 124 | max_wavelength=1, 125 | 
expected_outputs=np.array([[ 126 | [[1.0, 2.0]], 127 | [[ 128 | 3.0 * np.cos(1) - 4.0 * np.sin(1), 129 | 4.0 * np.cos(1) + 3.0 * np.sin(1), 130 | ]], 131 | ]]), 132 | ), 133 | ) 134 | def test_apply_rotary_embedding( 135 | self, 136 | inputs: np.ndarray, 137 | positions: np.ndarray | None, 138 | max_wavelength: int, 139 | expected_outputs: np.ndarray, 140 | ): 141 | self.assertAllClose( 142 | expected_outputs, 143 | utils.apply_rotary_encoding( 144 | inputs, 145 | positions=positions, 146 | max_wavelength=max_wavelength, 147 | ), 148 | ) 149 | 150 | 151 | if __name__ == "__main__": 152 | absltest.main() 153 | -------------------------------------------------------------------------------- /recml/layers/linen/sparsecore_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 RecML authors . 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Sparsecore tests.""" 15 | 16 | import functools 17 | 18 | from absl.testing import absltest 19 | from etils import epy 20 | import jax 21 | from recml.core.training import partitioning 22 | from recml.layers.linen import sparsecore 23 | 24 | with epy.lazy_imports(): 25 | from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec # pylint: disable=g-import-not-at-top 26 | 27 | 28 | class SparsecoreTest(absltest.TestCase): 29 | 30 | def test_sparsecore_embedder_equivalence(self): 31 | if jax.devices()[0].platform != "tpu": 32 | self.skipTest("Test only supported on TPUs.") 33 | 34 | k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4) 35 | 36 | inputs = { 37 | "a": jax.random.randint(k1, (32, 16), minval=1, maxval=100), 38 | "b": jax.random.randint(k2, (32, 16), minval=1, maxval=100), 39 | "w": jax.random.normal(k3, (32, 16)), 40 | } 41 | 42 | dp_partitioner = partitioning.DataParallelPartitioner() 43 | embedder = sparsecore.SparsecoreEmbedder( 44 | specs={ 45 | "a": sparsecore.EmbeddingSpec( 46 | input_dim=100, 47 | embedding_dim=16, 48 | combiner="mean", 49 | weight_name="w", 50 | ), 51 | "b": sparsecore.EmbeddingSpec( 52 | input_dim=100, 53 | embedding_dim=16, 54 | max_sequence_length=10, 55 | ), 56 | }, 57 | optimizer=embedding_spec.AdagradOptimizerSpec(learning_rate=0.01), 58 | ) 59 | preprocessor = embedder.make_preprocessor(32) 60 | layer = embedder.make_sparsecore_module() 61 | 62 | sc_inputs = dp_partitioner.shard_inputs(preprocessor(inputs)) 63 | sc_vars = dp_partitioner.partition_init(functools.partial(layer.init, k4))( 64 | sc_inputs 65 | ) 66 | 67 | def step(inputs, params): 68 | return layer.apply(params, inputs) 69 | 70 | p_step = dp_partitioner.partition_step(step, training=False) 71 | sparsecore_activations = jax.device_get(p_step(sc_inputs, sc_vars)) 72 | 73 | self.assertEqual(sparsecore_activations["a"].shape, (32, 16)) 74 | self.assertEqual(sparsecore_activations["b"].shape, (32, 10, 16)) 75 | 76 | 77 | if __name__ == "__main__": 78 | absltest.main() 79 | -------------------------------------------------------------------------------- 
/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.2.2 2 | astroid==3.3.9 3 | astunparse==1.6.3 4 | attrs==25.3.0 5 | certifi==2025.1.31 6 | cfgv==3.4.0 7 | charset-normalizer==3.4.1 8 | chex==0.1.89 9 | clu==0.0.12 10 | dill==0.4.0 11 | distlib==0.3.9 12 | dm-tree==0.1.9 13 | docstring-parser==0.16 14 | einops==0.8.1 15 | etils==1.12.2 16 | fiddle==0.3.0 17 | filelock==3.18.0 18 | flatbuffers==25.2.10 19 | flax==0.10.5 20 | fsspec==2025.3.2 21 | gast==0.6.0 22 | google-pasta==0.2.0 23 | googleapis-common-protos==1.70.0 24 | graphviz==0.20.3 25 | grpcio==1.71.0 26 | h5py==3.13.0 27 | humanize==4.12.2 28 | identify==2.6.9 29 | idna==3.10 30 | immutabledict==4.2.1 31 | importlib-resources==6.5.2 32 | iniconfig==2.1.0 33 | isort==6.0.1 34 | jax==0.6.0 35 | jaxlib==0.6.0 36 | jaxtyping==0.3.1 37 | jinja2==3.1.6 38 | kagglehub==0.3.11 39 | keras==3.9.2 40 | keras-hub==0.20.0 41 | libclang==18.1.1 42 | libcst==1.7.0 43 | markdown==3.8 44 | markdown-it-py==3.0.0 45 | markupsafe==3.0.2 46 | mccabe==0.7.0 47 | mdurl==0.1.2 48 | ml-collections==1.1.0 49 | ml-dtypes==0.5.1 50 | mpmath==1.3.0 51 | msgpack==1.1.0 52 | namex==0.0.8 53 | nest-asyncio==1.6.0 54 | networkx==3.4.2 55 | nodeenv==1.9.1 56 | numpy==2.1.3 57 | opt-einsum==3.4.0 58 | optax==0.2.4 59 | optree==0.15.0 60 | orbax-checkpoint==0.11.12 61 | packaging==24.2 62 | platformdirs==4.3.7 63 | pluggy==1.5.0 64 | pre-commit==4.2.0 65 | promise==2.3 66 | protobuf==5.29.4 67 | psutil==7.0.0 68 | pyarrow==19.0.1 69 | pygments==2.19.1 70 | pylint==3.3.6 71 | pytest==8.3.5 72 | pytest-env==1.1.5 73 | pyyaml==6.0.2 74 | regex==2024.11.6 75 | requests==2.32.3 76 | rich==14.0.0 77 | scipy==1.15.2 78 | setuptools==78.1.0 79 | simple-parsing==0.1.7 80 | simplejson==3.20.1 81 | six==1.17.0 82 | sympy==1.13.1 83 | tensorboard==2.19.0 84 | tensorboard-data-server==0.7.2 85 | tensorflow==2.19.0 86 | tensorflow-datasets==4.9.8 87 | tensorflow-metadata==1.17.1 88 | tensorflow-text==2.19.0 89 | tensorstore==0.1.73 90 | termcolor==3.0.1 91 | toml==0.10.2 92 | tomlkit==0.13.2 93 | toolz==1.0.0 94 | torch==2.6.0 95 | tqdm==4.67.1 96 | treescope==0.1.9 97 | typing-extensions==4.13.2 98 | urllib3==2.4.0 99 | virtualenv==20.30.0 100 | wadler-lindig==0.1.5 101 | werkzeug==3.1.3 102 | wheel==0.45.1 103 | wrapt==1.17.2 104 | zipp==3.21.0 105 | --------------------------------------------------------------------------------