├── .github └── workflows │ └── pytest_and_autopublish.yml ├── .gitignore ├── .pylintrc ├── .style.yapf ├── .vscode └── settings.json ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── gcp-finetune.md ├── gcp-pretrain.md ├── screenshot01-gcpdashboard.jpg ├── screenshot02-vminstances.jpg ├── screenshot03-vmcreate.jpg ├── screenshot04-ssh.jpg ├── screenshot05-config.jpg ├── screenshot06-buckets.jpg ├── screenshot07-createbucket.jpg ├── screenshot08-tpunode.jpg └── screenshot09-finetunetpus.jpg ├── jestimator ├── __init__.py ├── amos.py ├── amos_helper.py ├── amos_helper_test.py ├── amos_test.py ├── checkpoint_utils.py ├── data │ ├── pipeline_lm.py │ ├── pipeline_rec.py │ ├── pipeline_seqio.py │ └── reader.py ├── data_utils.py ├── estimator.py ├── modeling.py ├── models │ ├── bert │ │ ├── finetune.py │ │ ├── modeling.py │ │ └── pretrain.py │ ├── bert_rpe │ │ ├── finetune.py │ │ ├── modeling.py │ │ └── pretrain.py │ ├── linear_regression │ │ ├── linear_regression.py │ │ └── linear_regression_test.py │ ├── lstm │ │ ├── lm.py │ │ ├── modeling.py │ │ └── ptb │ │ │ ├── ptb.test.txt │ │ │ ├── ptb.train.txt │ │ │ ├── ptb.valid.txt │ │ │ └── vocab.txt │ ├── mnist │ │ ├── mnist.ipynb │ │ └── mnist.py │ └── rope │ │ ├── finetune.py │ │ ├── modeling.py │ │ └── pretrain.py └── states.py └── pyproject.toml /.github/workflows/pytest_and_autopublish.yml: -------------------------------------------------------------------------------- 1 | name: Unittests & Auto-publish 2 | 3 | # Allow triggering the workflow manually (e.g. when deps change) 4 | on: [push, workflow_dispatch] 5 | 6 | jobs: 7 | test-job: 8 | runs-on: ubuntu-latest 9 | timeout-minutes: 10 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.ref }} 13 | cancel-in-progress: true 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | 18 | # Install deps 19 | - uses: actions/setup-python@v4 20 | with: 21 | python-version: 3.8 22 | # Uncomment to cache pip dependencies (if tests are too slow) 23 | # cache: pip 24 | # cache-dependency-path: '**/pyproject.toml' 25 | 26 | - run: pip --version 27 | - run: pip install -e .[test] 28 | - run: pip freeze 29 | 30 | # Run tests (in parallel) 31 | - name: amos_test 32 | run: python3 jestimator/amos_test.py 33 | 34 | - name: amos_helper_test 35 | run: python3 jestimator/amos_helper_test.py 36 | 37 | # Auto-publish when version is increased 38 | publish-job: 39 | # Only try to publish if: 40 | # * Repo is self (prevents running from forks) 41 | # * Branch is `main` 42 | if: | 43 | github.repository == 'google-research/jestimator' 44 | && github.ref == 'refs/heads/main' 45 | needs: test-job # Only publish after tests are successful 46 | runs-on: ubuntu-latest 47 | permissions: 48 | contents: write 49 | timeout-minutes: 10 50 | 51 | steps: 52 | # Publish the package (if local `__version__` > pip version) 53 | - uses: etils-actions/pypi-auto-publish@v1 54 | with: 55 | pypi-token: ${{ secrets.PYPI_API_TOKEN }} 56 | gh-token: ${{ secrets.FOR_JESTIMATOR_RELEASE }} 57 | parse-changelog: true 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Byte-compiled 5 | __pycache__/ 6 | .cache/ 7 | 8 | # Poetry, setuptools, PyPI distribution artifacts.
9 | /*.egg-info 10 | .eggs/ 11 | build/ 12 | dist/ 13 | poetry.lock 14 | 15 | # Tests 16 | .pytest_cache/ 17 | 18 | # Type checking 19 | .pytype/ 20 | 21 | # Other 22 | *.DS_Store 23 | 24 | # PyCharm 25 | .idea 26 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = google 3 | indent_width = 2 4 | dedent_closing_brackets = True 5 | split_before_dot = True 6 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.insertFinalNewline": true, 3 | "files.trimFinalNewlines": true, 4 | "files.trimTrailingWhitespace": true, 5 | "files.associations": { 6 | ".pylintrc": "ini", 7 | ".style.yapf": "ini" 8 | }, 9 | "python.testing.unittestEnabled": false, 10 | "python.testing.nosetestsEnabled": false, 11 | "python.testing.pytestEnabled": true, 12 | "python.linting.pylintUseMinimalCheckers": false, 13 | "[python]": { 14 | "editor.rulers": [ 15 | 80 16 | ], 17 | "editor.tabSize": 2, 18 | "editor.detectIndentation": false 19 | }, 20 | "python.formatting.provider": "yapf", 21 | "files.watcherExclude": { 22 | "**/.git/**": true 23 | }, 24 | "files.exclude": { 25 | "**/__pycache__": true, 26 | "**/.pytest_cache": true, 27 | "**/*.egg-info": true 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | 23 | 24 | ## [Unreleased] 25 | 26 | * Multi-pod debug for estimator. 27 | * Show variable shape in tensorboard. 28 | * `save_mutable` defaults to False for LSTM debug. 29 | * Debug for checkpointing. 30 | * Simplified pre-training code (remove seqio). 31 | 32 | ## [0.3.3] - 2022-12-01 33 | 34 | * Python 3.7 compatibility. 35 | * Add `d_coef` and `c_coef` to the Amos hyper-parameters. 36 | * Support for flax_mutables in checkpointing. 37 | * Bug fix in data/pipeline_rec and simplified data_utils code. 38 | 39 | ## [0.3.2] - 2022-11-01 40 | 41 | * The Amos optimizer implementation sticks to the paper. 42 | * Initial PyPI package. 43 | * Set up GitHub workflow for unit tests and auto-publish. 44 | * MNIST examples. 45 | 46 | [Unreleased]: https://github.com/google-research/jestimator/compare/v0.3.3...HEAD 47 | [0.3.3]: https://github.com/google-research/jestimator/releases/tag/v0.3.3 48 | [0.3.2]: https://github.com/google-research/jestimator/releases/tag/v0.3.2 49 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | <https://cla.developers.google.com/> to see your current agreements on file or 13 | to sign a new one.
14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code Reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Amos and JEstimator 2 | 3 | [![Unittests & Auto-publish](https://github.com/google-research/jestimator/actions/workflows/pytest_and_autopublish.yml/badge.svg)](https://github.com/google-research/jestimator/actions/workflows/pytest_and_autopublish.yml) 4 | [![PyPI version](https://badge.fury.io/py/jestimator.svg)](https://badge.fury.io/py/jestimator) 5 | 6 | *This is not an officially supported Google product.* 7 | 8 | This is the source code for the paper "[Amos: An Adam-style Optimizer with 9 | Adaptive Weight Decay towards Model-Oriented 10 | Scale](https://arxiv.org/abs/2210.11693)". 11 | 12 | It implements **Amos**, an optimizer compatible with the 13 | [optax](https://github.com/deepmind/optax) library, and **JEstimator**, a 14 | light-weight library with a `tf.Estimator`-like interface to manage 15 | [T5X](https://github.com/google-research/t5x)-compatible checkpoints for machine 16 | learning programs in [JAX](https://github.com/google/jax), which we use to run 17 | experiments in the paper. 18 | 19 | ## Quickstart 20 | 21 | ``` 22 | pip install jestimator 23 | ``` 24 | 25 | It will install the Amos optimizer implemented in the jestimator lib. 26 | 27 | ## Usage of Amos 28 | 29 | This implementation of Amos is used with [JAX](https://github.com/google/jax), a 30 | high-performance numerical computing library with automatic differentiation, for 31 | machine learning research. The API of Amos is compatible with 32 | [optax](https://github.com/deepmind/optax), a library of JAX optimizers 33 | (hopefully Amos will be integrated into optax in the near future). 34 | 35 | In order to demonstrate the usage, we will apply Amos to MNIST. It is based on 36 | Flax's official 37 | [MNIST Example](https://github.com/google/flax/tree/main/examples/mnist), and 38 | you can find the code in a jupyter notebook 39 | [here](https://github.com/google-research/jestimator/tree/main/jestimator/models/mnist/mnist.ipynb). 40 | 41 | ### 1. Imports 42 | 43 | ``` 44 | import jax 45 | import jax.numpy as jnp # JAX NumPy 46 | from jestimator import amos # The Amos optimizer implementation 47 | from jestimator import amos_helper # Helper module for Amos 48 | 49 | from flax import linen as nn # The Linen API 50 | from flax.training import train_state # Useful dataclass to keep train state 51 | 52 | import math 53 | import tensorflow_datasets as tfds # TFDS for MNIST 54 | from sklearn.metrics import accuracy_score 55 | ``` 56 | 57 | ### 2. Load data 58 | 59 | ``` 60 | def get_datasets(): 61 | """Load MNIST train and test datasets into memory.""" 62 | 63 | ds_builder = tfds.builder('mnist') 64 | ds_builder.download_and_prepare() 65 | train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1)) 66 | test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1)) 67 | train_ds['image'] = jnp.float32(train_ds['image']) / 255. 68 | test_ds['image'] = jnp.float32(test_ds['image']) / 255. 69 | return train_ds, test_ds 70 | ``` 71 | 72 | ### 3. 
Build model 73 | 74 | ``` 75 | class CNN(nn.Module): 76 | """A simple CNN model.""" 77 | 78 | @nn.compact 79 | def __call__(self, x): 80 | x = nn.Conv(features=32, kernel_size=(3, 3))(x) 81 | x = nn.relu(x) 82 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) 83 | x = nn.Conv(features=64, kernel_size=(3, 3))(x) 84 | x = nn.relu(x) 85 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) 86 | x = x.reshape((x.shape[0], -1))  # flatten 87 | x = nn.Dense(features=256)(x) 88 | x = nn.relu(x) 89 | x = nn.Dense(features=10)(x) 90 | return x 91 | 92 | def classify_xe_loss(self, x, labels): 93 | # Labels read from the tfds MNIST are integers from 0 to 9. 94 | # Logits are arrays of size 10. 95 | logits = self(x) 96 | logits = jax.nn.log_softmax(logits) 97 | labels_ = jnp.expand_dims(labels, -1) 98 | llh_ = jnp.take_along_axis(logits, labels_, axis=-1) 99 | loss = -jnp.sum(llh_) 100 | return loss 101 | ``` 102 | 103 | ### 4. Create train state 104 | 105 | A `TrainState` object keeps the model parameters and optimizer states, and can 106 | be checkpointed into files. 107 | 108 | We create the model and optimizer in this function. 109 | 110 | **For the optimizer, we use Amos here.** The following hyper-parameters are set: 111 | 112 | * *learning_rate*:       The global learning rate. 113 | * *eta_fn*:              The model-specific 'eta'. 114 | * *shape_fn*:            Memory reduction setting. 115 | * *beta*:                Rate for running average of gradient squares. 116 | * *clip_value*:          Gradient clipping for stable training. 117 | 118 | The global learning rate is usually set to 1/sqrt(N), where N is the number 119 | of batches in the training data. For MNIST, we have 60k training examples and 120 | a batch size of 32, so learning_rate=1/sqrt(60000/32). 121 | 122 | The model-specific 'eta_fn' is a function that, given a variable name and 123 | shape, returns a float indicating the expected scale of that variable. Hopefully 124 | in the near future we will have libraries that can automatically calculate this 125 | 'eta_fn' from the modeling code; but for now we have to specify it manually. 126 | 127 | One can use the amos_helper.params_fn_from_assign_map() helper function to 128 | create 'eta_fn' from an assign_map. An assign_map is a dict which maps regex 129 | rules to values or simple Python expressions. The resulting function finds the 130 | first regex rule that matches the name of a variable, and evaluates the Python 131 | expression if necessary to return the value. See our example below. 132 | 133 | Similarly, 'shape_fn' is a function that, given a variable name and 134 | shape, returns a reduced shape for the corresponding slot variables. We can use 135 | the amos_helper.params_fn_from_assign_map() helper function to create 'shape_fn' 136 | from an assign_map as well. 137 | 138 | 'beta' is the exponential decay rate for the running average of gradient 139 | squares. We set it to 0.98 here. 140 | 141 | 'clip_value' is the gradient clipping value, which should match the magnitude of 142 | the loss function. If the loss function is a sum of cross-entropy terms, then we 143 | should set 'clip_value' to the sqrt of the number of labels. 144 | 145 | Please refer to our [paper](https://arxiv.org/abs/2210.11693) for more details 146 | of the hyper-parameters.
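To make the assign_map semantics concrete, here is a minimal sketch (the variable name `Dense_0/kernel` and the shape below are made up for illustration): the name tuple is joined with '/' before regex matching, the first matching rule wins, and a string value is evaluated with 'SHAPE' bound to the variable shape.

```
from jestimator import amos_helper

eta_fn = amos_helper.params_fn_from_assign_map(
    {
        '.*/bias': 0.5,                   # plain value
        '.*/kernel': 'sqrt(2/SHAPE[0])',  # expression over the shape
    },
    eval_str_value=True,
)
# 'Dense_0/kernel' matches the second rule, so this returns sqrt(2/784).
print(eta_fn(('Dense_0', 'kernel'), (784, 256)))
```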
147 | 148 | ``` 149 | def get_train_state(rng): 150 | model = CNN() 151 | dummy_x = jnp.ones([1, 28, 28, 1]) 152 | params = model.init(rng, dummy_x) 153 | 154 | eta_fn = amos_helper.params_fn_from_assign_map( 155 | { 156 | '.*/bias': 0.5, 157 | '.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))', 158 | '.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))', 159 | '.*Dense_0/kernel': 'sqrt(2/SHAPE[0])', 160 | '.*Dense_1/kernel': 'sqrt(1/SHAPE[0])', 161 | }, 162 | eval_str_value=True, 163 | ) 164 | shape_fn = amos_helper.params_fn_from_assign_map( 165 | { 166 | '.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])', 167 | '.*Dense_0/kernel': '(1, SHAPE[1])', 168 | '.*': (), 169 | }, 170 | eval_str_value=True, 171 | ) 172 | optimizer = amos.amos( 173 | learning_rate=1/math.sqrt(60000/32), 174 | eta_fn=eta_fn, 175 | shape_fn=shape_fn, 176 | beta=0.98, 177 | clip_value=math.sqrt(32), 178 | ) 179 | return train_state.TrainState.create( 180 | apply_fn=model.apply, params=params, tx=optimizer) 181 | ``` 182 | 183 | ### 5. Train step 184 | 185 | Use JAX’s @jit decorator to just-in-time compile the function for better 186 | performance. 187 | 188 | ``` 189 | @jax.jit 190 | def train_step(batch, state): 191 | grad_fn = jax.grad(state.apply_fn) 192 | grads = grad_fn( 193 | state.params, 194 | batch['image'], 195 | batch['label'], 196 | method=CNN.classify_xe_loss) 197 | return state.apply_gradients(grads=grads) 198 | ``` 199 | 200 | ### 6. Infer step 201 | 202 | Use JAX’s @jit decorator to just-in-time compile the function for better 203 | performance. 204 | 205 | ``` 206 | @jax.jit 207 | def infer_step(batch, state): 208 | logits = state.apply_fn(state.params, batch['image']) 209 | return jnp.argmax(logits, -1) 210 | ``` 211 | 212 | ### 7. Main 213 | 214 | Run the training loop and evaluate on the test set. 215 | 216 | ``` 217 | train_ds, test_ds = get_datasets() 218 | 219 | rng = jax.random.PRNGKey(0) 220 | rng, init_rng = jax.random.split(rng) 221 | state = get_train_state(init_rng) 222 | del init_rng  # Must not be used anymore. 223 | 224 | num_epochs = 9 225 | for epoch in range(1, num_epochs + 1): 226 | # Use a separate PRNG key to permute image data during shuffling 227 | rng, input_rng = jax.random.split(rng) 228 | perms = jax.random.permutation(input_rng, 60000) 229 | del input_rng 230 | perms = perms.reshape((60000 // 32, 32)) 231 | for perm in perms: 232 | batch = {k: v[perm, ...] for k, v in train_ds.items()} 233 | state = train_step(batch, state) 234 | 235 | pred = jax.device_get(infer_step(test_ds, state)) 236 | accuracy = accuracy_score(test_ds['label'], pred) 237 | print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100)) 238 | ``` 239 | 240 | After 9 epochs, we should get 99.26% test accuracy. If you made it, congrats! 241 | 242 | ## JEstimator 243 | 244 | With JEstimator, you can build your model much as in the MNIST example 245 | above, but without writing code for the "Main" section; JEstimator will serve as 246 | the entry point for your model, automatically handle checkpointing in 247 | train, eval-once, eval-while-training-and-save-the-best, or predict mode, and 248 | set up profiling, tensorboard, and logging. 249 | 250 | In addition, JEstimator supports model partitioning, which is required for 251 | training very large models across multiple TPU pods. It supports a 252 | [T5X](https://github.com/google-research/t5x)-compatible checkpoint format that 253 | saves and restores checkpoints in a distributed manner, which is suitable for 254 | large multi-pod models.
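As a small illustration of how the Amos slot variables interact with partitioning, the `amos_helper` module can derive partition specs for the (possibly reduced) slots from the parameters' logical axis names; a minimal sketch, where the axis names 'model' and 'data' are hypothetical:

```
import jax.numpy as jnp
from jestimator import amos_helper

# A kernel partitioned along logical axes ('model', 'data') keeps a reduced
# Amos slot of shape (1, 4096); the reduced dimension gets a renamed axis.
v = jnp.zeros((1, 4096))
print(amos_helper.maybe_reduce_axis_names(v, ('model', 'data')))
# PartitionSpec('reduced_model', 'data')
```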
255 | 256 | In order to run models with JEstimator, we need to install 257 | [T5X](https://github.com/google-research/t5x#installation) and 258 | [FlaxFormer](https://github.com/google/flaxformer): 259 | 260 | ``` 261 | git clone --branch=main https://github.com/google-research/t5x 262 | cd t5x 263 | python3 -m pip install -e . 264 | cd .. 265 | 266 | git clone --branch=main https://github.com/google/flaxformer 267 | cd flaxformer 268 | pip3 install . 269 | cd .. 270 | ``` 271 | 272 | Then, clone this repo to get the JEstimator code: 273 | 274 | ``` 275 | git clone --branch=main https://github.com/google-research/jestimator 276 | cd jestimator 277 | ``` 278 | 279 | Now, we can test a toy linear regression model: 280 | 281 | ``` 282 | PYTHONPATH=. python3 jestimator/models/linear_regression/linear_regression_test.py 283 | ``` 284 | 285 | ## MNIST Example in JEstimator 286 | 287 | We provide this 288 | [MNIST Example](https://github.com/google-research/jestimator/tree/main/jestimator/models/mnist/mnist.py) 289 | to demonstrate how to write modeling code with JEstimator. It is much like the 290 | example above, but with a big advantage: a config object is passed around 291 | to collect information from global flags and the dataset, in order to 292 | set up the model dynamically. This makes it easier to apply the model to different datasets; for example, one can immediately try the [emnist](https://www.tensorflow.org/datasets/catalog/emnist) or [eurosat](https://www.tensorflow.org/datasets/catalog/eurosat) datasets simply by changing a command-line argument, without modifying the code. 293 | 294 | With the following command, we can start a job to train on MNIST, log every 100 295 | steps, and save the checkpoints to $HOME/experiments/mnist/models: 296 | 297 | ``` 298 | PYTHONPATH=. python3 jestimator/estimator.py \ 299 | --module_imp="jestimator.models.mnist.mnist" \ 300 | --module_config="jestimator/models/mnist/mnist.py" \ 301 | --train_pattern="tfds://mnist/split=train" \ 302 | --model_dir="$HOME/experiments/mnist/models" \ 303 | --train_batch_size=32 \ 304 | --train_shuffle_buf=4096 \ 305 | --train_epochs=9 \ 306 | --check_every_steps=100 \ 307 | --max_ckpt=20 \ 308 | --save_every_steps=1000 \ 309 | --module_config.warmup=2000 \ 310 | --module_config.amos_beta=0.98 311 | ``` 312 | 313 | Meanwhile, we can start a job to monitor the $HOME/experiments/mnist/models 314 | folder, evaluate on the MNIST test set, and save the model with the highest 315 | accuracy: 316 | 317 | ``` 318 | PYTHONPATH=. python3 jestimator/estimator.py \ 319 | --module_imp="jestimator.models.mnist.mnist" \ 320 | --module_config="jestimator/models/mnist/mnist.py" \ 321 | --eval_pattern="tfds://mnist/split=test" \ 322 | --model_dir="$HOME/experiments/mnist/models" \ 323 | --eval_batch_size=32 \ 324 | --mode="eval_wait" \ 325 | --check_ckpt_every_secs=1 \ 326 | --save_high="test_accuracy" 327 | ``` 328 | 329 | At the same time, we can start a TensorBoard to monitor the process: 330 | 331 | ``` 332 | tensorboard --logdir $HOME/experiments/mnist/models 333 | ``` 334 | 335 | ## LSTM on PTB 336 | 337 | We can use the following command to train a single-layer LSTM on PTB: 338 | 339 | ``` 340 | PYTHONPATH=.
python3 jestimator/estimator.py \ 341 | --module_imp="jestimator.models.lstm.lm" \ 342 | --module_config="jestimator/models/lstm/lm.py" \ 343 | --module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt" \ 344 | --train_pattern="jestimator/models/lstm/ptb/ptb.train.txt" \ 345 | --model_dir="$HOME/models/ptb_lstm" \ 346 | --train_batch_size=64 \ 347 | --train_consecutive=113 \ 348 | --train_shuffle_buf=4096 \ 349 | --max_train_steps=200000 \ 350 | --check_every_steps=1000 \ 351 | --max_ckpt=20 \ 352 | --module_config.opt_config.optimizer="amos" \ 353 | --module_config.opt_config.learning_rate=0.01 \ 354 | --module_config.opt_config.beta=0.98 \ 355 | --module_config.opt_config.momentum=0.0 356 | ``` 357 | 358 | and evaluate: 359 | 360 | ``` 361 | PYTHONPATH=. python3 jestimator/estimator.py \ 362 | --module_imp="jestimator.models.lstm.lm" \ 363 | --module_config="jestimator/models/lstm/lm.py" \ 364 | --module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt" \ 365 | --eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt" \ 366 | --model_dir="$HOME/models/ptb_lstm" \ 367 | --eval_batch_size=1 368 | ``` 369 | 370 | It is suitable for running on a single-GPU machine. 371 | 372 | ## More JEstimator Models 373 | 374 | Here are some simple guides to pre-train and fine-tune BERT-like models, using TPUs on Google Cloud Platform (GCP). One can start with a Web browser with zero setup, by connecting to a Virtual Machine via the Google Cloud console, without installing anything locally. If this is the first time, the automatic free trial credit is enough to try the commands for free. 375 | 376 | * [My experience of pre-training a BERT-base model on GCP](docs/gcp-pretrain.md) 377 | * [My experience of fine-tuning MNLI on GCP](docs/gcp-finetune.md) 378 | -------------------------------------------------------------------------------- /docs/gcp-finetune.md: -------------------------------------------------------------------------------- 1 | # Using TPUs on Google Cloud to fine-tune MNLI 2 | 3 | First, we do the same as *Log into a VM instance* and *Create a storage 4 | bucket* in the [pre-training](gcp-pretrain.md) guide. If these are already done, 5 | we can use the same login VM and storage bucket. 6 | 7 | ## Create TPU nodes 8 | 9 | Then, we create TPU nodes for fine-tuning. We use one 'v3-8' node to train and 10 | one 'v2-8' node to evaluate, as below. 11 | 12 | ![TPU Nodes](screenshot09-finetunetpus.jpg) 13 | 14 | Now, we can log into each TPU node and set it up with the following commands: 15 | 16 | ``` 17 | gcloud auth application-default login 18 | gcloud auth login 19 | 20 | pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 21 | 22 | git clone --branch=main https://github.com/google-research/t5x 23 | cd t5x 24 | python3 -m pip install -e . 25 | cd .. 26 | 27 | git clone --branch=main https://github.com/google/flaxformer 28 | cd flaxformer 29 | pip3 install . 30 | cd .. 31 | 32 | git clone --branch=main https://github.com/google-research/jestimator 33 | pip install -U scikit-learn 34 | 35 | wget http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/archives/sentence_piece.tar.gz 36 | tar xvfz sentence_piece.tar.gz 37 | ``` 38 | 39 | ## Launch jobs 40 | 41 | We will fine-tune a BERT-base checkpoint pre-trained by Amos. 42 | 43 | At the 'v3-8' train node, we launch the training job as below: 44 | 45 | ``` 46 | cd jestimator 47 | PYTHONPATH=.
python3 jestimator/estimator.py \ 48 | --module_imp="jestimator.models.bert.finetune" \ 49 | --module_config="jestimator/models/bert/finetune.py" \ 50 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \ 51 | --module_config.segment_names="premise,hypothesis" \ 52 | --module_config.model_config.num_labels=3 \ 53 | --train_pattern="tfds://glue/mnli/split=train" \ 54 | --valid_pattern="tfds://glue/mnli/split=validation_matched" \ 55 | --model_dir="gs://jestimator_example/experiments/finetune/amos-bert-base/\ 56 | mnli/models" \ 57 | --checkpoint_path="gs://gresearch/checkpoints_in_amos_paper/\ 58 | amos-bert-base/checkpoint_300000" \ 59 | --train_shuffle_buf=65536 --max_train_steps=100000 \ 60 | --train_batch_size=64 --valid_batch_size=64 --num_valid_examples=256 \ 61 | --module_config.opt_config.optimizer="adam" \ 62 | --module_config.opt_config.learning_rate=1e-5 \ 63 | --check_every_steps=500 --logtostderr 64 | ``` 65 | 66 | At the 'v2-8' eval node, we launch the eval job as below: 67 | 68 | ``` 69 | cd jestimator 70 | PYTHONPATH=. python3 jestimator/estimator.py \ 71 | --module_imp="jestimator.models.bert.finetune" \ 72 | --module_config="jestimator/models/bert/finetune.py" \ 73 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \ 74 | --module_config.segment_names="premise,hypothesis" \ 75 | --module_config.model_config.num_labels=3 \ 76 | --eval_pattern="tfds://glue/mnli/split=validation_matched" \ 77 | --model_dir="gs://jestimator_example/experiments/finetune/amos-bert-base/\ 78 | mnli/models" \ 79 | --mode="eval_wait" --check_ckpt_every_secs=10 --max_train_steps=100000 \ 80 | --eval_batch_size=64 --module_config.eval_metric="accuracy" \ 81 | --logtostderr 82 | ``` 83 | 84 | ## Start TensorBoard 85 | 86 | The setup is the same as in [pre-training](gcp-pretrain.md#start-tensorboard). We use 87 | the following command to start a TensorBoard. 88 | 89 | ``` 90 | .local/bin/tensorboard dev upload \ 91 | --logdir mnt/jestimator_example/experiments/finetune/amos-bert-base \ 92 | --name "Finetune BERT-base" \ 93 | --description "Example of using TPUs on Google Cloud to run JEstimator." 94 | ``` 95 | 96 | ## Pre-trained checkpoints in the Amos paper 97 | 98 | We have released the following checkpoints from the Amos paper. 99 | 100 | Model | Google Storage | Download 101 | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -------- 102 | BERT-base | `gs://gresearch/checkpoints_in_amos_paper/adamw-bert-base/checkpoint_300000`
`gs://gresearch/checkpoints_in_amos_paper/amos-bert-base/checkpoint_300000` | [adamw-bert-base.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/adamw-bert-base.tar.gz)
[amos-bert-base.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/amos-bert-base.tar.gz) 103 | BERT-large | `gs://gresearch/checkpoints_in_amos_paper/adamw-bert-large/checkpoint_250000`
`gs://gresearch/checkpoints_in_amos_paper/amos-bert-large/checkpoint_250000` | [adamw-bert-large.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/adamw-bert-large.tar.gz)
[amos-bert-large.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/amos-bert-large.tar.gz) 104 | RoPE-base | `gs://gresearch/checkpoints_in_amos_paper/adamw-rope-base/checkpoint_300000`
`gs://gresearch/checkpoints_in_amos_paper/amos-rope-base/checkpoint_300000` | [adamw-rope-base.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/adamw-rope-base.tar.gz)
[amos-rope-base.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/amos-rope-base.tar.gz) 105 | RoPE-large | `gs://gresearch/checkpoints_in_amos_paper/adamw-rope-large/checkpoint_1000000`
`gs://gresearch/checkpoints_in_amos_paper/amos-rope-large/checkpoint_1000000` | [adamw-rope-large.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/adamw-rope-large.tar.gz)
[amos-rope-large.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/amos-rope-large.tar.gz) 106 | RPE-base | `gs://gresearch/checkpoints_in_amos_paper/adamw-rpe/checkpoint_300000`
`gs://gresearch/checkpoints_in_amos_paper/amos-rpe/checkpoint_300000` | [adamw-rpe.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/adamw-rpe.tar.gz)
[amos-rpe.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/amos-rpe.tar.gz) 107 | -------------------------------------------------------------------------------- /docs/gcp-pretrain.md: -------------------------------------------------------------------------------- 1 | # Using TPUs on Google Cloud to pre-train BERT 2 | 3 | ## Log into a VM instance 4 | 5 | It's not difficult to find the [Google Cloud page](https://cloud.google.com/) 6 | and start with a Google account. It requires payment information if this is the 7 | first time, but one gets automatic free trial credit and will not be charged. 8 | 9 | One should land on the dashboard as below, and let's hit the "GO TO COMPUTE 10 | ENGINE" button. 11 | 12 | ![GO TO COMPUTE ENGINE](screenshot01-gcpdashboard.jpg) 13 | 14 | We should be able to see VM instances after enabling Compute Engine in the 15 | Web browser. Then hit the "CREATE INSTANCE" button, as below. 16 | 17 | ![CREATE INSTANCE](screenshot02-vminstances.jpg) 18 | 19 | We can choose the smallest e2-micro (2 vCPU, 1 GB memory) machine type, since it 20 | will only be used to connect to TPU nodes and run TensorBoard. We may want to 21 | choose a region where we will find our TPUs, e.g. us-central (Iowa), as shown 22 | below. 23 | 24 | ![Region, Machine Type](screenshot03-vmcreate.jpg) 25 | 26 | After the Virtual Machine is created and appears on the dashboard, we hit the 27 | "SSH" button to connect to the VM, as below. 28 | 29 | ![SSH](screenshot04-ssh.jpg) 30 | 31 | Now we have a Linux terminal! We can, for example, install TensorBoard using 32 | the following commands. 33 | 34 | ``` 35 | sudo apt update 36 | sudo apt install python3-pip 37 | pip install -U tensorboard 38 | ``` 39 | 40 | The terminal is also a Google Cloud console. We can run GCP-specific commands, 41 | such as listing the pre-training corpus that we are going to use, already stored 42 | in a Cloud Storage Bucket: 43 | 44 | ``` 45 | gsutil ls gs://gresearch/checkpoints_in_amos_paper/data 46 | ``` 47 | 48 | In order to connect to TPUs, we should configure the `gcloud` command in our 49 | terminal, as below: 50 | 51 | ``` 52 | gcloud config set account your-google-account 53 | gcloud config set project your-project-id 54 | gcloud auth application-default login 55 | gcloud auth login 56 | ``` 57 | 58 | Here, the google-account (e.g. xyz@gmail.com) and project-id (e.g. 59 | graphical-quest-123456) can be found on the dashboard, as below. 60 | 61 | ![gcloud config](screenshot05-config.jpg) 62 | 63 | Now, we can list the TPU nodes we have created in any zone, but currently there 64 | are none: 65 | 66 | ``` 67 | gcloud compute tpus tpu-vm list --zone us-central1-a 68 | ``` 69 | 70 | ## Create a storage bucket 71 | 72 | We also need a storage bucket, so that the TPU node can save model checkpoints 73 | in it, and the login VM can access the TensorBoard log as well. We hit the 74 | dashboard menu and select 'Cloud Storage > Buckets' as below. 75 | 76 | ![Buckets](screenshot06-buckets.jpg) 77 | 78 | Then, we have to figure out a unique name for the bucket. We use 79 | 'jestimator_example' in this example. If we are sure about the region of TPUs we 80 | are going to use, we may restrict the location of the bucket. As below. 81 | 82 | ![Create a bucket](screenshot07-createbucket.jpg) 83 | 84 | ## Create a TPU node 85 | 86 | Now, we can create a TPU node for pre-training BERT. We hit the dashboard menu 87 | again and select 'Compute Engine > TPUs'.
For pre-training BERT, we select the 88 | TPU type 'v2-32'. Since we are using JAX, the TPU software version should be 89 | 'tpu-vm-base'. As below. 90 | 91 | ![Create TPU Node](screenshot08-tpunode.jpg) 92 | 93 | ## Log into the TPU node, install JEstimator, and launch the pre-training job 94 | 95 | A TPU node of type 'v2-32' consists of 4 workers, each of which has an 96 | 8-core TPUv2 board. We have to set up each TPU worker manually. To start, we have 97 | to log into each worker once. In our terminal, we use the following command to 98 | log into worker 0: 99 | 100 | ``` 101 | gcloud compute tpus tpu-vm ssh node-1 --zone us-central1-a --worker=0 102 | ``` 103 | 104 | We should do the same for workers 1, 2, and 3. In addition, we should do `gcloud 105 | auth` at least on worker 0, in order to access our storage bucket during 106 | pre-training: 107 | 108 | ``` 109 | gcloud auth application-default login 110 | gcloud auth login 111 | ``` 112 | 113 | Then, we can use the following commands to set up all the workers at the same 114 | time: 115 | 116 | ``` 117 | gcloud compute tpus tpu-vm ssh node-1 --zone us-central1-a --worker=all \ 118 | --command='pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html' 119 | 120 | gcloud compute tpus tpu-vm ssh node-1 --zone us-central1-a --worker=all \ 121 | --command='git clone --branch=main https://github.com/google-research/t5x' 122 | 123 | gcloud compute tpus tpu-vm ssh node-1 --zone us-central1-a --worker=all \ 124 | --command='cd t5x; python3 -m pip install -e .' 125 | 126 | gcloud compute tpus tpu-vm ssh node-1 --zone us-central1-a --worker=all \ 127 | --command='git clone --branch=main https://github.com/google/flaxformer' 128 | 129 | gcloud compute tpus tpu-vm ssh node-1 --zone us-central1-a --worker=all \ 130 | --command='cd flaxformer; pip3 install .' 131 | 132 | gcloud compute tpus tpu-vm ssh node-1 --zone us-central1-a --worker=all \ 133 | --command='git clone --branch=main https://github.com/google-research/jestimator' 134 | ``` 135 | 136 | Now, we are ready to launch the pre-training job! The following command will 137 | pre-train a BERT-base model on the Wikipedia+Books corpus, with training 138 | batch-size 256 and 300k training steps. It will use the Amos optimizer. Since 139 | the job will run for a long time, it is better to start a terminal multiplexer 140 | (e.g. GNU screen) and launch the command in it. 141 | 142 | ``` 143 | gcloud compute tpus tpu-vm ssh node-1 --zone us-central1-a --worker=all \ 144 | --command='cd jestimator; PYTHONPATH=. python3 jestimator/estimator.py --module_imp="jestimator.models.bert.pretrain" --module_config="jestimator/models/bert/pretrain.py" --module_config.model_config.vocab_size=32000 --module_config.mask_token_id=4 --train_pattern="gs://gresearch/checkpoints_in_amos_paper/data/wikipedia-00???-of-00500,gs://gresearch/checkpoints_in_amos_paper/data/books-00???-of-00500" --valid_pattern="gs://gresearch/checkpoints_in_amos_paper/data/ptb" --train_shuffle_buf=65536 --max_train_steps=300000 --train_batch_size=256 --valid_batch_size=512 --num_valid_examples=512 --check_every_steps=5000 --module_config.opt_config.optimizer="amos" --module_config.opt_config.learning_rate=0.005 --model_dir="gs://jestimator_example/experiments/pretrain/bert-base/amos/models" --logtostderr' 145 | ``` 146 | 147 | **Caution**: One should check the billing information now, because the 148 | pre-training job may quickly exhaust the credit!
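Once the job is done (or if the credit is running low), one way to stop TPU charges is to delete the node; assuming the node name and zone used above:

```
gcloud compute tpus tpu-vm delete node-1 --zone us-central1-a
```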
149 | 150 | ## Start TensorBoard 151 | 152 | After the pre-training job starts, we can start a TensorBoard to monitor the 153 | pre-training job. We will use 154 | [gcsfuse](https://github.com/GoogleCloudPlatform/gcsfuse) to mount our storage 155 | bucket to the login VM, and upload the pre-training logs to 156 | [TensorBoard.dev](https://tensorboard.dev/). 157 | 158 | In order to install gcsfuse: 159 | 161 | 162 | ``` 163 | export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s` 164 | echo "deb https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list 165 | curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add - 166 | 167 | sudo apt-get update 168 | sudo apt-get install gcsfuse 169 | ``` 170 | 171 | 173 | In order to mount the storage bucket and start TensorBoard: 174 | 175 | ``` 176 | mkdir -p mnt/jestimator_example 177 | gcsfuse jestimator_example mnt/jestimator_example 178 | 179 | .local/bin/tensorboard dev upload \ 180 | --logdir mnt/jestimator_example/experiments/pretrain/bert-base \ 181 | --name "Pretrain BERT-base" \ 182 | --description "Example of using TPUs on Google Cloud to run JEstimator." 183 | ``` 184 | -------------------------------------------------------------------------------- /docs/screenshot01-gcpdashboard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/jestimator/326824a477feef6d92e6f44175000907a72899ca/docs/screenshot01-gcpdashboard.jpg -------------------------------------------------------------------------------- /docs/screenshot02-vminstances.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/jestimator/326824a477feef6d92e6f44175000907a72899ca/docs/screenshot02-vminstances.jpg -------------------------------------------------------------------------------- /docs/screenshot03-vmcreate.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/jestimator/326824a477feef6d92e6f44175000907a72899ca/docs/screenshot03-vmcreate.jpg -------------------------------------------------------------------------------- /docs/screenshot04-ssh.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/jestimator/326824a477feef6d92e6f44175000907a72899ca/docs/screenshot04-ssh.jpg -------------------------------------------------------------------------------- /docs/screenshot05-config.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/jestimator/326824a477feef6d92e6f44175000907a72899ca/docs/screenshot05-config.jpg -------------------------------------------------------------------------------- /docs/screenshot06-buckets.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/jestimator/326824a477feef6d92e6f44175000907a72899ca/docs/screenshot06-buckets.jpg -------------------------------------------------------------------------------- /docs/screenshot07-createbucket.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/jestimator/326824a477feef6d92e6f44175000907a72899ca/docs/screenshot07-createbucket.jpg 
-------------------------------------------------------------------------------- /docs/screenshot08-tpunode.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/jestimator/326824a477feef6d92e6f44175000907a72899ca/docs/screenshot08-tpunode.jpg -------------------------------------------------------------------------------- /docs/screenshot09-finetunetpus.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/jestimator/326824a477feef6d92e6f44175000907a72899ca/docs/screenshot09-finetunetpus.jpg -------------------------------------------------------------------------------- /jestimator/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """JEstimator package with the Amos optimizer.""" 16 | 17 | __version__ = '0.3.3' 18 | 19 | from . import amos 20 | from . import amos_helper 21 | -------------------------------------------------------------------------------- /jestimator/amos.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implements the AMOS optimizer. 16 | 17 | AMOS stands for 'Adaptive weight-decay towards Model-Oriented Scale'. It 18 | combines Adam-like gradient scaling with a theoretically proven scheme for 19 | adaptive weight-decay and learning-rate decay. 20 | 21 | In order to be effective, AMOS requires each trainable variable to provide an 22 | `eta` hyper-parameter, indicating the target scale that the entries of the 23 | trained variable converge to. `eta` is used in a variable-specific learning-rate 24 | schedule. 25 | """ 26 | from typing import Any, Callable, NamedTuple, Optional, Tuple, Union 27 | 28 | from flax.serialization import from_state_dict, to_state_dict # pylint: disable=g-multiple-import 29 | from flax.traverse_util import empty_node, flatten_dict, unflatten_dict # pylint: disable=g-multiple-import 30 | import jax 31 | import jax.numpy as jnp 32 | from jax.typing import ArrayLike 33 | import optax 34 | 35 | Shape = Tuple[int, ...] 
36 | ScalarOrSchedule = Union[float, optax.Schedule] 37 | ParamsFn = Callable[[Tuple[str, ...], Shape], Any] 38 | 39 | 40 | class ScaleByAmosState(NamedTuple): 41 | """State for the Amos algorithm.""" 42 | count: Optional[ArrayLike] # shape=(), dtype=jnp.int32. 43 | v: optax.Updates 44 | b: optax.Updates 45 | 46 | 47 | def scale_by_amos( 48 | learning_rate: ScalarOrSchedule, 49 | eta_fn: ParamsFn, 50 | shape_fn: Optional[ParamsFn] = None, 51 | beta: float = 0.999, 52 | extra_l2: float = 0., 53 | d_coef: float = 0.25, 54 | c_coef: float = 0.25, 55 | epsilon: float = 1. / (1 << 125), 56 | ) -> optax.GradientTransformation: 57 | """Rescale updates according to the Amos algorithm.""" 58 | 59 | def init_fn(params): 60 | flat_v = {} 61 | flat_b = {} 62 | flat_params = _flatten(to_state_dict(params), keep_empty_nodes=True) 63 | for name, theta in flat_params.items(): 64 | if theta == empty_node: 65 | flat_v[name] = empty_node 66 | flat_b[name] = empty_node 67 | continue 68 | 69 | if shape_fn is None: 70 | v = jnp.zeros_like(theta) 71 | else: 72 | v = jnp.zeros(shape_fn(name, theta.shape), dtype=theta.dtype) 73 | flat_v[name] = v 74 | flat_b[name] = jnp.zeros_like(v) 75 | 76 | v = from_state_dict(params, _unflatten(flat_v)) 77 | b = from_state_dict(params, _unflatten(flat_b)) 78 | return ScaleByAmosState(count=jnp.array(0), v=v, b=b) 79 | 80 | def update_fn(updates, state, params): 81 | count = optax.safe_int32_increment(state.count) 82 | if callable(learning_rate): 83 | xi = learning_rate(count) 84 | else: 85 | xi = learning_rate 86 | bias_correction = 1. - beta**count 87 | xi2 = jnp.square(xi) 88 | c_coef_sqrt_xi = c_coef * jnp.sqrt(xi) 89 | 90 | flat_grad = _flatten(to_state_dict(updates), keep_empty_nodes=True) 91 | flat_v = _flatten(to_state_dict(state.v), keep_empty_nodes=True) 92 | flat_b = _flatten(to_state_dict(state.b), keep_empty_nodes=True) 93 | flat_params = _flatten(to_state_dict(params)) 94 | for name, theta in flat_params.items(): 95 | grad = flat_grad[name] 96 | v = flat_v[name] 97 | if v.shape: 98 | reduced = [i for i, k in enumerate(grad.shape) if v.shape[i] < k] 99 | g2 = jnp.mean(jnp.square(grad), axis=reduced, keepdims=True) 100 | else: 101 | g2 = jnp.mean(jnp.square(grad)) 102 | v = v * beta + g2 * (1. - beta) 103 | flat_v[name] = v 104 | rcpl_v_hat = bias_correction / jnp.maximum(v, epsilon) 105 | 106 | b = flat_b[name] 107 | decay_factor_c = jax.lax.rsqrt(1. + c_coef_sqrt_xi * b) 108 | gamma = decay_factor_c * xi2 * rcpl_v_hat * g2 109 | 110 | init_lr = xi * eta_fn(name, theta.shape) 111 | decay_factor_d = jnp.reciprocal(1. + d_coef * jnp.sqrt(init_lr) * b) 112 | l2_regularization = (-0.5 * gamma - extra_l2) * theta 113 | flat_grad[name] = decay_factor_d * ( 114 | l2_regularization - init_lr * jnp.sqrt(rcpl_v_hat) * grad) 115 | flat_b[name] = b + gamma * (1. 
+ b) 116 | 117 | updates = from_state_dict(updates, _unflatten(flat_grad)) 118 | v = from_state_dict(state.v, _unflatten(flat_v)) 119 | b = from_state_dict(state.b, _unflatten(flat_b)) 120 | return updates, ScaleByAmosState(count=count, v=v, b=b) 121 | 122 | return optax.GradientTransformation(init_fn, update_fn) 123 | 124 | 125 | def _flatten(x, keep_empty_nodes=False): 126 | if not isinstance(x, dict): 127 | return {(): x} 128 | 129 | return flatten_dict(x, keep_empty_nodes=keep_empty_nodes) 130 | 131 | 132 | def _unflatten(x): 133 | if tuple(x.keys()) == ((),): 134 | return x[()] 135 | 136 | return unflatten_dict(x) 137 | 138 | 139 | def amos( 140 | learning_rate: ScalarOrSchedule, 141 | eta_fn: ParamsFn, 142 | shape_fn: Optional[ParamsFn] = None, 143 | beta: float = 0.999, 144 | momentum: Optional[float] = None, 145 | clip_value: Optional[float] = None, 146 | extra_l2: float = 0., 147 | d_coef: float = 0.25, 148 | c_coef: float = 0.25, 149 | epsilon: float = 1. / (1 << 125), 150 | ) -> optax.GradientTransformation: 151 | """The full Amos optimizer with optional gradient clipping and momentum. 152 | 153 | References: 154 | [The Amos Paper](https://arxiv.org/abs/2210.11693) 155 | 156 | Args: 157 | learning_rate: A float or callable for learning rate. When it is callable, 158 | `learning_rate` takes the step count as input and returns a float scalar. 159 | Let N be the number of independent batches in the training data. It is 160 | recommended to set the learning rate to about 1/sqrt(N). 161 | eta_fn: A function that maps a variable name and shape to the variable- 162 | specific hyper-parameter 'eta' indicating the expected scale of entries. 163 | shape_fn: A function that maps a variable name and shape to the shape of the 164 | corresponding slot variables `v` and `b`. The returned shape should be 165 | broadcastable to the variable, while some axes might be reduced to 1 to 166 | save memory. 167 | beta: A float slightly < 1. We recommend setting `1 - beta` to the same 168 | order of magnitude as the learning rate. Defaults to 0.999. 169 | momentum: Exponential decay rate for optional moving average of updates. 170 | clip_value: Optional gradient clipping value. 171 | extra_l2: Additional L2 regularization (experimental). Defaults to 0. 172 | d_coef: Coefficient for decay_factor_d. Defaults to 0.25. 173 | c_coef: Coefficient for decay_factor_c. Defaults to 0.25. 174 | epsilon: The smallest positive normal to prevent division by 0. 175 | 176 | Returns: 177 | An (init_fn, update_fn) tuple. 178 | """ 179 | tx = [] 180 | if clip_value is not None and clip_value > 0.: 181 | tx.append(optax.clip(clip_value)) 182 | tx.append( 183 | scale_by_amos( 184 | learning_rate, 185 | eta_fn, 186 | shape_fn=shape_fn, 187 | beta=beta, 188 | extra_l2=extra_l2, 189 | d_coef=d_coef, 190 | c_coef=c_coef, 191 | epsilon=epsilon)) 192 | if momentum is not None and momentum > 0.: 193 | tx.append(optax.ema(momentum, debias=False)) 194 | 195 | if len(tx) >= 2: 196 | return optax.chain(*tx) 197 | return tx[0] 198 | -------------------------------------------------------------------------------- /jestimator/amos_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Helper utilities for the Amos optimizer.""" 16 | import ast 17 | import math 18 | import operator as op 19 | import re 20 | from typing import Any, Dict, Tuple 21 | 22 | from absl import logging 23 | import jax 24 | from jax.sharding import PartitionSpec # pylint: disable=g-importing-member 25 | from jestimator.amos import ParamsFn, ScaleByAmosState, Shape # pylint: disable=g-multiple-import,g-importing-member 26 | import numpy 27 | 28 | _BIN_OP_MAP = { 29 | ast.Add: op.add, 30 | ast.Sub: op.sub, 31 | ast.Mult: op.mul, 32 | ast.Div: op.truediv, 33 | ast.FloorDiv: op.floordiv, 34 | ast.Mod: op.mod, 35 | ast.Pow: op.pow, 36 | } 37 | 38 | 39 | def evaluate(s: str, shape: Shape): 40 | """Evaluate simple expression. Allow 'SHAPE' referring to variable shape.""" 41 | 42 | def _evaluate(node): 43 | if node is None: 44 | return None 45 | 46 | if isinstance(node, ast.BinOp): 47 | left = _evaluate(node.left) 48 | right = _evaluate(node.right) 49 | return _BIN_OP_MAP[type(node.op)](left, right) 50 | 51 | if isinstance(node, ast.Call): 52 | func_name = node.func 53 | assert isinstance(func_name, ast.Name) 54 | func = getattr(math, func_name.id, None) 55 | if func is None: 56 | func = getattr(numpy, func_name.id) 57 | 58 | assert not node.keywords 59 | args = [_evaluate(x) for x in node.args] 60 | return func(*args) 61 | 62 | if isinstance(node, ast.Constant): 63 | return node.value 64 | 65 | if isinstance(node, ast.Name): 66 | assert node.id == 'SHAPE' 67 | return shape 68 | 69 | if isinstance(node, ast.Num): # Python 3.7 compatibility 70 | return node.n 71 | 72 | if isinstance(node, ast.Index): # Python 3.8 compatibility 73 | return _evaluate(node.value) 74 | 75 | if isinstance(node, ast.Slice): 76 | return slice( 77 | _evaluate(node.lower), _evaluate(node.upper), _evaluate(node.step)) 78 | 79 | if isinstance(node, ast.Subscript): 80 | return _evaluate(node.value)[_evaluate(node.slice)] 81 | 82 | if isinstance(node, ast.Tuple): 83 | return tuple([_evaluate(x) for x in node.elts]) 84 | 85 | if isinstance(node, ast.UnaryOp): 86 | assert isinstance(node.op, ast.USub) 87 | return -_evaluate(node.operand) # pylint: disable=invalid-unary-operand-type 88 | 89 | raise TypeError(f'Cannot handle node type: {type(node).__name__}') 90 | 91 | node = ast.parse(s, mode='eval').body 92 | return _evaluate(node) 93 | 94 | 95 | def params_fn_from_assign_map(assign_map: Dict[str, Any], 96 | name_sep: str = '/', 97 | eval_str_value: bool = False) -> ParamsFn: 98 | """Creates a params_fn from assign_map. 99 | 100 | A params_fn maps each variable name and shape to some value. The variable name 101 | is a tuple of str, and shape is a tuple of int. An assign_map is a sequence of 102 | rules, where each rule maps a regex of variable names to a value. 103 | 104 | Args: 105 | assign_map: A dictionary mapping 'regex' to 'value'. Given a variable name, 106 | the returned params_fn will find the first matching 'regex' and return the 107 | corresponding 'value'. 108 | name_sep: Join the variable name (tuple of str) by this separator before 109 | regex matching.
Defaults to '/'. 110 | eval_str_value: If True, value can be str of simple expressions, which will 111 | be evaluated. 112 | 113 | Returns: 114 | params_fn: A function that maps each variable name and shape to a value. 115 | """ 116 | 117 | def params_fn(name: Tuple[str, ...], shape: Shape): 118 | name_str = name_sep.join(name) 119 | for regex, value in assign_map.items(): 120 | if re.match(regex, name_str): 121 | logging.info('Matched rule (%s -> %s) to variable %s of shape %s.', 122 | regex, value, name, shape) 123 | if eval_str_value and isinstance(value, str): 124 | return evaluate(value, shape) 125 | return value 126 | raise ValueError(f'No matching rule for variable {name} of shape {shape}.') 127 | 128 | return params_fn 129 | 130 | 131 | def maybe_reduce_axis_names(var, axes): 132 | """Prepend 'reduced_' to the axis name if a dimension is 1.""" 133 | if not var.shape: # Scalar. 134 | return None 135 | 136 | if axes is None: # No axes info. 137 | return None 138 | 139 | assert len(var.shape) == len(axes), f'shape: {var.shape} axis: {axes}' 140 | names = [(f'reduced_{x}' if d == 1 else x) for d, x in zip(var.shape, axes)] 141 | return PartitionSpec(*names) 142 | 143 | 144 | def state_partition_rule(state: ScaleByAmosState, params_axes): 145 | """Creates partition for Amos states from partition of parameters.""" 146 | return ScaleByAmosState( 147 | count=None, 148 | v=jax.tree.map(maybe_reduce_axis_names, state.v, params_axes), 149 | b=jax.tree.map(maybe_reduce_axis_names, state.b, params_axes)) 150 | -------------------------------------------------------------------------------- /jestimator/amos_helper_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for amos_helper.""" 16 | import math 17 | 18 | from absl.testing import absltest 19 | import jax 20 | from jestimator import amos_helper 21 | 22 | 23 | class AmosHelperTest(absltest.TestCase): 24 | 25 | def test_evaluate(self): 26 | shape = (7, 8, 9) 27 | x = amos_helper.evaluate('sqrt(1 / prod(SHAPE[:-1]))', shape) 28 | y = math.sqrt(1 / (shape[0] * shape[1])) 29 | self.assertEqual(x, y) 30 | 31 | x = amos_helper.evaluate('(1, 1, SHAPE[2])', shape) 32 | y = (1, 1, shape[2]) 33 | self.assertSequenceEqual(x, y) 34 | 35 | def test_params_fn_from_assign_map(self): 36 | assign_map = { 37 | 'init_bn/scale': 'sqrt(1 / SHAPE[-1])', 38 | r'.*bn.?/scale$': 1.0, 39 | } 40 | fn = amos_helper.params_fn_from_assign_map(assign_map, eval_str_value=True) 41 | self.assertEqual(fn(('init_bn', 'scale'), (7, 256)), math.sqrt(1 / 256)) 42 | self.assertEqual(fn(('decoder', 'layer_0', 'bn1', 'scale'), (256,)), 1.0) 43 | 44 | 45 | if __name__ == '__main__': 46 | absltest.main() 47 | -------------------------------------------------------------------------------- /jestimator/amos_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for amos.""" 16 | from absl.testing import absltest 17 | import jax 18 | import jax.numpy as jnp 19 | from jestimator import amos 20 | 21 | 22 | def _setup_parabola(dtype): 23 | """Quadratic function as an optimization target.""" 24 | initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype) 25 | final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype) 26 | 27 | @jax.grad 28 | def get_updates(params): 29 | return jnp.sum(jnp.square(params - final_params)) 30 | 31 | return initial_params, final_params, get_updates 32 | 33 | 34 | def _setup_rosenbrock(dtype): 35 | """Rosenbrock function as an optimization target.""" 36 | a = 1.0 37 | b = 100.0 38 | 39 | initial_params = jnp.array([0.0, 0.0], dtype=dtype) 40 | final_params = jnp.array([a, a**2], dtype=dtype) 41 | 42 | @jax.grad 43 | def get_updates(params): 44 | return (jnp.square(a - params[0]) + 45 | b * jnp.square(params[1] - params[0]**2)) 46 | 47 | return initial_params, final_params, get_updates 48 | 49 | 50 | class AmosTest(absltest.TestCase): 51 | 52 | def test_parabola(self): 53 | opt = amos.amos( 54 | learning_rate=1.0, 55 | eta_fn=lambda name, shape: 1.0, 56 | shape_fn=lambda name, shape: ()) 57 | initial_params, final_params, get_updates = _setup_parabola(jnp.float32) 58 | 59 | @jax.jit 60 | def step(params, state): 61 | updates = get_updates(params) 62 | updates, state = opt.update(updates, state, params) 63 | params = jax.tree_util.tree_map( 64 | lambda p, u: jnp.asarray(p + u).astype(jnp.asarray(p).dtype), params, 65 | updates) 66 | return params, state 67 | 68 | params = initial_params 69 | state = opt.init(params) 70 | for _ in range(10000): 71 | params, state = step(params, state) 72 | 73 | self.assertSequenceAlmostEqual(params, final_params) 74 | 75 | def test_rosenbrock(self): 76 | opt = amos.amos( 77 | learning_rate=0.5, 78 | eta_fn=lambda name, shape: 1.0, 79 | shape_fn=lambda name, shape: (), 80 | beta=0.5, 81 | clip_value=1.0, 82 | momentum=0.9) 83 | initial_params, final_params, get_updates = _setup_rosenbrock(jnp.float32) 84 | 85 | @jax.jit 86 | def step(params, state): 87 | updates = get_updates(params) 88 | updates, state = opt.update(updates, state, params) 89 | params = jax.tree_util.tree_map( 90 | lambda p, u: jnp.asarray(p + u).astype(jnp.asarray(p).dtype), params, 91 | updates) 92 | return params, state 93 | 94 | params = initial_params 95 | state = opt.init(params) 96 | for _ in range(10000): 97 | params, state = step(params, state) 98 | 99 | self.assertSequenceAlmostEqual(params, final_params) 100 | 101 | 102 | if __name__ == '__main__': 103 | absltest.main() 104 | -------------------------------------------------------------------------------- /jestimator/checkpoint_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Checkpointing utilities.""" 16 | import os 17 | import time 18 | from typing import Any, List, Mapping, Optional, Tuple, Union 19 | 20 | from absl import logging 21 | from flax.core.frozen_dict import freeze 22 | from flax.training import checkpoints 23 | from flax.traverse_util import flatten_dict, unflatten_dict # pylint: disable=g-multiple-import 24 | import jax 25 | from jax.experimental import multihost_utils 26 | from t5x.utils import get_local_data 27 | from tensorflow import errors 28 | from tensorflow.io import gfile 29 | 30 | 31 | def partial_restore(state, 32 | ckpt_dict: Mapping[str, Any], 33 | load_step: bool = False): 34 | """Partially restore from checkpoint.""" 35 | flat_x = flatten_dict(state.params) 36 | flat_y = flatten_dict(ckpt_dict['target']) 37 | flat_ret = {} 38 | for k, v in flat_x.items(): 39 | u = flat_y.get(k) 40 | if u is None: 41 | logging.warning('Key %s not found in ckpt.', '/'.join(k)) 42 | flat_ret[k] = v 43 | elif u.shape == v.shape: 44 | flat_ret[k] = u 45 | else: 46 | logging.warning('Shape %s != in ckpt %s', v.shape, u.shape) 47 | if u.shape[1:] == v.shape[1:]: 48 | if u.shape[0] < v.shape[0]: 49 | flat_ret[k] = v.at[:u.shape[0]].set(u) 50 | else: 51 | flat_ret[k] = u[:v.shape[0]] 52 | else: 53 | logging.warning('Ignored checkpoint due to shape discrepancy.') 54 | flat_ret[k] = v 55 | 56 | params = freeze(unflatten_dict(flat_ret)) 57 | state = state.replace(params=params) 58 | if 'flax_mutables' in ckpt_dict: 59 | state = state.replace(_vars=freeze(ckpt_dict['flax_mutables'])) 60 | if load_step: 61 | state = state.replace(step=get_local_data(ckpt_dict['state']['step'])) 62 | return state 63 | 64 | 65 | def latest_ckpt_path(model_dir: Optional[str] = None, 66 | init_ckpt_path: Optional[str] = None, 67 | prefix: str = 'checkpoint_') -> Tuple[Optional[str], bool]: 68 | """Get path of the latest checkpoint. 69 | 70 | If `init_ckpt_path` comes from `model_dir`, then it overrides other 71 | checkpoints in `model_dir`; 72 | Else, if `model_dir` is not empty, load the latest checkpoint in it; 73 | Otherwise, load the checkpoint specified by `init_ckpt_path`. 74 | 75 | Args: 76 | model_dir: Dir to store model checkpoints. 77 | init_ckpt_path: An optional checkpoint to initialize the model. 78 | prefix: str: name prefix of checkpoint files. 79 | 80 | Returns: 81 | (ckpt_path, same_dir). 82 | ckpt_path: The latest or init checkpoint path. 83 | same_dir: Whether the checkpoint is in the same `model_dir`. 
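
  Example (an illustrative sketch of the priority rules above; the paths
  below are hypothetical):
    latest_ckpt_path('/m', '/m/checkpoint_100')
      -> ('/m/checkpoint_100', True), since the init ckpt is inside `model_dir`.
    latest_ckpt_path('/m', '/init/checkpoint_0')
      -> the latest checkpoint in '/m' with same_dir=True if one exists;
         otherwise ('/init/checkpoint_0', False).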
84 | """ 85 | if model_dir is not None: 86 | if model_dir.startswith('gs://'): 87 | model_dir = model_dir.rstrip('/') + '/' 88 | else: 89 | model_dir = os.path.abspath(model_dir) + os.sep 90 | if init_ckpt_path is not None: 91 | ckpt_dir = os.path.abspath(os.path.dirname(init_ckpt_path)) + os.sep 92 | if ckpt_dir.startswith(model_dir): 93 | logging.info( 94 | 'Use checkpoint specified by `checkpoint_path` (%s),' 95 | ' since it comes from specified `model_dir` (%s) as well,' 96 | ' and hence overrides other checkpoints in the dir.', 97 | init_ckpt_path, model_dir) 98 | return init_ckpt_path, True 99 | 100 | if jax.process_index() == 0: 101 | for tmp in gfile.glob(os.path.join(model_dir, f'{prefix}*tmp*')): 102 | try: 103 | gfile.rmtree(tmp) 104 | except errors.NotFoundError: 105 | pass 106 | multihost_utils.sync_global_devices( 107 | f'jestimator:latest_ckpt_path:remove_tmp_ckpts:{model_dir}') 108 | latest = checkpoints.latest_checkpoint(model_dir, prefix=prefix) 109 | if latest is not None: 110 | logging.info( 111 | 'Use the latest checkpoint (%s) from `model_dir`,' 112 | ' and ignores `checkpoint_path`.', latest) 113 | return latest, True 114 | 115 | logging.info( 116 | 'Use checkpoint specified by `checkpoint_path` (%s),' 117 | ' since `model_dir` (%s) is empty.', init_ckpt_path, model_dir) 118 | return init_ckpt_path, False 119 | 120 | 121 | def last_evaluated_ckpt(last_eval_path: str) -> Optional[str]: 122 | """Get the last evaluated checkpoint path.""" 123 | if gfile.exists(last_eval_path): 124 | logging.info('Reading last_evaluated from: %s', last_eval_path) 125 | with gfile.GFile(last_eval_path, 'rb') as f: 126 | last_evaluated = f.read().decode('utf-8') 127 | logging.info('Found last_evaluated: %s', last_evaluated) 128 | else: 129 | last_evaluated = None 130 | return last_evaluated 131 | 132 | 133 | def sorted_checkpoints( 134 | ckpt_dir: Union[str, os.PathLike], # pylint: disable=g-bare-generic 135 | prefix: str = 'checkpoint_') -> List[str]: 136 | """Retrieve the path of all checkpoints in a directory. 137 | 138 | Args: 139 | ckpt_dir: str: directory of checkpoints to restore from. 140 | prefix: str: name prefix of checkpoint files. 141 | 142 | Returns: 143 | A list of checkpoint paths. 
144 | """ 145 | ckpt_dir = os.fspath(ckpt_dir) # Pathlib -> str 146 | glob_path = gfile.glob(os.path.join(ckpt_dir, f'{prefix}*')) 147 | glob_tmp = frozenset(gfile.glob(os.path.join(ckpt_dir, f'{prefix}*tmp*'))) 148 | glob_path = [f for f in glob_path if f not in glob_tmp] 149 | checkpoint_files = checkpoints.natural_sort(glob_path) 150 | return checkpoint_files 151 | 152 | 153 | def checkpoints_iterator_from_oldest(model_dir: str, 154 | last_eval_path: str, 155 | min_interval_secs: float = 0., 156 | last_evaluated: Optional[str] = None): 157 | """Iterate checkpoints in a dir from the oldest, and wait for new.""" 158 | logging.info('Monitoring checkpoints in dir: %s', model_dir) 159 | while True: 160 | ckpts = sorted_checkpoints(model_dir)[1:] 161 | if last_evaluated is not None: 162 | for i, x in enumerate(ckpts): 163 | if x == last_evaluated: 164 | ckpts = ckpts[i + 1:] 165 | break 166 | 167 | for x in ckpts: 168 | if gfile.exists(x): 169 | last_evaluated = x 170 | if jax.process_index() == 0: 171 | with gfile.GFile(last_eval_path, 'w') as f: 172 | f.write(x) 173 | yield x 174 | 175 | if not ckpts: 176 | time.sleep(min_interval_secs) 177 | 178 | 179 | def last_score(last_score_path: str) -> float: 180 | """Get the last evaluated score.""" 181 | score = None 182 | if gfile.exists(last_score_path): 183 | logging.info('Reading last score from: %s', last_score_path) 184 | with gfile.GFile(last_score_path, 'rb') as f: 185 | score = f.read().decode('utf-8') 186 | logging.info('Found last score: %s', score) 187 | if not score: 188 | score = '-inf' 189 | return float(score) 190 | -------------------------------------------------------------------------------- /jestimator/data/pipeline_lm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Data pipeline for language models: fixed-length sequences of token ids.""" 16 | from typing import Callable, Optional, Sequence 17 | 18 | import tensorflow as tf 19 | 20 | 21 | def pipeline_from_filenames( 22 | filenames: Sequence[str], 23 | seq_length: int, 24 | allow_remainder: bool = False, 25 | dataset_fn: Optional[Callable[[str], tf.data.Dataset]] = None, 26 | cache: bool = False, 27 | random_skip: bool = False, 28 | feature_fn: Optional[Callable] = None, # pylint: disable=g-bare-generic 29 | interleave: bool = False, 30 | shard_num: int = 1, 31 | shard_index: int = 0, 32 | epochs: Optional[int] = 1, 33 | num_take: int = -1): 34 | r"""Creates a tensorflow dataset from filenames. 35 | 36 | Args: 37 | filenames: A list of file name strings. 38 | seq_length: Length of token-id sequences in the output dtaset. 39 | allow_remainder: bool. Whether to allow the last sequence to be shorter. 40 | dataset_fn: A function that maps a path str to a tf.data.Dataset. If None, 41 | defaults to tf.data.Dataset.load(). 42 | cache: Whether to cache the constructed datasets in memory. 
43 | random_skip: bool. Whether to randomly skip some tokens in the beginning of 44 | each dataset. This is used to increase randomness in training. 45 | feature_fn: A function that maps a token-id sequence to model-specific 46 | features. This is called before batching. 47 | interleave: bool. Whether to randomly interleave multiple files. 48 | shard_num: int. Number of shards. 49 | shard_index: int. Worker index. 50 | epochs: Number of epochs to repeat. If None, repeat forever. 51 | num_take: int. If not -1, take the first n examples. 52 | 53 | Returns: 54 | An instance of tf.data.Dataset. 55 | """ 56 | num_files = len(filenames) 57 | 58 | shard_data = True 59 | if num_files % shard_num == 0 or num_files / shard_num > 9: 60 | filenames = filenames[shard_index::shard_num] 61 | shard_data = False 62 | 63 | ds = [] 64 | for path in filenames: 65 | if dataset_fn is None: 66 | d = tf.data.Dataset.load(path) 67 | else: 68 | d = dataset_fn(path) 69 | if cache and num_take == -1: 70 | d = d.cache() 71 | ds.append(d) 72 | fd = tf.data.Dataset.from_tensor_slices(ds) 73 | 74 | if interleave and num_files > 1: 75 | fd = fd.shuffle(num_files, seed=num_files + 37) 76 | fd = fd.repeat(epochs) 77 | if random_skip: 78 | fd = tf.data.Dataset.zip((fd, tf.data.Dataset.random(seed=num_files + 19))) 79 | 80 | def seq_fn(d, *rnd_): 81 | if random_skip: 82 | (rnd,) = rnd_ 83 | d = d.skip(rnd % seq_length) 84 | d = d.batch(seq_length, drop_remainder=(not allow_remainder)) 85 | if shard_data: 86 | d = d.shard(shard_num, shard_index) 87 | if feature_fn is not None: 88 | d = d.map(feature_fn, num_parallel_calls=tf.data.AUTOTUNE) 89 | return d 90 | 91 | if interleave and num_files > 1: 92 | ret = fd.interleave( 93 | seq_fn, deterministic=False, num_parallel_calls=tf.data.AUTOTUNE) 94 | else: 95 | ret = fd.flat_map(seq_fn) 96 | ret = ret.take(num_take) 97 | if num_take >= 0 and cache: 98 | ret = ret.cache() 99 | return ret 100 | 101 | 102 | def lm_data( 103 | seq_length: int, 104 | allow_remainder: bool = False, 105 | dataset_fn: Optional[Callable[[str], tf.data.Dataset]] = None, 106 | cache: bool = False, 107 | random_skip: bool = False, 108 | feature_fn: Optional[Callable] = None, # pylint: disable=g-bare-generic 109 | interleave: bool = False): 110 | """Builds a data pipeline for language modeling. 111 | 112 | Args: 113 | seq_length: Length of token-id sequences in the output dataset. 114 | allow_remainder: bool. Whether to allow the last sequence to be shorter. 115 | dataset_fn: A function that maps a path str to a tf.data.Dataset. If None, 116 | defaults to tf.data.Dataset.load(). 117 | cache: Whether to cache the constructed datasets in memory. 118 | random_skip: bool. Whether to randomly skip some tokens in the beginning of 119 | each dataset. This is used to increase randomness in training. 120 | feature_fn: A function that maps a token-id sequence to model-specific 121 | features. This is called before batching. 122 | interleave: bool. Whether to randomly interleave multiple files. 123 | 124 | Returns: 125 | A `data_fn` to be used by jestimator.
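
  Example (a minimal sketch; the token-dump path is hypothetical):

    data_fn = lm_data(seq_length=128, random_skip=True, interleave=True)
    ds = data_fn(['/tmp/wiki_tokens'], epochs=None)  # Repeat forever.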
126 | """ 127 | 128 | def data_fn(filenames, **kwargs): 129 | return pipeline_from_filenames( 130 | filenames, 131 | seq_length, 132 | allow_remainder=allow_remainder, 133 | dataset_fn=dataset_fn, 134 | cache=cache, 135 | random_skip=random_skip, 136 | feature_fn=feature_fn, 137 | interleave=interleave, 138 | **kwargs) 139 | 140 | return data_fn 141 | -------------------------------------------------------------------------------- /jestimator/data/pipeline_rec.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Data pipeline for record datasets.""" 16 | from typing import Callable, Optional, Sequence 17 | 18 | import tensorflow as tf 19 | 20 | 21 | def pipeline_from_filenames( 22 | filenames: Sequence[str], 23 | dataset_fn: Optional[Callable[[str], tf.data.Dataset]] = None, 24 | cache: bool = False, 25 | feature_fn: Optional[Callable] = None, # pylint: disable=g-bare-generic 26 | interleave: bool = False, 27 | shard_num: int = 1, 28 | shard_index: int = 0, 29 | epochs: Optional[int] = 1, 30 | num_take: int = -1): 31 | r"""Creates a tensorflow dataset from filenames. 32 | 33 | Args: 34 | filenames: A list of file name strings. 35 | dataset_fn: A function that maps a path str to a tf.data.Dataset. If None, 36 | defaults to tf.data.Dataset.load(). 37 | cache: Whether to cache the constructed datasets in memory. 38 | feature_fn: A function that maps a token-id sequence to model-specific 39 | features. This is called before batching. 40 | interleave: bool. Whether to randomly interleave multiple files. 41 | shard_num: int. Number of shards. 42 | shard_index: int. Worker index. 43 | epochs: Number of epochs to repeat. If None, repeat forever. 44 | num_take: int. If not -1, take the first n examples. 45 | 46 | Returns: 47 | An instance of tf.data.Dataset. 
48 | """ 49 | num_files = len(filenames) 50 | 51 | shard_data = True 52 | if num_files % shard_num == 0 or num_files / shard_num > 9: 53 | filenames = filenames[shard_index::shard_num] 54 | shard_data = False 55 | 56 | ds = [] 57 | for path in filenames: 58 | d = dataset_fn(path) 59 | if cache and num_take == -1: 60 | if feature_fn is not None: 61 | d = d.map(feature_fn, num_parallel_calls=tf.data.AUTOTUNE) 62 | d = d.cache() 63 | ds.append(d) 64 | 65 | if interleave and num_files > 1: 66 | rnd = tf.data.Dataset.random(seed=num_files + 11) 67 | if epochs is not None: 68 | rnd = rnd.take(epochs) 69 | ret = rnd.flat_map( 70 | lambda x: tf.data.Dataset.sample_from_datasets(ds, seed=x)) 71 | else: 72 | ret = ds[0] 73 | for d in ds[1:]: 74 | ret = ret.concatenate(d) 75 | ret = ret.repeat(epochs) 76 | if shard_data: 77 | ret = ret.shard(shard_num, shard_index) 78 | 79 | ret = ret.take(num_take) 80 | if not cache or num_take != -1: 81 | if feature_fn is not None: 82 | ret = ret.map(feature_fn, num_parallel_calls=tf.data.AUTOTUNE) 83 | if cache and num_take != -1: 84 | ret = ret.cache() 85 | return ret 86 | 87 | 88 | def rec_data( 89 | dataset_fn: Optional[Callable[[str], tf.data.Dataset]] = None, 90 | cache: bool = False, 91 | feature_fn: Optional[Callable] = None, # pylint: disable=g-bare-generic 92 | interleave: bool = False): 93 | """Builds a data pipeline for records. 94 | 95 | Args: 96 | dataset_fn: A function that maps a path str to a tf.data.Dataset. If None, 97 | defaults to tf.data.Dataset.load(). 98 | cache: Whether to cache the constructed datasets in memory. 99 | feature_fn: A function that maps a token-id sequence to model-specific 100 | features. This is called before batching. 101 | interleave: bool. Whether to randomly interleave multiple files. 102 | 103 | Returns: 104 | A `data_fn` to be used by jestimator. 105 | """ 106 | 107 | def data_fn(filenames, **kwargs): 108 | return pipeline_from_filenames( 109 | filenames, 110 | dataset_fn=dataset_fn, 111 | cache=cache, 112 | feature_fn=feature_fn, 113 | interleave=interleave, 114 | **kwargs) 115 | 116 | return data_fn 117 | -------------------------------------------------------------------------------- /jestimator/data/pipeline_seqio.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Data pipeline that wraps the seqio lib.""" 16 | from typing import Mapping, Optional, Sequence 17 | 18 | import seqio 19 | 20 | SEQIO_PREFIX = "seqio://" 21 | 22 | 23 | def is_seqio(filenames: Optional[Sequence[str]]) -> bool: 24 | return bool(filenames and filenames[0].startswith(SEQIO_PREFIX)) 25 | 26 | 27 | def pipeline_from_mixture_or_task_name( 28 | mixture_or_task_name: str, 29 | task_feature_lengths: Mapping[str, int], 30 | feature_converter: seqio.FeatureConverter, 31 | use_cached: bool = False, 32 | shuffle: bool = False, 33 | seed: Optional[int] = None, 34 | shard_num: int = 1, 35 | shard_index: int = 0, 36 | epochs: Optional[int] = 1, 37 | num_take: int = -1): 38 | r"""Creates a tensorflow dataset from a seqio mixture or task name. 39 | 40 | Args: 41 | mixture_or_task_name: str. Name of task or mixture. 42 | task_feature_lengths: Dict of sequence lengths of features. 43 | feature_converter: Model-specific feature converter. 44 | use_cached: bool. Whether to use a precomputed version of the dataset from a 45 | cache dir. Defaults to False. 46 | shuffle: bool. Whether to shuffle data. Defaults to False. 47 | seed: int. Random seed. Defaults to None. 48 | shard_num: int. Number of shards. 49 | shard_index: int. Worker index. 50 | epochs: Number of epochs to repeat. If None, repeat forever. 51 | num_take: int. If not -1, take the first n examples. 52 | 53 | Returns: 54 | An instance of tf.data.Dataset. 55 | """ 56 | if mixture_or_task_name.startswith(SEQIO_PREFIX): 57 | mixture_or_task_name = mixture_or_task_name[len(SEQIO_PREFIX):] 58 | 59 | sp = mixture_or_task_name.split("/") 60 | if sp[-1].startswith("split="): 61 | mixture_or_task_name = "/".join(sp[:-1]) 62 | split = sp[-1][len("split="):] 63 | else: 64 | split = None 65 | 66 | ret = seqio.get_dataset( 67 | mixture_or_task_name=mixture_or_task_name, 68 | task_feature_lengths=task_feature_lengths, 69 | dataset_split=split, 70 | shuffle=shuffle, 71 | num_epochs=epochs, 72 | feature_converter=feature_converter, 73 | shard_info=seqio.ShardInfo(shard_index, shard_num), 74 | use_cached=use_cached, 75 | seed=seed) 76 | ret = ret.take(num_take) 77 | return ret 78 | 79 | 80 | def seqio_data(task_feature_lengths: Mapping[str, int], 81 | feature_converter: seqio.FeatureConverter, 82 | use_cached: bool = False, 83 | shuffle: bool = False, 84 | seed: Optional[int] = None): 85 | """Wraps a seqio data pipeline.""" 86 | 87 | def data_fn(filenames, **kwargs): 88 | (mixture_or_task_name,) = filenames 89 | return pipeline_from_mixture_or_task_name( 90 | mixture_or_task_name, 91 | task_feature_lengths, 92 | feature_converter, 93 | use_cached=use_cached, 94 | shuffle=shuffle, 95 | seed=seed, 96 | **kwargs) 97 | 98 | return data_fn 99 | -------------------------------------------------------------------------------- /jestimator/data/reader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Data file readers for different formats.""" 16 | import dataclasses 17 | import enum 18 | from typing import Mapping, Optional, Tuple, Union 19 | 20 | import tensorflow as tf 21 | import tensorflow_datasets as tfds 22 | 23 | TFDS_PREFIX = 'tfds://' 24 | 25 | 26 | # Supported file format for record data. 27 | @enum.unique 28 | class RecordFormat(enum.Enum): 29 | TFRECORD = 'tfrecord' 30 | 31 | 32 | def get_format(path: str) -> RecordFormat: 33 | """Returns the format of a record file; currently only 'tfrecord' is detected.""" 34 | 35 | # Try as a TFRecord. 36 | try: 37 | next(tf.data.TFRecordDataset(path).as_numpy_iterator()) 38 | return RecordFormat.TFRECORD 39 | except (IOError, StopIteration, tf.errors.DataLossError, AttributeError): 40 | pass 41 | 42 | raise TypeError(f'Invalid file format: {path}') 43 | 44 | 45 | def get_record_dataset( 46 | path: str, 47 | file_format: Optional[Union[str, RecordFormat]] = None) -> tf.data.Dataset: 48 | r"""Creates a tensorflow dataset from path. 49 | 50 | Args: 51 | path: Path to the dataset. 52 | file_format: The format of dataset files. 53 | 54 | Returns: 55 | An instance of tf.data.Dataset. 56 | """ 57 | if file_format is None: 58 | file_format = get_format(path) 59 | elif not isinstance(file_format, RecordFormat): 60 | file_format = RecordFormat(file_format) 61 | 62 | if file_format == RecordFormat.TFRECORD: 63 | d = tf.data.TFRecordDataset(path) 64 | return d 65 | 66 | raise TypeError(f'Unknown file format: {path}') 67 | 68 | 69 | def get_tfds_dataset(path: str) -> tf.data.Dataset: 70 | """Creates a tensorflow dataset from path.""" 71 | if path.startswith(TFDS_PREFIX): 72 | path = path[len(TFDS_PREFIX):] 73 | sp = path.split(':', 1) 74 | if len(sp) == 2: 75 | data_dir, path = sp 76 | else: 77 | data_dir = None 78 | 79 | sp = path.split('/') 80 | if sp[-1].startswith('split='): 81 | path = '/'.join(sp[:-1]) 82 | split = sp[-1][len('split='):] 83 | else: 84 | split = None 85 | 86 | d = tfds.load( 87 | path, 88 | split=split, 89 | as_supervised=False, 90 | shuffle_files=False, 91 | data_dir=data_dir) 92 | return d 93 | 94 | 95 | def serialize_tensor_dict(data: Mapping[str, tf.Tensor]) -> bytes: 96 | """Converts a tensor dict to bytes, via tf.train.Example.""" 97 | feature = {} 98 | for k, v in data.items(): 99 | bytes_list = tf.train.BytesList(value=[tf.io.serialize_tensor(v).numpy()]) 100 | feature[k] = tf.train.Feature(bytes_list=bytes_list) 101 | ex = tf.train.Example(features=tf.train.Features(feature=feature)) 102 | return ex.SerializeToString() 103 | 104 | 105 | def parse_tensor_dict( 106 | x, elem_spec: Mapping[str, tf.TensorSpec]) -> Mapping[str, tf.Tensor]: 107 | """Creates a tensor dict from serialized bytes.""" 108 | features = { 109 | k: tf.io.FixedLenFeature([], v.dtype) for k, v in elem_spec.items() 110 | } 111 | x = tf.io.parse_single_example(x, features) 112 | x = {k: tf.ensure_shape(v, elem_spec[k].shape) for k, v in x.items()} 113 | return x 114 | 115 | 116 | def to_str(x: tf.Tensor, encoding: str = 'utf-8') -> str: 117 | """Converts an eager tf.string tensor to str. Fail-safe.""" 118 | try: 119 | ret = x.numpy().decode(encoding) 120 | except UnicodeDecodeError: 121 | ret = '' 122 | return ret 123 | 124 | 125 | @dataclasses.dataclass 126 | class PyOutSpec: 127 | """Specifies the shape and type of a value returned by a py_func.""" 128 | shape: Tuple[int] 129 | dtype: tf.DType 130 | 131 | 132 | def apply_py_fn(py_fn, data, out_spec): 133 | """Applies a python function to graph-mode data.
134 | 135 | Args: 136 | py_fn: A python function. 137 | data: A nested structure of graph-mode data. 138 | out_spec: A nested structure of PyOutSpec of the returned values. 139 | 140 | Returns: 141 | A nested structure of graph-mode values. 142 | """ 143 | flat_data = tf.nest.flatten(data, expand_composites=True) 144 | 145 | def fn(*flat): 146 | data_eager = tf.nest.pack_sequence_as(data, flat, expand_composites=True) 147 | ret = py_fn(data_eager) 148 | return tf.nest.flatten(ret) 149 | 150 | flat_out_spec = tf.nest.flatten(out_spec) 151 | ret = tf.py_function(fn, flat_data, [x.dtype for x in flat_out_spec]) 152 | ret = [tf.reshape(y, x.shape) for x, y in zip(flat_out_spec, ret)] 153 | return tf.nest.pack_sequence_as(out_spec, ret) 154 | 155 | 156 | def lines_iterator(path: str, 157 | encoding: str = 'utf-8', 158 | split: bool = False, 159 | allow_empty: bool = False): 160 | """Line iterator from a file.""" 161 | with tf.io.gfile.GFile(path, 'rb') as f: 162 | for line in f: 163 | try: 164 | line = line.decode(encoding) 165 | except UnicodeDecodeError: 166 | continue 167 | ret = line.split() if split else line.rstrip('\r\n') 168 | if allow_empty or ret: 169 | yield ret 170 | -------------------------------------------------------------------------------- /jestimator/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utility functions for building data pipelines.""" 16 | from typing import Callable, List, Sequence, Tuple, Union 17 | 18 | import jax 19 | from jax.experimental.multihost_utils import host_local_array_to_global_array 20 | import tensorflow as tf 21 | from tensorflow.io import gfile 22 | 23 | TFDS_PREFIX = 'tfds://' 24 | SEQIO_PREFIX = 'seqio://' 25 | DUMMY_PREFIX = 'dummy://' 26 | 27 | 28 | class StringIterable(object): 29 | """Converts `x` to iterable of strings. 30 | 31 | `x` can be a string, list of strings, or a tf.data.Dataset of tf.string. 32 | This prevents a Python string from being converted to an iterable of 33 | characters. 34 | """ 35 | 36 | def __init__(self, x: Union[str, Sequence[str], tf.data.Dataset]): 37 | self.x = x 38 | 39 | def __iter__(self): 40 | if isinstance(self.x, str): 41 | return iter([self.x]) 42 | if isinstance(self.x, tf.data.Dataset): 43 | return self.x.as_numpy_iterator() 44 | return iter(self.x) 45 | 46 | 47 | def get_dataset_filenames(pattern: Union[str, Sequence[str], tf.data.Dataset], 48 | do_glob: bool = True) -> List[str]: 49 | """Returns a list of file names for a given sharded/glob file pattern. 50 | 51 | Args: 52 | pattern: A string, list of strings, or tf.data.Dataset containing file path, 53 | sharded path, tfds address, or glob_pattern. 54 | do_glob: Whether to do gfile glob. 55 | 56 | Returns: 57 | A list of filenames. 58 | 59 | Raises: 60 | ValueError: When some filename pattern does not match any file.
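
  Example (illustrative; the glob pattern and its matches are hypothetical):
    get_dataset_filenames(['tfds://glue/rte/split=train', '/tmp/shard-*'])
    -> ['tfds://glue/rte/split=train', '/tmp/shard-00000', '/tmp/shard-00001']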
60 | """ 61 | paths = [] 62 | for pat in StringIterable(pattern): 63 | if pat.startswith(DUMMY_PREFIX): 64 | continue 65 | if pat.startswith(SEQIO_PREFIX): 66 | paths.append(pat) 67 | continue 68 | if pat.startswith(TFDS_PREFIX): 69 | paths.append(pat) 70 | continue 71 | 72 | if do_glob: 73 | glob = gfile.glob(pat) 74 | if not glob: 75 | raise ValueError(f'File pattern {pat} has no match.') from None 76 | paths.extend(glob) 77 | else: 78 | paths.append(pat) 79 | return paths 80 | 81 | 82 | def count_dataset(d: tf.data.Dataset, batch_size: int) -> Tuple[int, int]: 83 | """Count the iterator length of a dataset, and the last batch size. 84 | 85 | Args: 86 | d: tf.data.Dataset. 87 | batch_size: int. Batch size. 88 | 89 | Returns: 90 | (dataset_length, last_batch_size). 91 | """ 92 | d = d.batch(batch_size) 93 | d = d.prefetch(tf.data.AUTOTUNE) 94 | dataset_length = 0 95 | x = None 96 | for x in d: 97 | dataset_length += 1 98 | last_batch_size = tf.shape(tf.nest.flatten(x)[0])[0].numpy() 99 | return dataset_length, last_batch_size 100 | 101 | 102 | def transpose_dataset(d: tf.data.Dataset, size_d: int, size_per_elem: int, 103 | bs: int) -> tf.data.Dataset: 104 | """Transpose a dataset of datasets.""" 105 | flat = d.flat_map(lambda x: x) 106 | ret = flat.window(size_d, 1, size_per_elem, drop_remainder=True) 107 | ret = ret.flat_map(lambda x: x.batch(bs, drop_remainder=True)) 108 | return ret 109 | 110 | 111 | def create_data_pipeline(filenames: List[str], 112 | data_fn: Callable[..., tf.data.Dataset], 113 | data_layout, 114 | shuffle_buf=None, 115 | consecutive=None, 116 | shard_source=False, 117 | **kwargs): 118 | """Builds a data pipeline with partitioning, batching and shuffle. 119 | 120 | When `consecutive` is not None, this pipeline produces consecutive batches, 121 | which can be used to divide very long sequences into multiple batches. 122 | 123 | Args: 124 | filenames: List of data file names. 125 | data_fn: A function that returns a tf dataset. 126 | data_layout: Partitioning data layout. 127 | shuffle_buf: int. Buffer size for shuffling. Do not shuffle if None. 128 | consecutive: int. If not None, every n batches are consecutive. 129 | shard_source: bool. For multiple workers, whether to shard the data source 130 | instead of sharding at the end of data pipeline. Defaults to False. 131 | **kwargs: Other kwargs passed to `data_fn`. 132 | 133 | Returns: 134 | A tf dataset instance. 
135 | """ 136 | shard_id = data_layout.shard_id 137 | num_shards = data_layout.num_shards 138 | batch_size = data_layout.batch_size 139 | assert num_shards == 1 or batch_size is not None 140 | if batch_size is not None: 141 | bs, r = divmod(batch_size, num_shards) 142 | assert r == 0, f'{batch_size} % {num_shards} != 0' 143 | 144 | if consecutive is None: 145 | if shard_source: 146 | d = data_fn( 147 | filenames, shard_num=num_shards, shard_index=shard_id, **kwargs) 148 | else: 149 | d = data_fn(filenames, **kwargs) 150 | 151 | if shuffle_buf is not None: 152 | d = d.shuffle(shuffle_buf, seed=shuffle_buf) 153 | if batch_size is not None: 154 | d = d.batch(bs, drop_remainder=True) 155 | if num_shards > 1 and not shard_source: 156 | d = d.shard(num_shards, shard_id) 157 | 158 | else: 159 | epochs = kwargs.pop('epochs', 1) 160 | d = data_fn(filenames, **kwargs) 161 | d = d.window(consecutive, drop_remainder=True).repeat(epochs) 162 | if num_shards > 1: 163 | d = d.shard(num_shards, shard_id) 164 | if shuffle_buf is not None: 165 | d = d.shuffle(shuffle_buf, seed=shuffle_buf) 166 | if batch_size is None: 167 | d = d.flat_map(lambda x: x) 168 | else: 169 | d = d.window(bs, drop_remainder=True) 170 | d = d.flat_map(lambda x: transpose_dataset(x, bs, consecutive, bs)) 171 | 172 | d = d.prefetch(tf.data.AUTOTUNE) 173 | return d 174 | 175 | 176 | class DataIterable(object): 177 | """Converts tf.data.Dataset to iterable of numpy arrays.""" 178 | 179 | def __init__(self, x: tf.data.Dataset, partitioner): 180 | self.x = x 181 | self.partitioner = partitioner 182 | 183 | def __iter__(self): 184 | ret = self.x.as_numpy_iterator() 185 | mesh = self.partitioner.mesh 186 | spec = self.partitioner.data_partition_spec 187 | it = (host_local_array_to_global_array(batch, mesh, spec) for batch in ret) 188 | return it 189 | -------------------------------------------------------------------------------- /jestimator/modeling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Modeling utilities.""" 16 | import inspect 17 | import threading 18 | from typing import Callable, Optional, Tuple 19 | 20 | from flax import linen as nn 21 | import jax 22 | import jax.numpy as jnp 23 | 24 | from jax.typing import ArrayLike 25 | from flaxformer.types import DType, Initializer, PRNGKey, Shape # pylint: disable=g-multiple-import 26 | 27 | 28 | def sparse_xe_with_logits(labels: ArrayLike, 29 | logits: ArrayLike, 30 | mask: Optional[ArrayLike] = None, 31 | normalized: bool = False, 32 | reduce_all: bool = True): 33 | """Sparse cross entropy from softmax logits. 34 | 35 | Args: 36 | labels: int tensor. 37 | logits: float tensor of shape (labels.shape + [num_labels]). 38 | mask: 0/1 float tensor, the same shape as `labels`. 39 | normalized: Whether `logits` is normalized. 40 | reduce_all: Whether to reduce_sum the loss tensor. 
41 | 42 | Returns: 43 | Cross-entropy loss. If `reduce_all` is True, returns a scalar tensor. 44 | Otherwise returns a float tensor of the same shape as `labels`. 45 | """ 46 | if not normalized: 47 | logits = jax.nn.log_softmax(logits) 48 | 49 | labels_ = jnp.expand_dims(labels, -1) 50 | llh_ = jnp.take_along_axis(logits, labels_, axis=-1) 51 | llh = jnp.squeeze(llh_, -1) 52 | 53 | if mask is not None: 54 | llh = jnp.where(mask, llh, 0.) 55 | if reduce_all: 56 | loss = -jnp.sum(llh) 57 | else: 58 | loss = -llh 59 | return loss 60 | 61 | 62 | def normalize_loss_by_size( 63 | loss: ArrayLike, size: ArrayLike 64 | ) -> Tuple[ArrayLike, ArrayLike]: 65 | """Normalize a loss value by size of labels.""" 66 | loss = jnp.asarray(loss) 67 | size = jnp.asarray(size, loss.dtype) 68 | loss = loss * jax.lax.rsqrt(size) 69 | size = jnp.sqrt(size) 70 | return loss, size 71 | 72 | 73 | def unstack(x, axis): 74 | """Unstack a tensor along axis.""" 75 | return [ 76 | jax.lax.index_in_dim(x, i, axis, keepdims=False) 77 | for i in range(x.shape[axis]) 78 | ] 79 | 80 | 81 | _thread_local = threading.local() 82 | 83 | 84 | def global_kwargs(*inherits: str, pass_down: bool = False): 85 | """Function decorator to use global kwargs. 86 | 87 | A utility for passing keyword arguments down to nested sub-calls. 88 | 89 | # Example 90 | 91 | ``` 92 | @global_kwargs('attention_mask', 'training', 'dropout_rate') 93 | def func1(x, attention_mask=None, training=False, dropout_rate=0.1): 94 | # calculation code... 95 | if attention_mask is not None: 96 | att += attention_mask 97 | if training: 98 | att = tf.nn.dropout(att, dropout_rate) 99 | # ... 100 | 101 | @global_kwargs(pass_down=True) 102 | def func2(hidden): 103 | # call `func1` 104 | func1(hidden) 105 | # ... 106 | 107 | # Then, one can pass arguments 'attention_mask', 'training', 'dropout_rate' 108 | # from `func2` down to `func1`, without explicitly declaring those arguments 109 | # in `func2`: 110 | 111 | func2(a, attention_mask=b, training=True, dropout_rate=0.5) # It works! 112 | ``` 113 | 114 | Args: 115 | *inherits: Keys to be inherited from the global context. 116 | pass_down: bool. If True, unrecognized keys will be passed to sub-routines. 117 | 118 | Returns: 119 | The function wrapper. 120 | """ 121 | 122 | def wrap(func: Callable) -> Callable: # pylint: disable=g-bare-generic 123 | func_signature = inspect.signature(func, follow_wrapped=False) 124 | func_params = func_signature.parameters 125 | for v in func_params.values(): 126 | assert v.kind != inspect.Parameter.VAR_KEYWORD, ( 127 | '`func` should not have VAR_KEYWORD parameter.') 128 | for k in inherits: 129 | assert k in func_params, ( 130 | f'The inherit key ({k}) is not an argument of `func`.') 131 | 132 | def wrapped(*args, **kwargs): 133 | current = getattr(_thread_local, 'current_inherit_kwargs', {}) 134 | func_kwargs = {k: current[k] for k in inherits if k in current} 135 | if pass_down: 136 | subrout = {**current} 137 | 138 | for k, v in kwargs.items(): 139 | if k in func_params: 140 | func_kwargs[k] = v 141 | else: 142 | assert pass_down, f'Unrecognized kwarg ({k}).'
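# Keys not declared by `func` are collected into `subrout` and installed as
# the thread-local context, so nested @global_kwargs functions inherit them.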
143 | subrout[k] = v 144 | 145 | if pass_down: 146 | _thread_local.current_inherit_kwargs = subrout 147 | ret = func(*args, **func_kwargs) 148 | if pass_down: 149 | _thread_local.current_inherit_kwargs = current 150 | return ret 151 | 152 | return wrapped 153 | 154 | return wrap 155 | 156 | 157 | def truncated_normal_initializer(stddev: ArrayLike) -> Initializer: 158 | """Truncated normal initializer.""" 159 | 160 | def init(key: PRNGKey, shape: Shape, dtype: DType) -> ArrayLike: 161 | return jax.random.truncated_normal( 162 | key=key, lower=-2., upper=2., shape=shape, dtype=dtype) * stddev 163 | 164 | return init 165 | 166 | 167 | class Dropout(nn.Module): 168 | """Dropout layer with fast random generator.""" 169 | rate: float 170 | 171 | @global_kwargs('enable_dropout') 172 | def __call__(self, inputs: ArrayLike, enable_dropout: bool = False): 173 | """Applies a random dropout mask to the input.""" 174 | if not enable_dropout: 175 | return inputs 176 | if self.rate == 0.: 177 | return inputs 178 | # Prevent gradient NaNs in 1.0 edge-case. 179 | if self.rate == 1.0: 180 | return jnp.zeros_like(inputs) 181 | 182 | inputs = jnp.asarray(inputs) 183 | mask = jax.lax.rng_uniform(0., 1., inputs.shape) < self.rate 184 | return jnp.where(mask, 0., inputs / (1. - self.rate)) 185 | -------------------------------------------------------------------------------- /jestimator/models/bert/finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Sequence classification. 16 | 17 | # For debug run locally: 18 | 19 | ## Train: 20 | 21 | ``` 22 | PYTHONPATH=. python3 \ 23 | jestimator/estimator.py \ 24 | --module_imp="jestimator.models.bert.finetune" \ 25 | --module_config="jestimator/models/bert/finetune.py" \ 26 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \ 27 | --module_config.segment_names="sentence1,sentence2" \ 28 | --module_config.model_config.num_labels=2 \ 29 | --train_pattern="tfds://glue/rte/split=train" \ 30 | --valid_pattern="tfds://glue/rte/split=validation" \ 31 | --model_dir="$HOME/models/bert_rte" \ 32 | --checkpoint_path="gs://gresearch/checkpoints_in_amos_paper/\ 33 | adamw/bert-base/checkpoint_300000" \ 34 | --train_batch_size=4 --valid_batch_size=4 --num_valid_examples=4 \ 35 | --check_every_steps=10 --logtostderr 36 | ``` 37 | 38 | ## Eval: 39 | 40 | ``` 41 | PYTHONPATH=. 
python3 \ 42 | jestimator/estimator.py \ 43 | --module_imp="jestimator.models.bert.finetune" \ 44 | --module_config="jestimator/models/bert/finetune.py" \ 45 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \ 46 | --module_config.segment_names="sentence1,sentence2" \ 47 | --module_config.model_config.num_labels=2 \ 48 | --module_config.eval_metric="accuracy" \ 49 | --eval_pattern="tfds://glue/rte/split=validation" \ 50 | --model_dir="$HOME/models/bert_rte" \ 51 | --eval_batch_size=4 --num_eval_examples=4 \ 52 | --logtostderr 53 | ``` 54 | 55 | ## Predict: 56 | 57 | ``` 58 | PYTHONPATH=. python3 \ 59 | jestimator/estimator.py \ 60 | --module_imp="jestimator.models.bert.finetune" \ 61 | --module_config="jestimator/models/bert/finetune.py" \ 62 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \ 63 | --module_config.segment_names="sentence1,sentence2" \ 64 | --module_config.model_config.num_labels=2 \ 65 | --module_config.label_names="entailment,not_entailment" \ 66 | --pred_pattern="tfds://glue/rte/split=test" \ 67 | --model_dir="$HOME/models/bert_rte" \ 68 | --pred_batch_size=4 --num_pred_examples=4 \ 69 | --logtostderr 70 | ``` 71 | """ 72 | 73 | import dataclasses 74 | 75 | import jax 76 | import jax.numpy as jnp 77 | from jestimator.data import reader 78 | from jestimator.data.pipeline_rec import rec_data 79 | from jestimator.data.reader import PyOutSpec 80 | from jestimator.models.bert import modeling 81 | from jestimator.states import Evaluator, InferState, MeanMetrics, Predictor, TrainState # pylint: disable=g-multiple-import 82 | import ml_collections 83 | from ml_collections.config_dict import config_dict 84 | import optax 85 | from scipy import stats as scipy_stats 86 | from sklearn import metrics as sklearn_metrics 87 | import tensorflow as tf 88 | 89 | import sentencepiece as spm 90 | 91 | 92 | def get_config(): 93 | """Returns a config object for modeling flags.""" 94 | module_config = ml_collections.ConfigDict() 95 | 96 | # Model config. 97 | model_config = modeling.ModelConfig() 98 | model_config = ml_collections.ConfigDict(dataclasses.asdict(model_config)) 99 | module_config.model_config = model_config 100 | 101 | # Optimizer config. 102 | opt_config = ml_collections.ConfigDict() 103 | opt_config.optimizer = 'adam' 104 | opt_config.learning_rate = 5e-6 105 | module_config.opt_config = opt_config 106 | 107 | # Other config. 108 | module_config.vocab_path = config_dict.placeholder(str) 109 | module_config.segment_names = config_dict.placeholder(str) 110 | module_config.eval_metric = config_dict.placeholder(str) 111 | module_config.output_path = config_dict.placeholder(str) 112 | module_config.label_names = config_dict.placeholder(str) 113 | module_config.stsb = False 114 | return module_config 115 | 116 | 117 | def load_config(global_flags): 118 | """Init config data from global flags.""" 119 | config = ml_collections.ConfigDict() 120 | config.update(global_flags.module_config) 121 | 122 | tokenizer = spm.SentencePieceProcessor() 123 | tokenizer.Load(config.vocab_path) 124 | config.model_config.vocab_size = tokenizer.GetPieceSize() 125 | 126 | segment_names = config.segment_names.split(',') 127 | num_segments = len(segment_names) 128 | config.model_config.num_segments = num_segments + 1 129 | 130 | # Only a frozen config (hashable object) can be passed to jit functions 131 | # (i.e. train_step/valid_step/infer_step). 
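# (A FrozenConfigDict is hashable, so jax.jit can treat it as a static
# argument.)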
132 | config.frozen = ml_collections.FrozenConfigDict(config) 133 | 134 | # Construct data pipelines in the following (using TensorFlow): 135 | max_length = config.model_config.max_length 136 | max_len_1 = (max_length - 1) // num_segments 137 | cls_token_id = tokenizer.PieceToId('<cls>') 138 | sep_token_id = tokenizer.PieceToId('<sep>') 139 | eos_token_id = tokenizer.PieceToId('</s>') 140 | data_keys = ['idx', 'label'] + config.segment_names.split(',') 141 | mode = global_flags.mode 142 | 143 | def tokenize_fn(texts): 144 | ids = [] 145 | for s in texts: 146 | s = tf.strings.lower(s).numpy() 147 | ids.append(tf.convert_to_tensor(tokenizer.EncodeAsIds(s), tf.int32)) 148 | return ids 149 | 150 | def example_fn(data): 151 | data = {k: data[k] for k in data_keys if k in data} 152 | texts = [data[k] for k in segment_names] 153 | out_spec = [PyOutSpec((-1,), tf.int32)] * num_segments 154 | tokenized = reader.apply_py_fn(tokenize_fn, texts, out_spec) 155 | 156 | max_len_0 = max_length - 1 157 | input_ids = [tf.concat([[cls_token_id], tokenized[0]], 0)] 158 | for x in tokenized[1:]: 159 | x = tf.concat([[sep_token_id], x], 0)[:max_len_1] 160 | input_ids.append(x) 161 | max_len_0 = max_len_0 - tf.shape(x)[0] 162 | input_ids[0] = input_ids[0][:max_len_0] 163 | input_ids.append([eos_token_id]) 164 | 165 | segment_ids = [tf.ones_like(x) * i for i, x in enumerate(input_ids)] 166 | input_ids = tf.concat(input_ids, 0) 167 | input_mask = tf.ones_like(input_ids) 168 | segment_ids = tf.concat(segment_ids, 0) 169 | 170 | pad_len = max_length - tf.shape(input_ids)[0] 171 | input_ids = tf.pad(input_ids, [[0, pad_len]]) 172 | input_mask = tf.pad(input_mask, [[0, pad_len]]) 173 | segment_ids = tf.pad(segment_ids, [[0, pad_len]]) 174 | 175 | ret = { 176 | 'input_ids': tf.ensure_shape(input_ids, (max_length,)), 177 | 'input_mask': tf.ensure_shape(input_mask, (max_length,)), 178 | 'segment_ids': tf.ensure_shape(segment_ids, (max_length,)), 179 | } 180 | if mode == 'train': 181 | ret['label'] = data['label'] 182 | if config.stsb: 183 | ret['label'] /= 5.0 184 | else: 185 | ret['idx'] = data['idx'] 186 | if mode.startswith('eval'): 187 | ret = (data['label'], ret) 188 | return ret 189 | 190 | def dataset_fn(path: str) -> tf.data.Dataset: 191 | d = reader.get_tfds_dataset(path) 192 | d = d.map(example_fn, tf.data.AUTOTUNE) 193 | return d 194 | 195 | config.train_data_fn = rec_data( 196 | dataset_fn=dataset_fn, cache=True, interleave=True) 197 | config.eval_data_fn = config.valid_data_fn = rec_data( 198 | dataset_fn=dataset_fn, cache=True) 199 | config.pred_data_fn = rec_data(dataset_fn=dataset_fn) 200 | return config 201 | 202 | 203 | def get_train_state(config, rng): 204 | """Create train state.""" 205 | model_config = modeling.ModelConfig(**config.model_config.to_dict()) 206 | model = modeling.ModelForSeqCls(model_config) 207 | 208 | opt_config = config.opt_config 209 | if opt_config.optimizer == 'adam': 210 | optimizer = optax.adam(learning_rate=opt_config.learning_rate) 211 | 212 | metrics_mod = MeanMetrics.create('train_loss', 'valid_loss') 213 | return TrainState.create(metrics_mod, optimizer, model, rng, jnp.array([[0]])) 214 | 215 | 216 | def train_step(config, train_batch, state: TrainState, metrics): 217 | """Training step.""" 218 | loss_fn = ( 219 | modeling.ModelForSeqCls.mse_loss 220 | if config.stsb else modeling.ModelForSeqCls.xe_loss) 221 | (loss, size), grads = state.value_and_grad_apply_fn(has_aux=True)( 222 | state.params, 223 | train_batch['label'], 224 | train_batch['input_ids'], 225 |
segment_ids=train_batch['segment_ids'], 226 | input_mask=train_batch['input_mask'], 227 | enable_dropout=True, 228 | method=loss_fn) 229 | _, metrics = state.metrics_mod.apply( 230 | metrics, 231 | 'train_loss', 232 | loss, 233 | size, 234 | method=MeanMetrics.update, 235 | mutable=['metrics']) 236 | return state.apply_gradients(grads=grads), metrics 237 | 238 | 239 | def valid_step(config, valid_batch, state: TrainState, metrics): 240 | """Validation step.""" 241 | loss_fn = ( 242 | modeling.ModelForSeqCls.mse_loss 243 | if config.stsb else modeling.ModelForSeqCls.xe_loss) 244 | loss, size = state.apply_fn( 245 | state.variables(), 246 | valid_batch['label'], 247 | valid_batch['input_ids'], 248 | segment_ids=valid_batch['segment_ids'], 249 | input_mask=valid_batch['input_mask'], 250 | method=loss_fn) 251 | _, metrics = state.metrics_mod.apply( 252 | metrics, 253 | 'valid_loss', 254 | loss, 255 | size, 256 | method=MeanMetrics.update, 257 | mutable=['metrics']) 258 | return metrics 259 | 260 | 261 | def get_infer_state(config): 262 | """Create infer state.""" 263 | model_config = modeling.ModelConfig(**config.model_config.to_dict()) 264 | model = modeling.ModelForSeqCls(model_config) 265 | return InferState.create(model, jnp.array([[0]])) 266 | 267 | 268 | def infer_step(config, batch, state: InferState) -> InferState: 269 | """Infer step.""" 270 | logits = state.apply_fn( 271 | state.variables(), 272 | batch['input_ids'], 273 | segment_ids=batch['segment_ids'], 274 | input_mask=batch['input_mask']) 275 | if config.stsb: 276 | pred = jax.nn.softmax(logits)[..., 0] * 5.0 277 | else: 278 | pred = jnp.argmax(logits, axis=-1) 279 | return state.replace(ret={ 280 | 'idx': batch['idx'], 281 | 'prediction': pred, 282 | }) 283 | 284 | 285 | def get_evaluator(config) -> Evaluator: 286 | """Create evaluator.""" 287 | eval_fns = { 288 | 'accuracy': sklearn_metrics.accuracy_score, 289 | 'f1': sklearn_metrics.f1_score, 290 | 'spearmanr': lambda x, y: scipy_stats.spearmanr(x, y)[0], 291 | } 292 | 293 | def proc_fn(infer): 294 | return infer['prediction'] 295 | 296 | metric = config.eval_metric 297 | return Evaluator({metric: (proc_fn, eval_fns[metric])}) 298 | 299 | 300 | def get_predictor(config) -> Predictor: 301 | """Create predictor.""" 302 | pre_str = 'index\tprediction' 303 | label_names = (None if config.label_names is None else 304 | config.label_names.split(',')) 305 | 306 | def proc_fn(infer): 307 | ret = [] 308 | for x, y in zip(infer['idx'], infer['prediction']): 309 | z = y if label_names is None else label_names[y] 310 | ret.append(f'{x}\t{z}') 311 | return ret 312 | 313 | return Predictor(proc_fn, config.output_path, pre_str=pre_str) 314 | -------------------------------------------------------------------------------- /jestimator/models/bert/pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Pretrain. 
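
Masked language model (MLM) pretraining for BERT, using jestimator as the
entry point.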
16 | 17 | # For debug run locally: 18 | 19 | ``` 20 | PYTHONPATH=. python3 \ 21 | jestimator/estimator.py \ 22 | --module_imp="jestimator.models.bert.pretrain" \ 23 | --module_config="jestimator/models/bert/pretrain.py" \ 24 | --module_config.model_config.vocab_size=32000 \ 25 | --module_config.mask_token_id=4 \ 26 | --train_pattern="gs://gresearch/checkpoints_in_amos_paper/data/\ 27 | books-00000-of-00500" \ 28 | --valid_pattern="gs://gresearch/checkpoints_in_amos_paper/data/ptb" \ 29 | --model_dir="$HOME/models/bert_pretrain" \ 30 | --train_batch_size=4 --valid_batch_size=4 --num_valid_examples=4 \ 31 | --check_every_steps=10 --logtostderr 32 | ``` 33 | """ 34 | 35 | import dataclasses 36 | from typing import Mapping 37 | 38 | import jax 39 | import jax.numpy as jnp 40 | from jestimator import amos 41 | from jestimator.data.pipeline_lm import lm_data 42 | from jestimator.models.bert import modeling 43 | from jestimator.states import TrainState, MeanMetrics # pylint: disable=g-multiple-import 44 | import ml_collections 45 | from ml_collections.config_dict import config_dict 46 | import optax 47 | import tensorflow as tf 48 | 49 | 50 | def get_config(): 51 | """Returns a config object for modeling flags.""" 52 | module_config = ml_collections.ConfigDict() 53 | 54 | # Model config. 55 | model_config = modeling.ModelConfig() 56 | model_config = ml_collections.ConfigDict(dataclasses.asdict(model_config)) 57 | module_config.model_config = model_config 58 | 59 | # Optimizer config. 60 | opt_config = ml_collections.ConfigDict() 61 | opt_config.optimizer = 'adamw' 62 | opt_config.learning_rate = 1e-4 63 | opt_config.warmup_steps = 10000 64 | opt_config.linear_decay_to_step = config_dict.placeholder(int) 65 | opt_config.momentum = 0.9 66 | opt_config.beta = 0.999 67 | opt_config.weight_decay = 0.01 68 | module_config.opt_config = opt_config 69 | 70 | # Other config. 71 | module_config.mask_token_id = config_dict.placeholder(int) 72 | module_config.mask_rate = 0.15 73 | return module_config 74 | 75 | 76 | def load_config(global_flags): 77 | """Init config data from global flags.""" 78 | config = ml_collections.ConfigDict() 79 | config.update(global_flags.module_config) 80 | 81 | # Only a frozen config (hashable object) can be passed to jit functions 82 | # (i.e. train_step/valid_step/infer_step). 83 | config.frozen = ml_collections.FrozenConfigDict(config) 84 | 85 | # Construct data pipelines in the following (using TensorFLow): 86 | seq_length = config.model_config.max_length 87 | 88 | def feature_fn(token_ids: tf.Tensor) -> Mapping[str, tf.Tensor]: 89 | """Builds a feature dict to be compatible with seqio.""" 90 | return {'targets': tf.ensure_shape(token_ids, (seq_length,))} 91 | 92 | config.train_data_fn = lm_data( 93 | seq_length, random_skip=True, feature_fn=feature_fn, interleave=True) 94 | config.valid_data_fn = lm_data(seq_length, feature_fn=feature_fn) 95 | return config 96 | 97 | 98 | def get_train_state(config, rng) -> TrainState: 99 | """Create train state.""" 100 | model_config = modeling.ModelConfig(**config.model_config.to_dict()) 101 | model = modeling.ModelForPretrain(model_config) 102 | opt_config = config.opt_config 103 | warmup = opt_config.warmup_steps 104 | decay = opt_config.linear_decay_to_step 105 | 106 | def lr_schedule(step): 107 | lr = opt_config.learning_rate 108 | if warmup is not None: 109 | lr *= jnp.minimum(1., step / warmup) 110 | if decay is not None: 111 | lr *= 1. - jnp.maximum(0., step - warmup) / (decay - warmup) 112 | elif decay is not None: 113 | lr *= 1. 
- step / decay 114 | return lr 115 | 116 | if opt_config.optimizer == 'adamw': 117 | optimizer = optax.adamw( 118 | learning_rate=lr_schedule, 119 | b1=opt_config.momentum, 120 | b2=opt_config.beta, 121 | weight_decay=opt_config.weight_decay) 122 | elif opt_config.optimizer == 'amos': 123 | optimizer = amos.amos( 124 | lr_schedule, 125 | modeling.get_eta_fn(model_config), 126 | shape_fn=modeling.get_shape_fn(model_config), 127 | beta=opt_config.beta, 128 | momentum=opt_config.momentum, 129 | clip_value=1.) 130 | 131 | metrics_mod = MeanMetrics.create('train_loss', 'valid_loss', 'valid_mrr') 132 | return TrainState.create(metrics_mod, optimizer, model, rng, jnp.array([[0]])) 133 | 134 | 135 | def train_step(config, train_batch, state: TrainState, metrics): 136 | """Training step.""" 137 | (loss, size), grads = state.value_and_grad_apply_fn(has_aux=True)( 138 | state.params, 139 | train_batch['targets'], 140 | config.mask_token_id, 141 | mask_rate=config.mask_rate, 142 | input_mask=train_batch.get('input_mask'), 143 | enable_dropout=True, 144 | method=modeling.ModelForPretrain.mlm_train_loss) 145 | _, metrics = state.metrics_mod.apply( 146 | metrics, 147 | 'train_loss', 148 | loss, 149 | size, 150 | method=MeanMetrics.update, 151 | mutable=['metrics']) 152 | return state.apply_gradients(grads=grads), metrics 153 | 154 | 155 | def valid_step(config, valid_batch, state: TrainState, metrics): 156 | """Validation step.""" 157 | 158 | def body(i, metrics): 159 | del i # Unused. 160 | loss, mrr, size = state.apply_fn( 161 | state.variables(), 162 | valid_batch['targets'], 163 | config.mask_token_id, 164 | mask_rate=config.mask_rate, 165 | input_mask=valid_batch.get('input_mask'), 166 | method=modeling.ModelForPretrain.mlm_valid_metrics) 167 | _, metrics = state.metrics_mod.apply( 168 | metrics, 169 | 'valid_loss', 170 | loss, 171 | size, 172 | method=MeanMetrics.update, 173 | mutable=['metrics']) 174 | _, metrics = state.metrics_mod.apply( 175 | metrics, 176 | 'valid_mrr', 177 | mrr, 178 | size, 179 | method=MeanMetrics.update, 180 | mutable=['metrics']) 181 | return metrics 182 | 183 | return jax.lax.fori_loop(0, 20, body, metrics) 184 | -------------------------------------------------------------------------------- /jestimator/models/bert_rpe/finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Sequence classification. 16 | 17 | # For debug run locally: 18 | 19 | ## Train: 20 | 21 | ``` 22 | PYTHONPATH=. 
python3 \ 23 | jestimator/estimator.py \ 24 | --module_imp="jestimator.models.bert_rpe.finetune" \ 25 | --module_config="jestimator/models/bert_rpe/finetune.py" \ 26 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \ 27 | --module_config.segment_names="sentence1,sentence2" \ 28 | --module_config.model_config.num_labels=2 \ 29 | --train_pattern="tfds://glue/rte/split=train" \ 30 | --valid_pattern="tfds://glue/rte/split=validation" \ 31 | --model_dir="$HOME/models/bert_rte" \ 32 | --checkpoint_path="gs://gresearch/checkpoints_in_amos_paper/\ 33 | adamw/rpe/checkpoint_300000" \ 34 | --train_batch_size=4 --valid_batch_size=4 --num_valid_examples=4 \ 35 | --check_every_steps=10 --logtostderr 36 | ``` 37 | 38 | ## Eval: 39 | 40 | ``` 41 | PYTHONPATH=. python3 \ 42 | jestimator/estimator.py \ 43 | --module_imp="jestimator.models.bert_rpe.finetune" \ 44 | --module_config="jestimator/models/bert_rpe/finetune.py" \ 45 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \ 46 | --module_config.segment_names="sentence1,sentence2" \ 47 | --module_config.model_config.num_labels=2 \ 48 | --module_config.eval_metric="accuracy" \ 49 | --eval_pattern="tfds://glue/rte/split=validation" \ 50 | --model_dir="$HOME/models/bert_rte" \ 51 | --eval_batch_size=4 --num_eval_examples=4 \ 52 | --logtostderr 53 | ``` 54 | 55 | ## Predict: 56 | 57 | ``` 58 | PYTHONPATH=. python3 \ 59 | jestimator/estimator.py \ 60 | --module_imp="jestimator.models.bert_rpe.finetune" \ 61 | --module_config="jestimator/models/bert_rpe/finetune.py" \ 62 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \ 63 | --module_config.segment_names="sentence1,sentence2" \ 64 | --module_config.model_config.num_labels=2 \ 65 | --module_config.label_names="entailment,not_entailment" \ 66 | --pred_pattern="tfds://glue/rte/split=test" \ 67 | --model_dir="$HOME/models/bert_rte" \ 68 | --pred_batch_size=4 --num_pred_examples=4 \ 69 | --logtostderr 70 | ``` 71 | """ 72 | 73 | import dataclasses 74 | 75 | import jax 76 | import jax.numpy as jnp 77 | from jestimator.data import reader 78 | from jestimator.data.pipeline_rec import rec_data 79 | from jestimator.data.reader import PyOutSpec 80 | from jestimator.models.bert_rpe import modeling 81 | from jestimator.states import Evaluator, InferState, MeanMetrics, Predictor, TrainState # pylint: disable=g-multiple-import 82 | import ml_collections 83 | from ml_collections.config_dict import config_dict 84 | import optax 85 | from scipy import stats as scipy_stats 86 | from sklearn import metrics as sklearn_metrics 87 | import tensorflow as tf 88 | 89 | import sentencepiece as spm 90 | 91 | 92 | def get_config(): 93 | """Returns a config object for modeling flags.""" 94 | module_config = ml_collections.ConfigDict() 95 | 96 | # Model config. 97 | model_config = modeling.ModelConfig() 98 | model_config = ml_collections.ConfigDict(dataclasses.asdict(model_config)) 99 | module_config.model_config = model_config 100 | 101 | # Optimizer config. 102 | opt_config = ml_collections.ConfigDict() 103 | opt_config.optimizer = 'adam' 104 | opt_config.learning_rate = 5e-6 105 | module_config.opt_config = opt_config 106 | 107 | # Other config. 
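  # The placeholders below are filled from the command line via
  # --module_config.<name>=... flags (see the example runs in the docstring).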
108 | module_config.vocab_path = config_dict.placeholder(str) 109 | module_config.segment_names = config_dict.placeholder(str) 110 | module_config.eval_metric = config_dict.placeholder(str) 111 | module_config.output_path = config_dict.placeholder(str) 112 | module_config.label_names = config_dict.placeholder(str) 113 | module_config.stsb = False 114 | return module_config 115 | 116 | 117 | def load_config(global_flags): 118 | """Init config data from global flags.""" 119 | config = ml_collections.ConfigDict() 120 | config.update(global_flags.module_config) 121 | 122 | tokenizer = spm.SentencePieceProcessor() 123 | tokenizer.Load(config.vocab_path) 124 | config.model_config.vocab_size = tokenizer.GetPieceSize() 125 | 126 | segment_names = config.segment_names.split(',') 127 | num_segments = len(segment_names) 128 | config.model_config.num_segments = num_segments + 1 129 | 130 | # Only a frozen config (hashable object) can be passed to jit functions 131 | # (i.e. train_step/valid_step/infer_step). 132 | config.frozen = ml_collections.FrozenConfigDict(config) 133 | 134 | # Construct data pipelines in the following (using TensorFLow): 135 | max_length = config.model_config.max_length 136 | max_len_1 = (max_length - 1) // num_segments 137 | cls_token_id = tokenizer.PieceToId('') 138 | sep_token_id = tokenizer.PieceToId('') 139 | eos_token_id = tokenizer.PieceToId('') 140 | data_keys = ['idx', 'label'] + config.segment_names.split(',') 141 | mode = global_flags.mode 142 | 143 | def tokenize_fn(texts): 144 | ids = [] 145 | for s in texts: 146 | s = tf.strings.lower(s).numpy() 147 | ids.append(tf.convert_to_tensor(tokenizer.EncodeAsIds(s), tf.int32)) 148 | return ids 149 | 150 | def example_fn(data): 151 | data = {k: data[k] for k in data_keys if k in data} 152 | texts = [data[k] for k in segment_names] 153 | out_spec = [PyOutSpec((-1,), tf.int32)] * num_segments 154 | tokenized = reader.apply_py_fn(tokenize_fn, texts, out_spec) 155 | 156 | max_len_0 = max_length - 1 157 | input_ids = [tf.concat([[cls_token_id], tokenized[0]], 0)] 158 | for x in tokenized[1:]: 159 | x = tf.concat([[sep_token_id], x], 0)[:max_len_1] 160 | input_ids.append(x) 161 | max_len_0 = max_len_0 - tf.shape(x)[0] 162 | input_ids[0] = input_ids[0][:max_len_0] 163 | input_ids.append([eos_token_id]) 164 | 165 | segment_ids = [tf.ones_like(x) * i for i, x in enumerate(input_ids)] 166 | input_ids = tf.concat(input_ids, 0) 167 | input_mask = tf.ones_like(input_ids) 168 | segment_ids = tf.concat(segment_ids, 0) 169 | 170 | pad_len = max_length - tf.shape(input_ids)[0] 171 | input_ids = tf.pad(input_ids, [[0, pad_len]]) 172 | input_mask = tf.pad(input_mask, [[0, pad_len]]) 173 | segment_ids = tf.pad(segment_ids, [[0, pad_len]]) 174 | 175 | ret = { 176 | 'input_ids': tf.ensure_shape(input_ids, (max_length,)), 177 | 'input_mask': tf.ensure_shape(input_mask, (max_length,)), 178 | 'segment_ids': tf.ensure_shape(segment_ids, (max_length,)), 179 | } 180 | if mode == 'train': 181 | ret['label'] = data['label'] 182 | if config.stsb: 183 | ret['label'] /= 5.0 184 | else: 185 | ret['idx'] = data['idx'] 186 | if mode.startswith('eval'): 187 | ret = (data['label'], ret) 188 | return ret 189 | 190 | def dataset_fn(path: str) -> tf.data.Dataset: 191 | d = reader.get_tfds_dataset(path) 192 | d = d.map(example_fn, tf.data.AUTOTUNE) 193 | return d 194 | 195 | config.train_data_fn = rec_data( 196 | dataset_fn=dataset_fn, cache=True, interleave=True) 197 | config.eval_data_fn = config.valid_data_fn = rec_data( 198 | dataset_fn=dataset_fn, cache=True) 
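  # Train/eval/valid datasets are cached; training additionally interleaves
  # input files. The prediction pipeline below is left uncached.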
199 | config.pred_data_fn = rec_data(dataset_fn=dataset_fn) 200 | return config 201 | 202 | 203 | def get_train_state(config, rng): 204 | """Create train state.""" 205 | model_config = modeling.ModelConfig(**config.model_config.to_dict()) 206 | model = modeling.ModelForSeqCls(model_config) 207 | 208 | opt_config = config.opt_config 209 | if opt_config.optimizer == 'adam': 210 | optimizer = optax.adam(learning_rate=opt_config.learning_rate) 211 | 212 | metrics_mod = MeanMetrics.create('train_loss', 'valid_loss') 213 | dummy_input = jnp.array([[0] * config.model_config.max_length]) 214 | return TrainState.create(metrics_mod, optimizer, model, rng, dummy_input) 215 | 216 | 217 | def train_step(config, train_batch, state: TrainState, metrics): 218 | """Training step.""" 219 | loss_fn = ( 220 | modeling.ModelForSeqCls.mse_loss 221 | if config.stsb else modeling.ModelForSeqCls.xe_loss) 222 | (loss, size), grads = state.value_and_grad_apply_fn(has_aux=True)( 223 | state.params, 224 | train_batch['label'], 225 | train_batch['input_ids'], 226 | segment_ids=train_batch['segment_ids'], 227 | input_mask=train_batch['input_mask'], 228 | enable_dropout=True, 229 | method=loss_fn) 230 | _, metrics = state.metrics_mod.apply( 231 | metrics, 232 | 'train_loss', 233 | loss, 234 | size, 235 | method=MeanMetrics.update, 236 | mutable=['metrics']) 237 | return state.apply_gradients(grads=grads), metrics 238 | 239 | 240 | def valid_step(config, valid_batch, state: TrainState, metrics): 241 | """Validation step.""" 242 | loss_fn = ( 243 | modeling.ModelForSeqCls.mse_loss 244 | if config.stsb else modeling.ModelForSeqCls.xe_loss) 245 | loss, size = state.apply_fn( 246 | state.variables(), 247 | valid_batch['label'], 248 | valid_batch['input_ids'], 249 | segment_ids=valid_batch['segment_ids'], 250 | input_mask=valid_batch['input_mask'], 251 | method=loss_fn) 252 | _, metrics = state.metrics_mod.apply( 253 | metrics, 254 | 'valid_loss', 255 | loss, 256 | size, 257 | method=MeanMetrics.update, 258 | mutable=['metrics']) 259 | return metrics 260 | 261 | 262 | def get_infer_state(config): 263 | """Create infer state.""" 264 | model_config = modeling.ModelConfig(**config.model_config.to_dict()) 265 | model = modeling.ModelForSeqCls(model_config) 266 | dummy_input = jnp.array([[0] * config.model_config.max_length]) 267 | return InferState.create(model, dummy_input) 268 | 269 | 270 | def infer_step(config, batch, state: InferState) -> InferState: 271 | """Infer step.""" 272 | logits = state.apply_fn( 273 | state.variables(), 274 | batch['input_ids'], 275 | segment_ids=batch['segment_ids'], 276 | input_mask=batch['input_mask']) 277 | if config.stsb: 278 | pred = jax.nn.softmax(logits)[..., 0] * 5.0 279 | else: 280 | pred = jnp.argmax(logits, axis=-1) 281 | return state.replace(ret={ 282 | 'idx': batch['idx'], 283 | 'prediction': pred, 284 | }) 285 | 286 | 287 | def get_evaluator(config) -> Evaluator: 288 | """Create evaluator.""" 289 | eval_fns = { 290 | 'accuracy': sklearn_metrics.accuracy_score, 291 | 'f1': sklearn_metrics.f1_score, 292 | 'spearmanr': lambda x, y: scipy_stats.spearmanr(x, y)[0], 293 | } 294 | 295 | def proc_fn(infer): 296 | return infer['prediction'] 297 | 298 | metric = config.eval_metric 299 | return Evaluator({metric: (proc_fn, eval_fns[metric])}) 300 | 301 | 302 | def get_predictor(config) -> Predictor: 303 | """Create predictor.""" 304 | pre_str = 'index\tprediction' 305 | label_names = (None if config.label_names is None else 306 | config.label_names.split(',')) 307 | 308 | def proc_fn(infer): 309 
| ret = [] 310 | for x, y in zip(infer['idx'], infer['prediction']): 311 | z = y if label_names is None else label_names[y] 312 | ret.append(f'{x}\t{z}') 313 | return ret 314 | 315 | return Predictor(proc_fn, config.output_path, pre_str=pre_str) 316 | -------------------------------------------------------------------------------- /jestimator/models/bert_rpe/pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Pretrain. 16 | 17 | # For debug run locally: 18 | 19 | ``` 20 | PYTHONPATH=. python3 \ 21 | jestimator/estimator.py \ 22 | --module_imp="jestimator.models.bert_rpe.pretrain" \ 23 | --module_config="jestimator/models/bert_rpe/pretrain.py" \ 24 | --module_config.model_config.vocab_size=32000 \ 25 | --module_config.mask_token_id=4 \ 26 | --train_pattern="gs://gresearch/checkpoints_in_amos_paper/data/\ 27 | books-00000-of-00500" \ 28 | --valid_pattern="gs://gresearch/checkpoints_in_amos_paper/data/ptb" \ 29 | --model_dir="$HOME/models/bert_rpe_pretrain" \ 30 | --train_batch_size=4 --valid_batch_size=4 --num_valid_examples=4 \ 31 | --check_every_steps=10 --logtostderr 32 | ``` 33 | """ 34 | 35 | import dataclasses 36 | from typing import Mapping 37 | 38 | import jax 39 | import jax.numpy as jnp 40 | from jestimator import amos 41 | from jestimator.data.pipeline_lm import lm_data 42 | from jestimator.models.bert_rpe import modeling 43 | from jestimator.states import TrainState, MeanMetrics # pylint: disable=g-multiple-import 44 | import ml_collections 45 | from ml_collections.config_dict import config_dict 46 | import optax 47 | import tensorflow as tf 48 | 49 | 50 | def get_config(): 51 | """Returns a config object for modeling flags.""" 52 | module_config = ml_collections.ConfigDict() 53 | 54 | # Model config. 55 | model_config = modeling.ModelConfig() 56 | model_config = ml_collections.ConfigDict(dataclasses.asdict(model_config)) 57 | module_config.model_config = model_config 58 | 59 | # Optimizer config. 60 | opt_config = ml_collections.ConfigDict() 61 | opt_config.optimizer = 'adamw' 62 | opt_config.learning_rate = 1e-4 63 | opt_config.warmup_steps = 10000 64 | opt_config.linear_decay_to_step = config_dict.placeholder(int) 65 | opt_config.momentum = 0.9 66 | opt_config.beta = 0.999 67 | opt_config.weight_decay = 0.01 68 | module_config.opt_config = opt_config 69 | 70 | # Other config. 71 | module_config.mask_token_id = config_dict.placeholder(int) 72 | module_config.mask_rate = 0.15 73 | return module_config 74 | 75 | 76 | def load_config(global_flags): 77 | """Init config data from global flags.""" 78 | config = ml_collections.ConfigDict() 79 | config.update(global_flags.module_config) 80 | 81 | # Only a frozen config (hashable object) can be passed to jit functions 82 | # (i.e. train_step/valid_step/infer_step). 
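  # (Hashability is what allows jit to cache one compiled graph per config.)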
83 | config.frozen = ml_collections.FrozenConfigDict(config) 84 | 85 | # Construct data pipelines in the following (using TensorFLow): 86 | seq_length = config.model_config.max_length 87 | 88 | def feature_fn(token_ids: tf.Tensor) -> Mapping[str, tf.Tensor]: 89 | """Builds a feature dict to be compatible with seqio.""" 90 | return {'targets': tf.ensure_shape(token_ids, (seq_length,))} 91 | 92 | config.train_data_fn = lm_data( 93 | seq_length, random_skip=True, feature_fn=feature_fn, interleave=True) 94 | config.valid_data_fn = lm_data(seq_length, feature_fn=feature_fn) 95 | return config 96 | 97 | 98 | def get_train_state(config, rng) -> TrainState: 99 | """Create train state.""" 100 | model_config = modeling.ModelConfig(**config.model_config.to_dict()) 101 | model = modeling.ModelForPretrain(model_config) 102 | opt_config = config.opt_config 103 | warmup = opt_config.warmup_steps 104 | decay = opt_config.linear_decay_to_step 105 | 106 | def lr_schedule(step): 107 | lr = opt_config.learning_rate 108 | if warmup is not None: 109 | lr *= jnp.minimum(1., step / warmup) 110 | if decay is not None: 111 | lr *= 1. - jnp.maximum(0., step - warmup) / (decay - warmup) 112 | elif decay is not None: 113 | lr *= 1. - step / decay 114 | return lr 115 | 116 | if opt_config.optimizer == 'adamw': 117 | optimizer = optax.adamw( 118 | learning_rate=lr_schedule, 119 | b1=opt_config.momentum, 120 | b2=opt_config.beta, 121 | weight_decay=opt_config.weight_decay) 122 | elif opt_config.optimizer == 'amos': 123 | optimizer = amos.amos( 124 | lr_schedule, 125 | modeling.get_eta_fn(model_config), 126 | shape_fn=modeling.get_shape_fn(model_config), 127 | beta=opt_config.beta, 128 | momentum=opt_config.momentum, 129 | clip_value=1.) 130 | 131 | metrics_mod = MeanMetrics.create('train_loss', 'valid_loss', 'valid_mrr') 132 | dummy_input = jnp.array([[0] * config.model_config.max_length]) 133 | return TrainState.create(metrics_mod, optimizer, model, rng, dummy_input) 134 | 135 | 136 | def train_step(config, train_batch, state: TrainState, metrics): 137 | """Training step.""" 138 | (loss, size), grads = state.value_and_grad_apply_fn(has_aux=True)( 139 | state.params, 140 | train_batch['targets'], 141 | config.mask_token_id, 142 | mask_rate=config.mask_rate, 143 | input_mask=train_batch.get('input_mask'), 144 | enable_dropout=True, 145 | method=modeling.ModelForPretrain.mlm_train_loss) 146 | _, metrics = state.metrics_mod.apply( 147 | metrics, 148 | 'train_loss', 149 | loss, 150 | size, 151 | method=MeanMetrics.update, 152 | mutable=['metrics']) 153 | return state.apply_gradients(grads=grads), metrics 154 | 155 | 156 | def valid_step(config, valid_batch, state: TrainState, metrics): 157 | """Validation step.""" 158 | 159 | def body(i, metrics): 160 | del i # Unused. 
161 | loss, mrr, size = state.apply_fn( 162 | state.variables(), 163 | valid_batch['targets'], 164 | config.mask_token_id, 165 | mask_rate=config.mask_rate, 166 | input_mask=valid_batch.get('input_mask'), 167 | method=modeling.ModelForPretrain.mlm_valid_metrics) 168 | _, metrics = state.metrics_mod.apply( 169 | metrics, 170 | 'valid_loss', 171 | loss, 172 | size, 173 | method=MeanMetrics.update, 174 | mutable=['metrics']) 175 | _, metrics = state.metrics_mod.apply( 176 | metrics, 177 | 'valid_mrr', 178 | mrr, 179 | size, 180 | method=MeanMetrics.update, 181 | mutable=['metrics']) 182 | return metrics 183 | 184 | return jax.lax.fori_loop(0, 20, body, metrics) 185 | -------------------------------------------------------------------------------- /jestimator/models/linear_regression/linear_regression.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""A toy linear regression model. 16 | 17 | Using jestimator as the entry point. 18 | 19 | # For debug run locally: 20 | 21 | ## Train: 22 | 23 | ``` 24 | PYTHONPATH=. python3 \ 25 | jestimator/estimator.py \ 26 | --module_imp="jestimator.models.linear_regression.linear_regression" \ 27 | --module_config="jestimator/models/linear_regression/\ 28 | linear_regression.py" \ 29 | --train_pattern="dummy://" \ 30 | --valid_pattern="dummy://" \ 31 | --model_dir="$HOME/experiments/linear_regression/models" \ 32 | --train_batch_size=4 --valid_batch_size=4 \ 33 | --max_train_steps=200 --train_shuffle_buf=32 \ 34 | --check_every_steps=10 --logtostderr 35 | ``` 36 | 37 | ## Eval: 38 | 39 | ``` 40 | PYTHONPATH=. python3 \ 41 | jestimator/estimator.py \ 42 | --module_imp="jestimator.models.linear_regression.linear_regression" \ 43 | --module_config="jestimator/models/linear_regression/\ 44 | linear_regression.py" \ 45 | --eval_pattern="dummy://" \ 46 | --model_dir="$HOME/experiments/linear_regression/models" \ 47 | --eval_batch_size=4 \ 48 | --logtostderr 49 | ``` 50 | 51 | ## Predict: 52 | 53 | ``` 54 | PYTHONPATH=. 
python3 \ 55 | jestimator/estimator.py \ 56 | --module_imp="jestimator.models.linear_regression.linear_regression" \ 57 | --module_config="jestimator/models/linear_regression/\ 58 | linear_regression.py" \ 59 | --pred_pattern="dummy://" \ 60 | --model_dir="$HOME/experiments/linear_regression/models" \ 61 | --pred_batch_size=4 \ 62 | --logtostderr 63 | ``` 64 | """ 65 | from typing import Tuple 66 | 67 | from flax import linen as nn 68 | import jax 69 | import jax.numpy as jnp 70 | from jax.typing import ArrayLike 71 | from jestimator.states import Evaluator, InferState, MeanMetrics, Predictor, TrainState # pylint: disable=g-multiple-import 72 | import ml_collections 73 | import optax 74 | from sklearn.metrics import mean_squared_error 75 | import tensorflow as tf 76 | 77 | from flaxformer.components.dense import DenseGeneral 78 | 79 | 80 | def get_config(): 81 | """Returns a config object for modeling flags.""" 82 | module_config = ml_collections.ConfigDict() 83 | module_config.num_train = 20 84 | module_config.num_eval = 20 85 | module_config.x_dim = 10 86 | module_config.y_dim = 5 87 | return module_config 88 | 89 | 90 | def load_config(global_flags): 91 | """Init config data from global flags.""" 92 | config = ml_collections.ConfigDict() 93 | config.update(global_flags.module_config) 94 | 95 | # Only a frozen config (hashable object) can be passed to jit functions 96 | # (i.e. train_step/valid_step/infer_step). 97 | config.frozen = ml_collections.FrozenConfigDict(config) 98 | 99 | # Construct data pipelines in the following (using TensorFLow): 100 | def get_data_fn(ds): 101 | 102 | def data_fn(filenames, shard_num=1, shard_index=0, epochs=1, num_take=-1): 103 | del filenames # Unused. 104 | return ds.repeat(epochs).shard(shard_num, shard_index).take(num_take) 105 | 106 | return data_fn 107 | 108 | # Generate random ground truth W and b. 109 | W = tf.random.normal((config.x_dim, config.y_dim), seed=11) # pylint: disable=invalid-name 110 | b = tf.random.normal((config.y_dim,), seed=12) 111 | 112 | # Generate samples with additional noise. 
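  # That is, y = x @ W + b + 0.1 * eps with eps ~ N(0, I), so even a perfect
  # fit of (W, b) keeps a mean squared residual of about 0.01 per dimension.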
113 | x_train = tf.random.normal((config.num_train, config.x_dim), seed=13) 114 | y_train = tf.matmul(x_train, W) + b + 0.1 * tf.random.normal( 115 | (config.num_train, config.y_dim), seed=14) 116 | ds_train = tf.data.Dataset.from_tensor_slices({'x': x_train, 'y': y_train}) 117 | 118 | x_eval = tf.random.normal((config.num_eval, config.x_dim), seed=15) 119 | y_eval = tf.matmul(x_eval, W) + b + 0.1 * tf.random.normal( 120 | (config.num_eval, config.y_dim), seed=16) 121 | ds_valid = tf.data.Dataset.from_tensor_slices({'x': x_eval, 'y': y_eval}) 122 | ds_eval = tf.data.Dataset.from_tensor_slices((y_eval, x_eval)) 123 | ds_pred = tf.data.Dataset.from_tensor_slices(x_eval) 124 | 125 | config.train_data_fn = get_data_fn(ds_train) 126 | config.valid_data_fn = get_data_fn(ds_valid) 127 | config.eval_data_fn = get_data_fn(ds_eval) 128 | config.pred_data_fn = get_data_fn(ds_pred) 129 | return config 130 | 131 | 132 | class LinearRegression(nn.Module): 133 | """A simple linear regression module.""" 134 | y_dim: int 135 | 136 | @nn.compact 137 | def __call__(self, x: ArrayLike) -> ArrayLike: 138 | """Applies linear on the input.""" 139 | linear = DenseGeneral( 140 | features=self.y_dim, 141 | use_bias=True, 142 | kernel_init=nn.zeros, 143 | kernel_axis_names=('x', 'y')) 144 | return linear(x) 145 | 146 | def mse(self, x: ArrayLike, y: ArrayLike) -> Tuple[ArrayLike, ArrayLike]: 147 | """Mean squared error.""" 148 | loss = jnp.mean(jnp.square(self(x) - y), axis=-1) 149 | size = jnp.asarray(loss.size, loss.dtype) 150 | num_hosts = jnp.asarray(jax.host_count(), loss.dtype) 151 | loss = jnp.sum(loss) * jax.lax.rsqrt(size * num_hosts) 152 | size = jnp.sqrt(size / num_hosts) 153 | return loss, size 154 | 155 | 156 | def get_train_state(config, rng): 157 | """Create train state.""" 158 | model = LinearRegression(y_dim=config.y_dim) 159 | 160 | def lr_schedule(step): 161 | return 0.5 / (1. + 0.1 * step) 162 | 163 | optimizer = optax.sgd(learning_rate=lr_schedule) 164 | metrics_mod = MeanMetrics.create('train_loss', 'valid_loss') 165 | dummy_x = jnp.zeros((config.x_dim,), jnp.float32) 166 | return TrainState.create(metrics_mod, optimizer, model, rng, dummy_x) 167 | 168 | 169 | def train_step(config, train_batch, state: TrainState, metrics): 170 | """Training step.""" 171 | del config # Unused. 172 | (loss, size), grads = state.value_and_grad_apply_fn(has_aux=True)( 173 | state.params, 174 | train_batch['x'], 175 | train_batch['y'], 176 | method=LinearRegression.mse) 177 | _, metrics = state.metrics_mod.apply( 178 | metrics, 179 | 'train_loss', 180 | loss, 181 | size, 182 | method=MeanMetrics.update, 183 | mutable=['metrics']) 184 | return state.apply_gradients(grads=grads), metrics 185 | 186 | 187 | def valid_step(config, valid_batch, state: TrainState, metrics): 188 | """Validation step.""" 189 | del config # Unused. 190 | loss, size = state.apply_fn( 191 | state.variables(), 192 | valid_batch['x'], 193 | valid_batch['y'], 194 | method=LinearRegression.mse) 195 | _, metrics = state.metrics_mod.apply( 196 | metrics, 197 | 'valid_loss', 198 | loss, 199 | size, 200 | method=MeanMetrics.update, 201 | mutable=['metrics']) 202 | return metrics 203 | 204 | 205 | def get_infer_state(config): 206 | """Create infer state.""" 207 | model = LinearRegression(y_dim=config.y_dim) 208 | dummy_x = jnp.zeros((config.x_dim,), jnp.float32) 209 | return InferState.create(model, dummy_x) 210 | 211 | 212 | def infer_step(config, batch, state: InferState): 213 | """Infer step.""" 214 | del config # Unused. 
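  # Runs LinearRegression.__call__ on the batch and stores the predictions in
  # `ret`, to be consumed by the evaluator/predictor.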
215 | return state.replace(ret=state.apply_fn(state.variables(), batch)) 216 | 217 | 218 | def get_evaluator(config) -> Evaluator: 219 | """Create evaluator.""" 220 | del config # Unused. 221 | return Evaluator({'mse': (lambda y: y, mean_squared_error)}) 222 | 223 | 224 | def get_predictor(config) -> Predictor: 225 | """Create predictor.""" 226 | del config # Unused. 227 | 228 | def proc_fn(y_batched): 229 | return [str(y) for y in y_batched] 230 | 231 | return Predictor(proc_fn) 232 | -------------------------------------------------------------------------------- /jestimator/models/linear_regression/linear_regression_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for linear_regression.""" 16 | import os 17 | 18 | from absl import flags 19 | from absl.testing import absltest 20 | import jax 21 | from jestimator import checkpoint_utils 22 | from jestimator import estimator 23 | from jestimator.models.linear_regression import linear_regression 24 | import tensorflow as tf 25 | 26 | FLAGS = flags.FLAGS 27 | 28 | 29 | class LinearRegressionTest(absltest.TestCase): 30 | 31 | def setUp(self): 32 | super(LinearRegressionTest, self).setUp() 33 | tmp_model_dir = self.create_tempdir('tmp_model_dir') 34 | model_dir = os.fspath(tmp_model_dir) 35 | FLAGS.model_dir = model_dir 36 | 37 | FLAGS.module_config = linear_regression.get_config() 38 | self.config = linear_regression.load_config(FLAGS) 39 | self.partitioner = estimator.get_partitioner(self.config) 40 | 41 | def test_train_eval(self): 42 | FLAGS.train_pattern = 'dummy://' 43 | FLAGS.valid_pattern = 'dummy://' 44 | FLAGS.train_batch_size = 4 45 | FLAGS.valid_batch_size = 4 46 | FLAGS.train_shuffle_buf = 32 47 | FLAGS.check_every_steps = 10 48 | 49 | FLAGS.max_train_steps = 100 50 | seed = 100 51 | tf.random.set_seed(seed) 52 | rng = jax.random.PRNGKey(seed) 53 | estimator.train(None, False, rng, linear_regression, self.config, 54 | self.partitioner) 55 | middle_ckpt_path = os.path.join(FLAGS.model_dir, 'checkpoint_100') 56 | self.assertTrue(os.path.exists(middle_ckpt_path)) 57 | ckpt_path, same_dir = checkpoint_utils.latest_ckpt_path(FLAGS.model_dir) 58 | self.assertTrue(os.path.samefile(middle_ckpt_path, ckpt_path)) 59 | self.assertTrue(same_dir) 60 | 61 | FLAGS.max_train_steps = 200 62 | final_metrics = estimator.train(ckpt_path, same_dir, rng, linear_regression, 63 | self.config, self.partitioner) 64 | self.assertLess(final_metrics['train_loss'], 0.01) 65 | valid_loss = final_metrics['valid_loss'] 66 | self.assertLess(valid_loss, 0.05) 67 | final_ckpt_path = os.path.join(FLAGS.model_dir, 'checkpoint_200') 68 | self.assertTrue(os.path.exists(final_ckpt_path)) 69 | ckpt_path, same_dir = checkpoint_utils.latest_ckpt_path(FLAGS.model_dir) 70 | self.assertTrue(os.path.samefile(final_ckpt_path, ckpt_path)) 71 | self.assertTrue(same_dir) 72 | 73 | 
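    # Evaluating the final checkpoint with EVAL_ONCE should reproduce the
    # final validation loss as the 'mse' metric.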
FLAGS.eval_pattern = 'dummy://' 74 | FLAGS.eval_batch_size = 4 75 | mode = estimator.RunMode.EVAL_ONCE 76 | eval_metrics = estimator.eval_or_predict(ckpt_path, mode, linear_regression, 77 | self.config, self.partitioner) 78 | self.assertAlmostEqual(eval_metrics['mse'], valid_loss) 79 | 80 | FLAGS.save_low = ['mse'] 81 | mode = estimator.RunMode.EVAL_WAIT 82 | estimator.eval_or_predict(None, mode, linear_regression, self.config, 83 | self.partitioner) 84 | ckpt_dir = os.path.join(FLAGS.model_dir, 'eval', 'low_mse') 85 | ckpt_path, _ = checkpoint_utils.latest_ckpt_path(ckpt_dir) 86 | self.assertTrue(ckpt_path) 87 | mode = estimator.RunMode.EVAL_ONCE 88 | eval_metrics = estimator.eval_or_predict(ckpt_path, mode, linear_regression, 89 | self.config, self.partitioner) 90 | self.assertAlmostEqual(eval_metrics['mse'], valid_loss) 91 | 92 | def test_predict(self): 93 | FLAGS.pred_pattern = 'dummy://' 94 | FLAGS.pred_batch_size = 4 95 | mode = estimator.RunMode.PREDICT 96 | estimator.eval_or_predict(None, mode, linear_regression, self.config, 97 | self.partitioner) 98 | 99 | 100 | if __name__ == '__main__': 101 | absltest.main() 102 | -------------------------------------------------------------------------------- /jestimator/models/lstm/lm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Language modeling on PTB-like datasets. 16 | 17 | Using jestimator as the entry point. 18 | 19 | # For debug run locally: 20 | 21 | ## Train: 22 | 23 | ``` 24 | PYTHONPATH=. python3 jestimator/estimator.py \ 25 | --module_imp="jestimator.models.lstm.lm" \ 26 | --module_config="jestimator/models/lstm/lm.py" \ 27 | --module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt" \ 28 | --train_pattern="jestimator/models/lstm/ptb/ptb.train.txt"\ 29 | --model_dir="$HOME/experiments/ptb_lstm/models" \ 30 | --train_batch_size=64 --train_consecutive=113 \ 31 | --check_every_steps=10 --logtostderr 32 | ``` 33 | 34 | ## Eval: 35 | 36 | ``` 37 | PYTHONPATH=. 
python3 jestimator/estimator.py \ 38 | --module_imp="jestimator.models.lstm.lm" \ 39 | --module_config="jestimator/models/lstm/lm.py" \ 40 | --module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt" \ 41 | --eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt" \ 42 | --model_dir="$HOME/experiments/ptb_lstm/models" \ 43 | --eval_batch_size=1 --logtostderr 44 | ``` 45 | """ 46 | import dataclasses 47 | import math 48 | 49 | import jax 50 | import jax.numpy as jnp 51 | from jestimator import amos 52 | from jestimator.data.pipeline_lm import lm_data 53 | from jestimator.data.reader import lines_iterator 54 | from jestimator.models.lstm import modeling 55 | from jestimator.states import TrainState, MeanMetrics, InferState # pylint: disable=g-multiple-import 56 | import ml_collections 57 | from ml_collections.config_dict import config_dict 58 | import optax 59 | import tensorflow as tf 60 | 61 | 62 | def get_config(): 63 | """Returns a config object for modeling flags.""" 64 | module_config = ml_collections.ConfigDict() 65 | 66 | # Model config. 67 | model_config = modeling.ModelConfig() 68 | model_config = ml_collections.ConfigDict(dataclasses.asdict(model_config)) 69 | module_config.model_config = model_config 70 | 71 | # Optimizer config. 72 | opt_config = ml_collections.ConfigDict() 73 | opt_config.optimizer = 'adam' 74 | opt_config.learning_rate = 5e-4 75 | opt_config.momentum = 0.9 76 | opt_config.beta = 0.99 77 | opt_config.weight_decay = 0.01 78 | module_config.opt_config = opt_config 79 | 80 | # Other config. 81 | module_config.seq_length = 64 82 | module_config.vocab_path = config_dict.placeholder(str) 83 | return module_config 84 | 85 | 86 | def load_config(global_flags): 87 | """Init config data from global flags.""" 88 | config = ml_collections.ConfigDict() 89 | config.update(global_flags.module_config) 90 | 91 | mode = global_flags.mode 92 | if mode == 'train': 93 | batch_size = global_flags.train_batch_size 94 | train_consecutive = global_flags.train_consecutive 95 | assert train_consecutive is not None, ( 96 | 'Should set --train_consecutive for training LSTM language models.') 97 | config.train_consecutive = train_consecutive 98 | elif mode.startswith('eval'): 99 | batch_size = global_flags.eval_batch_size 100 | assert batch_size == 1, 'Should set --eval_batch_size to 1 for LSTM.' 101 | assert jax.process_count() == 1, 'Should evaluate on single process.' 102 | else: 103 | batch_size = global_flags.pred_batch_size 104 | config.mode = mode 105 | config.batch_size = batch_size 106 | 107 | # Read vocab file. 108 | count = 0 109 | word_dict = {} 110 | for w, _ in lines_iterator(config.vocab_path, split=True): 111 | word_dict[w] = count 112 | count += 1 113 | 114 | config.model_config.vocab_size = count 115 | config.model_config.start_token_id = word_dict[''] 116 | eos_token_id = word_dict[''] 117 | 118 | # Only a frozen config (hashable object) can be passed to jit functions 119 | # (i.e. train_step/valid_step/infer_step). 
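  # The frozen copy also captures `mode` and `batch_size` set above.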
120 | config.frozen = ml_collections.FrozenConfigDict(config) 121 | 122 | # Construct data pipelines in the following (using TensorFLow): 123 | def corpus_fn(path: str) -> tf.data.Dataset: 124 | 125 | def gen(): 126 | for tokens in lines_iterator(path, split=True): 127 | ids = [word_dict[w] for w in tokens] + [eos_token_id] 128 | for x in ids: 129 | yield x 130 | 131 | return tf.data.Dataset.from_generator( 132 | gen, output_signature=tf.TensorSpec(shape=(), dtype=tf.int32)) 133 | 134 | seq_length = config.seq_length 135 | 136 | def eval_feature_fn(x): 137 | length = tf.shape(x)[0] # The last sequence might be shorter. 138 | y = tf.pad(x, [(0, seq_length - length)]) 139 | 140 | # Eval dataset requires the gold label to be returned as first arg. 141 | # For language modeling, gold is not used. We return length as gold. 142 | return length, {'y': y, 'length': length} 143 | 144 | config.train_data_fn = lm_data( 145 | seq_length, dataset_fn=corpus_fn, cache=True, random_skip=True) 146 | config.eval_data_fn = lm_data( 147 | seq_length, 148 | allow_remainder=True, 149 | dataset_fn=corpus_fn, 150 | cache=True, 151 | feature_fn=eval_feature_fn) 152 | return config 153 | 154 | 155 | def get_train_state(config, rng) -> TrainState: 156 | """Create train state.""" 157 | model_config = modeling.ModelConfig(**config.model_config.to_dict()) 158 | model = modeling.SingleLstmLM(model_config, config.batch_size) 159 | 160 | opt_config = config.opt_config 161 | if opt_config.optimizer == 'adam': 162 | optimizer = optax.adam( 163 | learning_rate=opt_config.learning_rate, 164 | b1=opt_config.momentum, 165 | b2=opt_config.beta) 166 | elif opt_config.optimizer == 'adamw': 167 | optimizer = optax.adamw( 168 | learning_rate=opt_config.learning_rate, 169 | b1=opt_config.momentum, 170 | b2=opt_config.beta, 171 | weight_decay=opt_config.weight_decay) 172 | elif opt_config.optimizer == 'amos': 173 | optimizer = amos.amos( 174 | opt_config.learning_rate, 175 | modeling.get_eta_fn(model_config), 176 | shape_fn=modeling.get_shape_fn(model_config), 177 | beta=opt_config.beta, 178 | momentum=opt_config.momentum, 179 | clip_value=1.) 
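  # Note: opt_config.optimizer must be 'adam', 'adamw' or 'amos'; any other
  # value leaves `optimizer` unbound and raises a NameError below.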
180 | 181 | metrics_mod = MeanMetrics.create('train_loss') 182 | dummy = jnp.zeros((config.batch_size, config.seq_length), jnp.int32) 183 | return TrainState.create(metrics_mod, optimizer, model, rng, dummy, False) 184 | 185 | 186 | def train_step(config, train_batch, state: TrainState, metrics): 187 | """Training step.""" 188 | (loss, (size, vars_)), grads = state.value_and_grad_apply_fn(has_aux=True)( 189 | state.params, 190 | train_batch, 191 | state.step % config.train_consecutive != 0, 192 | enable_dropout=True) 193 | _, metrics = state.metrics_mod.apply( 194 | metrics, 195 | 'train_loss', 196 | loss, 197 | size, 198 | method=MeanMetrics.update, 199 | mutable=['metrics']) 200 | state = state.apply_gradients(grads=grads) 201 | state = state.replace(_vars=vars_) 202 | return state, metrics 203 | 204 | 205 | def get_infer_state(config): 206 | """Create infer state.""" 207 | model_config = modeling.ModelConfig(**config.model_config.to_dict()) 208 | model = modeling.SingleLstmLM(model_config, config.batch_size) 209 | dummy = jnp.zeros((config.batch_size, config.seq_length), jnp.int32) 210 | return InferState.create(model, dummy, True, mode=config.mode) 211 | 212 | 213 | def infer_step(config, batch, state: InferState) -> InferState: 214 | """Infer step.""" 215 | if config.mode.startswith('eval'): 216 | (loss, mrr, size), vars_ = state.apply_fn( 217 | state.variables(), 218 | batch['y'], 219 | True, 220 | mode=config.mode, 221 | length=batch['length'], 222 | mutable=state.mutable()) 223 | 224 | return state.replace( 225 | _vars=vars_, 226 | ret={ 227 | 'loss': jnp.expand_dims(loss, 0), 228 | 'mrr': jnp.expand_dims(mrr, 0), 229 | 'size': jnp.expand_dims(size, 0), 230 | }) 231 | 232 | raise NotImplementedError(f'Infer-step for {config.mode} not implemented.') 233 | 234 | 235 | class Evaluator(object): 236 | """Evaluator class for language modeling.""" 237 | 238 | def reset_states(self): 239 | self._total_loss = 0. 240 | self._total_mrr = 0. 241 | self._total_size = 0. 242 | 243 | def update_state(self, gold, infer): 244 | del gold # Unused. 245 | self._total_loss += infer['loss'].sum() 246 | self._total_mrr += infer['mrr'].sum() 247 | self._total_size += infer['size'].sum() 248 | 249 | def result(self): 250 | cost = self._total_loss / self._total_size 251 | return { 252 | 'cost': cost, 253 | 'perplexity': math.exp(cost), 254 | 'mrr': self._total_mrr / self._total_size, 255 | } 256 | 257 | 258 | def get_evaluator(config) -> Evaluator: 259 | del config # Unused. 260 | return Evaluator() 261 | -------------------------------------------------------------------------------- /jestimator/models/lstm/modeling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """LSTM model implemented in Flax.""" 16 | import dataclasses 17 | import math 18 | from typing import Optional, Tuple 19 | 20 | from flax import linen as nn 21 | from flax.linen.partitioning import variable_with_axes 22 | import jax 23 | import jax.numpy as jnp 24 | from jax.typing import ArrayLike 25 | from jestimator.modeling import global_kwargs, sparse_xe_with_logits, normalize_loss_by_size, unstack, truncated_normal_initializer, Dropout # pylint: disable=g-multiple-import 26 | 27 | DType = jnp.dtype 28 | Shape = Tuple[int, ...] 29 | 30 | 31 | @dataclasses.dataclass 32 | class ModelConfig: 33 | """Config object.""" 34 | hidden_size: int = 256 35 | memory_size: int = 1024 36 | forget_gate_bias: float = 1.0 37 | 38 | hidden_dropout_rate: float = 0.55 39 | memory_dropout_rate: float = 0.1 40 | 41 | vocab_size: int = -1 42 | start_token_id: int = -1 43 | 44 | 45 | class ShiftNorm(nn.Module): 46 | """Shifted normalization.""" 47 | 48 | @nn.compact 49 | def __call__(self, x: ArrayLike) -> ArrayLike: 50 | x = jnp.asarray(x) 51 | shift = self.param('shift', nn.zeros, x.shape[-1], x.dtype) 52 | x = x - shift 53 | x = x - jnp.mean(x, axis=-1, keepdims=True) 54 | mean2 = jnp.mean(jnp.square(x), axis=-1, keepdims=True) 55 | 56 | # Instead of normalize to 1, we normalize to d**(-0.25). 57 | x = x * jax.lax.rsqrt(jnp.maximum(mean2 * math.sqrt(x.shape[-1]), 1.)) 58 | return x 59 | 60 | 61 | class LstmCell(nn.Module): 62 | """LSTM cell with some working modifications.""" 63 | config: ModelConfig 64 | 65 | def setup(self): 66 | config = self.config 67 | inputs_size = 2 * config.hidden_size 68 | core_init = truncated_normal_initializer(math.sqrt(1 / inputs_size)) 69 | self.core = nn.DenseGeneral( 70 | features=(4, config.memory_size), use_bias=True, kernel_init=core_init) 71 | self.normalize = ShiftNorm() 72 | self.out = nn.Dense( 73 | features=config.hidden_size, use_bias=True, kernel_init=nn.zeros) 74 | 75 | def __call__(self, 76 | inputs: ArrayLike, 77 | memory: ArrayLike, 78 | memory_mask: Optional[ArrayLike] = None): 79 | """Call LSTM cell with some working modifications. 80 | 81 | Args: 82 | inputs: Inputs for the current step. 83 | memory: Long-term memory. 84 | memory_mask: Optional memory mask. 85 | 86 | Returns: 87 | (out, next_memory). 88 | """ 89 | xa, xb, xg, xo = unstack(self.core(inputs), -2) 90 | fb = self.config.forget_gate_bias 91 | gate = nn.sigmoid(fb - xg) 92 | xb = jnp.clip(1 / (1 + math.exp(fb)) * jnp.tanh(xb), -gate, gate) 93 | next_memory = memory * (1. - gate) + xb * 2. * nn.silu(xa) 94 | memory_out = self.normalize(next_memory) * 2. * nn.sigmoid(xo) 95 | 96 | if memory_mask is not None: 97 | memory_out *= memory_mask 98 | out = self.out(memory_out) 99 | return out, next_memory 100 | 101 | 102 | class LstmLayer(nn.Module): 103 | """LSTM layer which encodes a sequence.""" 104 | config: ModelConfig 105 | 106 | def setup(self): 107 | config = self.config 108 | self.hidden_dropout = Dropout(config.hidden_dropout_rate) 109 | self.memory_dropout = Dropout(config.memory_dropout_rate) 110 | self.normalize = ShiftNorm() 111 | self.cell = LstmCell(config) 112 | 113 | @global_kwargs('enable_dropout') 114 | def __call__(self, 115 | xs: ArrayLike, 116 | seq_axis: int = -2, 117 | init_carry: Optional[Tuple[ArrayLike, ArrayLike]] = None, 118 | enable_dropout: bool = False): 119 | """Encode a sequence with LSTM. 120 | 121 | Args: 122 | xs: Input sequence tensor. 123 | seq_axis: int. The sequence axis. 124 | init_carry: Initial state. 125 | enable_dropout: Whether to enable dropout. 
126 | 127 | Returns: 128 | Encoded sequence of the same shape as `xs`. 129 | """ 130 | xs = jnp.asarray(xs) 131 | if init_carry is None: 132 | batch_shape = xs.shape[:seq_axis] + xs.shape[seq_axis + 1:-1] 133 | init_carry = self.zero_carry(batch_shape, xs.dtype) 134 | 135 | memory_mask = None 136 | if enable_dropout: 137 | memory_mask = self.memory_dropout(jnp.ones_like(init_carry[1])) 138 | 139 | def body_fn(self, carry: Tuple[ArrayLike, ArrayLike], x: ArrayLike): 140 | recur, memory = carry 141 | inputs = jnp.concatenate((recur, x), -1) 142 | inputs = self.hidden_dropout(self.normalize(inputs)) 143 | out, next_memory = self.cell(inputs, memory, memory_mask=memory_mask) 144 | return (out, next_memory), out 145 | 146 | last_carry, outs = nn.scan( 147 | body_fn, 148 | in_axes=seq_axis, 149 | out_axes=seq_axis, 150 | variable_broadcast='params', 151 | split_rngs={'params': False})(self, init_carry, xs) 152 | outs = self.hidden_dropout(outs) 153 | return outs, last_carry 154 | 155 | def zero_carry(self, batch_shape: Shape, dtype: DType = jnp.float32): 156 | """Creates a zero state.""" 157 | recur = jnp.zeros(batch_shape + (self.config.hidden_size,), dtype) 158 | memory = jnp.zeros(batch_shape + (self.config.memory_size,), dtype) 159 | return (recur, memory) 160 | 161 | 162 | class SingleLstmLM(nn.Module): 163 | """Single layer LSTM language model.""" 164 | config: ModelConfig 165 | batch_size: int 166 | 167 | def setup(self): 168 | config = self.config 169 | embed_init = truncated_normal_initializer(math.sqrt(1 / config.hidden_size)) 170 | self.embed = nn.Embed( 171 | config.vocab_size, config.hidden_size, embedding_init=embed_init 172 | ) 173 | self.lstm = LstmLayer(config) 174 | self.bias = self.param('bias', nn.zeros, config.vocab_size) 175 | 176 | # Variables to keep context from previous batch. 177 | self.ctx_prev = variable_with_axes( 178 | 'context', 179 | 'prev', 180 | jnp.full, 181 | (self.batch_size,), 182 | config.start_token_id, 183 | axes=('data',), 184 | ) 185 | self.ctx_recur = variable_with_axes( 186 | 'context', 187 | 'recur', 188 | jnp.zeros, 189 | (self.batch_size, config.hidden_size), 190 | axes=('data', 'model'), 191 | ) 192 | self.ctx_memory = variable_with_axes( 193 | 'context', 194 | 'memory', 195 | jnp.zeros, 196 | (self.batch_size, config.memory_size), 197 | axes=('data', 'model'), 198 | ) 199 | 200 | @global_kwargs(pass_down=True) 201 | def __call__(self, 202 | y: ArrayLike, 203 | carry_mask: ArrayLike, 204 | mode: str = 'train', 205 | length: Optional[ArrayLike] = None): 206 | """Generation logits/loss for batch-major sequence `y`.""" 207 | y = jnp.asarray(y) 208 | _, seq_length = y.shape 209 | ty = jnp.transpose(y) # `ty` is time-major. 210 | 211 | if mode == 'predict': 212 | x = ty 213 | else: # `y` is label. Shift one position to create input ids. 
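      # If carry_mask is set, the previous batch's last token (kept in
      # ctx_prev) becomes the first input id; otherwise the start token does.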
214 | s = self.config.start_token_id 215 | prev_ids = jnp.expand_dims( 216 | jnp.where(carry_mask, self.ctx_prev.value, s), 0) 217 | x = jnp.concatenate((prev_ids, ty[:-1]), 0) 218 | self.ctx_prev.value = y[:, -1] 219 | 220 | x = self.embed(x) 221 | carry_mask = jnp.asarray(carry_mask, x.dtype) 222 | carry = (self.ctx_recur.value * carry_mask, 223 | self.ctx_memory.value * carry_mask) 224 | x, (last_recur, last_memory) = self.lstm(x, seq_axis=0, init_carry=carry) 225 | self.ctx_recur.value = last_recur 226 | self.ctx_memory.value = last_memory 227 | 228 | if mode == 'predict': 229 | if length is None: 230 | x = x[-1] 231 | else: 232 | x = jnp.swapaxes(x, 0, 1) 233 | x = jnp.take_along_axis(x, jnp.expand_dims(length - 1, 1), 1) 234 | 235 | logits = self.embed.attend(x) + self.bias 236 | if mode == 'predict': 237 | return logits 238 | 239 | if length is None: 240 | size = jnp.asarray(y.size, x.dtype) 241 | mask = None 242 | else: 243 | size = jnp.sum(jnp.asarray(length, x.dtype)) 244 | mask = (jnp.expand_dims(jnp.arange(seq_length), 1) < length) 245 | 246 | if mode == 'train': 247 | loss = sparse_xe_with_logits(ty, logits, mask=mask) 248 | return normalize_loss_by_size(loss, size) 249 | 250 | # Evaluation with loss and Mean Reciprocal Rank (MRR). 251 | logits = nn.log_softmax(logits) 252 | gold = sparse_xe_with_logits( 253 | ty, logits, mask=mask, normalized=True, reduce_all=False) 254 | loss = jnp.sum(gold) 255 | higher = (logits + jnp.expand_dims(gold, -1) >= 0) 256 | ranks = jnp.sum(jnp.asarray(higher, x.dtype), axis=-1) 257 | rcpl_ranks = jnp.reciprocal(ranks) 258 | if mask is not None: 259 | rcpl_ranks = jnp.where(mask, rcpl_ranks, 0.) 260 | mrr = jnp.sum(rcpl_ranks) 261 | return loss, mrr, size 262 | 263 | 264 | def get_eta_fn(config: ModelConfig): 265 | """Get the `eta_fn` function for Amos optimizer.""" 266 | hidden_size = config.hidden_size 267 | memory_size = config.memory_size 268 | 269 | def eta_fn(name: Tuple[str, ...], shape: Shape) -> ArrayLike: 270 | del shape # Unused. 271 | if name[-4:] == ('lstm', 'cell', 'core', 'kernel'): 272 | return math.pow(2 * hidden_size, -0.25) 273 | 274 | if name[-4:] == ('lstm', 'cell', 'normalize', 'shift'): 275 | return 0.5 276 | 277 | if name[-4:] == ('lstm', 'cell', 'out', 'kernel'): 278 | return math.pow(memory_size * hidden_size, -0.25) 279 | 280 | if name[-4:] == ('lstm', 'cell', 'out', 'bias'): 281 | return 0.5 * math.pow(hidden_size, -0.25) 282 | 283 | if name[-3:] == ('lstm', 'normalize', 'shift'): 284 | return 0.5 * math.pow(hidden_size, -0.25) 285 | 286 | if name[-2:] == ('embed', 'embedding'): 287 | return math.pow(hidden_size, -0.25) 288 | 289 | if name[-1] == 'bias': 290 | return 0.5 291 | 292 | raise ValueError(f'`eta_fn` for {name} not defined.') 293 | 294 | return eta_fn 295 | 296 | 297 | def get_shape_fn(config): 298 | """Get the `shape_fn` function for Amos optimizer.""" 299 | del config # Unused. 
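  # Kernels reduce to (1,) + shape[1:], embeddings to (vocab_size, 1), and
  # all other variables to a scalar (), per Amos' memory-reduction scheme.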
300 | 301 | def shape_fn(name: Tuple[str, ...], shape: Shape) -> Shape: 302 | if name[-1] == 'kernel': 303 | assert len(shape) >= 2 304 | return (1,) + shape[1:] 305 | 306 | if name[-1] == 'embedding': 307 | assert len(shape) == 2 308 | return (shape[0], 1) 309 | 310 | return () 311 | 312 | return shape_fn 313 | -------------------------------------------------------------------------------- /jestimator/models/mnist/mnist.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b841868d-4cf1-4fcf-9e93-9ddee582b82e", 6 | "metadata": {}, 7 | "source": [ 8 | "### 1. Imports" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "d3218139-816a-412d-8a19-1bbfb219ad40", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import jax\n", 19 | "import jax.numpy as jnp # JAX NumPy\n", 20 | "from jestimator import amos # The Amos optimizer implementation\n", 21 | "from jestimator import amos_helper # Helper module for Amos\n", 22 | "\n", 23 | "from flax import linen as nn # The Linen API\n", 24 | "from flax.training import train_state # Useful dataclass to keep train state\n", 25 | "\n", 26 | "import math\n", 27 | "import tensorflow_datasets as tfds # TFDS for MNIST\n", 28 | "from sklearn.metrics import accuracy_score" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "id": "ba3ea98c-080c-4ca8-bbc5-276c95d8196e", 34 | "metadata": {}, 35 | "source": [ 36 | "### 2. Load data" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "id": "f11a52ce-e9e4-490a-94d6-d76758305729", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "def get_datasets():\n", 47 | " \"\"\"Load MNIST train and test datasets into memory.\"\"\"\n", 48 | "\n", 49 | " ds_builder = tfds.builder('mnist')\n", 50 | " ds_builder.download_and_prepare()\n", 51 | " train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))\n", 52 | " test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))\n", 53 | " train_ds['image'] = jnp.float32(train_ds['image']) / 255.\n", 54 | " test_ds['image'] = jnp.float32(test_ds['image']) / 255.\n", 55 | " return train_ds, test_ds" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "id": "9996a480-b5dc-48b1-8fc1-d65c38aa9100", 61 | "metadata": {}, 62 | "source": [ 63 | "### 3. Build model" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "id": "0e81e1c5-30db-4ad3-a67d-ebab67a5c27f", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "class CNN(nn.Module):\n", 74 | " \"\"\"A simple CNN model.\"\"\"\n", 75 | "\n", 76 | " @nn.compact\n", 77 | " def __call__(self, x):\n", 78 | " x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n", 79 | " x = nn.relu(x)\n", 80 | " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", 81 | " x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n", 82 | " x = nn.relu(x)\n", 83 | " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", 84 | " x = x.reshape((x.shape[0], -1)) # flatten\n", 85 | " x = nn.Dense(features=256)(x)\n", 86 | " x = nn.relu(x)\n", 87 | " x = nn.Dense(features=10)(x)\n", 88 | " return x\n", 89 | "\n", 90 | " def classify_xe_loss(self, x, labels):\n", 91 | " # Labels read from the tfds MNIST are integers from 0 to 9. 
\n",
 92 |     "    # Logits are arrays of size 10.\n",
 93 |     "    logits = self(x)\n",
 94 |     "    logits = jax.nn.log_softmax(logits)\n",
 95 |     "    labels_ = jnp.expand_dims(labels, -1)\n",
 96 |     "    llh_ = jnp.take_along_axis(logits, labels_, axis=-1)\n",
 97 |     "    loss = -jnp.sum(llh_)\n",
 98 |     "    return loss"
 99 |    ]
100 |   },
101 |   {
102 |    "cell_type": "markdown",
103 |    "id": "71146ddd-6e10-4bb8-9e7f-4fc8f952a6e9",
104 |    "metadata": {},
105 |    "source": [
106 |     "### 4. Create train state\n",
107 |     "\n",
108 |     "A `TrainState` object keeps the model parameters and optimizer states, and can be checkpointed into files.\n",
109 |     "\n",
110 |     "We create the model and optimizer in this function.\n",
111 |     "\n",
112 |     "**For the optimizer, we use Amos here.** The following hyper-parameters are set:\n",
113 |     "\n",
114 |     " * *learning_rate*:       The global learning rate.\n",
115 |     " * *eta_fn*:              The model-specific 'eta'.\n",
116 |     " * *shape_fn*:            Memory reduction setting.\n",
117 |     " * *beta*:                Rate for running average of gradient squares.\n",
118 |     " * *clip_value*:          Gradient clipping for stable training.\n",
119 |     "\n",
120 |     "The global learning rate is usually set to 1/sqrt(N), where N is the number of batches in the training data. For MNIST, we have 60k training examples and the batch size is 32, so learning_rate=1/sqrt(60000/32).\n",
121 |     "\n",
122 |     "The model-specific 'eta_fn' requires a function that, given a variable name and shape, returns a float indicating the expected scale of that variable. Hopefully in the near future we will have libraries that can automatically calculate this 'eta_fn' from the modeling code, but for now we have to specify it manually.\n",
123 |     "\n",
124 |     "One can use the amos_helper.params_fn_from_assign_map() helper function to create 'eta_fn' from an assign_map. An assign_map is a dict which maps regex rules to a value or a simple Python expression. It finds the first regex rule that matches the name of a variable, and evaluates the Python expression if necessary to return the value. See our example below.\n",
125 |     "\n",
126 |     "The 'shape_fn' similarly requires a function that, given a variable name and shape, returns a reduced shape for the corresponding slot variables. We can use the amos_helper.params_fn_from_assign_map() helper function to create 'shape_fn' from an assign_map as well.\n",
127 |     "\n",
128 |     "'beta' is the exponential decay rate for the running average of gradient squares. We set it to 0.98 here.\n",
129 |     "\n",
130 |     "'clip_value' is the gradient clipping value, which should match the magnitude of the loss function. If the loss function is a sum of cross-entropy terms, then we should set 'clip_value' to the sqrt of the number of labels.\n",
131 |     "\n",
132 |     "Please refer to our [paper](https://arxiv.org/abs/2210.11693) for more details of the hyper-parameters.\n",
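    "\n",
    "As a quick sanity check of the formulas above: with 60000 training examples and batch size 32, learning_rate = 1/sqrt(60000/32) = 1/sqrt(1875) ≈ 0.0231, and since each training batch sums the cross-entropy over 32 labels, clip_value = sqrt(32) ≈ 5.66, matching the values in the code below."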
133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 4, 138 | "id": "eb049df8-70dc-447c-9a11-7166feb12d25", 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "def get_train_state(rng):\n", 143 | " model = CNN()\n", 144 | " dummy_x = jnp.ones([1, 28, 28, 1])\n", 145 | " params = model.init(rng, dummy_x)\n", 146 | "\n", 147 | " eta_fn = amos_helper.params_fn_from_assign_map(\n", 148 | " {\n", 149 | " '.*/bias': 0.5,\n", 150 | " '.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))',\n", 151 | " '.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',\n", 152 | " '.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',\n", 153 | " '.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',\n", 154 | " },\n", 155 | " eval_str_value=True,\n", 156 | " )\n", 157 | " shape_fn = amos_helper.params_fn_from_assign_map(\n", 158 | " {\n", 159 | " '.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',\n", 160 | " '.*Dense_0/kernel': '(1, SHAPE[1])',\n", 161 | " '.*': (),\n", 162 | " },\n", 163 | " eval_str_value=True,\n", 164 | " )\n", 165 | " optimizer = amos.amos(\n", 166 | " learning_rate=1/math.sqrt(60000/32),\n", 167 | " eta_fn=eta_fn,\n", 168 | " shape_fn=shape_fn,\n", 169 | " beta=0.98,\n", 170 | " clip_value=math.sqrt(32),\n", 171 | " )\n", 172 | " return train_state.TrainState.create(\n", 173 | " apply_fn=model.apply, params=params, tx=optimizer)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "id": "2b00564e-504e-4275-b80a-8deed0fde177", 179 | "metadata": {}, 180 | "source": [ 181 | "### 5. Training step\n", 182 | "\n", 183 | "Use JAX’s @jit decorator to just-in-time compile the function for better performance." 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 5, 189 | "id": "ca15ee35-eb1d-4685-8e98-568c1cafc08c", 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "@jax.jit\n", 194 | "def train_step(batch, state):\n", 195 | " grad_fn = jax.grad(state.apply_fn)\n", 196 | " grads = grad_fn(\n", 197 | " state.params,\n", 198 | " batch['image'],\n", 199 | " batch['label'],\n", 200 | " method=CNN.classify_xe_loss)\n", 201 | " return state.apply_gradients(grads=grads)" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "id": "7f0a64e7-5507-49aa-8399-c3f20aa72a0e", 207 | "metadata": {}, 208 | "source": [ 209 | "### 6. Infer step\n", 210 | "\n", 211 | "Use JAX’s @jit decorator to just-in-time compile the function for better performance." 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 6, 217 | "id": "ad5ef40d-7fac-479f-ad5f-8bda7f839b66", 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "@jax.jit\n", 222 | "def infer_step(batch, state):\n", 223 | " logits = state.apply_fn(state.params, batch['image'])\n", 224 | " return jnp.argmax(logits, -1)" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "id": "5429829e-2f7a-4fb6-acff-d4caaa3a20f6", 230 | "metadata": {}, 231 | "source": [ 232 | "### 7. Main\n", 233 | "\n", 234 | "Run the training loop and evaluate on test set." 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 7, 240 | "id": "2b181696-9268-4f8e-b359-249a544c53b1", 241 | "metadata": {}, 242 | "outputs": [ 243 | { 244 | "name": "stderr", 245 | "output_type": "stream", 246 | "text": [ 247 | "WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" 248 | ] 249 | }, 250 | { 251 | "name": "stdout", 252 | "output_type": "stream", 253 | "text": [ 254 | "epoch: 1, test accuracy: 97.28\n", 255 | "epoch: 2, test accuracy: 98.46\n", 256 | "epoch: 3, test accuracy: 98.63\n", 257 | "epoch: 4, test accuracy: 97.91\n", 258 | "epoch: 5, test accuracy: 98.59\n", 259 | "epoch: 6, test accuracy: 99.05\n", 260 | "epoch: 7, test accuracy: 99.15\n", 261 | "epoch: 8, test accuracy: 99.21\n", 262 | "epoch: 9, test accuracy: 99.26\n" 263 | ] 264 | } 265 | ], 266 | "source": [ 267 | "train_ds, test_ds = get_datasets()\n", 268 | "\n", 269 | "rng = jax.random.PRNGKey(0)\n", 270 | "rng, init_rng = jax.random.split(rng)\n", 271 | "state = get_train_state(init_rng)\n", 272 | "del init_rng # Must not be used anymore.\n", 273 | "\n", 274 | "num_epochs = 9\n", 275 | "for epoch in range(1, num_epochs + 1):\n", 276 | " # Use a separate PRNG key to permute image data during shuffling\n", 277 | " rng, input_rng = jax.random.split(rng)\n", 278 | " perms = jax.random.permutation(input_rng, 60000)\n", 279 | " del input_rng\n", 280 | " perms = perms.reshape((60000 // 32, 32))\n", 281 | " for perm in perms:\n", 282 | " batch = {k: v[perm, ...] for k, v in train_ds.items()}\n", 283 | " state = train_step(batch, state)\n", 284 | "\n", 285 | " pred = jax.device_get(infer_step(test_ds, state))\n", 286 | " accuracy = accuracy_score(test_ds['label'], pred)\n", 287 | " print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100))" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": null, 293 | "id": "9236e19d-3fa1-43b0-90d4-a29df8613deb", 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [] 297 | } 298 | ], 299 | "metadata": { 300 | "kernelspec": { 301 | "display_name": "Python 3 (ipykernel)", 302 | "language": "python", 303 | "name": "python3" 304 | }, 305 | "language_info": { 306 | "codemirror_mode": { 307 | "name": "ipython", 308 | "version": 3 309 | }, 310 | "file_extension": ".py", 311 | "mimetype": "text/x-python", 312 | "name": "python", 313 | "nbconvert_exporter": "python", 314 | "pygments_lexer": "ipython3", 315 | "version": "3.10.7" 316 | } 317 | }, 318 | "nbformat": 4, 319 | "nbformat_minor": 5 320 | } 321 | -------------------------------------------------------------------------------- /jestimator/models/mnist/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""MNIST Example using JEstimator. 16 | 17 | # For debug run locally: 18 | 19 | ## Train: 20 | 21 | ``` 22 | PYTHONPATH=. 
python3 jestimator/estimator.py \
23 |   --module_imp="jestimator.models.mnist.mnist" \
24 |   --module_config="jestimator/models/mnist/mnist.py" \
25 |   --train_pattern="tfds://mnist/split=train" \
26 |   --model_dir="$HOME/experiments/mnist/models" \
27 |   --train_batch_size=32 \
28 |   --train_shuffle_buf=4096 \
29 |   --train_epochs=9 \
30 |   --check_every_steps=100 \
31 |   --logtostderr
32 | ```
33 | 
34 | ## Eval continuously:
35 | 
36 | ```
37 | PYTHONPATH=. python3 jestimator/estimator.py \
38 |   --module_imp="jestimator.models.mnist.mnist" \
39 |   --module_config="jestimator/models/mnist/mnist.py" \
40 |   --eval_pattern="tfds://mnist/split=test" \
41 |   --model_dir="$HOME/experiments/mnist/models" \
42 |   --eval_batch_size=32 \
43 |   --mode="eval_wait" \
44 |   --check_ckpt_every_secs=1 \
45 |   --save_high="test_accuracy" \
46 |   --logtostderr
47 | ```
48 | """
49 | import math
50 | import re
51 | 
52 | from flax import linen as nn
53 | import jax
54 | import jax.numpy as jnp
55 | from jestimator import amos
56 | from jestimator import amos_helper
57 | from jestimator.data.pipeline_rec import rec_data
58 | from jestimator.states import Evaluator, InferState, MeanMetrics, TrainState  # pylint: disable=g-multiple-import
59 | import ml_collections
60 | from sklearn.metrics import accuracy_score
61 | import tensorflow as tf
62 | import tensorflow_datasets as tfds
63 | 
64 | 
65 | def get_config():
66 |   """Returns a config object for modeling flags."""
67 |   module_config = ml_collections.ConfigDict()
68 |   module_config.warmup = 2000
69 |   module_config.amos_beta = 0.98
70 |   return module_config
71 | 
72 | 
73 | def load_config(global_flags):
74 |   """Init config data from global flags."""
75 |   config = ml_collections.ConfigDict()
76 |   config.update(global_flags.module_config)
77 |   config.train_batch_size = global_flags.train_batch_size
78 | 
79 |   # Only a frozen config (hashable object) can be passed to jit functions
80 |   # (i.e. train_step/valid_step/infer_step).
81 |   config.frozen = ml_collections.FrozenConfigDict(config)
82 | 
83 |   # Construct data pipelines in the following (using TensorFlow):
84 |   def dataset_fn(path: str) -> tf.data.Dataset:
85 |     # Assume `path` is of the form 'tfds://{name}/split={split}'.
86 |     m = re.match('tfds://(.*)/split=(.*)', path)
87 |     assert m is not None, (f'Cannot parse "{path}" (should be of the form '
88 |                            '"tfds://{name}/split={split}").')
89 |     name = m.group(1)
90 |     split = m.group(2)
91 |     builder = tfds.builder(name)
92 | 
93 |     # Use dataset info to set up the model.
94 |     info = builder.info
95 |     config.image_shape = info.features['image'].shape
96 |     config.num_classes = info.features['label'].num_classes
97 |     config.num_examples = info.splits[split].num_examples
98 | 
99 |     builder.download_and_prepare()
100 |     return builder.as_dataset(split=split)
101 | 
102 |   def eval_feature_fn(x):
103 |     # For evaluation, we should return a (gold, data) tuple.
104 |     label = x.pop('label')
105 |     return label, x
106 | 
107 |   # `pipeline_rec.rec_data` wraps dataset of record type
108 |   # (i.e. each record is a single data point).
109 |   config.train_data_fn = rec_data(dataset_fn, interleave=True)
110 |   config.eval_data_fn = rec_data(dataset_fn, feature_fn=eval_feature_fn)
111 |   return config
112 | 
113 | 
114 | class CNN(nn.Module):
115 |   """A simple CNN model."""
116 |   num_classes: int
117 | 
118 |   @nn.compact
119 |   def __call__(self, x):
120 |     x /= 255.  # Each value (pixel) in the input image ranges from 0 to 255.
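    # (Unlike the notebook version above, which rescales the images in
    # get_datasets(), this model normalizes its input inside __call__, so raw
    # tfds images can be fed directly.)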
121 | x = nn.Conv(features=32, kernel_size=(3, 3))(x) 122 | x = nn.relu(x) 123 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) 124 | x = nn.Conv(features=64, kernel_size=(3, 3))(x) 125 | x = nn.relu(x) 126 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) 127 | x = x.reshape((x.shape[0], -1)) # flatten 128 | x = nn.Dense(features=256)(x) 129 | x = nn.relu(x) 130 | x = nn.Dense(features=self.num_classes)(x) 131 | return x 132 | 133 | def classify_xe_loss(self, x, labels): 134 | # Labels read from the tfds MNIST are integers from 0 to 9. 135 | # Logits are arrays of size 10. 136 | logits = self(x) 137 | logits = jax.nn.log_softmax(logits) 138 | labels_ = jnp.expand_dims(labels, -1) 139 | llh_ = jnp.take_along_axis(logits, labels_, axis=-1) 140 | loss = -jnp.sum(llh_) 141 | return loss 142 | 143 | 144 | def get_train_state(config, rng): 145 | """Create train state.""" 146 | model = CNN(num_classes=config.num_classes) 147 | # Shape of input is (batch_size,) + image_shape. 148 | dummy_x = jnp.zeros((1,) + config.image_shape) 149 | # Use a separate module to record training time metrics. 150 | metrics_mod = MeanMetrics.create('train_loss') 151 | 152 | def lr_schedule(step): 153 | # Set up a warm-up schedule. 154 | lr = math.sqrt(config.train_batch_size / config.num_examples) 155 | lr *= jnp.minimum(1., step / config.warmup) 156 | return lr 157 | 158 | eta_fn = amos_helper.params_fn_from_assign_map( 159 | { 160 | '.*/bias': 0.5, 161 | '.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-2]))', 162 | '.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))', 163 | '.*Dense_0/kernel': 'sqrt(2/SHAPE[0])', 164 | '.*Dense_1/kernel': 'sqrt(1/SHAPE[0])', 165 | }, 166 | eval_str_value=True, 167 | ) 168 | shape_fn = amos_helper.params_fn_from_assign_map( 169 | { 170 | '.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])', 171 | '.*Dense_0/kernel': '(1, SHAPE[1])', 172 | '.*': (), 173 | }, 174 | eval_str_value=True, 175 | ) 176 | optimizer = amos.amos( 177 | learning_rate=lr_schedule, 178 | eta_fn=eta_fn, 179 | shape_fn=shape_fn, 180 | beta=config.amos_beta, 181 | clip_value=math.sqrt(config.train_batch_size), 182 | ) 183 | return TrainState.create(metrics_mod, optimizer, model, rng, dummy_x) 184 | 185 | 186 | def train_step(config, train_batch, state: TrainState, metrics): 187 | """Training step.""" 188 | loss, grads = state.value_and_grad_apply_fn()( 189 | state.params, 190 | train_batch['image'], 191 | train_batch['label'], 192 | method=CNN.classify_xe_loss) 193 | _, metrics = state.metrics_mod.apply( 194 | metrics, 195 | 'train_loss', 196 | loss, 197 | config.train_batch_size, 198 | method=MeanMetrics.update, 199 | mutable=['metrics']) 200 | return state.apply_gradients(grads=grads), metrics 201 | 202 | 203 | def get_infer_state(config): 204 | """Create infer state.""" 205 | model = CNN(num_classes=config.num_classes) 206 | dummy_x = jnp.zeros((1,) + config.image_shape) 207 | return InferState.create(model, dummy_x) 208 | 209 | 210 | def infer_step(config, batch, state: InferState): 211 | """Infer step.""" 212 | del config # Unused. 213 | logits = state.apply_fn(state.variables(), batch['image']) 214 | return state.replace(ret=jnp.argmax(logits, -1)) 215 | 216 | 217 | def get_evaluator(config) -> Evaluator: 218 | """Create evaluator.""" 219 | del config # Unused. 
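  # Each entry maps a metric name to a (proc_fn, eval_fn) pair: proc_fn
  # post-processes the model outputs from infer_step, and eval_fn scores them
  # against the gold labels split out by eval_feature_fn above.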
220 | return Evaluator({'test_accuracy': (lambda y: y, accuracy_score)}) 221 | -------------------------------------------------------------------------------- /jestimator/models/rope/finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Sequence classification. 16 | 17 | # For debug run locally: 18 | 19 | ## Train: 20 | 21 | ``` 22 | PYTHONPATH=. python3 \ 23 | jestimator/estimator.py \ 24 | --module_imp="jestimator.models.rope.finetune" \ 25 | --module_config="jestimator/models/rope/finetune.py" \ 26 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \ 27 | --module_config.segment_names="sentence1,sentence2" \ 28 | --module_config.model_config.num_labels=2 \ 29 | --train_pattern="tfds://glue/rte/split=train" \ 30 | --valid_pattern="tfds://glue/rte/split=validation" \ 31 | --model_dir="$HOME/models/rope_rte" \ 32 | --checkpoint_path="gs://gresearch/checkpoints_in_amos_paper/\ 33 | adamw/rope-base/checkpoint_300000" \ 34 | --train_batch_size=4 --valid_batch_size=4 --num_valid_examples=4 \ 35 | --check_every_steps=10 --logtostderr 36 | ``` 37 | 38 | ## Eval: 39 | 40 | ``` 41 | PYTHONPATH=. python3 \ 42 | jestimator/estimator.py \ 43 | --module_imp="jestimator.models.rope.finetune" \ 44 | --module_config="jestimator/models/rope/finetune.py" \ 45 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \ 46 | --module_config.segment_names="sentence1,sentence2" \ 47 | --module_config.model_config.num_labels=2 \ 48 | --module_config.eval_metric="accuracy" \ 49 | --eval_pattern="tfds://glue/rte/split=validation" \ 50 | --model_dir="$HOME/models/rope_rte" \ 51 | --eval_batch_size=4 --num_eval_examples=4 \ 52 | --logtostderr 53 | ``` 54 | 55 | ## Predict: 56 | 57 | ``` 58 | PYTHONPATH=. 
python3 \ 59 | jestimator/estimator.py \ 60 | --module_imp="jestimator.models.rope.finetune" \ 61 | --module_config="jestimator/models/rope/finetune.py" \ 62 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \ 63 | --module_config.segment_names="sentence1,sentence2" \ 64 | --module_config.model_config.num_labels=2 \ 65 | --module_config.label_names="entailment,not_entailment" \ 66 | --pred_pattern="tfds://glue/rte/split=test" \ 67 | --model_dir="$HOME/models/rope_rte" \ 68 | --pred_batch_size=4 --num_pred_examples=4 \ 69 | --logtostderr 70 | ``` 71 | """ 72 | 73 | import dataclasses 74 | 75 | import jax 76 | import jax.numpy as jnp 77 | from jestimator.data import reader 78 | from jestimator.data.pipeline_rec import rec_data 79 | from jestimator.data.reader import PyOutSpec 80 | from jestimator.models.rope import modeling 81 | from jestimator.states import Evaluator, InferState, MeanMetrics, Predictor, TrainState # pylint: disable=g-multiple-import 82 | import ml_collections 83 | from ml_collections.config_dict import config_dict 84 | import optax 85 | from scipy import stats as scipy_stats 86 | from sklearn import metrics as sklearn_metrics 87 | import tensorflow as tf 88 | 89 | import sentencepiece as spm 90 | 91 | 92 | def get_config(): 93 | """Returns a config object for modeling flags.""" 94 | module_config = ml_collections.ConfigDict() 95 | 96 | # Model config. 97 | model_config = modeling.ModelConfig() 98 | model_config = ml_collections.ConfigDict(dataclasses.asdict(model_config)) 99 | module_config.model_config = model_config 100 | 101 | # Optimizer config. 102 | opt_config = ml_collections.ConfigDict() 103 | opt_config.optimizer = 'adam' 104 | opt_config.learning_rate = 5e-6 105 | module_config.opt_config = opt_config 106 | 107 | # Other config. 108 | module_config.vocab_path = config_dict.placeholder(str) 109 | module_config.segment_names = config_dict.placeholder(str) 110 | module_config.eval_metric = config_dict.placeholder(str) 111 | module_config.output_path = config_dict.placeholder(str) 112 | module_config.label_names = config_dict.placeholder(str) 113 | module_config.stsb = False 114 | return module_config 115 | 116 | 117 | def load_config(global_flags): 118 | """Init config data from global flags.""" 119 | config = ml_collections.ConfigDict() 120 | config.update(global_flags.module_config) 121 | 122 | tokenizer = spm.SentencePieceProcessor() 123 | tokenizer.Load(config.vocab_path) 124 | config.model_config.vocab_size = tokenizer.GetPieceSize() 125 | 126 | segment_names = config.segment_names.split(',') 127 | num_segments = len(segment_names) 128 | config.model_config.num_segments = num_segments + 1 129 | 130 | # Only a frozen config (hashable object) can be passed to jit functions 131 | # (i.e. train_step/valid_step/infer_step). 
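  # (jax.jit caches compiled code per distinct hashable static argument, which
  # is why a FrozenConfigDict is used here.)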
132 |   config.frozen = ml_collections.FrozenConfigDict(config)
133 | 
134 |   # Construct data pipelines in the following (using TensorFlow):
135 |   max_length = config.model_config.max_length
136 |   max_len_1 = (max_length - 1) // num_segments
137 |   cls_token_id = tokenizer.PieceToId('<cls>')  # Token strings on these
138 |   sep_token_id = tokenizer.PieceToId('<sep>')  # lines were lost in
139 |   eos_token_id = tokenizer.PieceToId('</s>')   # extraction; values assumed.
140 |   data_keys = ['idx', 'label'] + config.segment_names.split(',')
141 |   mode = global_flags.mode
142 | 
143 |   def tokenize_fn(texts):
144 |     ids = []
145 |     for s in texts:
146 |       s = tf.strings.lower(s).numpy()
147 |       ids.append(tf.convert_to_tensor(tokenizer.EncodeAsIds(s), tf.int32))
148 |     return ids
149 | 
150 |   def example_fn(data):
151 |     data = {k: data[k] for k in data_keys if k in data}
152 |     texts = [data[k] for k in segment_names]
153 |     out_spec = [PyOutSpec((-1,), tf.int32)] * num_segments
154 |     tokenized = reader.apply_py_fn(tokenize_fn, texts, out_spec)
155 | 
156 |     max_len_0 = max_length - 1
157 |     input_ids = [tf.concat([[cls_token_id], tokenized[0]], 0)]
158 |     for x in tokenized[1:]:
159 |       x = tf.concat([[sep_token_id], x], 0)[:max_len_1]
160 |       input_ids.append(x)
161 |       max_len_0 = max_len_0 - tf.shape(x)[0]
162 |     input_ids[0] = input_ids[0][:max_len_0]
163 |     input_ids.append([eos_token_id])
164 | 
165 |     segment_ids = [tf.ones_like(x) * i for i, x in enumerate(input_ids)]
166 |     input_ids = tf.concat(input_ids, 0)
167 |     input_mask = tf.ones_like(input_ids)
168 |     segment_ids = tf.concat(segment_ids, 0)
169 | 
170 |     pad_len = max_length - tf.shape(input_ids)[0]
171 |     input_ids = tf.pad(input_ids, [[0, pad_len]])
172 |     input_mask = tf.pad(input_mask, [[0, pad_len]])
173 |     segment_ids = tf.pad(segment_ids, [[0, pad_len]])
174 | 
175 |     ret = {
176 |         'input_ids': tf.ensure_shape(input_ids, (max_length,)),
177 |         'input_mask': tf.ensure_shape(input_mask, (max_length,)),
178 |         'segment_ids': tf.ensure_shape(segment_ids, (max_length,)),
179 |     }
180 |     if mode == 'train':
181 |       ret['label'] = data['label']
182 |       if config.stsb:
183 |         ret['label'] /= 5.0
184 |     else:
185 |       ret['idx'] = data['idx']
186 |       if mode.startswith('eval'):
187 |         ret = (data['label'], ret)
188 |     return ret
189 | 
190 |   def dataset_fn(path: str) -> tf.data.Dataset:
191 |     d = reader.get_tfds_dataset(path)
192 |     d = d.map(example_fn, tf.data.AUTOTUNE)
193 |     return d
194 | 
195 |   config.train_data_fn = rec_data(
196 |       dataset_fn=dataset_fn, cache=True, interleave=True)
197 |   config.eval_data_fn = config.valid_data_fn = rec_data(
198 |       dataset_fn=dataset_fn, cache=True)
199 |   config.pred_data_fn = rec_data(dataset_fn=dataset_fn)
200 |   return config
201 | 
202 | 
203 | def get_train_state(config, rng):
204 |   """Create train state."""
205 |   model_config = modeling.ModelConfig(**config.model_config.to_dict())
206 |   model = modeling.ModelForSeqCls(model_config)
207 | 
208 |   opt_config = config.opt_config
209 |   if opt_config.optimizer == 'adam':
210 |     optimizer = optax.adam(learning_rate=opt_config.learning_rate)
211 | 
212 |   metrics_mod = MeanMetrics.create('train_loss', 'valid_loss')
213 |   return TrainState.create(metrics_mod, optimizer, model, rng, jnp.array([[0]]))
214 | 
215 | 
216 | def train_step(config, train_batch, state: TrainState, metrics):
217 |   """Training step."""
218 |   loss_fn = (
219 |       modeling.ModelForSeqCls.mse_loss
220 |       if config.stsb else modeling.ModelForSeqCls.xe_loss)
221 |   (loss, size), grads = state.value_and_grad_apply_fn(has_aux=True)(
222 |       state.params,
223 |       train_batch['label'],
224 |       train_batch['input_ids'],
225 | 
segment_ids=train_batch['segment_ids'], 226 | input_mask=train_batch['input_mask'], 227 | enable_dropout=True, 228 | method=loss_fn) 229 | _, metrics = state.metrics_mod.apply( 230 | metrics, 231 | 'train_loss', 232 | loss, 233 | size, 234 | method=MeanMetrics.update, 235 | mutable=['metrics']) 236 | return state.apply_gradients(grads=grads), metrics 237 | 238 | 239 | def valid_step(config, valid_batch, state: TrainState, metrics): 240 | """Validation step.""" 241 | loss_fn = ( 242 | modeling.ModelForSeqCls.mse_loss 243 | if config.stsb else modeling.ModelForSeqCls.xe_loss) 244 | loss, size = state.apply_fn( 245 | state.variables(), 246 | valid_batch['label'], 247 | valid_batch['input_ids'], 248 | segment_ids=valid_batch['segment_ids'], 249 | input_mask=valid_batch['input_mask'], 250 | method=loss_fn) 251 | _, metrics = state.metrics_mod.apply( 252 | metrics, 253 | 'valid_loss', 254 | loss, 255 | size, 256 | method=MeanMetrics.update, 257 | mutable=['metrics']) 258 | return metrics 259 | 260 | 261 | def get_infer_state(config): 262 | """Create infer state.""" 263 | model_config = modeling.ModelConfig(**config.model_config.to_dict()) 264 | model = modeling.ModelForSeqCls(model_config) 265 | return InferState.create(model, jnp.array([[0]])) 266 | 267 | 268 | def infer_step(config, batch, state: InferState) -> InferState: 269 | """Infer step.""" 270 | logits = state.apply_fn( 271 | state.variables(), 272 | batch['input_ids'], 273 | segment_ids=batch['segment_ids'], 274 | input_mask=batch['input_mask']) 275 | if config.stsb: 276 | pred = jax.nn.softmax(logits)[..., 0] * 5.0 277 | else: 278 | pred = jnp.argmax(logits, axis=-1) 279 | return state.replace(ret={ 280 | 'idx': batch['idx'], 281 | 'prediction': pred, 282 | }) 283 | 284 | 285 | def get_evaluator(config) -> Evaluator: 286 | """Create evaluator.""" 287 | eval_fns = { 288 | 'accuracy': sklearn_metrics.accuracy_score, 289 | 'f1': sklearn_metrics.f1_score, 290 | 'spearmanr': lambda x, y: scipy_stats.spearmanr(x, y)[0], 291 | } 292 | 293 | def proc_fn(infer): 294 | return infer['prediction'] 295 | 296 | metric = config.eval_metric 297 | return Evaluator({metric: (proc_fn, eval_fns[metric])}) 298 | 299 | 300 | def get_predictor(config) -> Predictor: 301 | """Create predictor.""" 302 | pre_str = 'index\tprediction' 303 | label_names = (None if config.label_names is None else 304 | config.label_names.split(',')) 305 | 306 | def proc_fn(infer): 307 | ret = [] 308 | for x, y in zip(infer['idx'], infer['prediction']): 309 | z = y if label_names is None else label_names[y] 310 | ret.append(f'{x}\t{z}') 311 | return ret 312 | 313 | return Predictor(proc_fn, config.output_path, pre_str=pre_str) 314 | -------------------------------------------------------------------------------- /jestimator/models/rope/pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The jestimator Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Pretrain. 
16 | 
17 | # For debug run locally:
18 | 
19 | ```
20 | PYTHONPATH=. python3 \
21 | jestimator/estimator.py \
22 |   --module_imp="jestimator.models.rope.pretrain" \
23 |   --module_config="jestimator/models/rope/pretrain.py" \
24 |   --module_config.model_config.vocab_size=32000 \
25 |   --module_config.mask_token_id=4 \
26 |   --train_pattern="gs://gresearch/checkpoints_in_amos_paper/data/\
27 | books-00000-of-00500" \
28 |   --valid_pattern="gs://gresearch/checkpoints_in_amos_paper/data/ptb" \
29 |   --model_dir="$HOME/models/rope_pretrain" \
30 |   --train_batch_size=4 --valid_batch_size=4 --num_valid_examples=4 \
31 |   --check_every_steps=10 --logtostderr
32 | ```
33 | """
34 | 
35 | import dataclasses
36 | from typing import Mapping
37 | 
38 | import jax
39 | import jax.numpy as jnp
40 | from jestimator import amos
41 | from jestimator.data.pipeline_lm import lm_data
42 | from jestimator.models.rope import modeling
43 | from jestimator.states import TrainState, MeanMetrics  # pylint: disable=g-multiple-import
44 | import ml_collections
45 | from ml_collections.config_dict import config_dict
46 | import optax
47 | import tensorflow as tf
48 | 
49 | 
50 | def get_config():
51 |   """Returns a config object for modeling flags."""
52 |   module_config = ml_collections.ConfigDict()
53 | 
54 |   # Model config.
55 |   model_config = modeling.ModelConfig()
56 |   model_config = ml_collections.ConfigDict(dataclasses.asdict(model_config))
57 |   module_config.model_config = model_config
58 | 
59 |   # Optimizer config.
60 |   opt_config = ml_collections.ConfigDict()
61 |   opt_config.optimizer = 'adamw'
62 |   opt_config.learning_rate = 1e-4
63 |   opt_config.warmup_steps = 10000
64 |   opt_config.linear_decay_to_step = config_dict.placeholder(int)
65 |   opt_config.momentum = 0.9
66 |   opt_config.beta = 0.999
67 |   opt_config.weight_decay = 0.01
68 |   module_config.opt_config = opt_config
69 | 
70 |   # Other config.
71 |   module_config.mask_token_id = config_dict.placeholder(int)
72 |   module_config.mask_rate = 0.15
73 |   return module_config
74 | 
75 | 
76 | def load_config(global_flags):
77 |   """Init config data from global flags."""
78 |   config = ml_collections.ConfigDict()
79 |   config.update(global_flags.module_config)
80 | 
81 |   # Only a frozen config (hashable object) can be passed to jit functions
82 |   # (i.e. train_step/valid_step/infer_step).
83 |   config.frozen = ml_collections.FrozenConfigDict(config)
84 | 
85 |   # Construct data pipelines in the following (using TensorFlow):
86 |   seq_length = config.model_config.max_length
87 | 
88 |   def feature_fn(token_ids: tf.Tensor) -> Mapping[str, tf.Tensor]:
89 |     """Builds a feature dict to be compatible with seqio."""
90 |     return {'targets': tf.ensure_shape(token_ids, (seq_length,))}
91 | 
92 |   config.train_data_fn = lm_data(
93 |       seq_length, random_skip=True, feature_fn=feature_fn, interleave=True)
94 |   config.valid_data_fn = lm_data(seq_length, feature_fn=feature_fn)
95 |   return config
96 | 
97 | 
98 | def get_train_state(config, rng) -> TrainState:
99 |   """Create train state."""
100 |   model_config = modeling.ModelConfig(**config.model_config.to_dict())
101 |   model = modeling.ModelForPretrain(model_config)
102 |   opt_config = config.opt_config
103 |   warmup = opt_config.warmup_steps
104 |   decay = opt_config.linear_decay_to_step
105 | 
106 |   def lr_schedule(step):
107 |     lr = opt_config.learning_rate
108 |     if warmup is not None:
109 |       lr *= jnp.minimum(1., step / warmup)
110 |       if decay is not None:
111 |         lr *= 1. - jnp.maximum(0., step - warmup) / (decay - warmup)
112 |     elif decay is not None:
113 |       lr *= 1. 
- step / decay 114 | return lr 115 | 116 | if opt_config.optimizer == 'adamw': 117 | optimizer = optax.adamw( 118 | learning_rate=lr_schedule, 119 | b1=opt_config.momentum, 120 | b2=opt_config.beta, 121 | weight_decay=opt_config.weight_decay) 122 | elif opt_config.optimizer == 'amos': 123 | optimizer = amos.amos( 124 | lr_schedule, 125 | modeling.get_eta_fn(model_config), 126 | shape_fn=modeling.get_shape_fn(model_config), 127 | beta=opt_config.beta, 128 | momentum=opt_config.momentum, 129 | clip_value=1.) 130 | 131 | metrics_mod = MeanMetrics.create('train_loss', 'valid_loss', 'valid_mrr') 132 | return TrainState.create(metrics_mod, optimizer, model, rng, jnp.array([[0]])) 133 | 134 | 135 | def train_step(config, train_batch, state: TrainState, metrics): 136 | """Training step.""" 137 | (loss, size), grads = state.value_and_grad_apply_fn(has_aux=True)( 138 | state.params, 139 | train_batch['targets'], 140 | config.mask_token_id, 141 | mask_rate=config.mask_rate, 142 | input_mask=train_batch.get('input_mask'), 143 | enable_dropout=True, 144 | method=modeling.ModelForPretrain.mlm_train_loss) 145 | _, metrics = state.metrics_mod.apply( 146 | metrics, 147 | 'train_loss', 148 | loss, 149 | size, 150 | method=MeanMetrics.update, 151 | mutable=['metrics']) 152 | return state.apply_gradients(grads=grads), metrics 153 | 154 | 155 | def valid_step(config, valid_batch, state: TrainState, metrics): 156 | """Validation step.""" 157 | 158 | def body(i, metrics): 159 | del i # Unused. 160 | loss, mrr, size = state.apply_fn( 161 | state.variables(), 162 | valid_batch['targets'], 163 | config.mask_token_id, 164 | mask_rate=config.mask_rate, 165 | input_mask=valid_batch.get('input_mask'), 166 | method=modeling.ModelForPretrain.mlm_valid_metrics) 167 | _, metrics = state.metrics_mod.apply( 168 | metrics, 169 | 'valid_loss', 170 | loss, 171 | size, 172 | method=MeanMetrics.update, 173 | mutable=['metrics']) 174 | _, metrics = state.metrics_mod.apply( 175 | metrics, 176 | 'valid_mrr', 177 | mrr, 178 | size, 179 | method=MeanMetrics.update, 180 | mutable=['metrics']) 181 | return metrics 182 | 183 | return jax.lax.fori_loop(0, 20, body, metrics) 184 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "jestimator" 3 | description = "Implementation of the Amos optimizer from the JEstimator lib." 
4 | readme = "README.md" 5 | requires-python = ">=3.8" 6 | license = {file = "LICENSE"} 7 | authors = [{name = "jestimator authors", email="no-reply@google.com"}] 8 | classifiers = [ 9 | "Programming Language :: Python :: 3", 10 | "Programming Language :: Python :: 3 :: Only", 11 | "License :: OSI Approved :: Apache Software License", 12 | "Intended Audience :: Science/Research", 13 | ] 14 | keywords = [] 15 | 16 | # pip dependencies of the project 17 | dependencies = [ 18 | "jax>=0.4.3", 19 | "flax", 20 | ] 21 | 22 | # This is set automatically by flit using `jestimator.__version__` 23 | dynamic = ["version"] 24 | 25 | [project.urls] 26 | homepage = "https://github.com/google-research/jestimator" 27 | repository = "https://github.com/google-research/jestimator" 28 | # Other: `documentation`, `changelog` 29 | 30 | [project.optional-dependencies] 31 | # Installed through `pip install .[test]` 32 | test = [ 33 | "absl-py", 34 | ] 35 | 36 | [build-system] 37 | requires = ["flit_core >=3.5,<4"] 38 | build-backend = "flit_core.buildapi" 39 | 40 | # Only publish the Amos optimizer to pip. 41 | [tool.flit.sdist] 42 | exclude = [ 43 | "jestimator/amos_helper_test.py", 44 | "jestimator/amos_test.py", 45 | "jestimator/checkpoint_utils.py", 46 | "jestimator/data/", 47 | "jestimator/data_utils.py", 48 | "jestimator/estimator.py", 49 | "jestimator/modeling.py", 50 | "jestimator/models/", 51 | "jestimator/states.py", 52 | ] 53 | --------------------------------------------------------------------------------