├── .github
│   └── workflows
│       └── unit_tests.yaml
├── .gitignore
├── .python-version
├── LICENSE
├── README.md
├── docs
│   ├── code-of-conduct.md
│   └── contributing.md
├── pytest.ini
├── recml
│   ├── __init__.py
│   ├── core
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── iterator.py
│   │   │   ├── preprocessing.py
│   │   │   └── tf_dataset_factory.py
│   │   ├── metrics
│   │   │   ├── __init__.py
│   │   │   ├── base_metrics.py
│   │   │   ├── base_metrics_test.py
│   │   │   ├── confusion_metrics.py
│   │   │   ├── confusion_metrics_test.py
│   │   │   ├── mean_metrics.py
│   │   │   ├── mean_metrics_test.py
│   │   │   ├── reduction_metrics.py
│   │   │   ├── reduction_metrics_test.py
│   │   │   └── tools.py
│   │   ├── ops
│   │   │   └── embedding_ops.py
│   │   ├── training
│   │   │   ├── core.py
│   │   │   ├── jax_trainer.py
│   │   │   ├── jax_trainer_quality_test.py
│   │   │   ├── jax_trainer_test.py
│   │   │   ├── keras_trainer.py
│   │   │   ├── keras_trainer_test.py
│   │   │   ├── optax_factory.py
│   │   │   ├── optax_factory_test.py
│   │   │   ├── partitioning.py
│   │   │   └── partitioning_test.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── config.py
│   │       ├── config_test.py
│   │       ├── keras_utils.py
│   │       ├── keras_utils_test.py
│   │       ├── py_utils.py
│   │       ├── types.py
│   │       └── types_test.py
│   ├── examples
│   │   ├── dlrm_experiment.py
│   │   └── dlrm_experiment_test.py
│   └── layers
│       ├── keras
│       │   ├── README.md
│       │   ├── bert4rec.py
│       │   ├── bert4rec_test.py
│       │   ├── hstu.py
│       │   ├── hstu_test.py
│       │   ├── mamba.py
│       │   ├── mamba_test.py
│       │   ├── sasrec.py
│       │   ├── sasrec_test.py
│       │   ├── utils.py
│       │   └── utils_test.py
│       └── linen
│           ├── sparsecore.py
│           └── sparsecore_test.py
└── requirements.txt
/.github/workflows/unit_tests.yaml:
--------------------------------------------------------------------------------
1 | name: Run Unit Tests
2 |
3 | on:
4 | push:
5 | branches: [ "main", "master" ]
6 | pull_request:
7 | branches: [ "main", "master" ]
8 | workflow_dispatch:
9 |
10 | jobs:
11 | run_tests:
12 | name: Run Tests
13 | runs-on: ubuntu-latest
14 | steps:
15 | - name: Checkout code
16 | uses: actions/checkout@v4
17 | - name: Install Dependencies
18 | run: |
19 | python -m pip install --upgrade pip
20 | pip install -r requirements.txt # If you have a requirements.txt file
21 | - name: Run Pytest
22 | run: |
23 | export KERAS_BACKEND=jax
24 | python -m pytest -v
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # Ruff stuff:
171 | .ruff_cache/
172 |
173 | # PyPI configuration file
174 | .pypirc
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.12
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RecML: High-Performance Recommender Library
2 |
3 | ## Vision
4 |
5 | RecML is envisioned as a high-performance, large-scale deep learning recommender
6 | system library optimized for Cloud TPUs. It aims to provide researchers and
7 | practitioners with state-of-the-art reference implementations, tools, and
8 | best-practice guidelines for building and deploying recommender systems.
9 |
10 | The key goals of RecML are:
11 |
12 | * **Performance & Scalability:** Leverage Cloud TPUs (including SparseCore
13 | acceleration) to deliver exceptional performance for training and serving
14 | massive models with large embeddings on datasets with millions or billions
15 | of items/users. RecML can additionally target Cloud GPUs.
16 | * **State-of-the-Art Models:** Provide production-ready, easy-to-understand
17 | reference implementations of popular and cutting-edge models, with a strong
18 | focus on LLM-based recommenders.
19 | * **Ease of Use:** Offer a user-friendly API, intuitive abstractions, and
20 | comprehensive documentation/examples for rapid prototyping and deployment.
21 | * **Flexibility:** Primarily built with Keras and JAX, but designed with
22 | potential future expansion to other frameworks like PyTorch/XLA.
23 | * **Open Source:** Foster community collaboration and provide components to
24 | help users get started with advanced recommender workloads on Google Cloud.
25 |
26 | ## Features
27 |
28 | * **High Performance:** Optimized for Cloud TPU (SparseCore) training and
29 | inference.
30 | * **Scalable Architecture:** Designed for massive datasets and models with
31 | large embedding tables. Includes support for efficient data loading
32 | (tf.data, potentially Grain) and sharding/SPMD.
33 | * **State-of-the-Art Model Implementations:** Reference implementations for
34 | various recommendation tasks (ranking, retrieval, sequential).
35 | * **Reusable Building Blocks:**
36 | * Common recommendation layers (e.g., DCN, BERT4Rec).
37 | * Specialized Embedding APIs (e.g. JAX Embedding API for SparseCore).
38 | * Standardized metrics (e.g., AUC, Accuracy, NDCG@K, MRR, Recall@K).
39 | * Common loss functions.
40 | * **Unified Trainer:** A high-level trainer abstraction capable of targeting
41 | different hardware (TPU/GPU) and frameworks. Includes customizable training
42 | and evaluation loops.
43 | * **End-to-End Support:** Covers aspects from data pipelines to training,
44 | evaluation, checkpointing, metrics logging (e.g., to BigQuery), and model
45 | export/serving considerations.
46 |
47 | ## Models Included
48 |
49 | This library aims to house implementations for a variety of recommender models,
50 | including:
51 |
52 | * **SASRec:** Self-Attention based Sequential Recommendation
53 | * **BERT4Rec:** Bidirectional Encoder Representations from Transformer for
54 | Sequential Recommendation.
55 | * **Mamba4Rec:** Efficient Sequential Recommendation with Selective State
56 | Space Models.
57 | * **HSTU:** Hierarchical Sequential Transduction Units for Generative
58 | Recommendations.
59 | * **DLRM v2:** Deep Learning Recommendation Model
60 |
61 | ## Roadmap / Future Work
62 |
63 | * Expand reference model implementations (Retrieval, Uplift, foundation user
64 | model).
65 | * Add support for optimized configurations and lower precision training
66 | (bfloat16, fp16).
67 | * Improve support for Cloud GPU training and inference.
68 | * Enhance sharding and quantization support.
69 | * Improve integration with Keras (and Keras Recommenders) and potentially
70 | PyTorch/XLA.
71 | * Develop comprehensive model serving examples and integrations.
72 | * Refine data loading pipelines (e.g., Grain support).
73 | * Add more common layers, losses, and metrics.
74 |
75 | ## Responsible Use
76 |
77 | As with any machine learning model, potential risks exist. The performance and
78 | behavior depend heavily on the training data, which may contain biases reflected
79 | in the recommendations. Developers should carefully evaluate the model's
80 | fairness and potential limitations in their specific application context.
81 |
82 | ## License
83 |
84 | RecML is released under the Apache 2.0 license. Please see the `LICENSE` file
85 | for full details.
86 |
--------------------------------------------------------------------------------
/docs/code-of-conduct.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, gender identity and expression, level of
9 | experience, education, socio-economic status, nationality, personal appearance,
10 | race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or reject
41 | comments, commits, code, wiki edits, issues, and other contributions that are
42 | not aligned to this Code of Conduct, or to ban temporarily or permanently any
43 | contributor for other behaviors that they deem inappropriate, threatening,
44 | offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies both within project spaces and in public spaces
49 | when an individual is representing the project or its community. Examples of
50 | representing a project or community include using an official project e-mail
51 | address, posting via an official social media account, or acting as an appointed
52 | representative at an online or offline event. Representation of a project may be
53 | further defined and clarified by project maintainers.
54 |
55 | This Code of Conduct also applies outside the project spaces when the Project
56 | Steward has a reasonable belief that an individual's behavior may have a
57 | negative impact on the project or its community.
58 |
59 | ## Conflict Resolution
60 |
61 | We do not believe that all conflict is bad; healthy debate and disagreement
62 | often yield positive results. However, it is never okay to be disrespectful or
63 | to engage in behavior that violates the project’s code of conduct.
64 |
65 | If you see someone violating the code of conduct, you are encouraged to address
66 | the behavior directly with those involved. Many issues can be resolved quickly
67 | and easily, and this gives people more control over the outcome of their
68 | dispute. If you are unable to resolve the matter for any reason, or if the
69 | behavior is threatening or harassing, report it. We are dedicated to providing
70 | an environment where participants feel welcome and safe.
71 |
72 | Reports should be directed to *[PROJECT STEWARD NAME(s) AND EMAIL(s)]*, the
73 | Project Steward(s) for *[PROJECT NAME]*. It is the Project Steward’s duty to
74 | receive and address reported violations of the code of conduct. They will then
75 | work with a committee consisting of representatives from the Open Source
76 | Programs Office and the Google Open Source Strategy team. If for any reason you
77 | are uncomfortable reaching out to the Project Steward, please email
78 | opensource@google.com.
79 |
80 | We will investigate every complaint, but you may not receive a direct response.
81 | We will use our discretion in determining when and how to follow up on reported
82 | incidents, which may range from not taking action to permanent expulsion from
83 | the project and project-sponsored spaces. We will notify the accused of the
84 | report and provide them an opportunity to discuss it before any action is taken.
85 | The identity of the reporter will be omitted from the details of the report
86 | supplied to the accused. In potentially harmful situations, such as ongoing
87 | harassment or threats to anyone's safety, we may take action without notice.
88 |
89 | ## Attribution
90 |
91 | This Code of Conduct is adapted from the Contributor Covenant, version 1.4,
92 | available at
93 | https://www.contributor-covenant.org/version/1/4/code-of-conduct/
--------------------------------------------------------------------------------
/docs/contributing.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We would love to accept your patches and contributions to this project.
4 |
5 | ## Before you begin
6 |
7 | ### Sign our Contributor License Agreement
8 |
9 | Contributions to this project must be accompanied by a
10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA).
11 | You (or your employer) retain the copyright to your contribution; this simply
12 | gives us permission to use and redistribute your contributions as part of the
13 | project.
14 |
15 | If you or your current employer have already signed the Google CLA (even if it
16 | was for a different project), you probably don't need to do it again.
17 |
18 | Visit <https://cla.developers.google.com/> to see your current agreements or
19 | to sign a new one.
20 |
21 | ### Review our Community Guidelines
22 |
23 | This project follows [Google's Open Source Community
24 | Guidelines](https://opensource.google/conduct/).
25 |
26 | ## Contribution process
27 |
28 | ### Code Reviews
29 |
30 | All submissions, including submissions by project members, require review. We
31 | use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests)
32 | for this purpose.
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | required_plugins = pytest-env
3 | env =
4 | KERAS_BACKEND=jax
--------------------------------------------------------------------------------
/recml/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Public API for RecML."""
15 |
16 | # pylint: disable=g-importing-member
17 |
18 | from recml.core import data
19 | from recml.core import metrics
20 | from recml.core import utils
21 | from recml.core.metrics.base_metrics import Metric
22 | from recml.core.training.core import Experiment
23 | from recml.core.training.core import run_experiment
24 | from recml.core.training.core import Trainer
25 | from recml.core.training.jax_trainer import JaxState
26 | from recml.core.training.jax_trainer import JaxTask
27 | from recml.core.training.jax_trainer import JaxTrainer
28 | from recml.core.training.jax_trainer import KerasState
29 | from recml.core.training.keras_trainer import KerasTask
30 | from recml.core.training.keras_trainer import KerasTrainer
31 | from recml.core.training.optax_factory import AdagradFactory
32 | from recml.core.training.optax_factory import AdamFactory
33 | from recml.core.training.optax_factory import OptimizerFactory
34 | from recml.core.training.partitioning import DataParallelPartitioner
35 | from recml.core.training.partitioning import ModelParallelPartitioner
36 | from recml.core.training.partitioning import NullPartitioner
37 | from recml.core.training.partitioning import Partitioner
38 | from recml.core.utils.types import Factory
39 | from recml.core.utils.types import FactoryProtocol
40 | from recml.core.utils.types import ObjectFactory
41 |
--------------------------------------------------------------------------------
/recml/core/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Public API for RecML data."""
15 |
16 | # pylint: disable=g-importing-member
17 |
18 | from recml.core.data.iterator import Iterator
19 | from recml.core.data.iterator import TFDatasetIterator
20 | from recml.core.data.preprocessing import PreprocessingMode
21 | from recml.core.data.tf_dataset_factory import DatasetShardingInfo
22 | from recml.core.data.tf_dataset_factory import TFDatasetFactory
23 | from recml.core.data.tf_dataset_factory import TFDSMetadata
24 |
--------------------------------------------------------------------------------
/recml/core/data/iterator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Data loading and preprocessing for feeding Jax models."""
15 |
16 | from collections.abc import Callable
17 | import os
18 | from typing import Any
19 |
20 | import clu.data as clu_data
21 | from etils import epath
22 | import numpy as np
23 | import tensorflow as tf
24 |
25 |
26 | Iterator = clu_data.DatasetIterator
27 |
28 |
29 | class TFDatasetIterator(clu_data.DatasetIterator):
30 | """An iterator for TF Datasets that supports postprocessing."""
31 |
32 | def __init__(
33 | self,
34 | dataset: tf.data.Dataset,
35 | postprocessor: Callable[..., Any] | None = None,
36 | checkpoint: bool = False,
37 | ):
38 | """Initializes the iterator.
39 |
40 | Args:
41 | dataset: The TF Dataset to iterate over.
42 | postprocessor: An optional postprocessor to apply to each batch. This is
43 | useful for sending embedded ID features to a separate accelerator.
44 | checkpoint: Whether to checkpoint the iterator state.
45 | """
46 | self._dataset = dataset
47 | self._iterator = iter(dataset)
48 | self._postprocessor = postprocessor
49 | self._prefetched_batch = None
50 | self._element_spec = None
51 | self._checkpoint = None
52 | if checkpoint:
53 | self._checkpoint = tf.train.Checkpoint(ds=self._iterator)
54 |
55 | def __next__(self) -> clu_data.Element:
56 | """Returns the next batch."""
57 | if self._prefetched_batch is not None:
58 | batch = self._prefetched_batch
59 | self._prefetched_batch = None
60 | else:
61 | batch = next(self._iterator)
62 | if self._postprocessor is not None:
63 | batch = self._postprocessor(batch)
64 |
65 | def _maybe_to_numpy(
66 | x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray,
67 | ) -> np.ndarray | tf.SparseTensor | tf.RaggedTensor:
68 | if isinstance(x, (tf.SparseTensor, tf.RaggedTensor, np.ndarray)):
69 | return x
70 | if hasattr(x, "_numpy"):
71 | numpy = x._numpy() # pylint: disable=protected-access
72 | else:
73 | numpy = x.numpy()
74 | if isinstance(numpy, np.ndarray):
75 | # `numpy` shares the same underlying buffer as the `x` Tensor.
76 | # Tensors are expected to be immutable, so we disable writes.
77 | numpy.setflags(write=False)
78 | return numpy
79 |
80 | return tf.nest.map_structure(_maybe_to_numpy, batch)
81 |
82 | @property
83 | def element_spec(self) -> clu_data.ElementSpec:
84 | if self._element_spec is not None:
85 | return self._element_spec
86 |
87 | batch = next(self._iterator)
88 | if self._postprocessor is not None:
89 | batch = self._postprocessor(batch)
90 |
91 | self._prefetched_batch = batch
92 |
93 | def _to_element_spec(
94 | x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray,
95 | ) -> clu_data.ArraySpec:
96 | if isinstance(x, tf.SparseTensor):
97 | return clu_data.ArraySpec(
98 | dtype=x.dtype.as_numpy_dtype,
99 | shape=(x.shape[0], *[None for _ in x.shape[1:]]),
100 | )
101 | if isinstance(x, tf.RaggedTensor):
102 | return clu_data.ArraySpec(
103 | dtype=x.dtype.as_numpy_dtype, # pylint: disable=attribute-error
104 | shape=tuple(x.shape.as_list()), # pylint: disable=attribute-error
105 | )
106 | if isinstance(x, tf.Tensor):
107 | return clu_data.ArraySpec(
108 | dtype=x.dtype.as_numpy_dtype, shape=tuple(x.shape.as_list())
109 | )
110 | return clu_data.ArraySpec(dtype=x.dtype, shape=tuple(x.shape))
111 |
112 | element_spec = tf.nest.map_structure(_to_element_spec, batch)
113 | self._element_spec = element_spec
114 | return element_spec
115 |
116 | def reset(self):
117 | self._iterator = iter(self._dataset)
118 | if self._checkpoint is not None:
119 | self._checkpoint = tf.train.Checkpoint(ds=self._iterator)
120 |
121 | def save(self, filename: epath.Path):
122 | if self._checkpoint is not None:
123 | self._checkpoint.write(os.fspath(filename))
124 |
125 | def restore(self, filename: epath.Path):
126 | if self._checkpoint is not None:
127 | self._checkpoint.read(os.fspath(filename)).assert_consumed()
128 |
--------------------------------------------------------------------------------
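Note on usage: `TFDatasetIterator` above wraps a `tf.data.Dataset`, lazily converts each batch to NumPy, and can checkpoint its position. Below is a minimal sketch based only on the constructor and methods shown in `iterator.py`; the toy dataset and postprocessor are illustrative, not part of the library.

```python
import tensorflow as tf
from recml.core.data.iterator import TFDatasetIterator

# Any tf.data.Dataset works; this toy dataset yields batches of shape [4].
dataset = tf.data.Dataset.range(8).batch(4)

iterator = TFDatasetIterator(
    dataset,
    # Optional hook applied to each batch before conversion to NumPy.
    postprocessor=lambda batch: {"ids": batch},
    checkpoint=True,  # enables save()/restore() via tf.train.Checkpoint
)

spec = iterator.element_spec  # clu ArraySpec structure; prefetches one batch
batch = next(iterator)        # {"ids": np.ndarray}, buffers marked read-only
iterator.reset()              # restart iteration from the beginning
```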
/recml/core/data/preprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Preprocessing utilities."""
15 |
16 | import enum
17 |
18 |
19 | class PreprocessingMode(enum.StrEnum):
20 | """Mode for data preprocessing."""
21 |
22 | TRAINING = "training"
23 | EVAL = "eval"
24 | SERVING = "serving"
25 |
--------------------------------------------------------------------------------
/recml/core/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Public API for metrics."""
15 |
16 | # pylint: disable=g-importing-member
17 |
18 | from recml.core.metrics.base_metrics import Metric
19 | from recml.core.metrics.base_metrics import scalar
20 | from recml.core.metrics.confusion_metrics import aucpr
21 | from recml.core.metrics.confusion_metrics import aucroc
22 | from recml.core.metrics.confusion_metrics import estimate_confusion_matrix
23 | from recml.core.metrics.confusion_metrics import f1_score
24 | from recml.core.metrics.confusion_metrics import fbeta_score
25 | from recml.core.metrics.confusion_metrics import precision
26 | from recml.core.metrics.confusion_metrics import precision_at_recall
27 | from recml.core.metrics.confusion_metrics import recall
28 | from recml.core.metrics.mean_metrics import accuracy
29 | from recml.core.metrics.mean_metrics import binary_accuracy
30 | from recml.core.metrics.mean_metrics import mean_squared_error
31 | from recml.core.metrics.mean_metrics import top_k_accuracy
32 | from recml.core.metrics.reduction_metrics import mean
33 | from recml.core.metrics.reduction_metrics import sum # pylint: disable=redefined-builtin
34 | from recml.core.metrics.tools import MetricAccumulator
35 |
--------------------------------------------------------------------------------
/recml/core/metrics/base_metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Functional metrics inspired by the CLU interface and Keras semantics."""
15 |
16 | import abc
17 | from collections.abc import Mapping, Sequence
18 | import math
19 | from typing import Self, dataclass_transform
20 |
21 | import clu.metrics as clu_metrics
22 | from flax import struct
23 | import jax
24 | import jax.numpy as jnp
25 | import numpy as np
26 |
27 | Scalar = float | Sequence[float] | Mapping[str, float] | jax.Array | np.ndarray
28 |
29 | # TODO(aahil): Look into why pytype doesn't respect the Self type as a generic
30 | # type. We should not be violating LSP.
31 |
32 | # TODO(b/387463777): Consider removing the dependency on CLU metrics longer term
33 | # since it's just an interface.
34 | @dataclass_transform(field_specifiers=(struct.field,)) # pytype: disable=not-supported-yet
35 | class Metric(abc.ABC, clu_metrics.Metric, struct.PyTreeNode):
36 | """PyTree node representing the state of a metric.
37 |
38 | Note: This class follows the same interface as CLU metrics and can be used
39 | interchangeably.
40 |
41 | There are a few subtle differences between subclasses and standard CLU metrics:
42 | 1. Inheriting from this automatically makes the metric a PyTree node.
43 | 2. `mask` has been replaced by `weights` to be consistent with Keras.
44 | 3. Subclasses do not implement methods apart from the ones listed below.
45 | 4. The `localize` method is added to specify how to localize the metric from
46 | device to host.
47 | """
48 |
49 | @classmethod
50 | def from_model_output(cls, *args, **kwargs) -> Self:
51 | """Creates a metric from observations.
52 |
53 | Args:
54 | *args: Positional arguments to pass to the metric.
55 | **kwargs: Keyword arguments to pass to the metric.
56 |
57 | Returns:
58 | A new instance of the metric.
59 |
60 | NOTE: This metric is always called on the device and should therefore use
61 | only jax ops.
62 | """
63 | raise NotImplementedError()
64 |
65 | @abc.abstractmethod
66 | def merge(self, other: Self) -> Self: # pytype: disable=signature-mismatch
67 | """Merges two metrics.
68 |
69 | Args:
70 | other: Another metric of the same type to merge with.
71 |
72 | Returns:
73 | A new instance of the same class that is the merge of the two.
74 |
75 | NOTE: This method is almost always called on the host which means that it
76 | should *never* call jax ops - this will implicitly move the state of the
77 | metric back to the device for computation. It is safest to rely on dunder
78 | methods via `+` / `-`.
79 | """
80 |
81 | @abc.abstractmethod
82 | def compute(self) -> Scalar: # pytype: disable=signature-mismatch
83 | """Computes the value of the metric.
84 |
85 | NOTE: This method is almost always called on the host which means that it
86 | should *never* call jax ops - this will implicitly move the state of the
87 | metric back to the device for computation. Use numpy ops instead.
88 | """
89 |
90 | def localize(self) -> Self:
91 | """Localizes the metric from device to host.
92 |
93 | Returns:
94 | A new instance of the same class that is localized, i.e. jax arrays on the
95 | metric are replaced by numpy arrays.
96 | """
97 |
98 | def _localize(x):
99 | x = jax.device_get(x)
100 | if isinstance(x, jax.Array) and not isinstance(x, jax.core.Tracer):
101 | return x.addressable_data(0)
102 | return x
103 |
104 | return jax.tree.map(_localize, self)
105 |
106 |
107 | class ScalarMetric(Metric):
108 | """A metric for reporting scalar values without aggregation."""
109 |
110 | value: jax.Array
111 |
112 | @classmethod
113 | def from_model_output(cls, value: jax.Array | float) -> Self:
114 | if hasattr(value, "shape") and math.prod(value.shape) != 1:
115 | raise ValueError(
116 | f"Scalar metric values must be scalars. Got shape: {value.shape}"
117 | " instead."
118 | )
119 | return cls(value=jnp.squeeze(jnp.asarray(value, dtype=jnp.float32)))
120 |
121 | def merge(self, other: Self) -> Self:
122 | return other
123 |
124 | def compute(self) -> Scalar:
125 | return self.value
126 |
127 |
128 | def scalar(value: float | jax.Array) -> ScalarMetric:
129 | """Creates a scalar metric from a scalar value at a specific step.
130 |
131 | This is useful for reporting batch metrics during training. When merged with
132 | other instances, effectively the last value observed is reported.
133 |
134 | Note that using this metric during evaluation will result in multiple values
135 | being reported for the same step, which is generally undesirable.
136 |
137 | Example usage:
138 |
139 | ```
140 | state = ...
141 | metrics = {
142 | "average_loss": mean(loss),
143 | "per_batch_loss": scalar(loss),
144 | "learning_rate": scalar(learning_rate),
145 | }
146 | ```
147 |
148 | Args:
149 | value: The scalar value to report.
150 |
151 | Returns:
152 | A scalar metric reporting the value.
153 | """
154 | return ScalarMetric.from_model_output(value)
155 |
--------------------------------------------------------------------------------
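The `Metric` base class above fixes the contract for all metrics: `from_model_output` runs on device and may use jax ops, while `merge` and `compute` run on the host and should stick to dunder ops or NumPy. Below is a minimal sketch of a custom subclass following that contract, mirroring the pattern used by `ScalarMetric`; the `Count` metric here is illustrative and not part of the library.

```python
from typing import Self

import jax
import jax.numpy as jnp
from recml.core.metrics import base_metrics


class Count(base_metrics.Metric):
  """Illustrative metric that accumulates an (optionally weighted) count."""

  count: jax.Array

  @classmethod
  def from_model_output(
      cls, values: jax.Array, weights: jax.Array | None = None, **_
  ) -> Self:
    # Device-side: jax ops are fine here.
    if weights is not None:
      return cls(count=jnp.sum(jnp.asarray(weights, dtype=jnp.float32)))
    return cls(count=jnp.asarray(jnp.size(values), dtype=jnp.float32))

  def merge(self, other: Self) -> Self:
    # Host-side: rely on dunder ops instead of explicit jax calls.
    return type(self)(count=self.count + other.count)

  def compute(self) -> base_metrics.Scalar:
    return self.count
```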
/recml/core/metrics/base_metrics_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for base metrics."""
15 |
16 | from absl.testing import absltest
17 | import jax.numpy as jnp
18 | from recml.core.metrics import base_metrics
19 |
20 |
21 | class BaseMetricsTest(absltest.TestCase):
22 |
23 | def test_scalar(self):
24 | m1 = base_metrics.scalar(1.0)
25 | m2 = base_metrics.scalar(2.0)
26 | m3 = m1.merge(m2)
27 | self.assertEqual(m3.compute(), 2.0)
28 |
29 | self.assertRaises(ValueError, base_metrics.scalar, jnp.array([1.0, 2.0]))
30 |
31 |
32 | if __name__ == "__main__":
33 | absltest.main()
34 |
--------------------------------------------------------------------------------
/recml/core/metrics/mean_metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Mean metrics."""
15 |
16 | import jax
17 | import keras
18 | from recml.core.metrics import reduction_metrics
19 |
20 |
21 | def accuracy(
22 | y_true: jax.Array,
23 | y_pred: jax.Array,
24 | weights: jax.Array | None = None,
25 | **_,
26 | ) -> reduction_metrics.Mean:
27 | """Computes accuracy from observations.
28 |
29 | Args:
30 | y_true: The true labels of shape [D1, ..., D_N]
31 | y_pred: The predicted logits of shape [D1, ..., D_N, num_classes].
32 | weights: Optional weights of shape broadcastable to [D1, ... D_N].
33 | **_: Unused kwargs.
34 |
35 | Returns:
36 | A metric accumulation of the accuracy.
37 | """
38 | assert keras.backend.backend() == 'jax'
39 | acc = keras.metrics.sparse_categorical_accuracy(y_true, y_pred)
40 | return reduction_metrics.mean(acc, weights)
41 |
42 |
43 | def top_k_accuracy(
44 | y_true: jax.Array,
45 | y_pred: jax.Array,
46 | weights: jax.Array | None = None,
47 | *,
48 | k: int,
49 | **_,
50 | ) -> reduction_metrics.Mean:
51 | """Computes top-k accuracy from observations.
52 |
53 | Args:
54 | y_true: The true labels of shape [D1, ..., D_N]
55 | y_pred: The predicted logits of shape [D1, ..., D_N, num_classes].
56 | weights: Optional weights of shape broadcastable to [D1, ... D_N].
57 | k: The number of top classes to consider. Must be less than num_classes.
58 | **_: Unused kwargs.
59 |
60 | Returns:
61 | A metric accumulation of the top-k accuracy.
62 | """
63 | assert keras.backend.backend() == 'jax'
64 | acc = keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=k)
65 | return reduction_metrics.mean(acc, weights)
66 |
67 |
68 | def binary_accuracy(
69 | y_true: jax.Array,
70 | y_pred: jax.Array,
71 | weights: jax.Array | None = None,
72 | *,
73 | threshold: float = 0.5,
74 | **_,
75 | ) -> reduction_metrics.Mean:
76 | """Computes binary accuracy from observations.
77 |
78 | Args:
79 | y_true: The true labels of shape [D1, ..., D_N]
80 | y_pred: The binary predictions of shape [D1, ..., D_N].
81 | weights: Optional weights of shape broadcastable to [D1, ... D_N].
82 | threshold: The threshold to use for binary classification.
83 | **_: Unused kwargs.
84 |
85 | Returns:
86 | A metric accumulation of the binary accuracy.
87 | """
88 | assert keras.backend.backend() == 'jax'
89 | bin_acc = keras.metrics.binary_accuracy(y_true, y_pred, threshold=threshold)
90 | return reduction_metrics.mean(bin_acc, weights)
91 |
92 |
93 | def mean_squared_error(
94 | y_true: jax.Array,
95 | y_pred: jax.Array,
96 | weights: jax.Array | None = None,
97 | **_,
98 | ) -> reduction_metrics.Mean:
99 | """Computes mean squared error from observations.
100 |
101 | Args:
102 | y_true: The true labels of shape [D1, ..., D_N].
103 | y_pred: The predictions of shape [D1, ..., D_N].
104 | weights: Optional weights of shape broadcastable to [D1, ... D_N].
105 | **_: Unused kwargs.
106 |
107 | Returns:
108 | A metric accumulation of the mean squared error.
109 | """
110 | assert keras.backend.backend() == 'jax'
111 | mse = keras.metrics.mean_squared_error(y_true, y_pred)
112 | return reduction_metrics.mean(mse, weights)
113 |
--------------------------------------------------------------------------------
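Each helper above returns a `Mean` accumulation rather than a final number, so per-batch results can be merged across steps before `compute` is called. Below is a small sketch using only the functions shown; it assumes `KERAS_BACKEND=jax` (as set in `pytest.ini`), and the arrays are made-up examples.

```python
import numpy as np
from recml.core.metrics import mean_metrics

# Two "batches" of sparse labels and per-class scores.
y_true_1 = np.array([1, 0])
y_pred_1 = np.array([[0.2, 0.8], [0.9, 0.1]])
y_true_2 = np.array([1, 1])
y_pred_2 = np.array([[0.7, 0.3], [0.4, 0.6]])

# Accumulate per-batch metrics, merge them, then compute the overall mean.
acc = mean_metrics.accuracy(y_true_1, y_pred_1)
acc = acc.merge(mean_metrics.accuracy(y_true_2, y_pred_2))
print(acc.localize().compute())  # 3 of 4 predictions correct -> 0.75
```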
/recml/core/metrics/mean_metrics_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for mean metrics."""
15 |
16 | from collections.abc import Sequence
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import numpy as np
21 | from recml.core.metrics import mean_metrics
22 |
23 |
24 | class MeanMetricsTest(parameterized.TestCase):
25 |
26 | @parameterized.named_parameters(
27 | {
28 | 'testcase_name': 'unweighted',
29 | 'y_true': np.array([1, 0, 1, 0]),
30 | 'y_pred': np.array([[0.2, 0.8], [0.8, 0.2], [0.1, 0.9], [0.6, 0.4]]),
31 | 'weights': None,
32 | 'expected_output': 1.0,
33 | },
34 | {
35 | 'testcase_name': 'weighted',
36 | 'y_true': np.array([[1, 0, 1, 0]]),
37 | 'y_pred': np.array(
38 | [[[0.2, 0.8], [0.1, 0.9], [0.3, 0.7], [0.4, 0.6]]]
39 | ),
40 | 'weights': np.array([[1, 2, 3, 4]]),
41 | 'expected_output': 0.4,
42 | },
43 | )
44 | def test_accuracy(
45 | self,
46 | y_true: np.ndarray,
47 | y_pred: np.ndarray,
48 | weights: np.ndarray | None,
49 | expected_output: np.ndarray,
50 | ):
51 | accuracy = mean_metrics.accuracy(y_true, y_pred, weights)
52 | np.testing.assert_allclose(expected_output, accuracy.compute())
53 | np.testing.assert_allclose(expected_output, accuracy.localize().compute())
54 |
55 | @parameterized.named_parameters(
56 | {
57 | 'testcase_name': 'unweighted',
58 | 'y_true': np.array([[1], [4], [2], [3], [3], [1], [0], [5]]),
59 | 'y_pred': np.array([
60 | [[0.1, 0.7, 0.5, 0.3, 0.2, 0.0]], # [1, 2, 3, 4, 0, 5]
61 | [[0.2, 0.8, 0.0, 0.1, 0.4, 0.3]], # [1, 4, 5, 0, 3, 2]
62 | [[0.1, 0.2, 0.4, 0.8, 0.0, 0.3]], # [3, 2, 5, 1, 0, 4]
63 | [[1.0, 0.9, 0.1, 0.3, 0.2, 0.0]], # [0, 1, 3, 4, 2, 5]
64 | [[0.1, 0.7, 0.5, 0.3, 0.2, 0.0]], # [1, 2, 3, 4, 0, 5]
65 | [[0.2, 0.8, 0.0, 0.1, 0.4, 0.3]], # [1, 4, 5, 0, 3, 2]
66 | [[0.1, 0.2, 0.4, 0.8, 0.0, 0.3]], # [3, 2, 5, 1, 0, 4]
67 | [[1.0, 0.9, 0.1, 0.3, 0.2, 0.0]], # [0, 1, 3, 4, 2, 5]
68 | ]),
69 | 'weights': None,
70 | 'ks': [1, 2, 3, 4, 5, 6],
71 | 'expected_outputs': [0.25, 0.5, 0.75, 0.75, 0.875, 1.0],
72 | },
73 | {
74 | 'testcase_name': 'weighted',
75 | 'y_true': np.array([0, 1, 1, 0]),
76 | 'y_pred': np.array([[0.1, 0.7], [0.2, 0.4], [0.1, 0.3], [0.2, 0.1]]),
77 | 'weights': np.array([0.2, 0.6, 0.1, 0.1]),
78 | 'ks': [1, 2],
79 | 'expected_outputs': np.array([0.8, 1.0]),
80 | },
81 | )
82 | def test_top_k_accuracies(
83 | self,
84 | y_true: np.ndarray,
85 | y_pred: np.ndarray,
86 | weights: np.ndarray | None,
87 | ks: Sequence[int],
88 | expected_outputs: Sequence[float],
89 | ):
90 | for k, expected_output in zip(ks, expected_outputs):
91 | accuracy = mean_metrics.top_k_accuracy(y_true, y_pred, weights, k=k)
92 | np.testing.assert_allclose(expected_output, accuracy.compute())
93 | np.testing.assert_allclose(expected_output, accuracy.localize().compute())
94 |
95 | @parameterized.named_parameters(
96 | {
97 | 'testcase_name': 'unweighted',
98 | 'y_true': np.array([1, 0, 1, 0]),
99 | 'y_pred': np.array([0.4, 0.6, 0.8, 0.2]),
100 | 'weights': None,
101 | 'threshold': 0.5,
102 | 'expected_output': 0.5,
103 | },
104 | {
105 | 'testcase_name': 'weighted',
106 | 'y_true': np.array([[1, 0, 1, 0]]),
107 | 'y_pred': np.array([[0.8, 0.6, 0.7, 0.6]]),
108 | 'weights': np.array([[1, 2, 3, 4]]),
109 | 'threshold': 0.75,
110 | 'expected_output': 0.7,
111 | },
112 | )
113 | def test_binary_accuracy(
114 | self,
115 | y_true: np.ndarray,
116 | y_pred: np.ndarray,
117 | weights: np.ndarray | None,
118 | threshold: float,
119 | expected_output: np.ndarray,
120 | ):
121 | accuracy = mean_metrics.binary_accuracy(
122 | y_true, y_pred, weights, threshold=threshold
123 | )
124 | np.testing.assert_allclose(expected_output, accuracy.compute())
125 | np.testing.assert_allclose(expected_output, accuracy.localize().compute())
126 |
127 | @parameterized.named_parameters(
128 | {
129 | 'testcase_name': 'unweighted',
130 | 'y_true': np.array([0.3, 0.5, 0.7, 0.9]),
131 | 'y_pred': np.array([0.4, 0.6, 0.8, 0.2]),
132 | 'weights': None,
133 | 'expected_output': 0.13,
134 | },
135 | {
136 | 'testcase_name': 'weighted',
137 | 'y_true': np.array([[0.3, 0.6, 0.2, 0.6]]),
138 | 'y_pred': np.array([[0.8, 0.6, 0.7, 0.6]]),
139 | 'weights': np.array([0.5]),
140 | 'expected_output': 0.125,
141 | },
142 | )
143 | def test_mean_squared_error(
144 | self,
145 | y_true: np.ndarray,
146 | y_pred: np.ndarray,
147 | weights: np.ndarray | None,
148 | expected_output: np.ndarray,
149 | ):
150 | mse = mean_metrics.mean_squared_error(y_true, y_pred, weights)
151 | np.testing.assert_allclose(expected_output, mse.compute(), rtol=1e-3)
152 | np.testing.assert_allclose(
153 | expected_output, mse.localize().compute(), rtol=1e-3
154 | )
155 |
156 |
157 | if __name__ == '__main__':
158 | absltest.main()
159 |
--------------------------------------------------------------------------------
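
The bracketed comments in the `top_k_accuracies` case above list the class indices of each row of `y_pred` in descending score order, so a prediction counts as correct when the label appears among the first `k` entries. A quick NumPy cross-check of the unweighted expectations (a sketch independent of `mean_metrics`, with the singleton axes flattened):

```
import numpy as np

y_true = np.array([1, 4, 2, 3, 3, 1, 0, 5])
y_pred = np.array([
    [0.1, 0.7, 0.5, 0.3, 0.2, 0.0],
    [0.2, 0.8, 0.0, 0.1, 0.4, 0.3],
    [0.1, 0.2, 0.4, 0.8, 0.0, 0.3],
    [1.0, 0.9, 0.1, 0.3, 0.2, 0.0],
    [0.1, 0.7, 0.5, 0.3, 0.2, 0.0],
    [0.2, 0.8, 0.0, 0.1, 0.4, 0.3],
    [0.1, 0.2, 0.4, 0.8, 0.0, 0.3],
    [1.0, 0.9, 0.1, 0.3, 0.2, 0.0],
])

for k in (1, 2, 3):
  # Indices of the k highest scores per row, then check if the label is there.
  top_k = np.argsort(-y_pred, axis=-1)[:, :k]
  hits = (top_k == y_true[:, None]).any(axis=-1)
  print(k, hits.mean())  # 1 -> 0.25, 2 -> 0.5, 3 -> 0.75, matching the case above.
```
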
/recml/core/metrics/reduction_metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Reduction metrics."""
15 |
16 | from __future__ import annotations
17 |
18 | from collections.abc import Callable
19 | import math
20 | from typing import Any, Self
21 |
22 | import jax
23 | import jax.numpy as jnp
24 | from recml.core.metrics import base_metrics
25 |
26 |
27 | class ReductionMetric(base_metrics.Metric):
28 | """A base class for reduction metrics."""
29 |
30 | @classmethod
31 | def from_model_output(
32 | cls, values: jax.Array, weights: jax.Array | None = None, **_
33 | ) -> Self:
34 | raise NotImplementedError()
35 |
36 | @classmethod
37 | def from_fun(cls, fun: Callable[..., Any], **kwargs) -> type[Self]:
38 | """Returns a reduction metric class that is computed from a function."""
39 | base_cls = cls
40 | bound_kwargs = kwargs
41 |
42 | class _FromFun(cls):
43 | """A reduction metric that is computed from a function."""
44 |
45 | @classmethod
46 | def from_model_output(cls, *args, **kwargs) -> ReductionMetric:
47 | if "weights" in kwargs:
48 | weights = kwargs.pop("weights")
49 | else:
50 | weights = None
51 |
52 | values = fun(*args, **bound_kwargs, **kwargs)
53 |
54 | return base_cls.from_model_output(values, weights)
55 |
56 | return _FromFun
57 |
58 |
59 | class Sum(ReductionMetric):
60 | """Computes the sum of observations over multiple batches."""
61 |
62 | total: jax.Array
63 |
64 | @classmethod
65 | def from_model_output(
66 | cls, values: jax.Array, weights: jax.Array | None = None, **_
67 | ) -> Self:
68 | values = jnp.asarray(values, dtype=jnp.float32)
69 | if weights is not None:
70 | weights = jnp.asarray(weights, dtype=jnp.float32)
71 | values, weights = _maybe_reshape_or_broadcast(values, weights)
72 | total = jnp.sum(values * weights)
73 | else:
74 | total = jnp.sum(values)
75 |
76 | return cls(total=total)
77 |
78 | def merge(self, other: Self) -> Self: # pytype: disable=signature-mismatch
79 | return type(self)(total=self.total + other.total)
80 |
81 | def compute(self) -> base_metrics.Scalar:
82 | return self.total
83 |
84 |
85 | class Mean(ReductionMetric):
86 | """Computes the mean of observations over multiple batches.
87 |
88 | This is done by tracking a total and a count over multiple observations and
89 | aggregating their mean over multiple batches when `compute` is called.
90 | """
91 |
92 | total: jax.Array
93 | count: jax.Array
94 |
95 | @classmethod
96 | def from_model_output(
97 | cls, values: jax.Array, weights: jax.Array | None = None, **_
98 | ) -> Self:
99 | values = jnp.asarray(values, dtype=jnp.float32)
100 | if weights is not None:
101 | weights = jnp.asarray(weights, dtype=jnp.float32)
102 | values, weights = _maybe_reshape_or_broadcast(values, weights)
103 | total = jnp.sum(values * weights)
104 | count = jnp.sum(weights)
105 | elif values.ndim >= 1:
106 | total = jnp.sum(values)
107 | count = jnp.asarray(math.prod(values.shape), jnp.float32)
108 | else:
109 | total = values
110 | count = jnp.ones((), dtype=jnp.float32)
111 |
112 | return cls(total=total, count=count)
113 |
114 | def merge(self, other: Self) -> Self: # pytype: disable=signature-mismatch
115 | return type(self)(
116 | total=self.total + other.total,
117 | count=self.count + other.count,
118 | )
119 |
120 | def compute(self) -> base_metrics.Scalar:
121 | return self.total / self.count
122 |
123 |
124 | def mean(values: jax.Array, weights: jax.Array | None = None, **_) -> Mean:
125 | """Computes a mean metric from values and optional weights.
126 |
127 | The resulting metric instance is a reduction metric that will aggregate the
128 | mean of the values over multiple batches.
129 |
130 | The total and counts are computed as follows:
131 | weights = broadcast_to(weights, values.shape)
132 | total = sum(values * weights)
133 | count = sum(weights)
134 |
135 | Where the output of an aggregated metric is the total / count.
136 |
137 | Example usage:
138 |
139 | ```
140 | metrics = {
141 | # Reports the mean accuracy over multiple batches.
142 | 'accuracy': mean(y_true == y_pred),
143 | # Reports the mean loss over multiple batches.
144 | 'loss': mean(loss),
145 | }
146 | ```
147 |
148 | Args:
149 | values: The values to compute the mean over of shape [D1, ..., DN].
150 | weights: Optional weights to apply to the values. If provided, the shape of
151 | the weights must be broadcastable to the shape of the values. If not
152 | provided, all values will effectively have a weight of 1.0.
153 | **_: Unused keyword arguments.
154 |
155 | Returns:
156 | A mean metric accumulation.
157 | """
158 | return Mean.from_model_output(values, weights)
159 |
160 |
161 | def sum(values: jax.Array, weights: jax.Array | None = None, **_) -> Sum: # pylint: disable=redefined-builtin
162 | """Computes a sum metric from values and optional weights.
163 |
164 | The sum is computed as follows:
165 | weights = broadcast_to(weights, values.shape)
166 | total = sum(values * weights)
167 |
168 | Where total is the output of an aggregated metric.
169 |
170 | Example usage:
171 |
172 | ```
173 | metrics = {
174 | # Reports the total number of hits over multiple batches.
175 | 'number_of_hits': sum(y_true == y_pred),
176 | }
177 | ```
178 |
179 | Args:
180 | values: The values to compute the sum over of shape [D1, ..., DN].
181 | weights: Optional weights to apply to the values. If provided, the shape of
182 | the weights must be broadcastable to the shape of the values. If not
183 | provided, all values will effectively have a weight of 1.0.
184 | **_: Unused keyword arguments.
185 |
186 | Returns:
187 | A sum metric accumulation.
188 | """
189 | return Sum.from_model_output(values, weights)
190 |
191 |
192 | def _maybe_reshape_or_broadcast(
193 | values: jax.Array, weights: jax.Array
194 | ) -> tuple[jax.Array, jax.Array]:
195 | """Reshapes or broadcasts arrays to have the same shape or throws an error."""
196 | # Note that we broadcast the weights explicitly so that the sum of the weights
197 | # is not performed on the non-broadcasted array.
198 | if values.shape == weights.shape:
199 | return values, weights
200 | elif values.ndim == weights.ndim and all(
201 | v_d == w_d or w_d == 1 for v_d, w_d in zip(values.shape, weights.shape)
202 | ):
203 | return values, jnp.broadcast_to(weights, values.shape)
204 | elif values.ndim == weights.ndim:
205 | raise ValueError(
206 | f"Got incompatible shapes {values.shape} and {weights.shape}."
207 | )
208 | elif (
209 | values.ndim > weights.ndim
210 | and values.shape[: weights.ndim] == weights.shape
211 | ):
212 | weights = jax.lax.expand_dims(
213 | weights, list(range(weights.ndim, values.ndim))
214 | )
215 | return values, jnp.broadcast_to(weights, values.shape)
216 | elif (
217 | weights.ndim > values.ndim
218 | and weights.shape[: values.ndim] == values.shape
219 | ):
220 | values = jax.lax.expand_dims(values, list(range(values.ndim, weights.ndim)))
221 | return values, weights
222 |
223 | raise ValueError(
224 | "The arrays must have the same shape or the shape of one array must be"
225 | f" a broadcastable to the other. Got shapes: {values.shape} and"
226 | f" {weights.shape}."
227 | )
228 |
--------------------------------------------------------------------------------
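
A minimal sketch of how these reduction metrics are meant to be used across batches: each batch produces an accumulation, `merge` combines accumulations, and `compute` yields the final scalar. `from_fun` wraps a mapping function (here a squared error) in front of the reduction. This assumes only `jax` and the module above.

```
import jax.numpy as jnp
from recml.core.metrics import reduction_metrics

# Batch 1: total = 1*1 + 2*1 + 3*0 = 3, count = 1 + 1 + 0 = 2.
m1 = reduction_metrics.mean(
    jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 1.0, 0.0])
)
# Batch 2: unweighted, total = 4, count = 1.
m2 = reduction_metrics.mean(jnp.array([4.0]))

print(m1.merge(m2).compute())  # (3 + 4) / (2 + 1) = 2.333...

# `from_fun` maps model outputs through a function before reducing.
squared_error = lambda y_true, y_pred: (y_true - y_pred) ** 2
mse_metric_cls = reduction_metrics.Mean.from_fun(squared_error)
batch = mse_metric_cls.from_model_output(
    jnp.array([0.0, 1.0]), jnp.array([0.5, 1.0])
)
print(batch.compute())  # (0.25 + 0.0) / 2 = 0.125
```
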
/recml/core/metrics/reduction_metrics_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for reduction metrics."""
15 |
16 | from collections.abc import Callable, Mapping, Sequence
17 | from typing import Any
18 |
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 | import numpy as np
22 | from recml.core.metrics import reduction_metrics
23 |
24 |
25 | def mse(y_true, y_pred):
26 | return (y_true - y_pred) ** 2
27 |
28 |
29 | class ReductionMetricsTest(parameterized.TestCase):
30 |
31 | @parameterized.named_parameters(
32 | {
33 | 'testcase_name': 'scalar_weighted_sum',
34 | 'metric': reduction_metrics.sum,
35 | 'args': [0.5, 0.5],
36 | 'kwargs': {},
37 | 'expected_output': 0.25,
38 | },
39 | {
40 | 'testcase_name': 'unweighted_sum',
41 | 'metric': reduction_metrics.sum,
42 | 'args': [np.array([1, 3, 5, 7])],
43 | 'kwargs': {},
44 | 'expected_output': 16.0,
45 | },
46 | {
47 | 'testcase_name': 'weighted_sum',
48 | 'metric': reduction_metrics.sum,
49 | 'args': [np.array([1, 3, 5, 7]), np.array([1, 1, 0, 0])],
50 | 'kwargs': {},
51 | 'expected_output': 4.0,
52 | },
53 | {
54 | 'testcase_name': 'weighted_sum_2d',
55 | 'metric': reduction_metrics.sum,
56 | 'args': [np.array([[1, 3], [5, 7]])],
57 | 'kwargs': {'weights': np.array([[1, 1], [1, 0]])},
58 | 'expected_output': 9.0,
59 | },
60 | {
61 | 'testcase_name': 'weighted_sum_2d_broadcast',
62 | 'metric': reduction_metrics.sum,
63 | 'args': [np.array([[1, 3], [5, 7]]), np.array([[1, 0]])],
64 | 'kwargs': {},
65 | 'expected_output': 6.0,
66 | },
67 | {
68 | 'testcase_name': 'weighted_sum_3d_broadcast',
69 | 'metric': reduction_metrics.sum,
70 | 'args': [
71 | np.array([
72 | [[0.3, 0.7, 0.4, 0.6], [0.5, 0.75, 0.25, 1.5]],
73 | [[0.6, 0.3, 0.1, 1.0], [0.3, 0.7, 0.75, 0.25]],
74 | ])
75 | ],
76 | 'kwargs': {'weights': np.array([[1, 1], [1, 0]])},
77 | 'expected_output': 7.0,
78 | },
79 | {
80 | 'testcase_name': 'unweighted_sum_from_fun',
81 | 'metric': reduction_metrics.Sum.from_fun(mse).from_model_output,
82 | 'args': [
83 | np.array([
84 | [0, 1, 0, 1, 0],
85 | [0, 0, 1, 1, 1],
86 | [1, 1, 1, 1, 0],
87 | [0, 0, 0, 0, 1],
88 | ]),
89 | np.array([
90 | [0, 0, 1, 1, 0],
91 | [1, 1, 1, 1, 1],
92 | [0, 1, 0, 1, 0],
93 | [1, 1, 1, 1, 1],
94 | ]),
95 | ],
96 | 'kwargs': {},
97 | 'expected_output': 10.0,
98 | },
99 | {
100 | 'testcase_name': 'scalar_weighted_mean',
101 | 'metric': reduction_metrics.mean,
102 | 'args': [0.5, 0.5],
103 | 'kwargs': {},
104 | 'expected_output': 0.5,
105 | },
106 | {
107 | 'testcase_name': 'unweighted_mean',
108 | 'metric': reduction_metrics.mean,
109 | 'args': [np.array([1, 3, 5, 7])],
110 | 'kwargs': {},
111 | 'expected_output': 4.0,
112 | },
113 | {
114 | 'testcase_name': 'weighted_mean',
115 | 'metric': reduction_metrics.mean,
116 | 'args': [np.array([1, 3, 5, 7]), np.array([1, 1, 0, 0])],
117 | 'kwargs': {},
118 | 'expected_output': 2.0,
119 | },
120 | {
121 | 'testcase_name': 'weighted_mean_neg_weights',
122 | 'metric': reduction_metrics.mean,
123 | 'args': [np.array([1, 3, 5, 7]), np.array([-1, -1, 0, 0])],
124 | 'kwargs': {},
125 | 'expected_output': 2.0,
126 | },
127 | {
128 | 'testcase_name': 'weighted_mean_2d',
129 | 'metric': reduction_metrics.mean,
130 | 'args': [np.array([[1, 3], [5, 7]])],
131 | 'kwargs': {'weights': np.array([[1, 1], [1, 0]])},
132 | 'expected_output': 3.0,
133 | },
134 | {
135 | 'testcase_name': 'weighted_mean_2d_broadcast',
136 | 'metric': reduction_metrics.mean,
137 | 'args': [np.array([[1, 3], [5, 7]]), np.array([[1, 0]])],
138 | 'kwargs': {},
139 | 'expected_output': 3.0,
140 | },
141 | {
142 | 'testcase_name': 'weighted_mean_3d_broadcast',
143 | 'metric': reduction_metrics.mean,
144 | 'args': [
145 | np.array([
146 | [[0.3, 0.7, 0.4, 0.6], [0.5, 0.75, 0.25, 1.5]],
147 | [[0.6, 0.3, 0.1, 1.0], [0.3, 0.7, 0.75, 0.25]],
148 | ])
149 | ],
150 | 'kwargs': {'weights': np.array([[1, 1], [1, 0]])},
151 | 'expected_output': 7 / 12,
152 | },
153 | {
154 | 'testcase_name': 'unweighted_mean_from_fun',
155 | 'metric': reduction_metrics.Mean.from_fun(mse).from_model_output,
156 | 'args': [
157 | np.array([
158 | [0, 1, 0, 1, 0],
159 | [0, 0, 1, 1, 1],
160 | [1, 1, 1, 1, 0],
161 | [0, 0, 0, 0, 1],
162 | ]),
163 | np.array([
164 | [0, 0, 1, 1, 0],
165 | [1, 1, 1, 1, 1],
166 | [0, 1, 0, 1, 0],
167 | [1, 1, 1, 1, 1],
168 | ]),
169 | ],
170 | 'kwargs': {},
171 | 'expected_output': 0.5,
172 | },
173 | {
174 | 'testcase_name': 'weighted_mean_from_fun',
175 | 'metric': reduction_metrics.Mean.from_fun(mse).from_model_output,
176 | 'args': [
177 | np.array([
178 | [0, 1, 0, 1, 0],
179 | [0, 0, 1, 1, 1],
180 | [1, 1, 1, 1, 0],
181 | [0, 0, 0, 0, 1],
182 | ]),
183 | np.array([
184 | [0, 0, 1, 1, 0],
185 | [1, 1, 1, 1, 1],
186 | [0, 1, 0, 1, 0],
187 | [1, 1, 1, 1, 1],
188 | ]),
189 | ],
190 | 'kwargs': {'weights': np.array([1.0, 1.5, 2.0, 2.5])},
191 | 'expected_output': 0.542857,
192 | },
193 | )
194 | def test_reduction_metric(
195 | self,
196 | metric: Callable[..., reduction_metrics.ReductionMetric],
197 | args: Sequence[Any],
198 | kwargs: Mapping[str, Any],
199 | expected_output: float | np.ndarray,
200 | ):
201 | instance = metric(*args, **kwargs)
202 | np.testing.assert_allclose(expected_output, instance.compute(), 1e-3)
203 | np.testing.assert_allclose(
204 | expected_output, instance.localize().compute(), 1e-3
205 | )
206 |
207 |
208 | if __name__ == '__main__':
209 | absltest.main()
210 |
--------------------------------------------------------------------------------
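
For the `weighted_mean_from_fun` case above, the `(4,)` row weights are expanded to `(4, 1)` by `_maybe_reshape_or_broadcast` and broadcast against the `(4, 5)` squared errors, so every entry in a row shares that row's weight. A plain-NumPy cross-check of the expected 0.542857:

```
import numpy as np

y_true = np.array(
    [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]]
)
y_pred = np.array(
    [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]]
)
weights = np.array([1.0, 1.5, 2.0, 2.5])

errors = (y_true - y_pred) ** 2  # 2, 2, 2 and 4 mismatches per row.
total = np.sum(errors * weights[:, None])  # 1*2 + 1.5*2 + 2*2 + 2.5*4 = 19
count = np.sum(np.broadcast_to(weights[:, None], errors.shape))  # 5 * 7 = 35
print(total / count)  # 0.542857...
```
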
/recml/core/metrics/tools.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tools for RecML metrics."""
15 |
16 | from collections.abc import Mapping
17 | import concurrent.futures
18 | import functools
19 | import os
20 | import re
21 | from typing import Any
22 |
23 | from absl import flags
24 | from clu import metric_writers
25 | import clu.metrics as clu_metrics
26 | import jax
27 | from recml.core.metrics import base_metrics
28 |
29 |
30 | FLAGS = flags.FLAGS
31 |
32 |
33 | class AsyncMultiWriter(metric_writers.AsyncMultiWriter):
34 | """A multi writer that logs to a summary writer and a logging writer."""
35 |
36 | def __init__(self, *, log_dir: str, name: str):
37 | summary_writer = metric_writers.SummaryWriter(
38 | os.fspath(os.path.join(log_dir, name))
39 | )
40 | writers = [summary_writer]
41 |
42 | super().__init__(writers)
43 | self._summary_writer = summary_writer
44 |
45 | @property
46 | def summary_writer(self) -> metric_writers.SummaryWriter:
47 | return self._summary_writer
48 |
49 |
50 | class MetricAccumulator:
51 | """A utility for asynchronously accumulating metrics."""
52 |
53 | def __init__(self, writer: AsyncMultiWriter, max_workers: int = 1):
54 | if not isinstance(writer, AsyncMultiWriter):
55 | raise ValueError(
56 | "`summary_writer` must be an instance of AsyncMultiWriter, got"
57 | f" {type(writer)}."
58 | )
59 |
60 | self._writer = writer
61 | self._metrics: list[Mapping[str, clu_metrics.Metric]] = []
62 | self._scalar_log_pool = concurrent.futures.ThreadPoolExecutor(
63 | max_workers=max_workers
64 | )
65 | self._scalar_log_futures: list[concurrent.futures.Future[None]] = []
66 |
67 | def accumulate(
68 | self, metrics_accum: Mapping[str, clu_metrics.Metric], step: int
69 | ):
70 | """Asynchronously accumulates a set of metrics and logs scalars."""
71 | self._metrics.append(metrics_accum)
72 |
73 | scalar_metrics_accum = {
74 | k: v
75 | for k, v in metrics_accum.items()
76 | if isinstance(v, base_metrics.ScalarMetric)
77 | }
78 |
79 | self._scalar_log_futures.append(
80 | self._scalar_log_pool.submit(
81 | _localize_and_log_scalars,
82 | # We only want to log per-step scalars via the summary writer.
83 | # Logging per-step scalars via other writers can be expensive.
84 | self._writer.summary_writer,
85 | step,
86 | scalar_metrics_accum,
87 | )
88 | )
89 |
90 | def compute_and_log_scalars(
91 | self, step: int
92 | ) -> Mapping[str, base_metrics.Scalar]:
93 | """Computes the scalars from the accumulated metrics and logs them."""
94 |
95 | if not self._metrics:
96 | return {}
97 |
98 | for future in self._scalar_log_futures:
99 | future.result()
100 |
101 | self._scalar_log_futures.clear()
102 |
103 | metrics = functools.reduce(
104 | merge_metrics, [jax.tree.map(_localize, ms) for ms in self._metrics]
105 | )
106 | self._metrics.clear()
107 | scalars = compute_metrics(metrics)
108 |
109 | # Log only non-reported scalars but return all for tracking in checkpoints.
110 | non_reported_scalars = {
111 | k: v
112 | for k, v in scalars.items()
113 | if not isinstance(metrics[k], base_metrics.ScalarMetric)
114 | }
115 | self._writer.write_scalars(step, non_reported_scalars)
116 | self._writer.flush()
117 |
118 | return scalars
119 |
120 |
121 | def compute_metrics(
122 | metrics: Mapping[str, clu_metrics.Metric | base_metrics.Metric],
123 | ) -> Mapping[str, base_metrics.Scalar]:
124 | """Collects the merged metrics and returns the computed scalars."""
125 | return {k: m.compute() for k, m in metrics.items()}
126 |
127 |
128 | def merge_metrics(
129 | a: Mapping[str, clu_metrics.Metric | base_metrics.Metric],
130 | b: Mapping[str, clu_metrics.Metric | base_metrics.Metric],
131 | ) -> Mapping[str, clu_metrics.Metric]:
132 | """Merges two mappings of metrics."""
133 | merged_metrics = {}
134 | for k in [*a.keys(), *b.keys()]:
135 | if k in a and k in b:
136 | merged_metrics[k] = a[k].merge(b[k])
137 | elif k in a:
138 | merged_metrics[k] = a[k]
139 | elif k in b:
140 | merged_metrics[k] = b[k]
141 | return merged_metrics
142 |
143 |
144 | def _localize(x: Any) -> Any:
145 | """Returns the localized data for an object."""
146 | x = jax.device_get(x)
147 | if isinstance(x, jax.Array) and not isinstance(x, jax.core.Tracer):
148 | return x.addressable_data(0)
149 | return x
150 |
151 |
152 | def _localize_and_log_scalars(
153 | summary_writer: metric_writers.SummaryWriter,
154 | step: int,
155 | scalar_metrics: Mapping[str, base_metrics.ScalarMetric],
156 | ) -> None:
157 | """Localizes the metrics from device to host and logs scalars."""
158 | scalar_metrics = jax.tree.map(_localize, scalar_metrics)
159 | summary_writer.write_scalars(step, compute_metrics(scalar_metrics))
160 |
--------------------------------------------------------------------------------
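
A minimal sketch of the accumulation flow above, assuming `clu` is installed and `/tmp/logs` is writable: per-step scalar metrics are logged asynchronously as they arrive, and `compute_and_log_scalars` merges everything accumulated since the last call and writes the result once per loop.

```
import clu.metrics as clu_metrics
import jax.numpy as jnp
from recml.core.metrics import tools

writer = tools.AsyncMultiWriter(log_dir='/tmp/logs', name='train')
accumulator = tools.MetricAccumulator(writer)

for step in range(3):
  batch_metrics = {
      'loss': clu_metrics.Average.from_model_output(jnp.asarray(0.1 * step))
  }
  accumulator.accumulate(batch_metrics, step=step)

# Merges the three accumulations, logs the scalars at step 2 and returns them.
scalars = accumulator.compute_and_log_scalars(step=2)
print(scalars['loss'])  # Mean of 0.0, 0.1 and 0.2 -> ~0.1
```
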
/recml/core/ops/embedding_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Embedding lookup ops."""
15 |
16 | from collections.abc import Mapping, Sequence
17 | import dataclasses
18 | import functools
19 | from typing import Any, TypeVar
20 |
21 | from etils import epy
22 | import jax
23 | from jax.experimental import shard_map
24 |
25 | with epy.lazy_imports():
26 | # pylint: disable=g-import-not-at-top
27 | from jax_tpu_embedding.sparsecore.lib.nn import embedding
28 | # pylint: enable=g-import-not-at-top
29 |
30 |
31 | T = TypeVar("T")
32 | Nested = T | Sequence[T] | Mapping[str, T]
33 | FeatureSpec = Any
34 |
35 |
36 | @dataclasses.dataclass
37 | class SparsecoreParams:
38 | """Embedding parameters."""
39 |
40 | feature_specs: Nested[FeatureSpec]
41 | abstract_mesh: jax.sharding.AbstractMesh
42 | data_axes: Sequence[str | None]
43 | embedding_axes: Sequence[str | None]
44 | sharding_strategy: str
45 |
46 |
47 | @functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
48 | def sparsecore_lookup(
49 | sparsecore_params: SparsecoreParams,
50 | tables: Mapping[str, tuple[jax.Array, ...]],
51 | csr_inputs: tuple[jax.Array, ...],
52 | ):
53 | return shard_map.shard_map(
54 | functools.partial(
55 | embedding.tpu_sparse_dense_matmul,
56 | global_device_count=sparsecore_params.abstract_mesh.size,
57 | feature_specs=sparsecore_params.feature_specs,
58 | sharding_strategy=sparsecore_params.sharding_strategy,
59 | ),
60 | mesh=sparsecore_params.abstract_mesh,
61 | in_specs=(
62 | jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
63 | jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
64 | ),
65 | out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
66 | check_rep=False,
67 | )(csr_inputs, tables)
68 |
69 |
70 | def _emb_lookup_fwd(
71 | sparsecore_params: SparsecoreParams,
72 | tables: Mapping[str, tuple[jax.Array, ...]],
73 | csr_inputs: tuple[jax.Array, ...],
74 | ):
75 | out = sparsecore_lookup(sparsecore_params, tables, csr_inputs)
76 | return out, (tables, csr_inputs)
77 |
78 |
79 | def _emb_lookup_bwd(
80 | sparsecore_params: SparsecoreParams,
81 | res: tuple[Mapping[str, tuple[jax.Array, ...]], tuple[jax.Array, ...]],
82 | gradients: Nested[jax.Array],
83 | ) -> tuple[Nested[jax.Array], None]:
84 | """Backward pass for embedding lookup."""
85 | (tables, csr_inputs) = res
86 |
87 | emb_table_grads = shard_map.shard_map(
88 | functools.partial(
89 | embedding.tpu_sparse_dense_matmul_grad,
90 | feature_specs=sparsecore_params.feature_specs,
91 | sharding_strategy=sparsecore_params.sharding_strategy,
92 | ),
93 | mesh=sparsecore_params.abstract_mesh,
94 | in_specs=(
95 | jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
96 | jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
97 | jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
98 | ),
99 | out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
100 | check_rep=False,
101 | )(gradients, csr_inputs, tables)
102 |
103 | # `tpu_sparse_dense_matmul_grad` returns a general mapping (usually a dict).
104 | # It may not be the same type as the embedding table (e.g. FrozenDict).
105 | # Here we use flatten / unflatten to ensure the types are the same.
106 | emb_table_grads = jax.tree.unflatten(
107 | jax.tree.structure(tables), jax.tree.leaves(emb_table_grads)
108 | )
109 |
110 | return emb_table_grads, None
111 |
112 |
113 | sparsecore_lookup.defvjp(_emb_lookup_fwd, _emb_lookup_bwd)
114 |
--------------------------------------------------------------------------------
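
Exercising `sparsecore_lookup` end to end needs SparseCore feature specs and a device mesh, but the autodiff wiring is the standard `jax.custom_vjp` pattern. A toy dense lookup illustrating the same forward/backward plumbing (a sketch, not the SparseCore API):

```
import jax
import jax.numpy as jnp


@jax.custom_vjp
def dense_lookup(table: jax.Array, ids: jax.Array) -> jax.Array:
  return table[ids]


def _lookup_fwd(table, ids):
  # The forward pass returns the output plus residuals for the backward pass.
  return table[ids], (table, ids)


def _lookup_bwd(res, grads):
  # Scatter-add incoming gradients into a zero table; no gradient flows to the
  # integer ids (None), mirroring `_emb_lookup_bwd` above.
  table, ids = res
  return jnp.zeros_like(table).at[ids].add(grads), None


dense_lookup.defvjp(_lookup_fwd, _lookup_bwd)

table = jnp.arange(12.0).reshape(4, 3)
ids = jnp.array([0, 2, 2])
grad = jax.grad(lambda t: dense_lookup(t, ids).sum())(table)
print(grad)  # Row 0 gets ones, row 2 gets twos, rows 1 and 3 stay zero.
```
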
/recml/core/training/core.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Core training library for Jax."""
15 |
16 | import abc
17 | from collections.abc import Mapping, Sequence
18 | import dataclasses
19 | import enum
20 | from typing import Any, Generic, TypeVar
21 |
22 | import jax
23 | import jax.numpy as jnp
24 | from recml.core.data import iterator
25 | import tensorflow as tf
26 |
27 |
28 | # pylint: disable=logging-fstring-interpolation
29 |
30 | LOG_DIR = "logs"
31 | BACKUP_DIR = "backup"
32 | CHECKPOINT_DIR = "checkpoints"
33 | TRAINING_COMPLETE_MARKER_FILE = "marker.txt"
34 | TRAIN_LOG_DIRNAME = "train"
35 | EVAL_LOG_DIRNAME = "val"
36 | KERAS_MODEL_SAVEFILE = "model.keras"
37 | ORBAX_CHECKPOINT_DEFAULT_KEY = "default"
38 |
39 | DEFAULT_RNG_SEED = 0
40 | IN_TRAINER_CONTEXT = False # Set to true when run from the main trainer.
41 | STATE_CHECKPOINT_KEY = "state"
42 |
43 | TaskT = TypeVar("TaskT")
44 | DatasetT = TypeVar(
45 | "DatasetT",
46 | tf.data.Dataset,
47 | tuple[tf.data.Dataset, tf.data.Dataset],
48 | tuple[tf.data.Dataset, Mapping[str, tf.data.Dataset]],
49 | iterator.Iterator,
50 | tuple[iterator.Iterator, iterator.Iterator],
51 | tuple[iterator.Iterator, Mapping[str, iterator.Iterator]],
52 | )
53 | MetaT = TypeVar("MetaT")
54 | Logs = Any # Any metric logs returned by the training or evaluation task.
55 |
56 |
57 | class Trainer(abc.ABC, Generic[TaskT]):
58 | """A base trainer interface for training and evaluation."""
59 |
60 | @abc.abstractmethod
61 | def __init__(self, model_dir: str, *args, **kwargs):
62 | """Initializes the instance."""
63 |
64 | @abc.abstractmethod
65 | def train(self, task: TaskT, *args, **kwargs) -> Logs | None:
66 | """Performs training for a fixed number of steps."""
67 |
68 | @abc.abstractmethod
69 | def evaluate(self, task: TaskT, *args, **kwargs) -> Logs | None:
70 | """Performs evaluation for a fixed number of steps."""
71 |
72 | @abc.abstractmethod
73 | def train_and_evaluate(self, task: TaskT, *args, **kwargs) -> Logs | None:
74 | """Performs training and evaluation for a fixed number of steps."""
75 |
76 | @abc.abstractmethod
77 | def evaluate_continuously(self, task: TaskT, *args, **kwargs) -> Logs | None:
78 | """Performs continuous evaluation until a condition is met."""
79 |
80 |
81 | @dataclasses.dataclass(frozen=True)
82 | class Experiment(Generic[TaskT]):
83 | """Experiment definition.
84 |
85 | Properties:
86 | Mode: The mode to run the experiment in.
87 |
88 | Attributes:
89 | task: A user defined task that defines the training and evaluation logic.
90 | trainer: The trainer to use for the experiment.
91 | """
92 |
93 | class Mode(enum.StrEnum):
94 | """Mode to run an experiment."""
95 |
96 | TRAIN = "train"
97 | EVAL = "eval"
98 | TRAIN_AND_EVAL = "train_and_eval"
99 | CONTINUOUS_EVAL = "continuous_eval"
100 |
101 | task: TaskT
102 | trainer: Trainer[TaskT]
103 |
104 |
105 | def run_experiment(
106 | experiment: Experiment, mode: Experiment.Mode
107 | ) -> Logs | None:
108 | """Runs an experiment."""
109 | if mode == Experiment.Mode.TRAIN_AND_EVAL:
110 | return experiment.trainer.train_and_evaluate(experiment.task)
111 | elif mode == Experiment.Mode.TRAIN:
112 | return experiment.trainer.train(experiment.task)
113 | elif mode == Experiment.Mode.EVAL:
114 | return experiment.trainer.evaluate(experiment.task)
115 | elif mode == Experiment.Mode.CONTINUOUS_EVAL:
116 | return experiment.trainer.evaluate_continuously(experiment.task)
117 | else:
118 | raise ValueError(f"The job mode provided is not supported: {mode}.")
119 |
120 |
121 | def get_iterators(
122 | datasets: DatasetT,
123 | ) -> tuple[iterator.Iterator, Mapping[str, iterator.Iterator]]:
124 | """Creates and unpacks the datasets returned by the task."""
125 | if isinstance(datasets, (iterator.Iterator, tf.data.Dataset)):
126 | if isinstance(datasets, tf.data.Dataset):
127 | datasets = iterator.TFDatasetIterator(datasets)
128 | return datasets, {}
129 |   elif not isinstance(datasets, tuple) or len(datasets) != 2:
130 | raise ValueError(
131 | "Expected `datasets` to be a single dataset or a tuple of training"
132 | f" and evaluation datasets, but got {type(datasets)}."
133 | )
134 |
135 | train_dataset, eval_datasets = datasets
136 | if isinstance(train_dataset, (iterator.Iterator, tf.data.Dataset)):
137 | if isinstance(train_dataset, tf.data.Dataset):
138 | train_dataset = iterator.TFDatasetIterator(train_dataset)
139 | else:
140 | raise ValueError(
141 | "Expected the training dataset in `datasets` to be a"
142 | " `tf.data.Dataset` or CLU `DatasetIterator` instance, but"
143 | f" {type(train_dataset)}."
144 | )
145 |
146 | if isinstance(eval_datasets, (iterator.Iterator, tf.data.Dataset)):
147 | if isinstance(eval_datasets, tf.data.Dataset):
148 | eval_datasets = iterator.TFDatasetIterator(eval_datasets)
149 | return train_dataset, {"": eval_datasets}
150 |
151 | if not isinstance(eval_datasets, Mapping):
152 | raise ValueError(
153 | "Expected the evaluation dataset in `datasets` to either be a"
154 | " `tf.data.Dataset` or CLU `DatasetIterator` instance or be a"
155 | " mapping of datasets keyed by name, but got"
156 | f" {type(eval_datasets)}."
157 | )
158 |
159 | if all(isinstance(v, tf.data.Dataset) for v in eval_datasets.values()):
160 | eval_datasets = {
161 | k: iterator.TFDatasetIterator(v) for k, v in eval_datasets.items()
162 | }
163 |
164 | if not all(
165 | isinstance(v, iterator.Iterator) for v in eval_datasets.values()
166 | ):
167 | raise ValueError(
168 | "Expected all values in the evaluation datasets mapping to be either"
169 | " `tf.data.Dataset` instances or CLU `DatasetIterator` instances,"
170 | f" but got {eval_datasets}. You cannot mix both."
171 | )
172 |
173 | return train_dataset, eval_datasets # pytype: disable=bad-return-type
174 |
175 |
176 | def get_shape(
177 | x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | tf.TensorSpec,
178 | ) -> Sequence[int | None]:
179 | """Gets the shape of a dense / sparse / ragged tensor or tensor spec."""
180 | if isinstance(x, tf.SparseTensor):
181 | return [x.shape[0]] + [None for _ in x.shape[1:]]
182 | return x.shape.as_list()
183 |
184 |
185 | def in_tracing_context() -> bool:
186 | """Returns whether the current context is a tracing context."""
187 | return isinstance(jnp.ones(()), jax.core.Tracer)
188 |
--------------------------------------------------------------------------------
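
A small sketch of how `get_iterators` normalizes the different dataset bundles a task may return (assumes TensorFlow is available and runs on CPU): a bare dataset becomes a train iterator with no evaluation datasets, while a `(train, {name: eval})` tuple is wrapped into CLU-style iterators keyed by name.

```
import tensorflow as tf
from recml.core.training import core

train_ds = tf.data.Dataset.range(8).batch(2)
eval_dss = {
    'test': tf.data.Dataset.range(4).batch(2),
    'holdout': tf.data.Dataset.range(6).batch(2),
}

train_it, eval_its = core.get_iterators(train_ds)
assert not eval_its  # No evaluation datasets were provided.

train_it, eval_its = core.get_iterators((train_ds, eval_dss))
assert set(eval_its) == {'test', 'holdout'}
```
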
/recml/core/training/jax_trainer_quality_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for the quality of training loops."""
15 |
16 | from collections.abc import Mapping
17 | import functools
18 |
19 | from absl import flags
20 | from absl.testing import absltest
21 | import clu.metrics as clu_metrics
22 | import flax.linen as nn
23 | from flax.training import train_state as ts
24 | import jax
25 | import jax.numpy as jnp
26 | import jaxtyping as jt
27 | import optax
28 | from recml.core.training import jax_trainer
29 | from recml.core.training import partitioning
30 | import tensorflow as tf
31 | import tensorflow_datasets as tfds
32 |
33 |
34 | class _MNISTTask(jax_trainer.JaxTask):
35 | """Task for fitting a CNN on MNIST."""
36 |
37 | def create_datasets(self) -> tuple[tf.data.Dataset, tf.data.Dataset]:
38 |
39 | def _preprocessor(batch: jt.PyTree) -> jt.PyTree:
40 | images = batch['image']
41 | labels = batch['label']
42 | images = tf.cast(images, tf.float32) / 255.0
43 | labels = tf.cast(labels, tf.int32)
44 | return images, labels
45 |
46 | def _create_dataset(training: bool) -> tf.data.Dataset:
47 | dataset = tfds.load(
48 | name='mnist',
49 | split='train' if training else 'test',
50 | batch_size=32,
51 | shuffle_files=training,
52 | )
53 | return dataset.map(_preprocessor).prefetch(buffer_size=tf.data.AUTOTUNE)
54 |
55 | return _create_dataset(training=True), _create_dataset(training=False)
56 |
57 | def create_state(self, batch: jt.PyTree, rng: jax.Array) -> ts.TrainState:
58 | images, _ = batch
59 | model = nn.Sequential([
60 | nn.Conv(32, kernel_size=(3, 3)),
61 | nn.relu,
62 | functools.partial(nn.max_pool, window_shape=(2, 2), strides=(2, 2)),
63 | nn.Conv(64, kernel_size=(3, 3)),
64 | nn.relu,
65 | functools.partial(nn.max_pool, window_shape=(2, 2), strides=(2, 2)),
66 | lambda x: x.reshape((x.shape[0], -1)),
67 | nn.Dense(256),
68 | nn.relu,
69 | nn.Dense(10),
70 | ])
71 | variables = model.init(rng, jnp.zeros_like(images))
72 | optimizer = optax.sgd(0.1)
73 | return ts.TrainState.create(
74 | apply_fn=model.apply, params=variables, tx=optimizer
75 | )
76 |
77 | def train_step(
78 | self, batch: jt.PyTree, state: ts.TrainState, rng: jax.Array
79 | ) -> tuple[ts.TrainState, Mapping[str, clu_metrics.Metric]]:
80 | images, labels = batch
81 |
82 | def _loss_fn(params):
83 | logits = state.apply_fn(params, images)
84 | loss = jnp.mean(
85 | optax.softmax_cross_entropy_with_integer_labels(logits, labels),
86 | axis=0,
87 | )
88 | return loss, (logits, labels)
89 |
90 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
91 | (loss, (logits, labels)), grads = grad_fn(state.params)
92 | state = state.apply_gradients(grads=grads)
93 | metrics = {
94 | 'loss': clu_metrics.Average.from_model_output(loss),
95 | 'accuracy': clu_metrics.Accuracy.from_model_output(
96 | logits=logits, labels=labels
97 | ),
98 | }
99 | return state, metrics
100 |
101 | def eval_step(
102 | self, batch: jt.PyTree, state: ts.TrainState
103 | ) -> Mapping[str, clu_metrics.Metric]:
104 | images, labels = batch
105 | logits = state.apply_fn(state.params, images)
106 | loss = jnp.mean(
107 | optax.softmax_cross_entropy_with_integer_labels(logits, labels)
108 | )
109 | metrics = {
110 | 'loss': clu_metrics.Average.from_model_output(loss),
111 | 'accuracy': clu_metrics.Accuracy.from_model_output(
112 | logits=logits, labels=labels
113 | ),
114 | }
115 | return metrics
116 |
117 |
118 | class JaxQualityTest(absltest.TestCase):
119 |
120 | def setUp(self):
121 | super().setUp()
122 | # Workaround to make `create_tempdir` work with pytest.
123 | if not flags.FLAGS.is_parsed():
124 | flags.FLAGS.mark_as_parsed()
125 |
126 | def test_mnist_e2e(self):
127 | model_dir = self.create_tempdir().full_path
128 | task = _MNISTTask()
129 | trainer = jax_trainer.JaxTrainer(
130 | partitioner=partitioning.DataParallelPartitioner(),
131 | train_steps=1000,
132 | steps_per_eval=50,
133 | steps_per_loop=100,
134 | continuous_eval_timeout=5,
135 | model_dir=model_dir,
136 | rng_seed=42,
137 | )
138 | logs = trainer.train_and_evaluate(task)
139 | self.assertGreater(logs['train']['accuracy'], 0.95)
140 | self.assertGreater(logs['val']['accuracy'], 0.95)
141 |
142 | self.assertTrue(tf.io.gfile.exists(model_dir))
143 | continuous_eval_logs = trainer.evaluate_continuously(task)
144 | self.assertGreater(continuous_eval_logs['val']['accuracy'], 0.95)
145 |
146 |
147 | if __name__ == '__main__':
148 | absltest.main()
149 |
--------------------------------------------------------------------------------
/recml/core/training/jax_trainer_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for Jax task and trainer."""
15 |
16 | from collections.abc import Mapping, Sequence
17 | import dataclasses
18 | import os
19 |
20 | from absl import flags
21 | from absl.testing import absltest
22 | from absl.testing import parameterized
23 | import clu.metrics as clu_metrics
24 | import flax.linen as nn
25 | from flax.training import train_state as ts
26 | import jax
27 | import jax.numpy as jnp
28 | import jaxtyping as jt
29 | import keras
30 | import optax
31 | import orbax.checkpoint as ocp
32 | from recml.core.training import core
33 | from recml.core.training import jax_trainer
34 | from recml.core.training import partitioning
35 | import tensorflow as tf
36 |
37 |
38 | class _DummyFlaxModel(nn.Module):
39 |
40 | @nn.compact
41 | def __call__(self, inputs: jax.Array) -> jax.Array:
42 | return nn.Dense(1, kernel_init=nn.initializers.constant(-1.0))(inputs)
43 |
44 |
45 | class _JaxTask(jax_trainer.JaxTask):
46 |
47 | def create_datasets(
48 | self,
49 | ) -> tuple[tf.data.Dataset, Mapping[str, tf.data.Dataset]]:
50 | def _map_fn(x: int):
51 | return (tf.cast(x, tf.float32), 0.1 * tf.cast(x, tf.float32) + 3)
52 |
53 | return tf.data.Dataset.range(1000).map(_map_fn).batch(2), {
54 | "eval_on_train": tf.data.Dataset.range(1000).map(_map_fn).batch(2),
55 | "eval_on_test": tf.data.Dataset.range(2000).map(_map_fn).batch(2),
56 | }
57 |
58 | def create_state(self, batch: jt.PyTree, rng: jax.Array) -> ts.TrainState:
59 | x, _ = batch
60 | model = _DummyFlaxModel()
61 | params = model.init(rng, x)
62 | optimizer = optax.adagrad(0.1)
63 | return ts.TrainState.create(
64 | apply_fn=model.apply,
65 | params=params,
66 | tx=optimizer,
67 | )
68 |
69 | def train_step(
70 | self, batch: jt.PyTree, state: ts.TrainState, rng: jax.Array
71 | ) -> tuple[ts.TrainState, Mapping[str, clu_metrics.Metric]]:
72 | x, y = batch
73 |
74 | def _loss_fn(params):
75 | y_pred = state.apply_fn(params, x)
76 | loss = keras.losses.mean_squared_error(y, y_pred)
77 | return loss
78 |
79 | grad_fn = jax.value_and_grad(_loss_fn)
80 | loss, grads = grad_fn(state.params)
81 | state = state.apply_gradients(grads=grads)
82 | return state, {"loss": clu_metrics.Average.from_model_output(loss)}
83 |
84 | def eval_step(
85 | self, batch: jt.PyTree, state: ts.TrainState
86 | ) -> Mapping[str, clu_metrics.Metric]:
87 | x, y = batch
88 | y_pred = state.apply_fn(state.params, x)
89 | loss = keras.losses.mean_squared_error(y, y_pred)
90 | return {"loss": clu_metrics.Average.from_model_output(loss)}
91 |
92 |
93 | class _KerasJaxTask(jax_trainer.JaxTask):
94 |
95 | def create_datasets(self) -> tf.data.Dataset:
96 | def _map_fn(x: int):
97 | return (
98 | tf.expand_dims(tf.cast(x, tf.float32), axis=-1),
99 | 0.1 * tf.cast(x, tf.float32) + 3,
100 | )
101 |
102 | return (
103 | tf.data.Dataset.range(1000).map(_map_fn).batch(2),
104 | tf.data.Dataset.range(2000).map(_map_fn).batch(2),
105 | )
106 |
107 | def create_state(
108 | self, batch: jt.PyTree, rng: jax.Array
109 | ) -> jax_trainer.KerasState:
110 | x, _ = batch
111 |
112 | model = keras.Sequential(
113 | [
114 | keras.layers.Dense(
115 | 1,
116 | kernel_initializer=keras.initializers.constant(-1.0),
117 | name="dense",
118 | ),
119 | ],
120 | name="model",
121 | )
122 | model.build(x.shape)
123 |
124 | optimizer = optax.adagrad(0.1)
125 | return jax_trainer.KerasState.create(model=model, tx=optimizer)
126 |
127 | def train_step(
128 | self, batch: jt.PyTree, state: jax_trainer.KerasState, rng: jax.Array
129 | ) -> tuple[jax_trainer.KerasState, Mapping[str, clu_metrics.Metric]]:
130 | x, y = batch
131 |
132 | def _loss_fn(tvars):
133 | y_pred, _ = state.model.stateless_call(tvars, state.ntvars, x)
134 | loss = keras.ops.mean(keras.losses.mean_squared_error(y, y_pred))
135 | return loss
136 |
137 | grad_fn = jax.value_and_grad(_loss_fn)
138 | loss, grads = grad_fn(state.tvars)
139 | state = state.update(grads=grads)
140 | return state, {"loss": clu_metrics.Average.from_model_output(loss)}
141 |
142 | def eval_step(
143 | self, batch: jt.PyTree, state: jax_trainer.KerasState
144 | ) -> Mapping[str, clu_metrics.Metric]:
145 | x, y = batch
146 | y_pred, _ = state.model.stateless_call(state.tvars, state.ntvars, x)
147 | loss = keras.losses.mean_squared_error(y, y_pred)
148 | return {"loss": clu_metrics.Average.from_model_output(loss)}
149 |
150 |
151 | class JaxTest(parameterized.TestCase):
152 |
153 | def setUp(self):
154 | super().setUp()
155 | # Workaround to make `create_tempdir` work with pytest.
156 | if not flags.FLAGS.is_parsed():
157 | flags.FLAGS.mark_as_parsed()
158 |
159 | @parameterized.named_parameters(
160 | {
161 | "testcase_name": "jax_task_train",
162 | "task_cls": _JaxTask,
163 | "mode": core.Experiment.Mode.TRAIN,
164 | "expected_keys": ["train"],
165 | },
166 | {
167 | "testcase_name": "keras_jax_task_train",
168 | "task_cls": _KerasJaxTask,
169 | "mode": core.Experiment.Mode.TRAIN,
170 | "expected_keys": ["train"],
171 | },
172 | {
173 | "testcase_name": "jax_task_eval",
174 | "task_cls": _JaxTask,
175 | "mode": core.Experiment.Mode.EVAL,
176 | "expected_keys": ["val_eval_on_train", "val_eval_on_test"],
177 | },
178 | {
179 | "testcase_name": "keras_jax_task_eval",
180 | "task_cls": _KerasJaxTask,
181 | "mode": core.Experiment.Mode.EVAL,
182 | "expected_keys": ["val"],
183 | },
184 | {
185 | "testcase_name": "jax_task_train_and_eval",
186 | "task_cls": _JaxTask,
187 | "mode": core.Experiment.Mode.TRAIN_AND_EVAL,
188 | "expected_keys": ["train", "val_eval_on_train", "val_eval_on_test"],
189 | },
190 | {
191 | "testcase_name": "keras_jax_task_train_and_eval",
192 | "task_cls": _KerasJaxTask,
193 | "mode": core.Experiment.Mode.TRAIN_AND_EVAL,
194 | "expected_keys": ["train", "val"],
195 | },
196 | {
197 | "testcase_name": "jax_task_continuous_eval",
198 | "task_cls": _JaxTask,
199 | "mode": core.Experiment.Mode.CONTINUOUS_EVAL,
200 | "expected_keys": ["val_eval_on_train", "val_eval_on_test"],
201 | },
202 | {
203 | "testcase_name": "keras_jax_task_continuous_eval",
204 | "task_cls": _KerasJaxTask,
205 | "mode": core.Experiment.Mode.CONTINUOUS_EVAL,
206 | "expected_keys": ["val"],
207 | },
208 | )
209 | def test_jax_trainer(
210 | self,
211 | task_cls: type[jax_trainer.JaxTask],
212 | mode: str,
213 | expected_keys: Sequence[str],
214 | ):
215 | model_dir = self.create_tempdir().full_path
216 | task = task_cls()
217 | trainer = jax_trainer.JaxTrainer(
218 | partitioner=partitioning.DataParallelPartitioner(data_axis="batch"),
219 | train_steps=12,
220 | steps_per_eval=3,
221 | steps_per_loop=4,
222 | model_dir=model_dir,
223 | continuous_eval_timeout=5,
224 | )
225 | experiment = core.Experiment(task, trainer)
226 | if mode == core.Experiment.Mode.CONTINUOUS_EVAL:
227 | # Produce one checkpoint so there is something to evaluate.
228 | core.run_experiment(experiment, core.Experiment.Mode.TRAIN)
229 | logs = core.run_experiment(experiment, mode)
230 |
231 | for key in expected_keys:
232 | self.assertIn(key, logs)
233 | self.assertIn("loss", logs[key])
234 |
235 | if mode in [
236 | core.Experiment.Mode.TRAIN,
237 | core.Experiment.Mode.TRAIN_AND_EVAL,
238 | ]:
239 | checkpointed_steps = ocp.utils.checkpoint_steps(
240 | os.path.join(model_dir, core.CHECKPOINT_DIR)
241 | )
242 | self.assertEqual([3, 7, 11], sorted(checkpointed_steps))
243 |
244 | # TODO(aahil): Check the logs for the correct summaries.
245 | # TODO(aahil): Test exporting here.
246 |
247 | def test_optimizer_metrics(self):
248 | @dataclasses.dataclass
249 | class State:
250 | step: int
251 | opt_state: optax.OptState
252 |
253 | tx = optax.chain(
254 | optax.clip_by_global_norm(1.0),
255 | optax.scale_by_adam(),
256 | optax.inject_stateful_hyperparams(optax.scale_by_learning_rate)(
257 | learning_rate=0.1
258 | ),
259 | )
260 | state = State(step=10, opt_state=tx.init({"a": jnp.ones((10, 10))}))
261 | metrics = jax_trainer._state_metrics(state)
262 | self.assertIn("optimizer/learning_rate", metrics)
263 | self.assertEqual(metrics["optimizer/learning_rate"].compute(), 0.1)
264 |
265 |
266 | if __name__ == "__main__":
267 | absltest.main()
268 |
--------------------------------------------------------------------------------
/recml/core/training/keras_trainer_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for Jax training library."""
15 |
16 | from absl import flags
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import keras
20 | from recml.core.training import core
21 | from recml.core.training import keras_trainer
22 | import tensorflow as tf
23 |
24 |
25 | class _KerasTask(keras_trainer.KerasTask):
26 |
27 | def create_dataset(self, training: bool) -> tf.data.Dataset:
28 | def _map_fn(x: int):
29 | return (tf.cast(x, tf.float32), 0.1 * tf.cast(x, tf.float32) + 3)
30 |
31 | return tf.data.Dataset.range(1000).map(_map_fn).batch(2)
32 |
33 | def create_model(self) -> keras.Model:
34 | inputs = keras.Input(shape=(1,), dtype=tf.float32)
35 | outputs = keras.layers.Dense(
36 | 1, kernel_initializer=keras.initializers.constant(-1.0)
37 | )(inputs)
38 | model = keras.Model(inputs=inputs, outputs=outputs)
39 | model.compile(
40 | optimizer=keras.optimizers.Adagrad(0.1),
41 | loss=keras.losses.MeanSquaredError(),
42 | )
43 | return model
44 |
45 |
46 | class KerasTrainerTest(parameterized.TestCase):
47 |
48 | def setUp(self):
49 | super().setUp()
50 | # Workaround to make `create_tempdir` work with pytest.
51 | if not flags.FLAGS.is_parsed():
52 | flags.FLAGS.mark_as_parsed()
53 |
54 | @parameterized.named_parameters(
55 | {"testcase_name": "train", "mode": core.Experiment.Mode.TRAIN},
56 | {"testcase_name": "eval", "mode": core.Experiment.Mode.EVAL},
57 | {
58 | "testcase_name": "train_and_eval",
59 | "mode": core.Experiment.Mode.TRAIN_AND_EVAL,
60 | },
61 | {
62 | "testcase_name": "continuous_eval",
63 | "mode": core.Experiment.Mode.CONTINUOUS_EVAL,
64 | },
65 | )
66 | def test_keras_task_and_trainer(self, mode: str):
67 | if keras.backend.backend() == "jax":
68 | distribution = keras.distribution.DataParallel()
69 | else:
70 | distribution = None
71 | if mode == core.Experiment.Mode.CONTINUOUS_EVAL:
72 | self.skipTest("Continuous eval is only supported on the Jax backend.")
73 |
74 | trainer = keras_trainer.KerasTrainer(
75 | distribution=distribution,
76 | train_steps=5,
77 | steps_per_eval=3,
78 | steps_per_loop=2,
79 | model_dir=self.create_tempdir().full_path,
80 | continuous_eval_timeout=5,
81 | )
82 | experiment = core.Experiment(_KerasTask(), trainer)
83 |
84 | if mode == core.Experiment.Mode.CONTINUOUS_EVAL:
85 | # Produce one checkpoint so there is something to evaluate.
86 | core.run_experiment(experiment, core.Experiment.Mode.TRAIN)
87 |
88 | history = core.run_experiment(experiment, mode)
89 |
90 | if (
91 | mode
92 | in [core.Experiment.Mode.TRAIN, core.Experiment.Mode.TRAIN_AND_EVAL]
93 | and keras.backend.backend() == "jax"
94 | ):
95 | self.assertEqual(history.history["num_params/trainable"][0], 2)
96 |
97 |
98 | if __name__ == "__main__":
99 | absltest.main()
100 |
--------------------------------------------------------------------------------
/recml/core/training/optax_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Optax optimizer factories."""
15 |
16 | from collections.abc import Callable
17 | import dataclasses
18 | import re
19 | from typing import Any
20 |
21 | import jax
22 | import optax
23 | from recml.core.utils import types
24 |
25 |
26 | def _default_weight_decay_mask(params: optax.Params) -> optax.Params:
27 | """Default weight decay mask that only applies to non-1D parameters."""
28 | return jax.tree.map(lambda p: p.ndim > 1, params)
29 |
30 |
31 | def _regex_mask(regex: str) -> Callable[[optax.Params], optax.Params]:
32 | """Returns a mask that applies to parameters matching a regex."""
33 |
34 | def _matches_regex(path: tuple[str, ...], _: Any) -> bool:
35 | key = '/'.join([jax.tree_util.keystr((k,), simple=True) for k in path])
36 | return re.fullmatch(regex, key) is not None
37 |
38 | def _mask(params: optax.Params) -> optax.Params:
39 | return jax.tree.map_with_path(_matches_regex, params)
40 |
41 | return _mask
42 |
43 |
44 | class OptimizerFactory(types.Factory[optax.GradientTransformation]):
45 | """Standard optimizer factory for Optax optimizers.
46 |
47 | Attributes:
48 | learning_rate: The learning rate to use for the optimizer.
49 | scaling: The gradient scaling transformation to use during optimization.
50 | Defaults to identity.
51 | weight_decay: Optional weight decay to apply to variables during
52 | optimization. Defaults to None.
53 | grad_clip_norm: Optional gradient clipping norm to limit the maximum
54 | magnitude of the gradients during optimization. Defaults to None.
55 | weight_decay_mask: The weight decay mask to use when applying weight decay.
56 |       Defaults to applying weight decay to all non-1D parameters.
57 | freeze_mask: Optional mask to freeze parameters during optimization.
58 | Defaults to None.
59 |
60 | Example usage:
61 |
62 | ```
63 | sgd = OptimizerFactory(learning_rate=0.001).make()
64 |
65 | adamw = OptimizerFactory(
66 | learning_rate=1e-3,
67 |     scaling=optax.scale_by_adam(),
68 | weight_decay=1e-7,
69 | grad_clip_norm=1.0,
70 | ).make()
71 | ```
72 | """
73 |
74 | learning_rate: optax.ScalarOrSchedule
75 | scaling: optax.GradientTransformation = dataclasses.field(
76 | default_factory=optax.identity
77 | )
78 | weight_decay: float | None = None
79 | grad_clip_norm: float | None = None
80 | weight_decay_mask: str | Callable[[optax.Params], optax.Params] = (
81 | _default_weight_decay_mask
82 | )
83 | freeze_mask: str | Callable[[optax.Params], optax.Params] | None = None
84 |
85 | def make(self) -> optax.GradientTransformation:
86 | if self.grad_clip_norm is not None:
87 | apply_clipping = optax.clip_by_global_norm(self.grad_clip_norm)
88 | else:
89 | apply_clipping = optax.identity()
90 |
91 | # Tags the learning rate as a stateful hyperparameter so it can be logged.
92 | lr_scaling = optax.inject_stateful_hyperparams(
93 | optax.scale_by_learning_rate
94 | )(learning_rate=self.learning_rate)
95 |
96 | if self.weight_decay is not None:
97 | if isinstance(self.weight_decay_mask, str):
98 | mask = _regex_mask(self.weight_decay_mask)
99 | else:
100 | mask = self.weight_decay_mask
101 | weight_decay = optax.add_decayed_weights(self.weight_decay, mask=mask)
102 | else:
103 | weight_decay = optax.identity()
104 |
105 | tx = optax.chain(*[
106 | apply_clipping,
107 | self.scaling,
108 | weight_decay,
109 | lr_scaling,
110 | ])
111 |
112 | if self.freeze_mask is not None:
113 | if isinstance(self.freeze_mask, str):
114 | mask = _regex_mask(self.freeze_mask)
115 | else:
116 | mask = self.freeze_mask
117 |
118 | def _param_labels(params: optax.Params) -> optax.Params:
119 |         return jax.tree.map(
120 |             lambda m: 'frozen' if m else 'trainable', mask(params)
121 | )
122 |
123 | tx = optax.multi_transform(
124 | transforms={'trainable': tx, 'frozen': optax.set_to_zero()},
125 | param_labels=_param_labels,
126 | )
127 | return tx
128 |
129 |
130 | class AdamFactory(types.Factory[optax.GradientTransformation]):
131 | """Adam optimizer factory.
132 |
133 | Attributes:
134 | learning_rate: The learning rate to use for the optimizer.
135 | b1: The beta1 coefficient for the Adam optimizer. Defaults to 0.9.
136 | b2: The beta2 coefficient for the Adam optimizer. Defaults to 0.999.
137 | eps: The epsilon coefficient for the Adam optimizer. Defaults to 1e-8.
138 | weight_decay: Optional weight decay to apply to variables during
139 | optimization. Defaults to None.
140 | grad_clip_norm: Optional gradient clipping norm to limit the maximum
141 | magnitude of the gradients during optimization. Defaults to None.
142 | weight_decay_mask: The weight decay mask to use when applying weight decay.
143 |       Defaults to applying weight decay to all non-1D parameters.
144 | freeze_mask: Optional mask to freeze parameters during optimization.
145 | Defaults to None.
146 |
147 | Example usage:
148 | ```
149 | adam = AdamFactory(learning_rate=1e-3).make()
150 |
151 | adamw = AdamFactory(
152 | learning_rate=1e-3,
153 | weight_decay=1e-7,
154 | grad_clip_norm=1.0,
155 | ).make()
156 | ```
157 | """
158 |
159 | learning_rate: optax.ScalarOrSchedule
160 | b1: float = 0.9
161 | b2: float = 0.999
162 | eps: float = 1e-8
163 | weight_decay: float | None = None
164 | grad_clip_norm: float | None = None
165 | weight_decay_mask: str | Callable[[optax.Params], optax.Params] = (
166 | _default_weight_decay_mask
167 | )
168 | freeze_mask: str | Callable[[optax.Params], optax.Params] | None = None
169 |
170 | def make(self) -> optax.GradientTransformation:
171 | return OptimizerFactory(
172 | learning_rate=self.learning_rate,
173 | scaling=optax.scale_by_adam(b1=self.b1, b2=self.b2, eps=self.eps),
174 | weight_decay=self.weight_decay,
175 | grad_clip_norm=self.grad_clip_norm,
176 | weight_decay_mask=self.weight_decay_mask,
177 | ).make()
178 |
179 |
180 | class AdagradFactory(types.Factory[optax.GradientTransformation]):
181 | """Adagrad optimizer factory.
182 |
183 | Attributes:
184 | learning_rate: The learning rate to use for the optimizer.
185 | initial_accumulator_value: The initial accumulator value for the Adagrad
186 | optimizer. Defaults to 0.1.
187 | eps: The epsilon coefficient for the Adagrad optimizer. Defaults to 1e-7.
188 | grad_clip_norm: Optional gradient clipping norm to limit the maximum
189 | magnitude of the gradients during optimization. Defaults to None.
190 | freeze_mask: Optional mask to freeze parameters during optimization.
191 | Defaults to None.
192 |
193 | Example usage:
194 | ```
195 | adagrad = AdagradFactory(learning_rate=1e-3).make()
196 | ```
197 | """
198 |
199 | learning_rate: optax.ScalarOrSchedule
200 | initial_accumulator_value: float = 0.1
201 | eps: float = 1e-7
202 | grad_clip_norm: float | None = None
203 | freeze_mask: str | Callable[[optax.Params], optax.Params] | None = None
204 |
205 | def make(self) -> optax.GradientTransformation:
206 | return OptimizerFactory(
207 | learning_rate=self.learning_rate,
208 | scaling=optax.scale_by_rss(
209 | initial_accumulator_value=self.initial_accumulator_value,
210 | eps=self.eps,
211 | ),
212 | grad_clip_norm=self.grad_clip_norm,
213 | ).make()
214 |
--------------------------------------------------------------------------------
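
A quick sketch of how the string masks above resolve: `_regex_mask` joins the keys along each parameter path with '/' and requires the regex to match the full path, and the same strings can be passed to the factories, e.g. to restrict weight decay.

```
import jax.numpy as jnp
from recml.core.training import optax_factory

params = {
    'encoder': {'dense': {'kernel': jnp.ones((2, 2)), 'bias': jnp.ones((2,))}},
    'head': {'kernel': jnp.ones((2, 1))},
}

mask = optax_factory._regex_mask(r'encoder/.*/kernel')(params)
# {'encoder': {'dense': {'kernel': True, 'bias': False}}, 'head': {'kernel': False}}

adamw = optax_factory.AdamFactory(
    learning_rate=1e-3,
    weight_decay=1e-4,
    weight_decay_mask=r'encoder/.*/kernel',  # Decay only encoder kernels.
).make()
```
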
/recml/core/training/optax_factory_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Test for optax optimizer factories."""
15 |
16 | from absl.testing import absltest
17 | import jax
18 | import numpy as np
19 | import optax
20 | from recml.core.training import optax_factory
21 |
22 |
23 | class OptaxFactoryTest(absltest.TestCase):
24 |
25 | def assertOptimizersEqual(
26 | self,
27 | a: optax.GradientTransformation,
28 | b: optax.GradientTransformation,
29 | steps: int = 10,
30 | ):
31 | k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
32 | params = {
33 | "x": jax.random.uniform(k1, (128, 128)),
34 | "y": jax.random.uniform(k2, (128, 128)),
35 | "z": jax.random.uniform(k3, (128, 128)),
36 | }
37 | grads = jax.tree.map(lambda p: jax.random.uniform(k4, p.shape), params)
38 |
39 | opt_state_a = a.init(params)
40 | opt_state_b = b.init(params)
41 |
42 | for _ in range(steps):
43 | updates_a, opt_state_a = a.update(grads, opt_state_a, params)
44 | updates_b, opt_state_b = b.update(grads, opt_state_b, params)
45 | for k in params:
46 | np.testing.assert_allclose(updates_a[k], updates_b[k])
47 |
48 | def test_optimizer_factory(self):
49 | optimizer_a = optax_factory.OptimizerFactory(
50 | learning_rate=optax.warmup_cosine_decay_schedule(
51 | init_value=0.0,
52 | peak_value=1e-3,
53 | warmup_steps=5,
54 | decay_steps=10,
55 | end_value=0,
56 | ),
57 | scaling=optax.scale_by_rms(),
58 | weight_decay=1e-4,
59 | weight_decay_mask=r"^(?!.*(?:x|y)$).*",
60 | grad_clip_norm=1.0,
61 | ).make()
62 | optimizer_b = optax.chain(
63 | optax.clip_by_global_norm(1.0),
64 | optax.scale_by_rms(),
65 | optax.add_decayed_weights(
66 | 1e-4, mask=optax_factory._regex_mask(r"^(?!.*(?:x|y)$).*")
67 | ),
68 | optax.scale_by_learning_rate(
69 | optax.warmup_cosine_decay_schedule(
70 | init_value=0.0,
71 | peak_value=1e-3,
72 | warmup_steps=5,
73 | decay_steps=10,
74 | end_value=0,
75 | )
76 | ),
77 | )
78 | optimizer_c = optax_factory.OptimizerFactory(
79 | learning_rate=optax.warmup_cosine_decay_schedule(
80 | init_value=0.0,
81 | peak_value=1e-3,
82 | warmup_steps=5,
83 | decay_steps=10,
84 | end_value=0,
85 | ),
86 | scaling=optax.scale_by_rms(),
87 | weight_decay=1e-4,
88 | weight_decay_mask=r"^(?!.*(?:z)$).*",
89 | grad_clip_norm=1.0,
90 | ).make()
91 | self.assertOptimizersEqual(optimizer_a, optimizer_b, steps=10)
92 | self.assertRaises(
93 | AssertionError,
94 | self.assertOptimizersEqual,
95 | optimizer_a,
96 | optimizer_c,
97 | steps=10,
98 | )
99 |
100 | def test_adam_factory(self):
101 | optimizer_a = optax_factory.AdamFactory(
102 | learning_rate=optax.warmup_cosine_decay_schedule(
103 | init_value=0.0,
104 | peak_value=1e-3,
105 | warmup_steps=5,
106 | decay_steps=10,
107 | end_value=0,
108 | ),
109 | b1=0.9,
110 | b2=0.999,
111 | eps=1e-8,
112 | weight_decay=1e-4,
113 | grad_clip_norm=1.0,
114 | ).make()
115 | optimizer_b = optax.chain(
116 | optax.clip_by_global_norm(1.0),
117 | optax.adamw(
118 | learning_rate=optax.warmup_cosine_decay_schedule(
119 | init_value=0.0,
120 | peak_value=1e-3,
121 | warmup_steps=5,
122 | decay_steps=10,
123 | end_value=0,
124 | ),
125 | b1=0.9,
126 | b2=0.999,
127 | eps=1e-8,
128 | weight_decay=1e-4,
129 | mask=optax_factory._default_weight_decay_mask,
130 | ),
131 | )
132 | self.assertOptimizersEqual(optimizer_a, optimizer_b, steps=10)
133 |
134 | def test_adagrad_factory(self):
135 | optimizer_a = optax_factory.AdagradFactory(
136 | learning_rate=optax.warmup_cosine_decay_schedule(
137 | init_value=0.0,
138 | peak_value=1e-3,
139 | warmup_steps=5,
140 | decay_steps=10,
141 | end_value=0,
142 | ),
143 | initial_accumulator_value=0.1,
144 | eps=1e-7,
145 | grad_clip_norm=1.0,
146 | ).make()
147 | optimizer_b = optax.chain(
148 | optax.clip_by_global_norm(1.0),
149 | optax.adagrad(
150 | learning_rate=optax.warmup_cosine_decay_schedule(
151 | init_value=0.0,
152 | peak_value=1e-3,
153 | warmup_steps=5,
154 | decay_steps=10,
155 | end_value=0,
156 | ),
157 | initial_accumulator_value=0.1,
158 | eps=1e-7,
159 | ),
160 | )
161 | self.assertOptimizersEqual(optimizer_a, optimizer_b, steps=10)
162 |
163 |
164 | if __name__ == "__main__":
165 | absltest.main()
166 |
--------------------------------------------------------------------------------
/recml/core/training/partitioning.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utilities for partitioning."""
15 |
16 | import abc
17 | from collections.abc import Callable, Mapping, Sequence
18 | import math
19 | from typing import Any, ContextManager
20 |
21 | import flax.linen as nn
22 | import jax
23 | from jax.experimental import mesh_utils
24 | import numpy as np
25 |
26 |
27 | PyTree = Any
28 | State = Any
29 | CreateStateFn = Callable[[PyTree], State]
30 | InitFn = Callable[[PyTree, jax.Array], State]
31 | StepFn = Callable[[PyTree, State], Any]
32 |
33 |
34 | class Partitioner(abc.ABC):
35 | """An abstract class defining partitioning logic for data and computation."""
36 |
37 | @abc.abstractmethod
38 | def shard_inputs(self, inputs: Any) -> PyTree:
39 | """Shards the input batches and put them on the device."""
40 |
41 | @abc.abstractmethod
42 | def partition_init(
43 | self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
44 | ) -> CreateStateFn:
45 | """Shards the initialization function."""
46 |
47 | @abc.abstractmethod
48 | def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
49 | """Shards the training and evaluation steps."""
50 |
51 |
52 | class NullPartitioner(Partitioner):
53 | """A null partitioner."""
54 |
55 | def shard_inputs(self, inputs: PyTree) -> PyTree:
56 | return inputs
57 |
58 | def partition_init(
59 | self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
60 | ) -> CreateStateFn:
61 | return init_fn
62 |
63 | def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
64 | return fn
65 |
66 |
67 | class DataParallelPartitioner(Partitioner):
68 | """Data parallel partitioner."""
69 |
70 | def __init__(self, data_axis: str = "batch"):
71 | self.mesh = jax.sharding.Mesh(jax.devices(), (data_axis,))
72 | self.data_sharding = jax.sharding.NamedSharding(
73 | self.mesh, jax.sharding.PartitionSpec(data_axis)
74 | )
75 | self.state_sharding = jax.sharding.NamedSharding(
76 | self.mesh, jax.sharding.PartitionSpec()
77 | )
78 |
79 | def shard_inputs(self, inputs: PyTree) -> PyTree:
80 | local_devices = self.mesh.local_devices
81 | local_device_count = len(local_devices)
82 | device_count = len(self.mesh.devices)
83 |
84 | def _shard(x: np.ndarray) -> jax.Array:
85 | per_proc_batch_size = x.shape[0]
86 | per_replica_batch_size = per_proc_batch_size // local_device_count
87 | if per_proc_batch_size % local_device_count != 0:
88 | raise ValueError(
89 | "The per process batch size must be divisible by the number of"
90 | " local devices. Got per process batch size:"
91 | f" {per_proc_batch_size} and local device count:"
92 | f" {local_device_count}."
93 | )
94 |
95 | per_device_arrays = np.split(x, local_device_count, axis=0)
96 | device_buffers = [
97 | jax.device_put(arr, device)
98 | for arr, device in zip(per_device_arrays, local_devices)
99 | ]
100 |
101 | global_batch_size = per_replica_batch_size * device_count
102 | return jax.make_array_from_single_device_arrays(
103 | (global_batch_size,) + x.shape[1:], self.data_sharding, device_buffers
104 | )
105 |
106 | return jax.tree.map(_shard, inputs)
107 |
108 | def partition_init(
109 | self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
110 | ) -> CreateStateFn:
111 | with jax.sharding.use_mesh(self.mesh):
112 | init_fn = jax.jit(init_fn, out_shardings=self.state_sharding)
113 |
114 | def _wrapped_init(batch: PyTree) -> State:
115 | with jax.sharding.use_mesh(self.mesh):
116 | state = init_fn(batch)
117 | state = _maybe_unbox_state(state)
118 | return state
119 |
120 | return _wrapped_init
121 |
122 | def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
123 | jit_kws = {}
124 | if training:
125 | jit_kws["out_shardings"] = (self.state_sharding, None)
126 | jit_kws["donate_argnums"] = (1,)
127 |
128 | with jax.sharding.use_mesh(self.mesh):
129 | step_fn = jax.jit(
130 | fn,
131 | in_shardings=(self.data_sharding, self.state_sharding),
132 | **jit_kws,
133 | )
134 |
135 | def _wrapped_step(batch: PyTree, state: State) -> Any:
136 | with jax.sharding.use_mesh(self.mesh):
137 | return step_fn(batch, state)
138 |
139 | return _wrapped_step
140 |
141 |
142 | class ModelParallelPartitioner(Partitioner):
143 | """Model parallel partitioner.
144 |
145 | This only works with multi-controller Jax, i.e. communications along the ICI
146 | for TPUs. For scaling beyond a single TPU slice this needs to be extended to
147 | support Megascale XLA or single-controller Pathways. Consider using T5X, Pax,
148 | or Gemax for these use cases.
149 |
150 |   Note: This assumes that all mesh axes except the final one are used for data
151 |   parallelism while the final one is used for model parallelism.
152 | This tends to work well for 2D and 3D torus topologies since network latency
153 | tends to be much higher for the leading axes.
154 |
155 | IMPORTANT: `shard_inputs` operates on a per process batch. This means that the
156 | input batch size on CPU must already be the per process batch size,
157 | i.e. global batch size // jax.process_count(). It is the responsibility of the
158 | CPU input pipeline to ensure that inputs are different across processes.
159 | """
160 |
161 | def __init__(
162 | self,
163 | axes: Sequence[tuple[str, int]],
164 | rules: Mapping[str, str] | None = None,
165 | aot_compile: bool = False,
166 | options: jax.stages.CompilerOptions | None = None,
167 | ):
168 | if len(axes) < 2:
169 | raise ValueError(
170 | "`axes` cannot less than 2D, use data-parallel"
171 | f" partitioner instead. Got axes: {axes}."
172 | )
173 |
174 |     mesh_devices = mesh_utils.create_device_mesh([dim for _, dim in axes])
175 | self.mesh = jax.sharding.Mesh(mesh_devices, [axis for axis, _ in axes])
176 | self.rules = rules
177 | self.aot_compile = aot_compile
178 | self.options = options
179 |
180 | dp_axes, dp_dims = zip(*axes[:-1])
181 | _, mp_dim = axes[-1]
182 |
183 | if math.prod(dp_dims) % jax.process_count() != 0:
184 | raise ValueError(
185 | "The data parallel dimensions in the mesh must be divisible by the"
186 | " number of processes as we assume data parallelism across"
187 | f" processes. Got process count: {jax.process_count()} and data"
188 | f" parallelism dimensions: {dp_dims} for axes: {axes} and mesh"
189 | f" devices: {self.mesh.devices}."
190 | )
191 | if jax.local_device_count() % mp_dim != 0:
192 | raise ValueError(
193 | "The number of local devices on each host must be divisible by the"
194 | " model dimension as we assume model parallelism across local"
195 | f" devices. Got local device count: {jax.local_device_count()} and"
196 | f" model parallelism dimension: {mp_dim} for axes: {axes} and mesh"
197 | f" devices: {self.mesh.devices}."
198 | )
199 |
200 | self.data_sharding = jax.sharding.NamedSharding(
201 | self.mesh, jax.sharding.PartitionSpec(dp_axes)
202 | )
203 | self.state_sharding = None
204 | self.abstract_batch = None
205 | self.abstract_state = None
206 |
207 | @property
208 | def mesh_context_manager(
209 | self,
210 | ) -> Callable[[jax.sharding.Mesh], ContextManager[None]]:
211 | return jax.sharding.use_mesh
212 |
213 | def shard_inputs(self, inputs: PyTree) -> PyTree:
214 | def _shard(x: np.ndarray) -> jax.Array:
215 | return jax.make_array_from_process_local_data(self.data_sharding, x)
216 |
217 | return jax.tree.map(_shard, inputs)
218 |
219 | def partition_init(
220 | self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
221 | ) -> CreateStateFn:
222 | if abstract_batch is None:
223 | raise ValueError(
224 | "An `abstract_batch` is required for partitioning `init_fn` with a"
225 | " model parallel partitioner."
226 | )
227 |
228 | with self.mesh_context_manager(self.mesh):
229 | abstract_state = jax.eval_shape(init_fn, abstract_batch)
230 | specs = nn.get_partition_spec(abstract_state)
231 |
232 | if self.rules is not None:
233 | specs = nn.logical_to_mesh(specs, self.rules)
234 |
235 | state_sharding = jax.tree.map(
236 | lambda x: jax.sharding.NamedSharding(self.mesh, x), specs
237 | )
238 | compiled_init_fn = jax.jit(init_fn, out_shardings=state_sharding)
239 |
240 | def _init(batch: PyTree) -> State:
241 | with self.mesh_context_manager(self.mesh):
242 | state = compiled_init_fn(batch)
243 | state = _maybe_unbox_state(state)
244 | return state
245 |
246 | self.abstract_batch = abstract_batch
247 | self.abstract_state = abstract_state
248 | self.state_sharding = state_sharding
249 | return _init
250 |
251 | def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
252 | jit_kws = {}
253 | if training:
254 | jit_kws["out_shardings"] = (self.state_sharding, None)
255 | jit_kws["donate_argnums"] = (1,)
256 | else:
257 | jit_kws["out_shardings"] = None
258 |
259 | with self.mesh_context_manager(self.mesh):
260 | step_fn = jax.jit(
261 | fn,
262 | in_shardings=(self.data_sharding, self.state_sharding),
263 | compiler_options=(self.options if not self.aot_compile else None),
264 | **jit_kws,
265 | )
266 | if self.aot_compile:
267 | if self.abstract_batch is None or self.abstract_state is None:
268 | raise ValueError(
269 | "An `abstract_batch` and `abstract_state` must be set on the model"
270 | " parallel partitioner when `aot_compile` is set to True in order"
271 | " to compile the step. Make sure you call"
272 | " `partitioner.partition_init(...)` first."
273 | )
274 |
275 | step_fn = step_fn.lower(self.abstract_batch, self.abstract_state).compile(
276 | self.options
277 | )
278 |
279 | def _step(batch: PyTree, state: State) -> Any:
280 | with self.mesh_context_manager(self.mesh):
281 | return step_fn(batch, state)
282 |
283 | return _step
284 |
285 |
286 | def _maybe_unbox_state(x: Any) -> Any:
287 | def _maybe_unbox(x: Any) -> Any:
288 | if isinstance(x, nn.Partitioned):
289 | return x.unbox()
290 | return x
291 |
292 | return jax.tree.map(
293 | _maybe_unbox,
294 | x,
295 | is_leaf=lambda k: isinstance(k, nn.Partitioned),
296 | )
297 |
--------------------------------------------------------------------------------
/recml/core/training/partitioning_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for Jax partitioners."""
15 |
16 | from collections.abc import Mapping
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import flax.linen as nn
21 | import jax
22 | import jax.numpy as jnp
23 | import numpy as np
24 | from recml.core.training import partitioning
25 |
26 |
27 | class PartitioningTest(parameterized.TestCase):
28 |
29 | @parameterized.named_parameters(
30 | {
31 | "testcase_name": "data_parallel_partitioner",
32 | "partitioner_cls": partitioning.DataParallelPartitioner,
33 | },
34 | {
35 | "testcase_name": "model_parallel_partitioner",
36 | "partitioner_cls": partitioning.ModelParallelPartitioner,
37 | },
38 | )
39 | def test_data_parallelism(
40 | self, partitioner_cls: type[partitioning.Partitioner]
41 | ):
42 | if partitioner_cls is partitioning.ModelParallelPartitioner:
43 | kwargs = {"axes": [("data", jax.device_count()), ("model", 1)]}
44 | else:
45 | kwargs = {}
46 | partitioner = partitioner_cls(**kwargs)
47 |
48 | inputs = np.zeros((128, 16), dtype=np.float32)
49 | sharded_inputs = partitioner.shard_inputs(inputs)
50 |
51 | self.assertIsInstance(sharded_inputs, jax.Array)
52 | self.assertSequenceEqual(sharded_inputs.shape, (128, 16))
53 | self.assertEqual(sharded_inputs.sharding, partitioner.data_sharding)
54 |
55 | def _init(batch: jax.Array) -> jax.Array:
56 | return jnp.ones_like(batch)
57 |
58 | def _train_step(
59 | batch: jax.Array, state: jax.Array
60 | ) -> tuple[jax.Array, Mapping[str, jax.Array]]:
61 | return batch + state, {
62 | "batch_mean": jnp.mean(batch),
63 | "state_mean": jnp.mean(state),
64 | }
65 |
66 | def _eval_step(
67 | batch: jax.Array, state: jax.Array
68 | ) -> Mapping[str, jax.Array]:
69 | return {"batch_mean": jnp.mean(batch), "state_mean": jnp.mean(state)}
70 |
71 | state = partitioner.partition_init(_init, abstract_batch=sharded_inputs)(
72 | sharded_inputs
73 | )
74 | self.assertIsInstance(state, jax.Array)
75 | self.assertSequenceEqual(state.shape, (128, 16))
76 | self.assertEqual(state.sharding, partitioner.state_sharding)
77 |
78 | new_state, metrics = partitioner.partition_step(_train_step, training=True)(
79 | sharded_inputs, state
80 | )
81 | self.assertTrue(state.is_deleted()) # Buffer should be donated.
82 | self.assertIsInstance(new_state, jax.Array)
83 | self.assertSequenceEqual(new_state.shape, (128, 16))
84 | self.assertEqual(new_state.sharding, partitioner.state_sharding)
85 | for metric in jax.tree.flatten(metrics)[0]:
86 | self.assertIsInstance(metric, jax.Array)
87 | self.assertEqual(
88 | metric.sharding,
89 | jax.sharding.NamedSharding(
90 | partitioner.mesh, jax.sharding.PartitionSpec()
91 | ),
92 | )
93 |
94 | metrics = partitioner.partition_step(_eval_step, training=False)(
95 | sharded_inputs, new_state
96 | )
97 | self.assertFalse(new_state.is_deleted()) # Buffer should not be donated.
98 | for metric in jax.tree.flatten(metrics)[0]:
99 | self.assertIsInstance(metric, jax.Array)
100 | self.assertEqual(
101 | metric.sharding,
102 | jax.sharding.NamedSharding(
103 | partitioner.mesh, jax.sharding.PartitionSpec()
104 | ),
105 | )
106 |
107 | self.assertEqual(
108 | partitioner.state_sharding,
109 | jax.sharding.NamedSharding(
110 | partitioner.mesh, jax.sharding.PartitionSpec()
111 | ),
112 | )
113 |
114 | def test_model_parallelism(self):
115 | partitioner = partitioning.ModelParallelPartitioner(
116 | axes=[("data", 1), ("model", jax.device_count())]
117 | )
118 |
119 | inputs = np.zeros((128, 16), dtype=np.float32)
120 | sharded_inputs = partitioner.shard_inputs(inputs)
121 |
122 | self.assertIsInstance(sharded_inputs, jax.Array)
123 | self.assertSequenceEqual(sharded_inputs.shape, (128, 16))
124 | self.assertEqual(
125 | sharded_inputs.sharding,
126 | jax.sharding.NamedSharding(
127 | partitioner.mesh, jax.sharding.PartitionSpec("data")
128 | ),
129 | )
130 |
131 | def _init(batch: jax.Array) -> jax.Array:
132 | return nn.with_partitioning(
133 | jnp.ones_like, ("data", "model"), partitioner.mesh
134 | )(batch)
135 |
136 | state = partitioner.partition_init(_init, abstract_batch=sharded_inputs)(
137 | sharded_inputs
138 | )
139 |
140 | self.assertIsInstance(state, jax.Array)
141 | self.assertSequenceEqual(state.shape, (128, 16))
142 | self.assertEqual(state.sharding, partitioner.state_sharding)
143 | self.assertEqual(
144 | partitioner.state_sharding,
145 | jax.sharding.NamedSharding(
146 | partitioner.mesh,
147 | jax.sharding.PartitionSpec("data", "model"),
148 | ),
149 | )
150 |
151 | # TODO(aahil): Add tests for the steps.
152 |
153 |
154 | if __name__ == "__main__":
155 | absltest.main()
156 |
--------------------------------------------------------------------------------
/recml/core/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Public utilities API."""
15 |
16 | # pylint: disable=g-importing-member
17 |
18 | from recml.core.utils.config import DEFINE_fiddle_config
19 | from recml.core.utils.config import FiddleFlag
20 | from recml.core.utils.types import Dataclass
21 | from recml.core.utils.types import FrozenDataclass
22 |
--------------------------------------------------------------------------------
/recml/core/utils/config_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for configuration utilities."""
15 |
16 | from collections.abc import Sequence
17 | import dataclasses
18 | import sys
19 |
20 | from absl import flags
21 | from absl.testing import absltest
22 | from absl.testing import parameterized
23 | import fiddle as fdl
24 | from recml.core.utils import config as config_lib
25 |
26 | # Pytest may import this test under a different module name; otherwise this
27 | # should be __main__.
28 | _TEST_MODULE_NAME = sys.modules[__name__].__name__
29 |
30 |
31 | @dataclasses.dataclass
32 | class _X:
33 | value: int
34 |
35 |
36 | @dataclasses.dataclass
37 | class _Y:
38 | value: int
39 |
40 |
41 | @dataclasses.dataclass
42 | class _Object:
43 | x: _X
44 | y: _Y
45 |
46 |
47 | def config_1() -> fdl.Config[_Object]:
48 | return fdl.Config(
49 | _Object,
50 | x=fdl.Config(_X, value=1),
51 | y=fdl.Config(_Y, value=2),
52 | )
53 |
54 |
55 | def fiddler_1(cfg: fdl.Config[_Object]):
56 | cfg.x.value = 3
57 |
58 |
59 | def fiddler_2(cfg: fdl.Config[_Object], value: int):
60 | cfg.y.value = value
61 |
62 |
63 | class ConfigTest(parameterized.TestCase):
64 |
65 | @parameterized.named_parameters(
66 | {
67 | 'testcase_name': 'base_config',
68 | 'args': [f'config:{_TEST_MODULE_NAME}.config_1'],
69 | 'expected_config': config_1(),
70 | },
71 | {
72 | 'testcase_name': 'relative_fiddler',
73 | 'args': [
74 | f'config:{_TEST_MODULE_NAME}.config_1',
75 | 'fiddler:fiddler_1',
76 | 'fiddler:fiddler_2(value=4)',
77 | ],
78 | 'expected_config': fdl.Config(
79 | _Object,
80 | x=fdl.Config(_X, value=3),
81 | y=fdl.Config(_Y, value=4),
82 | ),
83 | },
84 | {
85 | 'testcase_name': 'absolute_fiddler',
86 | 'args': [
87 | f'config:{_TEST_MODULE_NAME}.config_1',
88 | f'fiddler:{_TEST_MODULE_NAME}.fiddler_2(3)',
89 | ],
90 | 'expected_config': fdl.Config(
91 | _Object,
92 | x=fdl.Config(_X, value=1),
93 | y=fdl.Config(_Y, value=3),
94 | ),
95 | },
96 | {
97 | 'testcase_name': 'set',
98 | 'args': [
99 | f'config:{_TEST_MODULE_NAME}.config_1',
100 | 'set:x.value=0',
101 | 'set:y.value=0',
102 | ],
103 | 'expected_config': fdl.Config(
104 | _Object,
105 | x=fdl.Config(_X, value=0),
106 | y=fdl.Config(_Y, value=0),
107 | ),
108 | },
109 | )
110 | def test_fiddle_flag(
111 | self, args: Sequence[str], expected_config: fdl.Config[_Object]
112 | ):
113 | fdl_flag = config_lib.FiddleFlag(
114 | name='test_flag',
115 | default=None,
116 | parser=flags.ArgumentParser(),
117 | serializer=None,
118 | help_string='My fiddle flag',
119 | )
120 | fdl_flag.parse(args)
121 | self.assertEqual(expected_config, fdl_flag.value)
122 |
123 | @parameterized.named_parameters(
124 | {
125 | 'testcase_name': 'bad_base_config',
126 | 'args': [f'config:{_TEST_MODULE_NAME}.config_3'],
127 | 'expected_error': AttributeError,
128 | 'expected_error_regex': 'Could not init a buildable from .*',
129 | },
130 | {
131 | 'testcase_name': 'bad_fiddler',
132 | 'args': [f'config:{_TEST_MODULE_NAME}.config_1', 'fiddler:fiddler_3'],
133 | 'expected_error': ValueError,
134 | # TODO(aahil): Figure out why the error regex is different in 3P.
135 | 'expected_error_regex': '.*',
136 | },
137 | )
138 | def test_invalid_fiddle_flag(
139 | self,
140 | args: Sequence[str],
141 | expected_error: type[Exception],
142 | expected_error_regex: str,
143 | ):
144 | fdl_flag = config_lib.FiddleFlag(
145 | name='test_flag',
146 | default=None,
147 | parser=flags.ArgumentParser(),
148 | serializer=None,
149 | help_string='My fiddle flag',
150 | )
151 |
152 | def _value(args: Sequence[str]):
153 | fdl_flag.parse(args)
154 | return fdl_flag.value
155 |
156 | self.assertRaisesRegex(expected_error, expected_error_regex, _value, args)
157 |
158 |
159 | if __name__ == '__main__':
160 | absltest.main()
161 |
--------------------------------------------------------------------------------
/recml/core/utils/py_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Miscellaneous utilities."""
15 |
16 | from collections.abc import Callable
17 | import inspect
18 | from typing import Any
19 |
20 |
21 | def has_argument(fn: Callable[..., Any], arg_name: str) -> bool:
22 | """Checks if a function has an argument with a given name."""
23 | params = inspect.signature(fn).parameters.values()
24 | param_names = [v.name for v in params]
25 | has_arg = arg_name in param_names
26 | has_kw_args = any([v.kind == inspect.Parameter.VAR_KEYWORD for v in params])
27 | return has_arg or has_kw_args
28 |
--------------------------------------------------------------------------------
/recml/core/utils/types.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Configuration tools for dealing with abstract types."""
15 |
16 | import abc
17 | import dataclasses
18 | from typing import Generic, Protocol, TypeVar
19 |
20 | from typing_extensions import dataclass_transform
21 |
22 |
23 | T = TypeVar("T")
24 |
25 |
26 | @dataclass_transform(field_specifiers=dataclasses.field) # type: ignore[literal-required]
27 | class Dataclass:
28 | """A dataclass transform that converts a class to a dataclass."""
29 |
30 | def __init_subclass__(cls, **kwargs):
31 | def replace(self, **updates):
32 | return dataclasses.replace(self, **updates)
33 |
34 | data_cls = dataclasses.dataclass(**kwargs)(cls)
35 | data_cls.replace = replace
36 |
37 | def __init__(self, *args, **kwargs):
38 | # stub for pytype
39 | raise NotImplementedError
40 |
41 | def replace(self: T, **overrides) -> T:
42 | # stub for pytype
43 | raise NotImplementedError
44 |
45 |
46 | # TODO(aahil): Share code with `Dataclass`.
47 | @dataclass_transform(field_specifiers=dataclasses.field) # type: ignore[literal-required]
48 | class FrozenDataclass:
49 | """A dataclass transform that converts a class to a frozen dataclass."""
50 |
51 | def __init_subclass__(cls, **kwargs):
52 | if "frozen" not in kwargs:
53 | kwargs["frozen"] = True
54 |
55 | def replace(self, **updates):
56 | return dataclasses.replace(self, **updates)
57 |
58 | data_cls = dataclasses.dataclass(**kwargs)(cls)
59 | data_cls.replace = replace
60 |
61 | def __init__(self, *args, **kwargs):
62 | # stub for pytype
63 | raise NotImplementedError
64 |
65 | def replace(self: T, **overrides) -> T:
66 | # stub for pytype
67 | raise NotImplementedError
68 |
69 |
70 | class Factory(abc.ABC, Generic[T], Dataclass):
71 | """A factory interface for configuring an arbitary object via a dataclass.
72 |
73 | This is useful for creating objects that require run-time information.
74 | """
75 |
76 | @abc.abstractmethod
77 | def make(self, *args, **kwargs) -> T:
78 | """Builds the object instance."""
79 |
80 |
81 | class FactoryProtocol(Protocol, Generic[T]):
82 | """A protocol for typing factories."""
83 |
84 | def make(self, *args, **kwargs) -> T:
85 | """Builds the object instance."""
86 |
87 | def replace(self, **overrides) -> T:
88 | """Replaces the object instance."""
89 |
90 |
91 | class ObjectFactory(Factory[T]):
92 | """A factory that wraps around the constructor of an object.
93 |
94 | This is useful when a library only accepts a factory but creating a factory
95 | introduces unnecessary boilerplate.
96 |
97 | Example usage:
98 | ```
99 | class MyObject:
100 | def __init__(self, x: int, y: int):
101 | self._x = x
102 | self._y = y
103 |
104 |
105 | factory = ObjectFactory(MyObject, x=1, y=2)
106 | obj = factory.make()
107 | assert obj._x == 1
108 | assert obj._y == 2
109 | ```
110 | """
111 |
112 | def __new__(cls, *args, **kwargs) -> Factory[T]:
113 | if args:
114 | raise ValueError(
115 | "`StaticFactory` does not accept positional arguments. Got args:"
116 | f" {args}."
117 | )
118 | if "type" not in kwargs:
119 | raise ValueError(
120 | "`StaticFactory` requires a `type` keyword argument. Got kwargs:"
121 | f" {kwargs}."
122 | )
123 |
124 | class _ObjectFactory(Factory):
125 |
126 | def __init_subclass__(cls, **kwargs):
127 | # Override the dataclass transform from the base class.
128 | pass
129 |
130 | def make(self):
131 | return getattr(self, "type")(**{
132 | f.name: getattr(self, f.name)
133 | for f in dataclasses.fields(self)
134 | if f.name != "type"
135 | })
136 |
137 | sub_cls = dataclasses.make_dataclass(
138 | cls_name=cls.__name__,
139 | fields=[(k, type(v)) for k, v in kwargs.items()],
140 | bases=(_ObjectFactory,),
141 | kw_only=True,
142 | )
143 | obj = sub_cls(**kwargs)
144 | return obj
145 |
146 | def __init__(self, *, type: type[T], **kwargs): # pylint: disable=redefined-builtin
147 | # Stub for pytype.
148 | raise NotImplementedError()
149 |
150 | @property
151 | def type(self) -> type[T]:
152 | # Stub for pytype.
153 | raise NotImplementedError()
154 |
155 | def make(self) -> T:
156 | # Stub for pytype.
157 | raise NotImplementedError()
158 |
--------------------------------------------------------------------------------
/recml/core/utils/types_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for type utilities."""
15 |
16 | import dataclasses
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from recml.core.utils import types
21 |
22 |
23 | class TypesTest(parameterized.TestCase):
24 |
25 | @parameterized.named_parameters(
26 | {'testcase_name': 'dataclass', 'cls': types.Dataclass},
27 | {'testcase_name': 'frozen_dataclass', 'cls': types.FrozenDataclass},
28 | )
29 | def test_dataclass_transform(self, cls: type[types.Dataclass]):
30 | class Foo(cls):
31 | x: int
32 | y: int
33 | z: int = dataclasses.field(default_factory=lambda: 1)
34 |
35 | class Bar(Foo):
36 | u: int = dataclasses.field(default_factory=lambda: 2)
37 |
38 | foo = Foo(x=1, y=2)
39 | self.assertEqual(foo.x, 1)
40 | self.assertEqual(foo.y, 2)
41 | self.assertEqual(foo.z, 1)
42 | self.assertTrue(dataclasses.is_dataclass(Foo))
43 | self.assertTrue(dataclasses.is_dataclass(foo))
44 |
45 | bar = Bar(x=1, y=2, u=3)
46 | self.assertEqual(bar.x, 1)
47 | self.assertEqual(bar.y, 2)
48 | self.assertEqual(bar.z, 1)
49 | self.assertEqual(bar.u, 3)
50 | self.assertTrue(dataclasses.is_dataclass(Bar))
51 | self.assertTrue(dataclasses.is_dataclass(bar))
52 |
53 | def test_frozen_dataclass(self):
54 |
55 | class Foo(types.Dataclass):
56 | x: int
57 | y: int
58 |
59 | class Bar(types.FrozenDataclass):
60 | x: int
61 | y: int
62 |
63 | def _mutate_foo_or_bar(foo_or_bar: Foo | Bar):
64 | foo_or_bar.x = 2
65 |
66 | # Mutating Foo is allowed.
67 | _mutate_foo_or_bar(Foo(x=1, y=2))
68 |
69 | self.assertRaises(
70 | dataclasses.FrozenInstanceError,
71 | _mutate_foo_or_bar,
72 | Bar(x=1, y=2),
73 | )
74 |
75 | def test_object_factory(self):
76 | class Foo(types.Dataclass):
77 | x: int
78 | y: int
79 |
80 | factory = types.ObjectFactory(type=Foo, x=1, y=2)
81 | self.assertEqual(factory.type, Foo)
82 | self.assertEqual(factory.x, 1)
83 | self.assertEqual(factory.y, 2)
84 | self.assertEqual(factory.make(), Foo(x=1, y=2))
85 |
86 |
87 | if __name__ == '__main__':
88 | absltest.main()
89 |
--------------------------------------------------------------------------------
/recml/examples/dlrm_experiment.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """DLRM experiment."""
15 |
16 | from __future__ import annotations
17 |
18 | from collections.abc import Iterator, Mapping, Sequence
19 | import dataclasses
20 | from typing import Generic, Literal, TypeVar
21 |
22 | from etils import epy
23 | import fiddle as fdl
24 | import flax.linen as nn
25 | import jax
26 | import jax.numpy as jnp
27 | import jaxtyping as jt
28 | import numpy as np
29 | import optax
30 | import recml
31 | from recml.layers.linen import sparsecore
32 | import tensorflow as tf
33 |
34 | with epy.lazy_imports():
35 | from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec # pylint: disable=g-import-not-at-top
36 |
37 |
38 | @dataclasses.dataclass
39 | class Feature:
40 | name: str
41 |
42 |
43 | FeatureT = TypeVar('FeatureT', bound=Feature)
44 |
45 |
46 | @dataclasses.dataclass
47 | class DenseFeature(Feature):
48 | """Dense feature."""
49 |
50 |
51 | @dataclasses.dataclass
52 | class SparseFeature(Feature):
53 | """Sparse feature."""
54 |
55 | vocab_size: int
56 | embedding_dim: int
57 | max_sequence_length: int | None = None
58 | combiner: Literal['mean', 'sum', 'sqrtn'] = 'mean'
59 | sparsity: float = 0.8
60 |
61 |
62 | @dataclasses.dataclass
63 | class FeatureSet(Generic[FeatureT]):
64 | """A collection of features."""
65 |
66 | features: Sequence[FeatureT]
67 |
68 | def __post_init__(self):
69 | feature_names = [f.name for f in self.features]
70 | if len(feature_names) != len(set(feature_names)):
71 | raise ValueError(
72 | f'Feature names must be unique. Got names: {feature_names}.'
73 | )
74 |
75 | def dense_features(self) -> FeatureSet[DenseFeature]:
76 | return FeatureSet[DenseFeature](
77 | [f for f in self if isinstance(f, DenseFeature)]
78 | )
79 |
80 | def sparse_features(self) -> FeatureSet[SparseFeature]:
81 | return FeatureSet[SparseFeature](
82 | [f for f in self if isinstance(f, SparseFeature)]
83 | )
84 |
85 | def __iter__(self) -> Iterator[FeatureT]:
86 | return iter(self.features)
87 |
88 | def __or__(self, other: FeatureSet[Feature]) -> FeatureSet[Feature]:
89 | return FeatureSet([*self.features, *other.features])
90 |
91 |
92 | class DLRMModel(nn.Module):
93 | """DLRM DCN v2 model."""
94 |
95 | features: FeatureSet
96 | embedding_optimizer: sparsecore.OptimizerSpec
97 | bottom_mlp_dims: Sequence[int]
98 | top_mlp_dims: Sequence[int]
99 | dcn_layers: int
100 | dcn_inner_dim: int
101 |
102 | # We need to track the embedder on the Flax module to ensure it is not
103 | # re-created on cloning. It is not possible to create an embedder inside
104 | # setup() because it is called lazily at compile time. The embedder needs
105 | # to be created before `model.init` so we can use it to create a preprocessor.
106 | # A simpler pattern that works is passing `embedder` directly to the module.
107 | _embedder: sparsecore.SparsecoreEmbedder | None = None
108 |
109 | @property
110 | def embedder(self) -> sparsecore.SparsecoreEmbedder:
111 | if self._embedder is not None:
112 | return self._embedder
113 |
114 | embedder = sparsecore.SparsecoreEmbedder(
115 | specs={
116 | f.name: sparsecore.EmbeddingSpec(
117 | input_dim=f.vocab_size,
118 | embedding_dim=f.embedding_dim,
119 | max_sequence_length=f.max_sequence_length,
120 | combiner=f.combiner,
121 | )
122 | for f in self.features.sparse_features()
123 | },
124 | optimizer=self.embedding_optimizer,
125 | )
126 | object.__setattr__(self, '_embedder', embedder)
127 | return embedder
128 |
129 | def bottom_mlp(self, inputs: Mapping[str, jt.Array]) -> jt.Array:
130 | x = jnp.concatenate(
131 | [inputs[f.name] for f in self.features.dense_features()], axis=-1
132 | )
133 |
134 | for dim in self.bottom_mlp_dims:
135 | x = nn.Dense(dim)(x)
136 | x = nn.relu(x)
137 | return x
138 |
139 | def top_mlp(self, x: jt.Array) -> jt.Array:
140 | for dim in self.top_mlp_dims[:-1]:
141 | x = nn.Dense(dim)(x)
142 | x = nn.relu(x)
143 |
144 | x = nn.Dense(self.top_mlp_dims[-1])(x)
145 | return x
146 |
147 | def dcn(self, x0: jt.Array) -> jt.Array:
148 | xl = x0
149 | input_dim = x0.shape[-1]
150 |
151 | for i in range(self.dcn_layers):
152 | u_kernel = self.param(
153 | f'u_kernel_{i}',
154 | nn.initializers.xavier_normal(),
155 | (input_dim, self.dcn_inner_dim),
156 | )
157 | v_kernel = self.param(
158 | f'v_kernel_{i}',
159 | nn.initializers.xavier_normal(),
160 | (self.dcn_inner_dim, input_dim),
161 | )
162 | bias = self.param(f'bias_{i}', nn.initializers.zeros, (input_dim,))
163 |
164 | u = jnp.matmul(xl, u_kernel)
165 | v = jnp.matmul(u, v_kernel)
166 | v += bias
167 |
168 | xl = x0 * v + xl
169 |
170 | return xl
171 |
172 | @nn.compact
173 | def __call__(
174 | self, inputs: Mapping[str, jt.Array], training: bool = False
175 | ) -> jt.Array:
176 | dense_embeddings = self.bottom_mlp(inputs)
177 | sparse_embeddings = self.embedder.make_sparsecore_module()(inputs)
178 | sparse_embeddings = jax.tree.flatten(sparse_embeddings)[0]
179 | concatenated_embeddings = jnp.concatenate(
180 | (dense_embeddings, *sparse_embeddings), axis=-1
181 | )
182 | interaction_outputs = self.dcn(concatenated_embeddings)
183 | predictions = self.top_mlp(interaction_outputs)
184 | predictions = jnp.reshape(predictions, (-1,))
185 | return predictions
186 |
187 |
188 | class CriteoFactory(recml.Factory[tf.data.Dataset]):
189 | """Data loader for dummy Criteo data optimized for Jax training."""
190 |
191 | features: FeatureSet
192 | global_batch_size: int
193 | use_cached_data: bool = False
194 |
195 | def make(self) -> tf.data.Dataset:
196 | data = {}
197 | batch_size = self.global_batch_size // jax.process_count()
198 |
199 | for f in self.features.dense_features():
200 | feature = np.random.normal(0.0, 1.0, size=(batch_size, 1))
201 | data[f.name] = feature.astype(np.float32)
202 |
203 | for f in self.features.sparse_features():
204 | non_zero_mask = (
205 | np.random.normal(size=(batch_size, f.embedding_dim)) > f.sparsity
206 | )
207 | sparse_feature = np.random.randint(
208 | low=0,
209 | high=f.vocab_size,
210 | size=(batch_size, f.embedding_dim),
211 | )
212 | sparse_feature = np.where(
213 | non_zero_mask, sparse_feature, np.zeros_like(sparse_feature)
214 | )
215 | data[f.name] = tf.constant(sparse_feature, dtype=tf.int64)
216 |
217 | label = np.random.randint(0, 2, size=(batch_size,))
218 |
219 | dataset = tf.data.Dataset.from_tensors((data, label))
220 | dataset = dataset.take(1).repeat()
221 | dataset = dataset.prefetch(buffer_size=2048)
222 | options = tf.data.Options()
223 | options.deterministic = False
224 | options.threading.private_threadpool_size = 96
225 | dataset = dataset.with_options(options)
226 | return dataset
227 |
228 |
229 | @dataclasses.dataclass
230 | class PredictionTask(recml.JaxTask):
231 | """Prediction task."""
232 |
233 | train_data: CriteoFactory
234 | eval_data: CriteoFactory
235 | model: DLRMModel
236 | optimizer: recml.Factory[optax.GradientTransformation]
237 |
238 | def create_datasets(self) -> tuple[recml.data.Iterator, recml.data.Iterator]:
239 | global_batch_size = self.train_data.global_batch_size
240 | train_iter = recml.data.TFDatasetIterator(
241 | dataset=self.train_data.make(),
242 | postprocessor=self.model.embedder.make_preprocessor(global_batch_size),
243 | )
244 | eval_iter = recml.data.TFDatasetIterator(
245 | dataset=self.eval_data.make(),
246 | postprocessor=self.model.embedder.make_preprocessor(global_batch_size),
247 | )
248 | return train_iter, eval_iter
249 |
250 | def create_state(self, batch: jt.PyTree, rng: jt.Array) -> recml.JaxState:
251 | inputs, _ = batch
252 | params = self.model.init(rng, inputs)
253 | optimizer = self.optimizer.make()
254 | return recml.JaxState.create(params=params, tx=optimizer)
255 |
256 | def train_step(
257 | self, batch: jt.PyTree, state: recml.JaxState, rng: jt.Array
258 | ) -> tuple[recml.JaxState, Mapping[str, recml.Metric]]:
259 | inputs, label = batch
260 |
261 | def _loss_fn(params: jt.PyTree) -> tuple[jt.Scalar, jt.Array]:
262 | logits = self.model.apply(params, inputs, training=True)
263 | loss = jnp.mean(optax.sigmoid_binary_cross_entropy(logits, label), axis=0)
264 | return loss, logits
265 |
266 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True, allow_int=True)
267 | (loss, logits), grads = grad_fn(state.params)
268 | state = state.update(grads=grads)
269 |
270 | metrics = {
271 | 'loss': recml.metrics.scalar(loss),
272 | 'accuracy': recml.metrics.binary_accuracy(label, logits, threshold=0.0),
273 | 'auc': recml.metrics.aucpr(label, logits, from_logits=True),
274 | 'aucroc': recml.metrics.aucroc(label, logits, from_logits=True),
275 | 'label/mean': recml.metrics.mean(label),
276 | 'prediction/mean': recml.metrics.mean(jax.nn.sigmoid(logits)),
277 | }
278 | return state, metrics
279 |
280 | def eval_step(
281 | self, batch: jt.PyTree, state: recml.JaxState
282 | ) -> Mapping[str, recml.Metric]:
283 | inputs, label = batch
284 | logits = self.model.apply(state.params, inputs, training=False)
285 | loss = jnp.mean(optax.sigmoid_binary_cross_entropy(logits, label), axis=0)
286 |
287 | metrics = {
288 | 'loss': recml.metrics.mean(loss),
289 | 'accuracy': recml.metrics.binary_accuracy(label, logits, threshold=0.0),
290 | 'auc': recml.metrics.aucpr(label, logits, from_logits=True),
291 | 'aucroc': recml.metrics.aucroc(label, logits, from_logits=True),
292 | 'label/mean': recml.metrics.mean(label),
293 | 'prediction/mean': recml.metrics.mean(jax.nn.sigmoid(logits)),
294 | }
295 | return metrics
296 |
297 |
298 | def features() -> fdl.Config[FeatureSet]:
299 | """Creates a feature collection for the DLRM model."""
300 | table_sizes = [
301 | (40000000, 3),
302 | (39060, 2),
303 | (17295, 1),
304 | (7424, 2),
305 | (20265, 6),
306 | (3, 1),
307 | (7122, 1),
308 | (1543, 1),
309 | (63, 1),
310 | (40000000, 7),
311 | (3067956, 3),
312 | (405282, 8),
313 | (10, 1),
314 | (2209, 6),
315 | (11938, 9),
316 | (155, 5),
317 | (4, 1),
318 | (976, 1),
319 | (14, 1),
320 | (40000000, 12),
321 | (40000000, 100),
322 | (40000000, 27),
323 | (590152, 10),
324 | (12973, 3),
325 | (108, 1),
326 | (36, 1),
327 | ]
328 | return fdl.Config(
329 | FeatureSet,
330 | features=[
331 | fdl.Config(DenseFeature, name=f'float-feature-{i}') for i in range(13)
332 | ]
333 | + [
334 | fdl.Config(
335 | SparseFeature,
336 | vocab_size=vocab_size,
337 | embedding_dim=embedding_dim,
338 | name=f'categorical-feature-{i}',
339 | )
340 | for i, (vocab_size, embedding_dim) in enumerate(table_sizes)
341 | ],
342 | )
343 |
344 |
345 | def experiment() -> fdl.Config[recml.Experiment]:
346 | """DLRM experiment."""
347 |
348 | feature_set = features()
349 |
350 | task = fdl.Config(
351 | PredictionTask,
352 | train_data=fdl.Config(
353 | CriteoFactory,
354 | features=feature_set,
355 | global_batch_size=131_072,
356 | ),
357 | eval_data=fdl.Config(
358 | CriteoFactory,
359 | features=feature_set,
360 | global_batch_size=131_072,
361 | use_cached_data=True,
362 | ),
363 | model=fdl.Config(
364 | DLRMModel,
365 | features=feature_set,
366 | embedding_optimizer=fdl.Config(
367 | embedding_spec.AdagradOptimizerSpec,
368 | learning_rate=0.01,
369 | ),
370 | bottom_mlp_dims=[512, 256, 128],
371 | top_mlp_dims=[1024, 1024, 512, 256, 1],
372 | dcn_layers=3,
373 | dcn_inner_dim=512,
374 | ),
375 | optimizer=fdl.Config(
376 | recml.AdagradFactory,
377 | learning_rate=0.01,
378 | # Sparsecore embedding parameters are optimized in the backward pass.
379 | freeze_mask=rf'.*{sparsecore.EMBEDDING_PARAM_NAME}.*',
380 | ),
381 | )
382 | trainer = fdl.Config(
383 | recml.JaxTrainer,
384 | partitioner=fdl.Config(recml.DataParallelPartitioner),
385 | train_steps=1_000,
386 | steps_per_eval=100,
387 | steps_per_loop=100,
388 | )
389 | return fdl.Config(recml.Experiment, task=task, trainer=trainer)
390 |
--------------------------------------------------------------------------------
/recml/examples/dlrm_experiment_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for the DLRM experiment."""
15 |
16 | from absl.testing import absltest
17 | import fiddle as fdl
18 | from fiddle import selectors
19 | import jax
20 | import numpy as np
21 | import recml
22 | from recml.examples import dlrm_experiment
23 |
24 |
25 | class DLRMExperimentTest(absltest.TestCase):
26 |
27 | def test_dlrm_experiment(self):
28 | if jax.devices()[0].platform != "tpu":
29 | self.skipTest("Test only supported on TPUs.")
30 |
31 | np.random.seed(1337)
32 |
33 | experiment = dlrm_experiment.experiment()
34 |
35 | experiment.task.train_data.global_batch_size = 4
36 | experiment.task.eval_data.global_batch_size = 4
37 | experiment.trainer.train_steps = 12
38 | experiment.trainer.steps_per_loop = 4
39 | experiment.trainer.steps_per_eval = 4
40 |
41 | for cfg in selectors.select(experiment, dlrm_experiment.SparseFeature):
42 | cfg.vocab_size = 200
43 | cfg.embedding_dim = 8
44 |
45 | experiment = fdl.build(experiment)
46 | recml.run_experiment(experiment, recml.Experiment.Mode.TRAIN_AND_EVAL)
47 |
48 |
49 | if __name__ == "__main__":
50 | absltest.main()
51 |
--------------------------------------------------------------------------------
/recml/layers/keras/README.md:
--------------------------------------------------------------------------------
1 | ## Models
2 |
3 | ### SASRec
4 |
5 | Uses self-attention to predict a user's next action based on their past
6 | activities. It aims to understand long-term user interests while also making
7 | good predictions based on just the most recent actions. It smartly adapts which
8 | past actions to focus on depending on how much history a user has. Built
9 | entirely with efficient attention blocks, SASRec avoids the complex structures
10 | of older RNN or CNN models, leading to faster training and better performance on
11 | diverse datasets.
12 |
13 | #### Architecture Overview
14 |
15 | - **Embedding Layer** - Converts item IDs into dense vectors. Adds a learnable
16 | absolute positional embedding to the item embedding to incorporate sequence
17 | order information. Dropout is applied to the combined embedding.
18 | -   **Multi-Head Self-Attention Layer** - Computes attention scores between all
19 | pairs of items within the allowed sequence window. Employs causality by
20 | masking out attention to future positions to prevent information leakage
21 | when training with a causal prediction objective.
22 | -   **Feed-Forward Network** - Applied independently to each embedding vector
23 | output by the attention layer. Uses two linear layers with a GeLU activation
24 | in between to add non-linearity.
25 | -   **Residual Connections and Pre-Layernorm** - Applied around both the
26 | self-attention and feed-forward network sub-layers for stable and faster
27 | training of deeper models. Dropout is also used within the block.
28 | -   **Prediction Head** - Decodes the sequence embeddings into logits using the
29 | input item embedding table and computes a causal categorical cross entropy
30 | loss between the inputs and the inputs shifted right.
31 |
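The encoder block described above can be sketched in a few lines of plain Keras 3. This is only an illustrative sketch, not the implementation in `sasrec.py`; the helper name `sasrec_block` and all layer sizes and dropout rates are assumptions chosen for brevity.

```python
import keras
from keras import layers, ops


def sasrec_block(x, model_dim=64, num_heads=2, ff_dim=256, dropout=0.1):
  """One pre-LayerNorm encoder block with causal self-attention (sketch)."""
  # Causal multi-head self-attention sub-layer with a residual connection.
  h = layers.LayerNormalization()(x)
  h = layers.MultiHeadAttention(
      num_heads=num_heads, key_dim=model_dim // num_heads, dropout=dropout
  )(h, h, use_causal_mask=True)
  x = x + layers.Dropout(dropout)(h)
  # Position-wise feed-forward sub-layer with a GeLU activation and a residual.
  h = layers.LayerNormalization()(x)
  h = layers.Dense(ff_dim, activation="gelu")(h)
  h = layers.Dense(model_dim)(h)
  return x + layers.Dropout(dropout)(h)


# Item embeddings plus learnable absolute positional embeddings, then two blocks.
item_ids = keras.Input(shape=(50,), dtype="int32")
positions = ops.cumsum(ops.ones_like(item_ids), axis=-1) - 1  # 0 .. seq_len - 1
h = layers.Embedding(10_000, 64)(item_ids) + layers.Embedding(50, 64)(positions)
h = layers.Dropout(0.1)(h)
for _ in range(2):
  h = sasrec_block(h)
encoder = keras.Model(item_ids, h)
```
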
32 | ### BERT4Rec
33 |
34 | Models how user preferences change based on their past actions for
35 | recommendations. Unlike older methods that only look at history in chronological
36 | order, BERT4Rec uses a transformer based approach to look at the user's sequence
37 | of actions in both directions. This helps capture context better, as user
38 | behavior isn't always strictly ordered. To learn effectively, it is trained
39 | using a masked prediction objective: some items are randomly masked and the
40 | model learns to predict them from the context. BERT4Rec consistently performs
41 | better than many standard sequential models.
42 |
43 | #### Architecture Overview
44 |
45 | - **Embedding Layer** - Converts item IDs into dense vectors. Adds a learnable
46 | absolute positional embedding to the item embedding to incorporate sequence
47 | order information. An optional type embedding can be added to the item
48 | embedding. Embedding dropout is applied to the combined embedding. Uses a
49 | separate embedding for masked features to prevent other item tokens from
50 | attending to them.
51 |
52 | -   **Multi-Head Self-Attention Layer** - Computes attention scores between all
53 |     pairs of items within the allowed sequence window. Attention is
54 |     bidirectional: no causal mask is applied, so each item can attend to both
55 |
56 | -   **Feed-Forward Network** - Applied independently to each embedding vector
57 | output by the attention layer. Uses two linear layers with a GeLU activation
58 | in between to add non-linearity.
59 |
60 | -   **Residual Connections and Post-Layernorm** - Applied around both the
61 | self-attention and feed-forward network sub-layers for stable and faster
62 | training of deeper models. Dropout is also used within the block.
63 |
64 | -   **Masked Prediction Head** - Gathers and projects the masked sequence
65 | embeddings, and decodes them using the item embedding layer. Computes a
66 | categorical cross entropy loss between the masked item ids and the predicted
67 | logits for the corresponding masked item embeddings.
68 |
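A rough sketch of the masked prediction head described above, written with plain Keras 3 ops. It is not the implementation in `bert4rec.py`; the helper name `masked_item_loss`, the tied item embedding table argument, and the GeLU projection are illustrative assumptions.

```python
import keras
from keras import layers, ops


def masked_item_loss(sequence_emb, masked_positions, masked_item_ids, item_table):
  """Masked-item cross-entropy over gathered positions (illustrative sketch)."""
  # Gather encoder outputs at the masked positions: (batch, num_masked, dim).
  masked_emb = ops.take_along_axis(
      sequence_emb, masked_positions[..., None], axis=1
  )
  # Project, then decode against the tied item embedding table to get logits
  # over the item vocabulary: (batch, num_masked, vocab_size).
  masked_emb = layers.Dense(item_table.shape[-1], activation="gelu")(masked_emb)
  logits = ops.matmul(masked_emb, ops.transpose(item_table))
  # Categorical cross-entropy between the masked item ids and the logits
  # predicted at the corresponding masked positions.
  loss = keras.losses.sparse_categorical_crossentropy(
      masked_item_ids, logits, from_logits=True
  )
  return ops.mean(loss)
```
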
69 | ### HSTU
70 |
71 | HSTU is a novel architecture designed for sequential recommendation,
72 | particularly suited for high cardinality, non-stationary streaming data. It
73 | reformulates recommendation as a sequential transduction task within a
74 | generative modeling framework -"Generative Recommenders". HSTU aims to provide
75 | state-of-the-art results while being highly scalable and efficient, capable of
76 | handling models with up to trillions of parameters. It has demonstrated
77 | significant improvements over baselines in offline benchmarks and online A/B
78 | tests, leading to deployment on large-scale internet platforms.
79 |
80 | #### Architecture Overview
81 |
82 | -   **Embedding Layer** - Converts various action tokens into dense vectors in the
83 | same space. Optionally, adds a learnable absolute positional embedding to
84 | incorporate sequence order information. Embedding dropout is applied to the
85 | combined embedding.
86 | - **Gated Pointwise Aggregated Attention** - Uses a multi-head gated pointwise
87 | attention mechanism with a Layernorm on the attention outputs before
88 | projecting them. This captures the intensity of interactions between
89 | actions, which is lost in softmax attention.
90 | - **Relative Attention Bias** - Uses a T5 style relative attention bias
91 | computed using the positions and timestamps of the actions to improve the
92 | position encoding.
93 | -   **Residual Connections and Pre-Layernorm** - Applied around the pointwise
94 | attention blocks for stable and faster training of deeper models.
95 | - **No Feedforward Network** - The feedforward network is removed.
96 | - **Prediction Head** - Decodes the sequence embeddings into logits using
97 | separately learnt weights and computes a causal categorical cross entropy
98 | loss between the inputs and the inputs shifted right.
99 |
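The pointwise gated attention idea can be sketched as follows. This is a simplified single-head approximation under several assumptions (no relative attention bias, a plain length scaling, a static sequence length) and is not the implementation in `hstu.py`.

```python
import keras
from keras import layers, ops


def hstu_pointwise_attention(x, dim=64, dropout=0.1):
  """Single-head gated pointwise attention block (simplified sketch)."""
  seq_len = x.shape[1]
  # Project the inputs into gate, value, query and key tensors with a SiLU.
  u, v, q, k = ops.split(layers.Dense(4 * dim, activation="silu")(x), 4, axis=-1)
  # Pointwise attention: an elementwise SiLU replaces softmax normalization,
  # keeping the intensity of interactions; a causal mask keeps the block
  # autoregressive. Scaling by the sequence length is a simplification.
  scores = ops.silu(ops.matmul(q, ops.transpose(k, (0, 2, 1)))) / seq_len
  scores = scores * ops.tril(ops.ones((seq_len, seq_len)))
  # Aggregate values, normalize, gate elementwise with `u`, then project back.
  h = layers.LayerNormalization()(ops.matmul(scores, v)) * u
  return x + layers.Dropout(dropout)(layers.Dense(x.shape[-1])(h))
```
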
100 | ### Mamba4Rec
101 |
102 | A linear recurrent Mamba 2 architecture to model sequences of items for
103 | recommendations. This scales better on longer sequences than attention-based
104 | methods due to its linear complexity, compared to the quadratic complexity of
105 | attention. Mamba4Rec performs better than RNNs and matches the quality of
106 | standard attention models while being more efficient at both training and
107 | inference time.
108 |
109 | #### Architecture Overview
110 |
111 | -   **Embedding Layer** - Converts item IDs into dense vectors. No position
112 | embedding is used since the recurrent nature of Mamba inherently encodes
113 | positional information as an inductive bias.
114 | -   **Mamba SSD** - Computes a causal interaction between different item
115 | embeddings in the sequence using the Mamba state space duality algorithm.
116 | - **Feedforward Network** Applied independently to each embedding vector
117 | output by the Mamba layer. Uses two linear layers with a GeLU activation in
118 | between to add non-linearity.
119 | - **Residual Connections and Post-Layernorm** Applied around both the Mamba
120 | and feed-forward network sub-layers for stable and faster training of deeper
121 | models. Dropout is also used within the block.
122 | - **Prediction Head** Decodes the sequence embeddings into logits using the
123 | input item embedding table and computes a causal categorical cross entropy
124 |   loss between the predicted logits and the next items, as sketched below.
125 |
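The causal next-item loss used by the HSTU and Mamba4Rec prediction heads can
be summarized by the following sketch, which computes a padded cross entropy
between the logits at step `t` and the item at step `t + 1`. This is an
illustration of the idea, not the library's exact loss code.

```python
import keras


def causal_next_item_loss(item_ids, logits, padding_mask):
  """Cross entropy between logits at step t and the item at step t + 1."""
  # Targets are the inputs shifted by one position; the last step has no
  # target and is dropped.
  targets = item_ids[:, 1:]                    # (B, L - 1)
  logits = logits[:, :-1, :]                   # (B, L - 1, V)
  weights = keras.ops.cast(padding_mask[:, 1:], keras.ops.dtype(logits))

  losses = keras.losses.sparse_categorical_crossentropy(
      targets, logits, from_logits=True
  )                                            # (B, L - 1)
  return keras.ops.sum(losses * weights) / keras.ops.maximum(
      keras.ops.sum(weights), 1.0
  )
```
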
126 | ## References
127 |
128 | - SASRec Paper: Kang, W. C., & McAuley, J. (2018). Self-Attentive Sequential
129 | Recommendation. arXiv preprint arXiv:1808.09781v1.
130 | https://arxiv.org/abs/1808.09781
131 | - Transformer Paper: Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J.,
132 | Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you
133 | need. Advances in neural information processing systems, 30.
134 | - Mamba4Rec Paper: Liu, C., Lin, J., Liu, H., Wang, J., & Caverlee, J. (2024).
135 | Mamba4Rec: Towards Efficient Sequential Recommendation with Selective State
136 | Space Models. arXiv preprint arXiv:2403.03900v2.
137 | https://arxiv.org/abs/2403.03900
138 | - Mamba Paper: Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling
139 | with Selective State Spaces. arXiv preprint arXiv:2312.00752.
140 | - BERT4Rec Paper: Sun, F., Liu, J., Wu, J., Pei, C., Lin, X., Ou, W., & Jiang,
141 | P. (2019). BERT4Rec: Sequential Recommendation with Bidirectional Encoder
142 | Representations from Transformer. arXiv preprint arXiv:1904.06690v2.
143 | https://arxiv.org/abs/1904.06690
144 | - BERT Paper: Devlin, J., Chang, M. W., Lee, K., & Toutanova, K. (2018). BERT:
145 | Pre-training of Deep Bidirectional Transformers for Language Understanding.
146 | arXiv preprint arXiv:1810.04805.
147 | - HSTU Paper: Zhai, J., et al. (2024). Actions Speak Louder than Words:
148 |   Trillion-Parameter Sequential Transducers for Generative Recommendations.
149 |   arXiv preprint arXiv:2402.17152. https://arxiv.org/abs/2402.17152
--------------------------------------------------------------------------------
/recml/layers/keras/bert4rec.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Models baselined."""
15 |
16 | from collections.abc import Mapping, Sequence
17 | from typing import Any
18 |
19 | import keras
20 | import keras_hub
21 | from recml.layers.keras import utils
22 |
23 | Tensor = Any
24 |
25 |
26 | @keras.saving.register_keras_serializable("recml")
27 | class BERT4Rec(keras.layers.Layer):
28 | """BERT4Rec architecture as in [1].
29 |
30 | Implements the BERT4Rec model architecture as described in 'BERT4Rec:
31 | Sequential Recommendation with Bidirectional Encoder Representations from
32 | Transformer' [1].
33 |
34 | [1] https://arxiv.org/abs/1904.06690
35 | """
36 |
37 | def __init__(
38 | self,
39 | *,
40 | vocab_size: int,
41 | max_positions: int,
42 | num_types: int | None = None,
43 | model_dim: int,
44 | mlp_dim: int,
45 | num_heads: int,
46 | num_layers: int,
47 | dropout: float = 0.0,
48 | norm_eps: float = 1e-12,
49 | add_head: bool = True,
50 | **kwargs,
51 | ):
52 | """Initializes the instance.
53 |
54 | Args:
55 | vocab_size: The size of the item vocabulary.
56 | max_positions: The maximum number of positions in a sequence.
57 | num_types: The number of types. If None, no type embedding is used.
58 | Defaults to None.
59 | model_dim: The width of the embeddings in the model.
60 | mlp_dim: The width of the MLP in each transformer block.
61 | num_heads: The number of attention heads in each transformer block.
62 | num_layers: The number of transformer blocks in the model.
63 | dropout: The dropout rate. Defaults to 0.
64 | norm_eps: The epsilon for layer normalization.
65 | add_head: Whether to add a masked language modeling head.
66 | **kwargs: Passed through to the super class.
67 | """
68 |
69 | super().__init__(**kwargs)
70 |
71 | self.item_embedding = keras_hub.layers.ReversibleEmbedding(
72 | input_dim=vocab_size,
73 | output_dim=model_dim,
74 | embeddings_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
75 | dtype=self.dtype_policy,
76 | reverse_dtype=self.compute_dtype,
77 | name="item_embedding",
78 | )
79 | if num_types is not None:
80 | self.type_embedding = keras.layers.Embedding(
81 | input_dim=num_types,
82 | output_dim=model_dim,
83 | embeddings_initializer=keras.initializers.TruncatedNormal(
84 | stddev=0.02
85 | ),
86 | dtype=self.dtype_policy,
87 | name="type_embedding",
88 | )
89 | else:
90 | self.type_embedding = None
91 |
92 | self.position_embedding = keras_hub.layers.PositionEmbedding(
93 | sequence_length=max_positions,
94 | initializer=keras.initializers.TruncatedNormal(stddev=0.02),
95 | dtype=self.dtype_policy,
96 | name="position_embedding",
97 | )
98 |
99 | self.embeddings_norm = keras.layers.LayerNormalization(
100 |         epsilon=norm_eps, name="embedding_norm"
101 | )
102 | self.embeddings_dropout = keras.layers.Dropout(
103 | dropout, name="embedding_dropout"
104 | )
105 |
106 | self.encoder_blocks = [
107 | keras_hub.layers.TransformerEncoder(
108 | intermediate_dim=mlp_dim,
109 | num_heads=num_heads,
110 | dropout=dropout,
111 | activation=utils.gelu_approximate,
112 | layer_norm_epsilon=norm_eps,
113 | normalize_first=False,
114 | dtype=self.dtype_policy,
115 | name=f"encoder_block_{i}",
116 | )
117 | for i in range(num_layers)
118 | ]
119 | if add_head:
120 | self.head = keras_hub.layers.MaskedLMHead(
121 | vocabulary_size=vocab_size,
122 | token_embedding=self.item_embedding,
123 | intermediate_activation=utils.gelu_approximate,
124 | kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
125 | dtype=self.dtype_policy,
126 | name="mlm_head",
127 | )
128 | else:
129 | self.head = None
130 |
131 | self._vocab_size = vocab_size
132 | self._model_dim = model_dim
133 | self._config = {
134 | "vocab_size": vocab_size,
135 | "max_positions": max_positions,
136 | "num_types": num_types,
137 | "model_dim": model_dim,
138 | "mlp_dim": mlp_dim,
139 | "num_heads": num_heads,
140 | "num_layers": num_layers,
141 | "dropout": dropout,
142 | "norm_eps": norm_eps,
143 | "add_head": add_head,
144 | }
145 |
146 | def build(self, inputs_shape: Sequence[int]):
147 | self.item_embedding.build(inputs_shape)
148 | if self.type_embedding is not None:
149 | self.type_embedding.build(inputs_shape)
150 |
151 | self.position_embedding.build((*inputs_shape, self._model_dim))
152 | self.embeddings_norm.build((*inputs_shape, self._model_dim))
153 |
154 | for encoder_block in self.encoder_blocks:
155 | encoder_block.build((*inputs_shape, self._model_dim))
156 |
157 | if self.head is not None:
158 | self.head.build((*inputs_shape, self._model_dim))
159 |
160 | def call(
161 | self,
162 | inputs: Tensor,
163 | type_ids: Tensor | None = None,
164 | padding_mask: Tensor | None = None,
165 | attention_mask: Tensor | None = None,
166 | mask_positions: Tensor | None = None,
167 | training: bool = False,
168 | ) -> Tensor:
169 | embeddings = self.item_embedding(inputs)
170 | if self.type_embedding is not None:
171 | if type_ids is None:
172 | raise ValueError(
173 | "`type_ids` cannot be None when `num_types` is not None."
174 | )
175 | embeddings += self.type_embedding(type_ids)
176 | embeddings += self.position_embedding(embeddings)
177 |
178 | embeddings = self.embeddings_norm(embeddings)
179 | embeddings = self.embeddings_dropout(embeddings, training=training)
180 |
181 | for encoder_block in self.encoder_blocks:
182 | embeddings = encoder_block(
183 | embeddings,
184 | padding_mask=padding_mask,
185 | attention_mask=attention_mask,
186 | training=training,
187 | )
188 |
189 | if self.head is None:
190 | return embeddings
191 |
192 | return self.head(embeddings, mask_positions)
193 |
194 | def compute_output_shape(
195 | self,
196 | inputs_shape: Sequence[int],
197 | mask_positions_shape: Tensor | None = None,
198 | ) -> Sequence[int | None]:
199 | if self.head is not None:
200 | if mask_positions_shape is None:
201 | raise ValueError(
202 | "`mask_positions_shape` cannot be None when `add_head` is True."
203 | )
204 | return (*inputs_shape[:-1], mask_positions_shape[-1], self._vocab_size)
205 | return (*inputs_shape, self._model_dim)
206 |
207 | def get_config(self) -> Mapping[str, Any]:
208 | return {**super().get_config(), **self._config}
209 |
--------------------------------------------------------------------------------
/recml/layers/keras/bert4rec_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for Keras architectures."""
15 |
16 | from absl.testing import absltest
17 | import keras
18 | from keras.src import testing
19 | from recml.layers.keras import bert4rec
20 |
21 |
22 | class BERT4RecTest(testing.TestCase):
23 |
24 | def test_bert4rec(self):
25 | item_ids = keras.ops.array([[1, 2, 3], [4, 5, 0]], "int32")
26 | item_type_ids = keras.ops.array([[1, 2, 3], [4, 4, 0]], "int32")
27 | mask = keras.ops.array([[1, 1, 1], [1, 1, 0]], "int32")
28 | mask_positions = keras.ops.array([[0], [0]], "int32")
29 | init_kws = {
30 | "vocab_size": 500,
31 | "num_types": 5,
32 | "max_positions": 20,
33 | "model_dim": 32,
34 | "mlp_dim": 64,
35 | "num_heads": 4,
36 | "num_layers": 3,
37 | "dropout": 0.1,
38 | }
39 |
40 | tvars = (
41 | (500 * 32) # Item embedding
42 | + (5 * 32) # Type embedding
43 | + (20 * 32) # Position embedding
44 | + (2 * 32) # Embedding norm
45 | + 3 # 3 encoder blocks
46 | * (
47 | ((32 + 1) * 32 * 3 + (32 + 1) * 32) # Attention QKVO
48 | + (2 * 32) # Attention block norm
49 | + ((32 + 1) * 64) # MLP inner projection
50 | + ((64 + 1) * 32) # MLP outer projection
51 | + (2 * 32) # MLP block norm
52 | )
53 | + (32 + 1) * 32 # Head projection
54 | + (2 * 32) # Head norm
55 | + 500 # Head bias
56 | )
57 | seed_generators = 1 + 3 * 3 # 1 seed generator for each dropout layer.
58 |
59 | model = bert4rec.BERT4Rec(**init_kws)
60 | model.build(keras.ops.shape(item_ids))
61 | self.assertEqual(model.count_params(), tvars)
62 |
63 | self.run_layer_test(
64 | bert4rec.BERT4Rec,
65 | init_kwargs={**init_kws, "add_head": False},
66 | input_data=item_ids,
67 | call_kwargs={
68 | "type_ids": item_type_ids,
69 | "padding_mask": mask,
70 | "mask_positions": mask_positions,
71 | },
72 | expected_output_shape=(2, 3, 32),
73 | expected_output_dtype="float32",
74 | expected_num_seed_generators=seed_generators,
75 | run_training_check=False,
76 | )
77 |
78 |
79 | if __name__ == "__main__":
80 | absltest.main()
81 |
--------------------------------------------------------------------------------
/recml/layers/keras/hstu_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for the HSTU implementation."""
15 |
16 | from absl.testing import absltest
17 | from absl.testing import parameterized
18 | import keras
19 | from keras.src import testing
20 | import numpy as np
21 | from recml.layers.keras import hstu
22 |
23 |
24 | class HSTUTest(testing.TestCase):
25 |
26 | def test_hstu(self):
27 | item_ids = keras.ops.array([[1, 2, 3], [4, 5, 0]], "int32")
28 | padding_mask = keras.ops.array([[1, 1, 1], [1, 1, 0]], "int32")
29 | init_kws = {
30 | "vocab_size": 500,
31 | "max_positions": 20,
32 | "model_dim": 32,
33 | "num_heads": 4,
34 | "num_layers": 3,
35 | "dropout": 0.1,
36 | }
37 |
38 | tvars = (
39 | (500 * 32) # Item embedding
40 | + (20 * 32) # Position embedding
41 | + 3 # 3 decoder blocks
42 | * (
43 | (32 * 32 * 4 + 32 * 32) # UQKV + output
44 | + (2 * 32) * 2 # Input + attention Layer norms.
45 | )
46 | + (2 * 32) # Final norm
47 | + (500 * 32) # Output embedding
48 | )
49 | seed_generators = 1 + 3 # 1 seed generator for each dropout layer.
50 | model = hstu.HSTU(**init_kws)
51 | model.build(keras.ops.shape(item_ids))
52 | self.assertEqual(model.count_params(), tvars)
53 |
54 | self.run_layer_test(
55 | hstu.HSTU,
56 | init_kwargs=init_kws,
57 | input_data=item_ids,
58 | call_kwargs={"padding_mask": padding_mask},
59 | expected_output_shape=(2, 3, 500),
60 | expected_output_dtype="float32",
61 | expected_num_seed_generators=seed_generators,
62 | run_training_check=False,
63 | )
64 |
65 |
66 | if __name__ == "__main__":
67 | absltest.main()
68 |
--------------------------------------------------------------------------------
/recml/layers/keras/mamba_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for the Mamba implementation."""
15 |
16 | from absl.testing import absltest
17 | from absl.testing import parameterized
18 | import einops
19 | import keras
20 | from keras.src import testing
21 | import numpy as np
22 | from recml.layers.keras import mamba
23 | import tensorflow as tf
24 | import torch
25 | import torch.nn.functional as F
26 |
27 |
28 | # originally found here:
29 | # https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py
30 | def segsum(x):
31 | """More stable segment sum calculation."""
32 | t = x.size(-1)
33 | x = einops.repeat(x, "... d -> ... d e", e=t)
34 | mask = torch.tril(torch.ones(t, t, device=x.device, dtype=bool), diagonal=-1)
35 | x = x.masked_fill(~mask, 0)
36 | x_segsum = torch.cumsum(x, dim=-2)
37 | mask = torch.tril(torch.ones(t, t, device=x.device, dtype=bool), diagonal=0)
38 | x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
39 | return x_segsum
40 |
41 |
42 | # originally found here:
43 | # https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py
44 | def ssd_minimal_discrete(x, a, b, c, block_len, initial_states=None):
45 | """Original Pytorch implementation of Mamba2.
46 |
47 | Args:
48 | x: (batch, length, n_heads, d_head)
49 | a: (batch, length, n_heads)
50 | b: (batch, length, n_groups, d_state)
51 | c: (batch, length, n_groups, d_state)
52 | block_len: int
53 | initial_states: tensor of initial state values.
54 |
55 | Returns:
56 | Y: (batch, length, n_heads, d_head)
57 | """
58 | assert x.dtype == a.dtype == b.dtype == c.dtype
59 | assert x.shape[1] % block_len == 0
60 |
61 | # Rearrange into blocks/chunks
62 | x, a, b, c = [
63 | einops.rearrange(x, "b (c l) ... -> b c l ...", l=block_len)
64 | for x in (x, a, b, c)
65 | ]
66 |
67 | a = einops.rearrange(a, "b c l h -> b h c l")
68 | a_cumsum = torch.cumsum(a, dim=-1)
69 |
70 | # 1. Compute the output for each intra-chunk (diagonal blocks)
71 | length = torch.exp(segsum(a))
72 | y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", c, b, length, x)
73 |
74 | # 2. Compute the state for each intra-chunk
75 | # (right term of low-rank factorization of off-diagonal blocks; B terms)
76 | decay_states = torch.exp((a_cumsum[:, :, :, -1:] - a_cumsum))
77 | states = torch.einsum("bclhn,bhcl,bclhp->bchpn", b, decay_states, x)
78 |
79 | # 3. Compute the inter-chunk SSM recurrence;
80 | # produces correct SSM states at chunk boundaries
81 | # (middle term of factorization of off-diag blocks; A terms)
82 | if initial_states is None:
83 | initial_states = torch.zeros_like(states[:, :1])
84 | states = torch.cat([initial_states, states], dim=1)
85 | decay_chunk = torch.exp(segsum(F.pad(a_cumsum[:, :, :, -1], (1, 0))))
86 | new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
87 | states, final_state = new_states[:, :-1], new_states[:, -1]
88 |
89 | # 4. Compute state -> output conversion per chunk
90 | # (left term of low-rank factorization of off-diagonal blocks; C terms)
91 | state_decay_out = torch.exp(a_cumsum)
92 | y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", c, states, state_decay_out)
93 |
94 |   # Add output of intra-chunk and inter-chunk terms
95 | # (diagonal and off-diagonal blocks)
96 | y = einops.rearrange(y_diag + y_off, "b c l h p -> b (c l) h p")
97 | return y, final_state
98 |
99 |
100 | class MambaSSDTest(testing.TestCase):
101 |
102 | # Simple equivalence test
103 | @parameterized.parameters(dict(seed=40), dict(seed=50), dict(seed=70))
104 | def test_ssd_correctness(self, seed: int):
105 | keras.utils.set_random_seed(seed)
106 |
107 | ## Dimensions
108 | # Denoted (B, T, Q, D, P) in the paper
109 | batch, seqlen, chunk_size, dim, nheads = 1, 2048, 64, 2048, 32
110 | ngroups = 1 # (G) in the paper
111 | dstate = 64 # (N) in the paper
112 |
113 | dtype = "float32"
114 | x = keras.random.normal((batch, seqlen, nheads, dim // nheads), dtype=dtype)
115 | dt = keras.ops.nn.softplus(
116 | keras.random.normal((batch, seqlen, nheads), dtype=dtype) - 4
117 | )
118 | a = keras.ops.multiply(
119 | -1, keras.ops.exp(keras.random.normal((nheads,), dtype=dtype))
120 | )
121 | b = keras.random.normal((batch, seqlen, ngroups, dstate), dtype=dtype)
122 | c = keras.random.normal((batch, seqlen, ngroups, dstate), dtype=dtype)
123 |
124 | torch_a = torch.tensor(np.array(a))
125 | torch_b = torch.tensor(np.array(b))
126 | torch_c = torch.tensor(np.array(c))
127 | torch_dt = torch.tensor(np.array(dt))
128 | torch_x = torch.tensor(np.array(x))
129 | ground_truth = ssd_minimal_discrete(
130 | torch_x * torch_dt.unsqueeze(-1),
131 | torch_a * torch_dt,
132 | torch_b,
133 | torch_c,
134 | chunk_size,
135 | )
136 | ours = mamba.ssd_minimal_discrete(
137 | keras.ops.multiply(x, keras.ops.expand_dims(dt, axis=-1)),
138 | keras.ops.multiply(a, dt),
139 | b,
140 | c,
141 | chunk_size,
142 | )
143 |
144 | self.assertAllClose(ground_truth[0], ours, atol=1e-5, rtol=1e-5)
145 |
146 |
147 | class Mamba4RecTest(testing.TestCase):
148 |
149 | def test_mamba4rec(self):
150 | item_ids = keras.ops.array([[1, 2, 3, 4], [4, 5, 0, 0]], "int32")
151 | padding_mask = keras.ops.array([[1, 1, 1, 0], [1, 1, 0, 0]], "int32")
152 | init_kws = {
153 | "vocab_size": 500,
154 | "model_dim": 32,
155 | "mlp_expand": 4,
156 | "num_heads": 4,
157 | "num_layers": 3,
158 | "dropout": 0.1,
159 | "d_expand": 128,
160 | "d_state": 64,
161 | "d_conv": 4,
162 | "chunk_size": 2,
163 | }
164 |
165 | self.run_layer_test(
166 | mamba.Mamba4Rec,
167 | init_kwargs=init_kws,
168 | input_data=item_ids,
169 | call_kwargs={"padding_mask": padding_mask},
170 | expected_output_shape=(2, 4, 500),
171 | expected_output_dtype="float32",
172 | expected_num_seed_generators=1 + 3 * 3,
173 | )
174 |
175 |
176 | if __name__ == "__main__":
177 | absltest.main()
178 |
--------------------------------------------------------------------------------
/recml/layers/keras/sasrec.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Models baselined."""
15 |
16 | from collections.abc import Mapping, Sequence
17 | from typing import Any
18 |
19 | import keras
20 | import keras_hub
21 | from recml.layers.keras import utils
22 |
23 | Tensor = Any
24 |
25 |
26 | @keras.saving.register_keras_serializable("recml")
27 | class SASRec(keras.layers.Layer):
28 | """SASRec architecture as in [1].
29 |
30 | Implements the SASRec model architecture as described in 'Self-Attentive
31 | Sequential Recommendation' [1].
32 |
33 | [1] https://arxiv.org/abs/1808.09781
34 | """
35 |
36 | def __init__(
37 | self,
38 | *,
39 | vocab_size: int,
40 | max_positions: int,
41 | model_dim: int,
42 | mlp_dim: int,
43 | num_heads: int,
44 | num_layers: int,
45 | dropout: float = 0.0,
46 | norm_eps: float = 1e-6,
47 | scale_by_sqrt_dim: bool = False,
48 | add_head: bool = True,
49 | **kwargs,
50 | ):
51 | """Initializes the instance.
52 |
53 | Args:
54 | vocab_size: The size of the item vocabulary.
55 | max_positions: The maximum number of positions in a sequence.
56 | model_dim: The width of the embeddings in the model.
57 | mlp_dim: The width of the MLP in each transformer block.
58 | num_heads: The number of attention heads in each transformer block.
59 | num_layers: The number of transformer blocks in the model.
60 | dropout: The dropout rate. Defaults to 0.
61 |       norm_eps: The epsilon for layer normalization.
62 | scale_by_sqrt_dim: Whether to scale the item embeddings by
63 | sqrt(model_dim). Defaults to False.
64 | add_head: Whether to decode the sequence embeddings to logits.
65 | **kwargs: Passed through to the super class.
66 | """
67 | super().__init__(**kwargs)
68 |
69 | self.item_embedding = keras_hub.layers.ReversibleEmbedding(
70 | input_dim=vocab_size,
71 | output_dim=model_dim,
72 | embeddings_initializer=keras.initializers.RandomNormal(stddev=0.02),
73 | dtype=self.dtype_policy,
74 | reverse_dtype=self.compute_dtype,
75 | name="item_embedding",
76 | )
77 |
78 | self.position_embedding = keras_hub.layers.PositionEmbedding(
79 | sequence_length=max_positions,
80 | initializer=keras.initializers.RandomNormal(stddev=0.02),
81 | dtype=self.dtype_policy,
82 | name="position_embedding",
83 | )
84 |
85 | self.embeddings_dropout = keras.layers.Dropout(
86 | dropout, name="embedding_dropout"
87 | )
88 |
89 | self.decoder_blocks = [
90 | keras_hub.layers.TransformerDecoder(
91 | intermediate_dim=mlp_dim,
92 | num_heads=num_heads,
93 | dropout=dropout,
94 | activation=utils.gelu_approximate,
95 | layer_norm_epsilon=norm_eps,
96 | normalize_first=True,
97 | dtype=self.dtype_policy,
98 | name=f"decoder_block_{i}",
99 | )
100 | for i in range(num_layers)
101 | ]
102 | self.final_norm = keras.layers.LayerNormalization(
103 | epsilon=norm_eps, name="final_norm"
104 | )
105 |
106 | self._vocab_size = vocab_size
107 | self._model_dim = model_dim
108 | self._scale_by_sqrt_dim = scale_by_sqrt_dim
109 | self._add_head = add_head
110 | self._config = {
111 | "vocab_size": vocab_size,
112 | "max_positions": max_positions,
113 | "model_dim": model_dim,
114 | "mlp_dim": mlp_dim,
115 | "num_heads": num_heads,
116 | "num_layers": num_layers,
117 | "dropout": dropout,
118 | "norm_eps": norm_eps,
119 | "scale_by_sqrt_dim": scale_by_sqrt_dim,
120 | "add_head": add_head,
121 | }
122 |
123 | def build(self, inputs_shape: Sequence[int]):
124 | self.item_embedding.build(inputs_shape)
125 | self.position_embedding.build((*inputs_shape, self._model_dim))
126 |
127 | for decoder_block in self.decoder_blocks:
128 | decoder_block.build((*inputs_shape, self._model_dim))
129 |
130 | self.final_norm.build((*inputs_shape, self._model_dim))
131 |
132 | def call(
133 | self,
134 | inputs: Tensor,
135 | padding_mask: Tensor | None = None,
136 | attention_mask: Tensor | None = None,
137 | mask_positions: Tensor | None = None,
138 | training: bool = False,
139 | ) -> Tensor:
140 | embeddings = self.item_embedding(inputs)
141 | if self._scale_by_sqrt_dim:
142 | embeddings *= keras.ops.cast(
143 | self._model_dim**0.5, keras.ops.dtype(embeddings)
144 | )
145 | embeddings += self.position_embedding(embeddings)
146 |
147 | embeddings = self.final_norm(embeddings)
148 | embeddings = self.embeddings_dropout(embeddings, training=training)
149 |
150 | for decoder_block in self.decoder_blocks:
151 | embeddings = decoder_block(
152 | embeddings,
153 | decoder_padding_mask=padding_mask,
154 | decoder_attention_mask=attention_mask,
155 | training=training,
156 | )
157 |
158 | embeddings = self.final_norm(embeddings)
159 |
160 | if not self._add_head:
161 | return embeddings
162 |
163 | return self.item_embedding(embeddings, reverse=True)
164 |
165 | def compute_output_shape(self, inputs_shape: Sequence[int]) -> Sequence[int]:
166 | output_dim = self._vocab_size if self._add_head else self._model_dim
167 | return (*inputs_shape, output_dim)
168 |
169 | def get_config(self) -> Mapping[str, Any]:
170 | return {**super().get_config(), **self._config}
171 |
--------------------------------------------------------------------------------
/recml/layers/keras/sasrec_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for Keras architectures."""
15 |
16 | from absl.testing import absltest
17 | import keras
18 | from keras.src import testing
19 | from recml.layers.keras import sasrec
20 |
21 |
22 | class SASRecTest(testing.TestCase):
23 |
24 | def test_sasrec(self):
25 | item_ids = keras.ops.array([[1, 2, 3], [4, 5, 0]], "int32")
26 | padding_mask = keras.ops.array([[1, 1, 1], [1, 1, 0]], "int32")
27 | init_kws = {
28 | "vocab_size": 500,
29 | "max_positions": 20,
30 | "model_dim": 32,
31 | "mlp_dim": 64,
32 | "num_heads": 4,
33 | "num_layers": 3,
34 | "dropout": 0.1,
35 | }
36 |
37 | tvars = (
38 | (500 * 32) # Item embedding
39 | + (20 * 32) # Position embedding
40 | + 3 # 3 decoder blocks
41 | * (
42 | ((32 + 1) * 32 * 3 + (32 + 1) * 32) # Attention QKVO
43 | + (2 * 32) # Attention block norm
44 | + ((32 + 1) * 64) # MLP inner projection
45 | + ((64 + 1) * 32) # MLP outer projection
46 | + (2 * 32) # MLP block norm
47 | )
48 | + (2 * 32) # Final norm
49 | )
50 | seed_generators = 1 + 3 * 3 # 1 seed generator for each dropout layer.
51 | model = sasrec.SASRec(**init_kws)
52 | model.build(keras.ops.shape(item_ids))
53 | self.assertEqual(model.count_params(), tvars)
54 |
55 | self.run_layer_test(
56 | sasrec.SASRec,
57 | init_kwargs=init_kws,
58 | input_data=item_ids,
59 | call_kwargs={"padding_mask": padding_mask},
60 | expected_output_shape=(2, 3, 500),
61 | expected_output_dtype="float32",
62 | expected_num_seed_generators=seed_generators,
63 | run_training_check=False,
64 | )
65 |
66 |
67 | if __name__ == "__main__":
68 | absltest.main()
69 |
--------------------------------------------------------------------------------
/recml/layers/keras/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Layer utilities."""
15 |
16 | from typing import Any
17 | import keras
18 |
19 | Tensor = Any
20 |
21 |
22 | def clone_initializer(initializer: Any) -> Any:
23 | """Clones an initializer."""
24 | if isinstance(initializer, keras.initializers.Initializer):
25 | return initializer.clone()
26 | return initializer
27 |
28 |
29 | def make_attention_mask(mask: Tensor, dtype: str = "float32") -> Tensor:
30 | """Creates a 3D self-attention mask from a padding mask."""
31 | # Element wise pairwise function on [B, L, 1], [B, 1, L].
32 | attention_mask = keras.ops.multiply(
33 | keras.ops.expand_dims(mask, axis=-1),
34 | keras.ops.expand_dims(mask, axis=-2),
35 | )
36 | return keras.ops.cast(attention_mask, dtype=dtype)
37 |
38 |
39 | def make_causal_mask(mask: Tensor, dtype: str = "float32") -> Tensor:
40 | """Creates a 3D causal self-attention mask from a padding mask."""
41 | return keras.ops.tril(make_attention_mask(mask, dtype=dtype))
42 |
43 |
44 | @keras.saving.register_keras_serializable("recml")
45 | def gelu_approximate(x: Tensor) -> Tensor:
46 | """Approximate GELU activation function."""
47 | return keras.activations.gelu(x, approximate=True)
48 |
49 |
50 | @keras.saving.register_keras_serializable("recml")
51 | def relu_squared(x: Tensor) -> Tensor:
52 | """RELU squared activation function."""
53 | return keras.ops.square(keras.activations.relu(x))
54 |
55 |
56 | @keras.saving.register_keras_serializable("recml")
57 | def norm_embedding_post_processor(inputs: Tensor, eps: float = 1e-6) -> Tensor:
58 | """L2 Normalization Post Processor for HSTU.
59 |
60 | Take output embeddings and normalize them to unit length.
61 |
62 | Args:
63 | inputs: The input sequence tensor. shape = [B, N, D]
64 | eps: Epsilon to use for division.
65 |
66 | Returns:
67 | The normalized output embeddings.
68 | """
69 | return keras.ops.divide(
70 | inputs,
71 | keras.ops.clip(
72 | keras.ops.norm(inputs, ord=None, axis=-1, keepdims=True),
73 | x_min=eps,
74 | x_max=None,
75 | ),
76 | )
77 |
78 |
79 | def apply_rotary_encoding(
80 | x: Tensor, *, positions: Tensor | None = None, max_wavelength: int
81 | ) -> Tensor:
82 | """Returns the rotary positional encodings.
83 |
84 | Args:
85 | x: Array of embeddings of shape [*batch_size, seq_len, num_heads, head_dim].
86 | Where head_dim must be even.
87 | positions: Optional array of shape [*batch_size, seq_len] holding the
88 | position of each token in the sequence. If not provided, the input is
89 | assumed to be a contiguous sequence and the positions are therefore [0, 1,
90 | ..., seq_len - 1] for each example.
91 | max_wavelength: Maximum wavelength that will appear in sin / cosine
92 | waveforms. This specifies the maximum sequence length for identifying
93 | unique positions.
94 |
95 | Returns:
96 | Array of rotary encoded input of shape [batch_size, seq_len, num_heads,
97 | head_dim].
98 | """
99 | x_shape = keras.ops.shape(x)
100 |   b = tuple(x_shape[i] for i in range(len(x_shape) - 3))  # Batch dims.
101 | seq_len = x_shape[-3]
102 | if x_shape[-1] % 2 != 0:
103 | raise ValueError(
104 | "Embedding dimension must be even, but got"
105 | f" {x_shape[-1]} for input of shape {x_shape}."
106 | )
107 | if len(x_shape) < 4:
108 | raise ValueError(
109 | f"Unexpected input shape: {x_shape}. Expected shape of rank 4 or"
110 | " greater."
111 | )
112 | if positions is None:
113 | positions = keras.ops.tile(
114 | keras.ops.arange(seq_len)[None, :],
115 | (*[d if d is not None else -1 for d in b], 1),
116 | )
117 | # Only do shape checks on not TF backends.
118 | if keras.backend.backend() != "tensorflow":
119 | if keras.ops.shape(positions) != x_shape[:-2]:
120 | raise ValueError(
121 | f"Positions must be of shape: {(x_shape[:-2])} but got shape:"
122 | f" {keras.ops.shape(positions)}."
123 | )
124 | freq_exponents = (2.0 / x_shape[-1]) * keras.ops.arange(
125 | x_shape[-1] // 2, dtype="float32"
126 | )
127 | timescale = max_wavelength**freq_exponents
128 | timescale = timescale[
129 | (*[None for _ in b], None, slice(None))
130 | ] # timescale[None, None, :] when len(b) == 1
131 | radians = keras.ops.cast(positions[..., None], "float32") / timescale
132 | radians = radians[..., None, :]
133 | # radians.shape = [...,L,1,d=D/2]
134 | sin, cos = keras.ops.sin(radians), keras.ops.cos(radians)
135 | x1, x2 = keras.ops.split(x, 2, axis=-1)
136 | x1, x2 = keras.ops.cast(x1, "float32"), keras.ops.cast(x2, "float32")
137 | res = keras.ops.concatenate(
138 | [x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1
139 | )
140 | return keras.ops.cast(res, keras.ops.dtype(x))
141 |
142 |
143 | def large_negative_for_attention(dtype: Any) -> float:
144 | """Return a large negative number based on dtype."""
145 | if keras.backend.standardize_dtype(dtype) == "float16":
146 | return -3e4
147 | return -1e9
148 |
--------------------------------------------------------------------------------
/recml/layers/keras/utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for layer utilities."""
15 |
16 | from absl.testing import absltest
17 | from absl.testing import parameterized
18 | import keras
19 | from keras.src import testing
20 | import numpy as np
21 | from recml.layers.keras import utils
22 |
23 |
24 | class UtilsTest(testing.TestCase):
25 |
26 | def test_clone_initializer(self):
27 | random_initializer = keras.initializers.RandomNormal(stddev=1.0)
28 | random_initializer_clone = utils.clone_initializer(random_initializer)
29 | self.assertNotEqual(random_initializer.seed, random_initializer_clone.seed)
30 |
31 | lecun_initializer = keras.initializers.LecunNormal(seed=1)
32 | lecun_initializer_clone = utils.clone_initializer(lecun_initializer)
33 | self.assertEqual(lecun_initializer.seed, lecun_initializer_clone.seed)
34 |
35 | self.assertEqual(utils.clone_initializer("lecun_normal"), "lecun_normal")
36 |
37 | # Remember to read these sideways =))
38 | @parameterized.parameters(
39 | dict(
40 | inputs=np.array([[1, 1, 1], [1, 0, 0], [1, 1, 0]], dtype=np.float32),
41 | expected_outputs=np.array(
42 | [
43 | [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
44 | [[1, 0, 0], [0, 0, 0], [0, 0, 0]],
45 | [[1, 1, 0], [1, 1, 0], [0, 0, 0]],
46 | ],
47 | dtype=np.float32,
48 | ),
49 | ),
50 | )
51 | def test_make_attention_mask(
52 | self, inputs: np.array, expected_outputs: np.array
53 | ):
54 | self.assertAllClose(
55 | utils.make_attention_mask(keras.ops.array(inputs)),
56 | keras.ops.array(expected_outputs),
57 | )
58 |
59 | @parameterized.parameters(
60 | dict(
61 | inputs=np.array([[1, 1, 1], [1, 0, 0], [1, 1, 0]], dtype=np.float32),
62 | expected_outputs=np.array(
63 | [
64 | [[1, 0, 0], [1, 1, 0], [1, 1, 1]],
65 | [[1, 0, 0], [0, 0, 0], [0, 0, 0]],
66 | [[1, 0, 0], [1, 1, 0], [0, 0, 0]],
67 | ],
68 | dtype=np.float32,
69 | ),
70 | ),
71 | )
72 | def test_make_causal_mask(self, inputs: np.array, expected_outputs: np.array):
73 | self.assertAllClose(
74 | utils.make_causal_mask(keras.ops.array(inputs)),
75 | keras.ops.array(expected_outputs),
76 | )
77 |
78 | @parameterized.parameters(
79 | dict(
80 | inputs=np.array([[[1.0, 1.0], [0.0, 1.0]]]),
81 | eps=1e-6,
82 | expected_output=np.array([[[0.70710678, 0.70710678], [0.0, 1.0]]]),
83 | ),
84 | dict(
85 | inputs=np.array([[[0.0, 1.0, 0.0], [1.0, 1.0, 1.0]]]),
86 | eps=1e-06,
87 | expected_output=np.array(
88 | [[[0.0, 1.0, 0.0], [0.5773502, 0.5773502, 0.5773502]]]
89 | ),
90 | ),
91 | )
92 | def test_l2_norm_embedding_postprocessor_output(
93 | self, inputs, eps, expected_output
94 | ):
95 | inputs = keras.ops.array(inputs)
96 | expected_output = keras.ops.array(expected_output)
97 |
98 | got = utils.norm_embedding_post_processor(inputs, eps=eps)
99 | self.assertAllClose(got, expected_output)
100 |
101 | @parameterized.named_parameters(
102 | dict(
103 | # wavelength == 1 so radians == expand_dims(positions, axes=[-1, -2])
104 | testcase_name="with_positions",
105 | inputs=np.array([[[[1.0, 2.0]], [[3.0, 4.0]]]]),
106 | positions=np.array([[4, 1]]),
107 | max_wavelength=1,
108 | expected_outputs=np.array([[
109 | [[
110 | 1.0 * np.cos(4) - 2.0 * np.sin(4),
111 | 2.0 * np.cos(4) + 1.0 * np.sin(4),
112 | ]],
113 | [[
114 | 3.0 * np.cos(1) - 4.0 * np.sin(1),
115 | 4.0 * np.cos(1) + 3.0 * np.sin(1),
116 | ]],
117 | ]]),
118 | ),
119 | dict(
120 | # wavelength == 1 so radians == expand_dims(positions, axes=[-1, -2])
121 | testcase_name="no_positions",
122 | inputs=np.array([[[[1.0, 2.0]], [[3.0, 4.0]]]]),
123 | positions=None, # Should evaluate to [[0, 1]]
124 | max_wavelength=1,
125 | expected_outputs=np.array([[
126 | [[1.0, 2.0]],
127 | [[
128 | 3.0 * np.cos(1) - 4.0 * np.sin(1),
129 | 4.0 * np.cos(1) + 3.0 * np.sin(1),
130 | ]],
131 | ]]),
132 | ),
133 | )
134 | def test_apply_rotary_embedding(
135 | self,
136 | inputs: np.ndarray,
137 | positions: np.ndarray | None,
138 | max_wavelength: int,
139 | expected_outputs: np.ndarray,
140 | ):
141 | self.assertAllClose(
142 | expected_outputs,
143 | utils.apply_rotary_encoding(
144 | inputs,
145 | positions=positions,
146 | max_wavelength=max_wavelength,
147 | ),
148 | )
149 |
150 |
151 | if __name__ == "__main__":
152 | absltest.main()
153 |
--------------------------------------------------------------------------------
/recml/layers/linen/sparsecore_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 RecML authors .
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Sparsecore tests."""
15 |
16 | import functools
17 |
18 | from absl.testing import absltest
19 | from etils import epy
20 | import jax
21 | from recml.core.training import partitioning
22 | from recml.layers.linen import sparsecore
23 |
24 | with epy.lazy_imports():
25 | from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec # pylint: disable=g-import-not-at-top
26 |
27 |
28 | class SparsecoreTest(absltest.TestCase):
29 |
30 | def test_sparsecore_embedder_equivalence(self):
31 | if jax.devices()[0].platform != "tpu":
32 | self.skipTest("Test only supported on TPUs.")
33 |
34 | k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
35 |
36 | inputs = {
37 | "a": jax.random.randint(k1, (32, 16), minval=1, maxval=100),
38 | "b": jax.random.randint(k2, (32, 16), minval=1, maxval=100),
39 | "w": jax.random.normal(k3, (32, 16)),
40 | }
41 |
42 | dp_partitioner = partitioning.DataParallelPartitioner()
43 | embedder = sparsecore.SparsecoreEmbedder(
44 | specs={
45 | "a": sparsecore.EmbeddingSpec(
46 | input_dim=100,
47 | embedding_dim=16,
48 | combiner="mean",
49 | weight_name="w",
50 | ),
51 | "b": sparsecore.EmbeddingSpec(
52 | input_dim=100,
53 | embedding_dim=16,
54 | max_sequence_length=10,
55 | ),
56 | },
57 | optimizer=embedding_spec.AdagradOptimizerSpec(learning_rate=0.01),
58 | )
59 | preprocessor = embedder.make_preprocessor(32)
60 | layer = embedder.make_sparsecore_module()
61 |
62 | sc_inputs = dp_partitioner.shard_inputs(preprocessor(inputs))
63 | sc_vars = dp_partitioner.partition_init(functools.partial(layer.init, k4))(
64 | sc_inputs
65 | )
66 |
67 | def step(inputs, params):
68 | return layer.apply(params, inputs)
69 |
70 | p_step = dp_partitioner.partition_step(step, training=False)
71 | sparsecore_activations = jax.device_get(p_step(sc_inputs, sc_vars))
72 |
73 | self.assertEqual(sparsecore_activations["a"].shape, (32, 16))
74 | self.assertEqual(sparsecore_activations["b"].shape, (32, 10, 16))
75 |
76 |
77 | if __name__ == "__main__":
78 | absltest.main()
79 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.2.2
2 | astroid==3.3.9
3 | astunparse==1.6.3
4 | attrs==25.3.0
5 | certifi==2025.1.31
6 | cfgv==3.4.0
7 | charset-normalizer==3.4.1
8 | chex==0.1.89
9 | clu==0.0.12
10 | dill==0.4.0
11 | distlib==0.3.9
12 | dm-tree==0.1.9
13 | docstring-parser==0.16
14 | einops==0.8.1
15 | etils==1.12.2
16 | fiddle==0.3.0
17 | filelock==3.18.0
18 | flatbuffers==25.2.10
19 | flax==0.10.5
20 | fsspec==2025.3.2
21 | gast==0.6.0
22 | google-pasta==0.2.0
23 | googleapis-common-protos==1.70.0
24 | graphviz==0.20.3
25 | grpcio==1.71.0
26 | h5py==3.13.0
27 | humanize==4.12.2
28 | identify==2.6.9
29 | idna==3.10
30 | immutabledict==4.2.1
31 | importlib-resources==6.5.2
32 | iniconfig==2.1.0
33 | isort==6.0.1
34 | jax==0.6.0
35 | jaxlib==0.6.0
36 | jaxtyping==0.3.1
37 | jinja2==3.1.6
38 | kagglehub==0.3.11
39 | keras==3.9.2
40 | keras-hub==0.20.0
41 | libclang==18.1.1
42 | libcst==1.7.0
43 | markdown==3.8
44 | markdown-it-py==3.0.0
45 | markupsafe==3.0.2
46 | mccabe==0.7.0
47 | mdurl==0.1.2
48 | ml-collections==1.1.0
49 | ml-dtypes==0.5.1
50 | mpmath==1.3.0
51 | msgpack==1.1.0
52 | namex==0.0.8
53 | nest-asyncio==1.6.0
54 | networkx==3.4.2
55 | nodeenv==1.9.1
56 | numpy==2.1.3
57 | opt-einsum==3.4.0
58 | optax==0.2.4
59 | optree==0.15.0
60 | orbax-checkpoint==0.11.12
61 | packaging==24.2
62 | platformdirs==4.3.7
63 | pluggy==1.5.0
64 | pre-commit==4.2.0
65 | promise==2.3
66 | protobuf==5.29.4
67 | psutil==7.0.0
68 | pyarrow==19.0.1
69 | pygments==2.19.1
70 | pylint==3.3.6
71 | pytest==8.3.5
72 | pytest-env==1.1.5
73 | pyyaml==6.0.2
74 | regex==2024.11.6
75 | requests==2.32.3
76 | rich==14.0.0
77 | scipy==1.15.2
78 | setuptools==78.1.0
79 | simple-parsing==0.1.7
80 | simplejson==3.20.1
81 | six==1.17.0
82 | sympy==1.13.1
83 | tensorboard==2.19.0
84 | tensorboard-data-server==0.7.2
85 | tensorflow==2.19.0
86 | tensorflow-datasets==4.9.8
87 | tensorflow-metadata==1.17.1
88 | tensorflow-text==2.19.0
89 | tensorstore==0.1.73
90 | termcolor==3.0.1
91 | toml==0.10.2
92 | tomlkit==0.13.2
93 | toolz==1.0.0
94 | torch==2.6.0
95 | tqdm==4.67.1
96 | treescope==0.1.9
97 | typing-extensions==4.13.2
98 | urllib3==2.4.0
99 | virtualenv==20.30.0
100 | wadler-lindig==0.1.5
101 | werkzeug==3.1.3
102 | wheel==0.45.1
103 | wrapt==1.17.2
104 | zipp==3.21.0
105 |
--------------------------------------------------------------------------------