├── .github
│   └── workflows
│       └── pytest_and_autopublish.yml
├── .gitignore
├── .pylintrc
├── .style.yapf
├── .vscode
│   └── settings.json
├── CHANGELOG.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
│   ├── gcp-finetune.md
│   ├── gcp-pretrain.md
│   ├── screenshot01-gcpdashboard.jpg
│   ├── screenshot02-vminstances.jpg
│   ├── screenshot03-vmcreate.jpg
│   ├── screenshot04-ssh.jpg
│   ├── screenshot05-config.jpg
│   ├── screenshot06-buckets.jpg
│   ├── screenshot07-createbucket.jpg
│   ├── screenshot08-tpunode.jpg
│   └── screenshot09-finetunetpus.jpg
├── jestimator
│   ├── __init__.py
│   ├── amos.py
│   ├── amos_helper.py
│   ├── amos_helper_test.py
│   ├── amos_test.py
│   ├── checkpoint_utils.py
│   ├── data
│   │   ├── pipeline_lm.py
│   │   ├── pipeline_rec.py
│   │   ├── pipeline_seqio.py
│   │   └── reader.py
│   ├── data_utils.py
│   ├── estimator.py
│   ├── modeling.py
│   ├── models
│   │   ├── bert
│   │   │   ├── finetune.py
│   │   │   ├── modeling.py
│   │   │   └── pretrain.py
│   │   ├── bert_rpe
│   │   │   ├── finetune.py
│   │   │   ├── modeling.py
│   │   │   └── pretrain.py
│   │   ├── linear_regression
│   │   │   ├── linear_regression.py
│   │   │   └── linear_regression_test.py
│   │   ├── lstm
│   │   │   ├── lm.py
│   │   │   ├── modeling.py
│   │   │   └── ptb
│   │   │       ├── ptb.test.txt
│   │   │       ├── ptb.train.txt
│   │   │       ├── ptb.valid.txt
│   │   │       └── vocab.txt
│   │   ├── mnist
│   │   │   ├── mnist.ipynb
│   │   │   └── mnist.py
│   │   └── rope
│   │       ├── finetune.py
│   │       ├── modeling.py
│   │       └── pretrain.py
│   └── states.py
└── pyproject.toml
/.github/workflows/pytest_and_autopublish.yml:
--------------------------------------------------------------------------------
1 | name: Unittests & Auto-publish
2 |
3 | # Allow triggering the workflow manually (e.g. when dependencies change)
4 | on: [push, workflow_dispatch]
5 |
6 | jobs:
7 | test-job:
8 | runs-on: ubuntu-latest
9 | timeout-minutes: 10
10 |
11 | concurrency:
12 | group: ${{ github.workflow }}-${{ github.ref }}
13 | cancel-in-progress: true
14 |
15 | steps:
16 | - uses: actions/checkout@v3
17 |
18 | # Install deps
19 | - uses: actions/setup-python@v4
20 | with:
21 | python-version: 3.8
22 | # Uncomment to cache pip dependencies (if tests are too slow)
23 | # cache: pip
24 | # cache-dependency-path: '**/pyproject.toml'
25 |
26 | - run: pip --version
27 | - run: pip install -e .[test]
28 | - run: pip freeze
29 |
30 | # Run tests
31 | - name: amos_test
32 | run: python3 jestimator/amos_test.py
33 |
34 | - name: amos_helper_test
35 | run: python3 jestimator/amos_helper_test.py
36 |
37 | # Auto-publish when version is increased
38 | publish-job:
39 | # Only try to publish if:
40 | # * Repo is self (prevents running from forks)
41 | # * Branch is `main`
42 | if: |
43 | github.repository == 'google-research/jestimator'
44 | && github.ref == 'refs/heads/main'
45 | needs: test-job # Only publish after tests are successful
46 | runs-on: ubuntu-latest
47 | permissions:
48 | contents: write
49 | timeout-minutes: 10
50 |
51 | steps:
52 | # Publish the package (if local `__version__` > pip version)
53 | - uses: etils-actions/pypi-auto-publish@v1
54 | with:
55 | pypi-token: ${{ secrets.PYPI_API_TOKEN }}
56 | gh-token: ${{ secrets.FOR_JESTIMATOR_RELEASE }}
57 | parse-changelog: true
58 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled python modules.
2 | *.pyc
3 |
4 | # Byte-compiled
5 | __pycache__/
6 | .cache/
7 |
8 | # Poetry, setuptools, PyPI distribution artifacts.
9 | /*.egg-info
10 | .eggs/
11 | build/
12 | dist/
13 | poetry.lock
14 |
15 | # Tests
16 | .pytest_cache/
17 |
18 | # Type checking
19 | .pytype/
20 |
21 | # Other
22 | *.DS_Store
23 |
24 | # PyCharm
25 | .idea
26 |
--------------------------------------------------------------------------------
/.style.yapf:
--------------------------------------------------------------------------------
1 | [style]
2 | based_on_style = google
3 | indent_width = 2
4 | dedent_closing_brackets = True
5 | split_before_dot = True
6 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "files.insertFinalNewline": true,
3 | "files.trimFinalNewlines": true,
4 | "files.trimTrailingWhitespace": true,
5 | "files.associations": {
6 | ".pylintrc": "ini",
7 | ".style.yapf": "ini"
8 | },
9 | "python.testing.unittestEnabled": false,
10 | "python.testing.nosetestsEnabled": false,
11 | "python.testing.pytestEnabled": true,
12 | "python.linting.pylintUseMinimalCheckers": false,
13 | "[python]": {
14 | "editor.rulers": [
15 | 80
16 | ],
17 | "editor.tabSize": 2,
18 | "editor.detectIndentation": false
19 | },
20 | "python.formatting.provider": "yapf",
21 | "files.watcherExclude": {
22 | "**/.git/**": true
23 | },
24 | "files.exclude": {
25 | "**/__pycache__": true,
26 | "**/.pytest_cache": true,
27 | "**/*.egg-info": true
28 | }
29 | }
30 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 |
23 |
24 | ## [Unreleased]
25 |
26 | * Multi-pod debug for estimator.
27 | * Show variable shape in tensorboard.
28 | * `save_mutable` defaults to False for LSTM debug.
29 | * Debug for checkpointing.
30 | * Simplified pre-training code (removed seqio).
31 |
32 | ## [0.3.3] - 2022-12-01
33 |
34 | * Python 3.7 compatibility.
35 | * Add `d_coef` and `c_coef` to Amos hyper-parameters.
36 | * Support for flax_mutables in checkpointing.
37 | * Bug fix in data/pipeline_rec and simplified data_utils code.
38 |
39 | ## [0.3.2] - 2022-11-01
40 |
41 | * The Amos optimizer implementation sticks to the paper.
42 | * Initial Pypi package.
43 | * Setup Github workflow for unit test and auto-publish.
44 | * MNIST examples.
45 |
46 | [Unreleased]: https://github.com/google-research/jestimator/compare/v0.3.3...HEAD
47 | [0.3.3]: https://github.com/google-research/jestimator/releases/tag/v0.3.3
48 | [0.3.2]: https://github.com/google-research/jestimator/releases/tag/v0.3.2
49 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file
13 | or to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code Reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Amos and JEstimator
2 |
3 | [](https://github.com/google-research/jestimator/actions/workflows/pytest_and_autopublish.yml)
4 | [](https://badge.fury.io/py/jestimator)
5 |
6 | *This is not an officially supported Google product.*
7 |
8 | This is the source code for the paper "[Amos: An Adam-style Optimizer with
9 | Adaptive Weight Decay towards Model-Oriented
10 | Scale](https://arxiv.org/abs/2210.11693)".
11 |
12 | It implements **Amos**, an optimizer compatible with the
13 | [optax](https://github.com/deepmind/optax) library, and **JEstimator**, a
14 | light-weight library with a `tf.Estimator`-like interface to manage
15 | [T5X](https://github.com/google-research/t5x)-compatible checkpoints for machine
16 | learning programs in [JAX](https://github.com/google/jax), which we use to run
17 | experiments in the paper.
18 |
19 | ## Quickstart
20 |
21 | ```
22 | pip install jestimator
23 | ```
24 |
25 | This installs the Amos optimizer implemented in the jestimator library.
26 |
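Amos is exposed as an optax-style gradient transformation. As a minimal,
self-contained sketch (the toy parameters and the constant `eta_fn` below are
illustrative only, not a recommended setting):

```
import jax
import jax.numpy as jnp
import optax
from jestimator import amos

params = {'w': jnp.ones((4, 2)), 'b': jnp.zeros((2,))}  # toy parameters

optimizer = amos.amos(
    learning_rate=0.01,
    eta_fn=lambda name, shape: 1.0,  # expected scale of each variable
)
opt_state = optimizer.init(params)

grads = jax.tree_util.tree_map(jnp.ones_like, params)  # stand-in gradients
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```

Note that Amos requires `params` to be passed to `update`, since the adaptive
weight decay depends on the current parameter values.
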
27 | ## Usage of Amos
28 |
29 | This implementation of Amos is used with [JAX](https://github.com/google/jax), a
30 | high-performance numerical computing library with automatic differentiation, for
31 | machine learning research. The API of Amos is compatible with
32 | [optax](https://github.com/deepmind/optax), a library of JAX optimizers
33 | (hopefully Amos will be integrated into optax in the near future).
34 |
35 | In order to demonstrate the usage, we will apply Amos to MNIST. It is based on
36 | Flax's official
37 | [MNIST Example](https://github.com/google/flax/tree/main/examples/mnist), and
38 | you can find the code in a jupyter notebook
39 | [here](https://github.com/google-research/jestimator/tree/main/jestimator/models/mnist/mnist.ipynb).
40 |
41 | ### 1. Imports
42 |
43 | ```
44 | import jax
45 | import jax.numpy as jnp # JAX NumPy
46 | from jestimator import amos # The Amos optimizer implementation
47 | from jestimator import amos_helper # Helper module for Amos
48 |
49 | from flax import linen as nn # The Linen API
50 | from flax.training import train_state # Useful dataclass to keep train state
51 |
52 | import math
53 | import tensorflow_datasets as tfds # TFDS for MNIST
54 | from sklearn.metrics import accuracy_score
55 | ```
56 |
57 | ### 2. Load data
58 |
59 | ```
60 | def get_datasets():
61 | """Load MNIST train and test datasets into memory."""
62 |
63 | ds_builder = tfds.builder('mnist')
64 | ds_builder.download_and_prepare()
65 | train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
66 | test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
67 | train_ds['image'] = jnp.float32(train_ds['image']) / 255.
68 | test_ds['image'] = jnp.float32(test_ds['image']) / 255.
69 | return train_ds, test_ds
70 | ```
71 |
72 | ### 3. Build model
73 |
74 | ```
75 | class CNN(nn.Module):
76 | """A simple CNN model."""
77 |
78 | @nn.compact
79 | def __call__(self, x):
80 | x = nn.Conv(features=32, kernel_size=(3, 3))(x)
81 | x = nn.relu(x)
82 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
83 | x = nn.Conv(features=64, kernel_size=(3, 3))(x)
84 | x = nn.relu(x)
85 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
86 | x = x.reshape((x.shape[0], -1)) # flatten
87 | x = nn.Dense(features=256)(x)
88 | x = nn.relu(x)
89 | x = nn.Dense(features=10)(x)
90 | return x
91 |
92 | def classify_xe_loss(self, x, labels):
93 | # Labels read from the tfds MNIST are integers from 0 to 9.
94 | # Logits are arrays of size 10.
95 | logits = self(x)
96 | logits = jax.nn.log_softmax(logits)
97 | labels_ = jnp.expand_dims(labels, -1)
98 | llh_ = jnp.take_along_axis(logits, labels_, axis=-1)
99 | loss = -jnp.sum(llh_)
100 | return loss
101 | ```
102 |
103 | ### 4. Create train state
104 |
105 | A `TrainState` object keeps the model parameters and optimizer states, and can
106 | be checkpointed into files.
107 |
108 | We create the model and optimizer in this function.
109 |
110 | **For the optimizer, we use Amos here.** The following hyper-parameters are set:
111 |
112 | * *learning_rate*: The global learning rate.
113 | * *eta_fn*: The model-specific 'eta'.
114 | * *shape_fn*: Memory reduction setting.
115 | * *beta*: Rate for running average of gradient squares.
116 | * *clip_value*: Gradient clipping for stable training.
117 |
118 | The global learning rate is usually set to 1/sqrt(N), where N is the number
119 | of batches in the training data. For MNIST, we have 60k training examples and
120 | a batch size of 32, so learning_rate = 1/sqrt(60000/32) ≈ 0.023.
121 |
122 | The model-specific 'eta_fn' requires a function that, given a variable name and
123 | shape, returns a float indicating the expected scale of that variable. Hopefully
124 | in the near future we will have libraries that can automatically calculate this
125 | 'eta_fn' from the modeling code; but for now we have to specify it manually.
126 |
127 | One can use the amos_helper.params_fn_from_assign_map() helper function to
128 | create 'eta_fn' from an assign_map. An assign_map is a dict that maps regex
129 | rules to values or simple Python expressions. The helper finds the first regex
130 | rule that matches the name of a variable, and evaluates the Python expression
131 | if necessary to return the value. See our example below.
132 |
133 | The 'shape_fn' similarly requires a function that, given a variable name and
134 | shape, returns a reduced shape for the corresponding slot variables. We can use
135 | the amos_helper.params_fn_from_assign_map() helper function to create 'shape_fn'
136 | from an assign_map as well.
137 |
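To make these helpers concrete, here is a small illustrative sketch of calling
the created functions directly (assuming, as in the optimizer's internals,
that variable names are flattened path tuples matched against the regex rules;
the shapes and values below are made up):

```
from jestimator import amos_helper

eta_fn = amos_helper.params_fn_from_assign_map(
    {'.*Dense_0/kernel': 'sqrt(2/SHAPE[0])', '.*': 1.0},
    eval_str_value=True,
)
shape_fn = amos_helper.params_fn_from_assign_map(
    {'.*Dense_0/kernel': '(1, SHAPE[1])', '.*': ()},
    eval_str_value=True,
)
# In an expression, SHAPE is bound to the variable's shape.
eta_fn(('params', 'Dense_0', 'kernel'), (256, 10))    # expect sqrt(2/256)
shape_fn(('params', 'Dense_0', 'kernel'), (256, 10))  # expect (1, 10)
eta_fn(('params', 'Dense_0', 'bias'), (10,))          # falls through to '.*'
```
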
138 | 'beta' is the exponential decay rate for the running average of gradient
139 | squares. We set it to 0.98 here.
140 |
141 | 'clip_value' is the gradient clipping value, which should match the magnitude of
142 | the loss function. If the loss function is a sum of cross-entropy, then we
143 | should set 'clip_value' to the sqrt of the number of labels.
144 |
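Putting these heuristics together for this example, a quick sanity check of
the two derived values (they match the code below):

```
import math

learning_rate = 1 / math.sqrt(60000 / 32)  # 1/sqrt(#batches) ≈ 0.023
clip_value = math.sqrt(32)                 # sqrt(#labels per batch) ≈ 5.66
```
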
145 | Please refer to our [paper](https://arxiv.org/abs/2210.11693) for more details
146 | of the hyper-parameters.
147 |
148 | ```
149 | def get_train_state(rng):
150 | model = CNN()
151 | dummy_x = jnp.ones([1, 28, 28, 1])
152 | params = model.init(rng, dummy_x)
153 |
154 | eta_fn = amos_helper.params_fn_from_assign_map(
155 | {
156 | '.*/bias': 0.5,
157 | '.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))',
158 | '.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',
159 | '.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',
160 | '.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',
161 | },
162 | eval_str_value=True,
163 | )
164 | shape_fn = amos_helper.params_fn_from_assign_map(
165 | {
166 | '.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',
167 | '.*Dense_0/kernel': '(1, SHAPE[1])',
168 | '.*': (),
169 | },
170 | eval_str_value=True,
171 | )
172 | optimizer = amos.amos(
173 | learning_rate=1/math.sqrt(60000/32),
174 | eta_fn=eta_fn,
175 | shape_fn=shape_fn,
176 | beta=0.98,
177 | clip_value=math.sqrt(32),
178 | )
179 | return train_state.TrainState.create(
180 | apply_fn=model.apply, params=params, tx=optimizer)
181 | ```
182 |
183 | ### 5. Train step
184 |
185 | Use JAX’s @jit decorator to just-in-time compile the function for better
186 | performance.
187 |
188 | ```
189 | @jax.jit
190 | def train_step(batch, state):
191 | grad_fn = jax.grad(state.apply_fn)
192 | grads = grad_fn(
193 | state.params,
194 | batch['image'],
195 | batch['label'],
196 | method=CNN.classify_xe_loss)
197 | return state.apply_gradients(grads=grads)
198 | ```
199 |
200 | ### 6. Infer step
201 |
202 | Use JAX’s @jit decorator to just-in-time compile the function for better
203 | performance.
204 |
205 | ```
206 | @jax.jit
207 | def infer_step(batch, state):
208 | logits = state.apply_fn(state.params, batch['image'])
209 | return jnp.argmax(logits, -1)
210 | ```
211 |
212 | ### 7. Main
213 |
214 | Run the training loop and evaluate on test set.
215 |
216 | ```
217 | train_ds, test_ds = get_datasets()
218 |
219 | rng = jax.random.PRNGKey(0)
220 | rng, init_rng = jax.random.split(rng)
221 | state = get_train_state(init_rng)
222 | del init_rng # Must not be used anymore.
223 |
224 | num_epochs = 9
225 | for epoch in range(1, num_epochs + 1):
226 | # Use a separate PRNG key to permute image data during shuffling
227 | rng, input_rng = jax.random.split(rng)
228 | perms = jax.random.permutation(input_rng, 60000)
229 | del input_rng
230 | perms = perms.reshape((60000 // 32, 32))
231 | for perm in perms:
232 | batch = {k: v[perm, ...] for k, v in train_ds.items()}
233 | state = train_step(batch, state)
234 |
235 | pred = jax.device_get(infer_step(test_ds, state))
236 | accuracy = accuracy_score(test_ds['label'], pred)
237 | print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100))
238 | ```
239 |
240 | After 9 epochs, we should get 99.26% test accuracy. If you made it, congrats!
241 |
242 | ## JEstimator
243 |
244 | With JEstimator, you can build your model much as in the MNIST example
245 | above, but without writing code for the "Main" section; JEstimator serves as
246 | the entry point for your model, automatically handles checkpointing in
247 | train/eval-once/eval-while-training-and-save-the-best/predict modes, and sets
248 | up profiling, tensorboard, and logging.
249 |
250 | In addition, JEstimator supports model partitioning which is required for
251 | training very large models across multiple TPU pods. It supports a
252 | [T5X](https://github.com/google-research/t5x)-compatible checkpoint format that
253 | saves and restores checkpoints in a distributed manner, which is suitable for
254 | large multi-pod models.
255 |
256 | In order to run models with JEstimator, we need to install
257 | [T5X](https://github.com/google-research/t5x#installation) and
258 | [FlaxFormer](https://github.com/google/flaxformer):
259 |
260 | ```
261 | git clone --branch=main https://github.com/google-research/t5x
262 | cd t5x
263 | python3 -m pip install -e .
264 | cd ..
265 |
266 | git clone --branch=main https://github.com/google/flaxformer
267 | cd flaxformer
268 | pip3 install .
269 | cd ..
270 | ```
271 |
272 | Then, clone this repo to get the JEstimator code:
273 |
274 | ```
275 | git clone --branch=main https://github.com/google-research/jestimator
276 | cd jestimator
277 | ```
278 |
279 | Now, we can test a toy linear regression model:
280 |
281 | ```
282 | PYTHONPATH=. python3 jestimator/models/linear_regression/linear_regression_test.py
283 | ```
284 |
285 | ## MNIST Example in JEstimator
286 |
287 | We provide this
288 | [MNIST Example](https://github.com/google-research/jestimator/tree/main/jestimator/models/mnist/mnist.py)
289 | to demonstrate how to write modeling code with JEstimator. It is much like the
290 | example above, but with one big advantage: a config object is passed around to
291 | collect information from global flags and the dataset, in order to set up the
292 | model dynamically. This makes it easier to apply the model to different datasets; for example, one can immediately try the [emnist](https://www.tensorflow.org/datasets/catalog/emnist) or [eurosat](https://www.tensorflow.org/datasets/catalog/eurosat) datasets simply by changing a command-line argument, without modifying the code.
293 |
294 | With the following command, we can start a job to train on MNIST, log every 100
295 | steps, and save the checkpoints to $HOME/experiments/mnist/models:
296 |
297 | ```
298 | PYTHONPATH=. python3 jestimator/estimator.py \
299 | --module_imp="jestimator.models.mnist.mnist" \
300 | --module_config="jestimator/models/mnist/mnist.py" \
301 | --train_pattern="tfds://mnist/split=train" \
302 | --model_dir="$HOME/experiments/mnist/models" \
303 | --train_batch_size=32 \
304 | --train_shuffle_buf=4096 \
305 | --train_epochs=9 \
306 | --check_every_steps=100 \
307 | --max_ckpt=20 \
308 | --save_every_steps=1000 \
309 | --module_config.warmup=2000 \
310 | --module_config.amos_beta=0.98
311 | ```
312 |
313 | Meanwhile, we can start a job to monitor the $HOME/experiments/mnist/models
314 | folder, evaluate on MNIST test set, and save the model with the highest
315 | accuracy:
316 |
317 | ```
318 | PYTHONPATH=. python3 jestimator/estimator.py \
319 | --module_imp="jestimator.models.mnist.mnist" \
320 | --module_config="jestimator/models/mnist/mnist.py" \
321 | --eval_pattern="tfds://mnist/split=test" \
322 | --model_dir="$HOME/experiments/mnist/models" \
323 | --eval_batch_size=32 \
324 | --mode="eval_wait" \
325 | --check_ckpt_every_secs=1 \
326 | --save_high="test_accuracy"
327 | ```
328 |
329 | At the same time, we can start a tensorboard to monitor the process:
330 |
331 | ```
332 | tensorboard --logdir $HOME/experiments/mnist/models
333 | ```
334 |
335 | ## LSTM on PTB
336 |
337 | We can use the following command to train a single layer LSTM on PTB:
338 |
339 | ```
340 | PYTHONPATH=. python3 jestimator/estimator.py \
341 | --module_imp="jestimator.models.lstm.lm" \
342 | --module_config="jestimator/models/lstm/lm.py" \
343 | --module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt" \
344 | --train_pattern="jestimator/models/lstm/ptb/ptb.train.txt" \
345 | --model_dir="$HOME/models/ptb_lstm" \
346 | --train_batch_size=64 \
347 | --train_consecutive=113 \
348 | --train_shuffle_buf=4096 \
349 | --max_train_steps=200000 \
350 | --check_every_steps=1000 \
351 | --max_ckpt=20 \
352 | --module_config.opt_config.optimizer="amos" \
353 | --module_config.opt_config.learning_rate=0.01 \
354 | --module_config.opt_config.beta=0.98 \
355 | --module_config.opt_config.momentum=0.0
356 | ```
357 |
358 | and evaluate:
359 |
360 | ```
361 | PYTHONPATH=. python3 jestimator/estimator.py \
362 | --module_imp="jestimator.models.lstm.lm" \
363 | --module_config="jestimator/models/lstm/lm.py" \
364 | --module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt" \
365 | --eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt" \
366 | --model_dir="$HOME/models/ptb_lstm" \
367 | --eval_batch_size=1
368 | ```
369 |
370 | It is suitable for running on a single-GPU machine.
371 |
372 | ## More JEstimator Models
373 |
374 | Here are some simple guides to pre-training and fine-tuning BERT-like models, using TPUs on Google Cloud Platform (GCP). One can start in a Web browser with zero setup, by connecting to a Virtual Machine via the Google Cloud console, without installing anything locally. First-time users get enough free-trial credits to try the commands for free.
375 |
376 | * [My experience of pre-training a BERT-base model on GCP](docs/gcp-pretrain.md)
377 | * [My experience of fine-tuning MNLI on GCP](docs/gcp-finetune.md)
378 |
--------------------------------------------------------------------------------
/docs/gcp-finetune.md:
--------------------------------------------------------------------------------
1 | # Using TPUs on Google Cloud to fine-tune MNLI
2 |
3 | Firstly, we do the same as *Login into a VM instance* and *Create a storage
4 | bucket* in the [pre-training](gcp-pretrain.md) job. If these are already done,
5 | we can use the same login VM and storage bucket.
6 |
7 | ## Create TPU nodes
8 |
9 | Then, we create TPU nodes for fine-tuning. We use one 'v3-8' node to train and
10 | one 'v2-8' node to evaluate, as below.
11 |
12 | 
13 |
14 | Now, we can log in to each TPU node and set it up with the following commands:
15 |
16 | ```
17 | gcloud auth application-default login
18 | gcloud auth login
19 |
20 | pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
21 |
22 | git clone --branch=main https://github.com/google-research/t5x
23 | cd t5x
24 | python3 -m pip install -e .
25 | cd ..
26 |
27 | git clone --branch=main https://github.com/google/flaxformer
28 | cd flaxformer
29 | pip3 install .
30 | cd ..
31 |
32 | git clone --branch=main https://github.com/google-research/jestimator
33 | pip install -U scikit-learn
34 |
35 | wget http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/archives/sentence_piece.tar.gz
36 | tar xvfz sentence_piece.tar.gz
37 | ```
38 |
39 | ## Launch jobs
40 |
41 | We will fine-tune a BERT-base checkpoint pre-trained by Amos.
42 |
43 | At the 'v3-8' train node, we launch the training job as below:
44 |
45 | ```
46 | cd jestimator
47 | PYTHONPATH=. python3 jestimator/estimator.py \
48 | --module_imp="jestimator.models.bert.finetune" \
49 | --module_config="jestimator/models/bert/finetune.py" \
50 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \
51 | --module_config.segment_names="premise,hypothesis" \
52 | --module_config.model_config.num_labels=3 \
53 | --train_pattern="tfds://glue/mnli/split=train" \
54 | --valid_pattern="tfds://glue/mnli/split=validation_matched" \
55 | --model_dir="gs://jestimator_example/experiments/finetune/amos-bert-base/\
56 | mnli/models" \
57 | --checkpoint_path="gs://gresearch/checkpoints_in_amos_paper/\
58 | amos-bert-base/checkpoint_300000" \
59 | --train_shuffle_buf=65536 --max_train_steps=100000 \
60 | --train_batch_size=64 --valid_batch_size=64 --num_valid_examples=256 \
61 | --module_config.opt_config.optimizer="adam" \
62 | --module_config.opt_config.learning_rate=1e-5 \
63 | --check_every_steps=500 --logtostderr
64 | ```
65 |
66 | At the 'v2-8' eval node, we launch the eval job as below:
67 |
68 | ```
69 | cd jestimator
70 | PYTHONPATH=. python3 jestimator/estimator.py \
71 | --module_imp="jestimator.models.bert.finetune" \
72 | --module_config="jestimator/models/bert/finetune.py" \
73 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \
74 | --module_config.segment_names="premise,hypothesis" \
75 | --module_config.model_config.num_labels=3 \
76 | --eval_pattern="tfds://glue/mnli/split=validation_matched" \
77 | --model_dir="gs://jestimator_example/experiments/finetune/amos-bert-base/\
78 | mnli/models" \
79 | --mode="eval_wait" --check_ckpt_every_secs=10 --max_train_steps=100000 \
80 | --eval_batch_size=64 --module_config.eval_metric="accuracy" \
81 | --logtostderr
82 | ```
83 |
84 | ## Start TensorBoard
85 |
86 | The setup is the same as in [pre-training](gcp-pretrain.md#start-tensorboard). We use
87 | the following command to start a TensorBoard.
88 |
89 | ```
90 | .local/bin/tensorboard dev upload \
91 | --logdir mnt/jestimator_example/experiments/finetune/amos-bert-base \
92 | --name "Finetune BERT-base" \
93 | --description "Example of using TPUs on Google Cloud to run JEstimator."
94 | ```
95 |
96 | ## Pre-trained checkpoints in the Amos paper
97 |
98 | We have released the following checkpoints from the Amos paper.
99 |
100 | Model | Google Storage | Download
101 | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------
102 | BERT-base | `gs://gresearch/checkpoints_in_amos_paper/adamw-bert-base/checkpoint_300000`<br>`gs://gresearch/checkpoints_in_amos_paper/amos-bert-base/checkpoint_300000` | [adamw-bert-base.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/adamw-bert-base.tar.gz)<br>[amos-bert-base.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/amos-bert-base.tar.gz)
103 | BERT-large | `gs://gresearch/checkpoints_in_amos_paper/adamw-bert-large/checkpoint_250000`<br>`gs://gresearch/checkpoints_in_amos_paper/amos-bert-large/checkpoint_250000` | [adamw-bert-large.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/adamw-bert-large.tar.gz)<br>[amos-bert-large.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/amos-bert-large.tar.gz)
104 | RoPE-base | `gs://gresearch/checkpoints_in_amos_paper/adamw-rope-base/checkpoint_300000`<br>`gs://gresearch/checkpoints_in_amos_paper/amos-rope-base/checkpoint_300000` | [adamw-rope-base.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/adamw-rope-base.tar.gz)<br>[amos-rope-base.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/amos-rope-base.tar.gz)
105 | RoPE-large | `gs://gresearch/checkpoints_in_amos_paper/adamw-rope-large/checkpoint_1000000`<br>`gs://gresearch/checkpoints_in_amos_paper/amos-rope-large/checkpoint_1000000` | [adamw-rope-large.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/adamw-rope-large.tar.gz)<br>[amos-rope-large.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/amos-rope-large.tar.gz)
106 | RPE-base | `gs://gresearch/checkpoints_in_amos_paper/adamw-rpe/checkpoint_300000`<br>`gs://gresearch/checkpoints_in_amos_paper/amos-rpe/checkpoint_300000` | [adamw-rpe.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/adamw-rpe.tar.gz)<br>[amos-rpe.tar.gz](http://storage.googleapis.com/gresearch/checkpoints_in_amos_paper/tgz/amos-rpe.tar.gz)
107 |
--------------------------------------------------------------------------------
/docs/gcp-pretrain.md:
--------------------------------------------------------------------------------
1 | # Using TPUs on Google Cloud to pre-train BERT
2 |
3 | ## Login into a VM instance
4 |
5 | It's not difficult to find the [Google Cloud page](https://cloud.google.com/)
6 | and start with a Google account. It requires payment information if this is the
7 | first time, but one gets automatic free trial credit and will not be charged.
8 |
9 | One should land on the dashboard as below, and let's hit the "GO TO COMPUTE
10 | ENGINE" button.
11 |
12 | 
13 |
14 | We should be able to see VM instances after the Compute Engine App is installed
15 | in the Web browser. Then hit the "CREATE INSTANCE" button, as below.
16 |
17 | 
18 |
19 | We can choose the smallest e2-micro (2 vCPU, 1 GB memory) machine type, since it
20 | will only be used to connect to TPU nodes and run TensorBoard. We may want to
21 | choose a region where we will find our TPUs, e.g. us-central (Iowa), as shown
22 | below.
23 |
24 | 
25 |
26 | After the Virtual Machine is created and appears on the dashboard, we hit the
27 | "SSH" button to connect to the VM, as below.
28 |
29 | 
30 |
31 | Now we have a Linux terminal! We can, for example, install TensorBoard using
32 | the following command.
33 |
34 | ```
35 | sudo apt update
36 | sudo apt install python3-pip
37 | pip install -U tensorboard
38 | ```
39 |
40 | The terminal is also a Google Cloud console. We can run GCP-specific commands,
41 | such as listing the pre-training corpus that we are going to use, already stored
42 | in a Cloud Storage Bucket:
43 |
44 | ```
45 | gsutil ls gs://gresearch/checkpoints_in_amos_paper/data
46 | ```
47 |
48 | In order to connect to TPUs, we should configure the `gcloud` command in our
49 | terminal, as below:
50 |
51 | ```
52 | gcloud config set account your-google-account
53 | gcloud config set project your-project-id
54 | gcloud auth application-default login
55 | gcloud auth login
56 | ```
57 |
58 | Here, the google-account (e.g. xyz@gmail.com) and project-id (e.g.
59 | graphical-quest-123456) can be found by clicking the dashboard, as below.
60 |
61 | 
62 |
63 | Now, we can list the TPU nodes we have created in any zone, but currently there
64 | is none:
65 |
66 | ```
67 | gcloud compute tpus tpu-vm list --zone us-central1-a
68 | ```
69 |
70 | ## Create a storage bucket
71 |
72 | We also need a storage bucket, so that the TPU node can save model checkpoints
73 | in it, and the login VM can access the TensorBoard log as well. We hit the
74 | dashboard menu and select 'Cloud Storage > Buckets' as below.
75 |
76 | 
77 |
78 | Then, we have to figure out a unique name for the bucket. We use
79 | 'jestimator_example' in this example. If we are sure about the region of TPUs we
80 | are going to use, we may restrict the location of the bucket, as below.
81 |
82 | 
83 |
84 | ## Create a TPU node
85 |
86 | Now, we can create a TPU node for pre-training BERT. We hit the dashboard menu
87 | again and select 'Compute Engine > TPUs'. For pre-training BERT, we select the
88 | TPU type 'v2-32'. Since we are using JAX, the TPU software version should be
89 | 'tpu-vm-base'. As below.
90 |
91 | 
92 |
93 | ## Login into the TPU node, install JEstimator, and launch the pre-training job
94 |
95 | A TPU node of type 'v2-32' consists of 4 workers, each of which has an
96 | 8-core TPUv2 board. We have to set up each TPU worker manually. To start, we
97 | have to log in to each worker once. In our terminal, we use the following
98 | command to log in to worker 0:
99 |
100 | ```
101 | gcloud compute tpus tpu-vm ssh node-1 --zone us-central1-a --worker=0
102 | ```
103 |
104 | And we should do the same for workers 1, 2, and 3. In addition, we should run
105 | `gcloud auth` at least on worker 0, in order to access our storage bucket
106 | during pre-training:
107 |
108 | ```
109 | gcloud auth application-default login
110 | gcloud auth login
111 | ```
112 |
113 | Then, we can use the following commands to set up all the workers at the same
114 | time:
115 |
116 | ```
117 | gcloud compute tpus tpu-vm ssh node-1 --zone us-central1-a --worker=all \
118 | --command='pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
119 |
120 | gcloud compute tpus tpu-vm ssh node-1 --zone us-central1-a --worker=all \
121 | --command='git clone --branch=main https://github.com/google-research/t5x'
122 |
123 | gcloud compute tpus tpu-vm ssh node-1 --zone us-central1-a --worker=all \
124 | --command='cd t5x; python3 -m pip install -e .'
125 |
126 | gcloud compute tpus tpu-vm ssh node-1 --zone us-central1-a --worker=all \
127 | --command='git clone --branch=main https://github.com/google/flaxformer'
128 |
129 | gcloud compute tpus tpu-vm ssh node-1 --zone us-central1-a --worker=all \
130 | --command='cd flaxformer; pip3 install .'
131 |
132 | gcloud compute tpus tpu-vm ssh node-1 --zone us-central1-a --worker=all \
133 | --command='git clone --branch=main https://github.com/google-research/jestimator'
134 | ```
135 |
136 | Now, we are ready to launch the pre-training job! The following command will
137 | pre-train a BERT-base model on the Wikipedia+Books corpus, with training
138 | batch-size 256 and 300k training steps. It will use the Amos optimizer. Since
139 | the job will run for a long time, it is better to start a terminal multiplexer
140 | (e.g. GNU screen) and launch the command in it.
141 |
142 | ```
143 | gcloud compute tpus tpu-vm ssh node-1 --zone us-central1-a --worker=all \
144 | --command='cd jestimator; PYTHONPATH=. python3 jestimator/estimator.py --module_imp="jestimator.models.bert.pretrain" --module_config="jestimator/models/bert/pretrain.py" --module_config.model_config.vocab_size=32000 --module_config.mask_token_id=4 --train_pattern="gs://gresearch/checkpoints_in_amos_paper/data/wikipedia-00???-of-00500,gs://gresearch/checkpoints_in_amos_paper/data/books-00???-of-00500" --valid_pattern="gs://gresearch/checkpoints_in_amos_paper/data/ptb" --train_shuffle_buf=65536 --max_train_steps=300000 --train_batch_size=256 --valid_batch_size=512 --num_valid_examples=512 --check_every_steps=5000 --module_config.opt_config.optimizer="amos" --module_config.opt_config.learning_rate=0.005 --model_dir="gs://jestimator_example/experiments/pretrain/bert-base/amos/models" --logtostderr'
145 | ```
146 |
147 | **Caution**: Check the billing information now, because the pre-training job
148 | may quickly exhaust the credit!
149 |
150 | ## Start TensorBoard
151 |
152 | After the pre-training job starts, we can start a TensorBoard to monitor the
153 | pre-training job. We will use
154 | [gcsfuse](https://github.com/GoogleCloudPlatform/gcsfuse) to mount our storage
155 | bucket to the login VM, and upload the pre-training logs to
156 | [TensorBoard.dev](https://tensorboard.dev/).
157 |
158 | In order to install gcsfuse:
159 |
161 |
162 | ```
163 | export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s`
164 | echo "deb https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list
165 | curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
166 |
167 | sudo apt-get update
168 | sudo apt-get install gcsfuse
169 | ```
170 |
171 |
173 | In order to mount the storage bucket and start TensorBoard:
174 |
175 | ```
176 | mkdir -p mnt/jestimator_example
177 | gcsfuse jestimator_example mnt/jestimator_example
178 |
179 | .local/bin/tensorboard dev upload \
180 | --logdir mnt/jestimator_example/experiments/pretrain/bert-base \
181 | --name "Pretrain BERT-base" \
182 | --description "Example of using TPUs on Google Cloud to run JEstimator."
183 | ```
184 |
--------------------------------------------------------------------------------
/docs/screenshot01-gcpdashboard.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/jestimator/326824a477feef6d92e6f44175000907a72899ca/docs/screenshot01-gcpdashboard.jpg
--------------------------------------------------------------------------------
/docs/screenshot02-vminstances.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/jestimator/326824a477feef6d92e6f44175000907a72899ca/docs/screenshot02-vminstances.jpg
--------------------------------------------------------------------------------
/docs/screenshot03-vmcreate.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/jestimator/326824a477feef6d92e6f44175000907a72899ca/docs/screenshot03-vmcreate.jpg
--------------------------------------------------------------------------------
/docs/screenshot04-ssh.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/jestimator/326824a477feef6d92e6f44175000907a72899ca/docs/screenshot04-ssh.jpg
--------------------------------------------------------------------------------
/docs/screenshot05-config.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/jestimator/326824a477feef6d92e6f44175000907a72899ca/docs/screenshot05-config.jpg
--------------------------------------------------------------------------------
/docs/screenshot06-buckets.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/jestimator/326824a477feef6d92e6f44175000907a72899ca/docs/screenshot06-buckets.jpg
--------------------------------------------------------------------------------
/docs/screenshot07-createbucket.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/jestimator/326824a477feef6d92e6f44175000907a72899ca/docs/screenshot07-createbucket.jpg
--------------------------------------------------------------------------------
/docs/screenshot08-tpunode.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/jestimator/326824a477feef6d92e6f44175000907a72899ca/docs/screenshot08-tpunode.jpg
--------------------------------------------------------------------------------
/docs/screenshot09-finetunetpus.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/jestimator/326824a477feef6d92e6f44175000907a72899ca/docs/screenshot09-finetunetpus.jpg
--------------------------------------------------------------------------------
/jestimator/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """JEstimator package with the Amos optimizer."""
16 |
17 | __version__ = '0.3.3'
18 |
19 | from . import amos
20 | from . import amos_helper
21 |
--------------------------------------------------------------------------------
/jestimator/amos.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Implements the AMOS optimizer.
16 |
17 | AMOS stands for 'Adaptive weight-decay towards Model-Oriented Scale'. It
18 | combines Adam-like gradient scaling with a theoretically proven scheme for
19 | adaptive weight-decay and learning-rate decay.
20 |
21 | In order to be effective, AMOS requires each trainable variable to provide an
22 | `eta` hyper-parameter, indicating the target scale that the entries of the
23 | trained variable converge to. `eta` is used in a variable-specific learning-rate
24 | schedule.
25 | """
26 | from typing import Any, Callable, NamedTuple, Optional, Tuple, Union
27 |
28 | from flax.serialization import from_state_dict, to_state_dict # pylint: disable=g-multiple-import
29 | from flax.traverse_util import empty_node, flatten_dict, unflatten_dict # pylint: disable=g-multiple-import
30 | import jax
31 | import jax.numpy as jnp
32 | from jax.typing import ArrayLike
33 | import optax
34 |
35 | Shape = Tuple[int, ...]
36 | ScalarOrSchedule = Union[float, optax.Schedule]
37 | ParamsFn = Callable[[Tuple[str, ...], Shape], Any]
38 |
39 |
40 | class ScaleByAmosState(NamedTuple):
41 | """State for the Amos algorithm."""
42 | count: Optional[ArrayLike] # shape=(), dtype=jnp.int32.
43 | v: optax.Updates
44 | b: optax.Updates
45 |
46 |
47 | def scale_by_amos(
48 | learning_rate: ScalarOrSchedule,
49 | eta_fn: ParamsFn,
50 | shape_fn: Optional[ParamsFn] = None,
51 | beta: float = 0.999,
52 | extra_l2: float = 0.,
53 | d_coef: float = 0.25,
54 | c_coef: float = 0.25,
55 | epsilon: float = 1. / (1 << 125),
56 | ) -> optax.GradientTransformation:
57 | """Rescale updates according to the Amos algorithm."""
58 |
59 | def init_fn(params):
60 | flat_v = {}
61 | flat_b = {}
62 | flat_params = _flatten(to_state_dict(params), keep_empty_nodes=True)
63 | for name, theta in flat_params.items():
64 | if theta == empty_node:
65 | flat_v[name] = empty_node
66 | flat_b[name] = empty_node
67 | continue
68 |
69 | if shape_fn is None:
70 | v = jnp.zeros_like(theta)
71 | else:
72 | v = jnp.zeros(shape_fn(name, theta.shape), dtype=theta.dtype)
73 | flat_v[name] = v
74 | flat_b[name] = jnp.zeros_like(v)
75 |
76 | v = from_state_dict(params, _unflatten(flat_v))
77 | b = from_state_dict(params, _unflatten(flat_b))
78 | return ScaleByAmosState(count=jnp.array(0), v=v, b=b)
79 |
80 | def update_fn(updates, state, params):
81 | count = optax.safe_int32_increment(state.count)
82 | if callable(learning_rate):
83 | xi = learning_rate(count)
84 | else:
85 | xi = learning_rate
86 | bias_correction = 1. - beta**count
87 | xi2 = jnp.square(xi)
88 | c_coef_sqrt_xi = c_coef * jnp.sqrt(xi)
89 |
90 | flat_grad = _flatten(to_state_dict(updates), keep_empty_nodes=True)
91 | flat_v = _flatten(to_state_dict(state.v), keep_empty_nodes=True)
92 | flat_b = _flatten(to_state_dict(state.b), keep_empty_nodes=True)
93 | flat_params = _flatten(to_state_dict(params))
94 | for name, theta in flat_params.items():
95 | grad = flat_grad[name]
96 | v = flat_v[name]
97 | if v.shape:
98 | reduced = [i for i, k in enumerate(grad.shape) if v.shape[i] < k]
99 | g2 = jnp.mean(jnp.square(grad), axis=reduced, keepdims=True)
100 | else:
101 | g2 = jnp.mean(jnp.square(grad))
102 | v = v * beta + g2 * (1. - beta)
103 | flat_v[name] = v
104 | rcpl_v_hat = bias_correction / jnp.maximum(v, epsilon)
105 |
106 | b = flat_b[name]
107 | decay_factor_c = jax.lax.rsqrt(1. + c_coef_sqrt_xi * b)
108 | gamma = decay_factor_c * xi2 * rcpl_v_hat * g2
109 |
110 | init_lr = xi * eta_fn(name, theta.shape)
111 | decay_factor_d = jnp.reciprocal(1. + d_coef * jnp.sqrt(init_lr) * b)
112 | l2_regularization = (-0.5 * gamma - extra_l2) * theta
113 | flat_grad[name] = decay_factor_d * (
114 | l2_regularization - init_lr * jnp.sqrt(rcpl_v_hat) * grad)
115 | flat_b[name] = b + gamma * (1. + b)
116 |
117 | updates = from_state_dict(updates, _unflatten(flat_grad))
118 | v = from_state_dict(state.v, _unflatten(flat_v))
119 | b = from_state_dict(state.b, _unflatten(flat_b))
120 | return updates, ScaleByAmosState(count=count, v=v, b=b)
121 |
122 | return optax.GradientTransformation(init_fn, update_fn)
123 |
124 |
125 | def _flatten(x, keep_empty_nodes=False):
126 | if not isinstance(x, dict):
127 | return {(): x}
128 |
129 | return flatten_dict(x, keep_empty_nodes=keep_empty_nodes)
130 |
131 |
132 | def _unflatten(x):
133 | if tuple(x.keys()) == ((),):
134 | return x[()]
135 |
136 | return unflatten_dict(x)
137 |
138 |
139 | def amos(
140 | learning_rate: ScalarOrSchedule,
141 | eta_fn: ParamsFn,
142 | shape_fn: Optional[ParamsFn] = None,
143 | beta: float = 0.999,
144 | momentum: Optional[float] = None,
145 | clip_value: Optional[float] = None,
146 | extra_l2: float = 0.,
147 | d_coef: float = 0.25,
148 | c_coef: float = 0.25,
149 | epsilon: float = 1. / (1 << 125),
150 | ) -> optax.GradientTransformation:
151 | """The full Amos optimizer with optional gradient clipping and momentum.
152 |
153 | References:
154 | [The Amos Paper](https://arxiv.org/abs/2210.11693)
155 |
156 | Args:
157 | learning_rate: A float or callable for learning rate. When it is callable,
158 |     the `learning_rate` takes step count as input and returns a float scalar.
159 | Let N be the number of independent batches in the training data. It is
160 | recommended to set the learning rate to about 1/sqrt(N).
161 | eta_fn: A function that maps a variable name and shape to the variable-
162 | specific hyper-parameter 'eta' indicating the expected scale of entries.
163 | shape_fn: A function that maps a variable name and shape to the shape of the
164 | corresponding slot variables `v` and `b`. The returned shape should be
165 | broadcastable to the varialbe, while some axes might be reduced to 1 to
166 | save memory.
167 | beta: A float slightly < 1. We recommend setting `1 - beta` to the same
168 | order of magnitude as the learning rate. Defaults to 0.999.
169 | momentum: Exponential decay rate for optional moving average of updates.
170 | clip_value: Optional gradient clipping value.
171 |     extra_l2: Additional L2 regularization (experimental). Defaults to 0.
172 | d_coef: Coefficient for decay_factor_d. Defaults to 0.25.
173 | c_coef: Coefficient for decay_factor_c. Defaults to 0.25.
174 |     epsilon: A tiny positive constant to prevent division by 0.
175 |
176 | Returns:
177 | An (init_fn, update_fn) tuple.
178 | """
179 | tx = []
180 | if clip_value is not None and clip_value > 0.:
181 | tx.append(optax.clip(clip_value))
182 | tx.append(
183 | scale_by_amos(
184 | learning_rate,
185 | eta_fn,
186 | shape_fn=shape_fn,
187 | beta=beta,
188 | extra_l2=extra_l2,
189 | d_coef=d_coef,
190 | c_coef=c_coef,
191 | epsilon=epsilon))
192 | if momentum is not None and momentum > 0.:
193 | tx.append(optax.ema(momentum, debias=False))
194 |
195 | if len(tx) >= 2:
196 | return optax.chain(*tx)
197 | return tx[0]
198 |
--------------------------------------------------------------------------------
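As a usage sketch (not part of the repository), the `amos` transformation above plugs into a standard optax-style loop; the toy parameter tree and hyper-parameter values below are assumptions:

```python
import jax
import jax.numpy as jnp
from jestimator import amos

# Toy parameter tree; eta_fn returns the expected scale of entries (here 1.0).
params = {'w': jnp.zeros((4, 8)), 'b': jnp.zeros((8,))}
opt = amos.amos(
    learning_rate=0.01,  # roughly 1/sqrt(N) for N independent training batches.
    eta_fn=lambda name, shape: 1.0,
    beta=0.99,
    clip_value=1.0)
state = opt.init(params)

@jax.jit
def train_step(params, state, grads):
  # Amos needs `params` in update() for its built-in L2 regularization.
  updates, state = opt.update(grads, state, params)
  return jax.tree_util.tree_map(lambda p, u: p + u, params, updates), state
```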
/jestimator/amos_helper.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Helper utilities for the Amos optimizer."""
16 | import ast
17 | import math
18 | import operator as op
19 | import re
20 | from typing import Any, Dict, Tuple
21 |
22 | from absl import logging
23 | import jax
24 | from jax.sharding import PartitionSpec # pylint: disable=g-importing-member
25 | from jestimator.amos import ParamsFn, ScaleByAmosState, Shape # pylint: disable=g-multiple-import,g-importing-member
26 | import numpy
27 |
28 | _BIN_OP_MAP = {
29 | ast.Add: op.add,
30 | ast.Sub: op.sub,
31 | ast.Mult: op.mul,
32 | ast.Div: op.truediv,
33 | ast.FloorDiv: op.floordiv,
34 | ast.Mod: op.mod,
35 | ast.Pow: op.pow,
36 | }
37 |
38 |
39 | def evaluate(s: str, shape: Shape):
40 | """Evaluate simple expression. Allow 'SHAPE' referring to variable shape."""
41 |
42 | def _evaluate(node):
43 | if node is None:
44 | return None
45 |
46 | if isinstance(node, ast.BinOp):
47 | left = _evaluate(node.left)
48 | right = _evaluate(node.right)
49 | return _BIN_OP_MAP[type(node.op)](left, right)
50 |
51 | if isinstance(node, ast.Call):
52 | func_name = node.func
53 | assert isinstance(func_name, ast.Name)
54 | func = getattr(math, func_name.id, None)
55 | if func is None:
56 | func = getattr(numpy, func_name.id)
57 |
58 | assert not node.keywords
59 | args = [_evaluate(x) for x in node.args]
60 | return func(*args)
61 |
62 | if isinstance(node, ast.Constant):
63 | return node.value
64 |
65 | if isinstance(node, ast.Name):
66 | assert node.id == 'SHAPE'
67 | return shape
68 |
69 | if isinstance(node, ast.Num): # Python 3.7 compatibility
70 | return node.n
71 |
72 | if isinstance(node, ast.Index): # Python 3.8 compatibility
73 | return _evaluate(node.value)
74 |
75 | if isinstance(node, ast.Slice):
76 | return slice(
77 | _evaluate(node.lower), _evaluate(node.upper), _evaluate(node.step))
78 |
79 | if isinstance(node, ast.Subscript):
80 | return _evaluate(node.value)[_evaluate(node.slice)]
81 |
82 | if isinstance(node, ast.Tuple):
83 | return tuple([_evaluate(x) for x in node.elts])
84 |
85 | if isinstance(node, ast.UnaryOp):
86 | assert isinstance(node.op, ast.USub)
87 | return -_evaluate(node.operand) # pylint: disable=invalid-unary-operand-type
88 |
89 | raise TypeError(f'Cannot handle node type: {type(node).__name__}')
90 |
91 | node = ast.parse(s, mode='eval').body
92 | return _evaluate(node)
93 |
94 |
95 | def params_fn_from_assign_map(assign_map: Dict[str, Any],
96 | name_sep: str = '/',
97 | eval_str_value: bool = False) -> ParamsFn:
98 | """Creates a params_fn from assign_map.
99 |
100 | A params_fn maps each variable name and shape to some value. The variable name
101 | is a tuple of str, and shape is a tuple of int. An assign_map is a sequence of
102 | rules, where each rule maps a regex of variable names to a value.
103 |
104 | Args:
105 | assign_map: A dictionary mapping 'regex' to 'value'. Given a variable name,
106 | the returned params_fn will find the first matching 'regex' and return the
107 | corresponding 'value'.
108 |     name_sep: Join the variable name (tuple of str) by this separator before
109 | regex matching. Defaults to '/'.
110 | eval_str_value: If True, value can be str of simple expressions, which will
111 | be evaluated.
112 |
113 | Returns:
114 | params_fn: A function that maps each variable name and shape to a value.
115 | """
116 |
117 | def params_fn(name: Tuple[str, ...], shape: Shape):
118 | name_str = name_sep.join(name)
119 | for regex, value in assign_map.items():
120 | if re.match(regex, name_str):
121 | logging.info('Matched rule (%s -> %s) to variable %s of shape %s.',
122 | regex, value, name, shape)
123 | if eval_str_value and isinstance(value, str):
124 | return evaluate(value, shape)
125 | return value
126 | raise ValueError(f'No matching rule for variable {name} of shape {shape}.')
127 |
128 | return params_fn
129 |
130 |
131 | def maybe_reduce_axis_names(var, axes):
132 | """Prepend 'reduced_' to the axis name if a dimension is 1."""
133 | if not var.shape: # Scalar.
134 | return None
135 |
136 | if axes is None: # No axes info.
137 | return None
138 |
139 | assert len(var.shape) == len(axes), f'shape: {var.shape} axis: {axes}'
140 | names = [(f'reduced_{x}' if d == 1 else x) for d, x in zip(var.shape, axes)]
141 | return PartitionSpec(*names)
142 |
143 |
144 | def state_partition_rule(state: ScaleByAmosState, params_axes):
145 | """Creates partition for Amos states from partition of parameters."""
146 | return ScaleByAmosState(
147 | count=None,
148 | v=jax.tree.map(maybe_reduce_axis_names, state.v, params_axes),
149 | b=jax.tree.map(maybe_reduce_axis_names, state.b, params_axes))
150 |
--------------------------------------------------------------------------------
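A hedged sketch of wiring the helpers above into Amos: the regexes and shape rules below are illustrative, not taken from any model in this repository.

```python
from jestimator import amos, amos_helper

# Per-variable eta: string rules are evaluated with `evaluate`, so they may
# reference the variable's SHAPE.
eta_fn = amos_helper.params_fn_from_assign_map({
    r'.*/embedding$': 'sqrt(1 / SHAPE[1])',
    r'.*/kernel$': 'sqrt(1 / SHAPE[0])',
    r'.*': 1.0,
}, eval_str_value=True)

# Reduce the row axis of the kernel slot variables `v` and `b` to save memory.
shape_fn = amos_helper.params_fn_from_assign_map({
    r'.*/kernel$': '(1, SHAPE[1])',
    r'.*': 'SHAPE',
}, eval_str_value=True)

opt = amos.amos(0.01, eta_fn, shape_fn=shape_fn)
```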
/jestimator/amos_helper_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for amos_helper."""
16 | import math
17 |
18 | from absl.testing import absltest
19 | import jax
20 | from jestimator import amos_helper
21 |
22 |
23 | class AmosHelperTest(absltest.TestCase):
24 |
25 | def test_evaluate(self):
26 | shape = (7, 8, 9)
27 | x = amos_helper.evaluate('sqrt(1 / prod(SHAPE[:-1]))', shape)
28 | y = math.sqrt(1 / (shape[0] * shape[1]))
29 | self.assertEqual(x, y)
30 |
31 | x = amos_helper.evaluate('(1, 1, SHAPE[2])', shape)
32 | y = (1, 1, shape[2])
33 | self.assertSequenceEqual(x, y)
34 |
35 | def test_params_fn_from_assign_map(self):
36 | assign_map = {
37 | 'init_bn/scale': 'sqrt(1 / SHAPE[-1])',
38 | r'.*bn.?/scale$': 1.0,
39 | }
40 | fn = amos_helper.params_fn_from_assign_map(assign_map, eval_str_value=True)
41 | self.assertEqual(fn(('init_bn', 'scale'), (7, 256)), math.sqrt(1 / 256))
42 | self.assertEqual(fn(('decoder', 'layer_0', 'bn1', 'scale'), (256,)), 1.0)
43 |
44 |
45 | if __name__ == '__main__':
46 | absltest.main()
47 |
--------------------------------------------------------------------------------
/jestimator/amos_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for amos."""
16 | from absl.testing import absltest
17 | import jax
18 | import jax.numpy as jnp
19 | from jestimator import amos
20 |
21 |
22 | def _setup_parabola(dtype):
23 | """Quadratic function as an optimization target."""
24 | initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype)
25 | final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype)
26 |
27 | @jax.grad
28 | def get_updates(params):
29 | return jnp.sum(jnp.square(params - final_params))
30 |
31 | return initial_params, final_params, get_updates
32 |
33 |
34 | def _setup_rosenbrock(dtype):
35 | """Rosenbrock function as an optimization target."""
36 | a = 1.0
37 | b = 100.0
38 |
39 | initial_params = jnp.array([0.0, 0.0], dtype=dtype)
40 | final_params = jnp.array([a, a**2], dtype=dtype)
41 |
42 | @jax.grad
43 | def get_updates(params):
44 | return (jnp.square(a - params[0]) +
45 | b * jnp.square(params[1] - params[0]**2))
46 |
47 | return initial_params, final_params, get_updates
48 |
49 |
50 | class AmosTest(absltest.TestCase):
51 |
52 | def test_parabola(self):
53 | opt = amos.amos(
54 | learning_rate=1.0,
55 | eta_fn=lambda name, shape: 1.0,
56 | shape_fn=lambda name, shape: ())
57 | initial_params, final_params, get_updates = _setup_parabola(jnp.float32)
58 |
59 | @jax.jit
60 | def step(params, state):
61 | updates = get_updates(params)
62 | updates, state = opt.update(updates, state, params)
63 | params = jax.tree_util.tree_map(
64 | lambda p, u: jnp.asarray(p + u).astype(jnp.asarray(p).dtype), params,
65 | updates)
66 | return params, state
67 |
68 | params = initial_params
69 | state = opt.init(params)
70 | for _ in range(10000):
71 | params, state = step(params, state)
72 |
73 | self.assertSequenceAlmostEqual(params, final_params)
74 |
75 | def test_rosenbrock(self):
76 | opt = amos.amos(
77 | learning_rate=0.5,
78 | eta_fn=lambda name, shape: 1.0,
79 | shape_fn=lambda name, shape: (),
80 | beta=0.5,
81 | clip_value=1.0,
82 | momentum=0.9)
83 | initial_params, final_params, get_updates = _setup_rosenbrock(jnp.float32)
84 |
85 | @jax.jit
86 | def step(params, state):
87 | updates = get_updates(params)
88 | updates, state = opt.update(updates, state, params)
89 | params = jax.tree_util.tree_map(
90 | lambda p, u: jnp.asarray(p + u).astype(jnp.asarray(p).dtype), params,
91 | updates)
92 | return params, state
93 |
94 | params = initial_params
95 | state = opt.init(params)
96 | for _ in range(10000):
97 | params, state = step(params, state)
98 |
99 | self.assertSequenceAlmostEqual(params, final_params)
100 |
101 |
102 | if __name__ == '__main__':
103 | absltest.main()
104 |
--------------------------------------------------------------------------------
/jestimator/checkpoint_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Checkpointing utilities."""
16 | import os
17 | import time
18 | from typing import Any, List, Mapping, Optional, Tuple, Union
19 |
20 | from absl import logging
21 | from flax.core.frozen_dict import freeze
22 | from flax.training import checkpoints
23 | from flax.traverse_util import flatten_dict, unflatten_dict # pylint: disable=g-multiple-import
24 | import jax
25 | from jax.experimental import multihost_utils
26 | from t5x.utils import get_local_data
27 | from tensorflow import errors
28 | from tensorflow.io import gfile
29 |
30 |
31 | def partial_restore(state,
32 | ckpt_dict: Mapping[str, Any],
33 | load_step: bool = False):
34 | """Partially restore from checkpoint."""
35 | flat_x = flatten_dict(state.params)
36 | flat_y = flatten_dict(ckpt_dict['target'])
37 | flat_ret = {}
38 | for k, v in flat_x.items():
39 | u = flat_y.get(k)
40 | if u is None:
41 | logging.warning('Key %s not found in ckpt.', '/'.join(k))
42 | flat_ret[k] = v
43 | elif u.shape == v.shape:
44 | flat_ret[k] = u
45 | else:
46 |       logging.warning('Shape %s != checkpoint shape %s.', v.shape, u.shape)
47 | if u.shape[1:] == v.shape[1:]:
48 | if u.shape[0] < v.shape[0]:
49 | flat_ret[k] = v.at[:u.shape[0]].set(u)
50 | else:
51 | flat_ret[k] = u[:v.shape[0]]
52 | else:
53 | logging.warning('Ignored checkpoint due to shape discrepancy.')
54 | flat_ret[k] = v
55 |
56 | params = freeze(unflatten_dict(flat_ret))
57 | state = state.replace(params=params)
58 | if 'flax_mutables' in ckpt_dict:
59 | state = state.replace(_vars=freeze(ckpt_dict['flax_mutables']))
60 | if load_step:
61 | state = state.replace(step=get_local_data(ckpt_dict['state']['step']))
62 | return state
63 |
64 |
65 | def latest_ckpt_path(model_dir: Optional[str] = None,
66 | init_ckpt_path: Optional[str] = None,
67 | prefix: str = 'checkpoint_') -> Tuple[Optional[str], bool]:
68 | """Get path of the latest checkpoint.
69 |
70 |   If `init_ckpt_path` comes from `model_dir`, it overrides other
71 |   checkpoints in `model_dir`;
72 |   otherwise, if `model_dir` is not empty, load the latest checkpoint in it;
73 |   failing that, load the checkpoint specified by `init_ckpt_path`.
74 |
75 | Args:
76 | model_dir: Dir to store model checkpoints.
77 | init_ckpt_path: An optional checkpoint to initialize the model.
78 | prefix: str: name prefix of checkpoint files.
79 |
80 | Returns:
81 | (ckpt_path, same_dir).
82 | ckpt_path: The latest or init checkpoint path.
83 | same_dir: Whether the checkpoint is in the same `model_dir`.
84 | """
85 | if model_dir is not None:
86 | if model_dir.startswith('gs://'):
87 | model_dir = model_dir.rstrip('/') + '/'
88 | else:
89 | model_dir = os.path.abspath(model_dir) + os.sep
90 | if init_ckpt_path is not None:
91 | ckpt_dir = os.path.abspath(os.path.dirname(init_ckpt_path)) + os.sep
92 | if ckpt_dir.startswith(model_dir):
93 | logging.info(
94 | 'Use checkpoint specified by `checkpoint_path` (%s),'
95 | ' since it comes from specified `model_dir` (%s) as well,'
96 | ' and hence overrides other checkpoints in the dir.',
97 | init_ckpt_path, model_dir)
98 | return init_ckpt_path, True
99 |
100 | if jax.process_index() == 0:
101 | for tmp in gfile.glob(os.path.join(model_dir, f'{prefix}*tmp*')):
102 | try:
103 | gfile.rmtree(tmp)
104 | except errors.NotFoundError:
105 | pass
106 | multihost_utils.sync_global_devices(
107 | f'jestimator:latest_ckpt_path:remove_tmp_ckpts:{model_dir}')
108 | latest = checkpoints.latest_checkpoint(model_dir, prefix=prefix)
109 | if latest is not None:
110 | logging.info(
111 | 'Use the latest checkpoint (%s) from `model_dir`,'
112 |         ' and ignore `checkpoint_path`.', latest)
113 | return latest, True
114 |
115 | logging.info(
116 | 'Use checkpoint specified by `checkpoint_path` (%s),'
117 | ' since `model_dir` (%s) is empty.', init_ckpt_path, model_dir)
118 | return init_ckpt_path, False
119 |
120 |
121 | def last_evaluated_ckpt(last_eval_path: str) -> Optional[str]:
122 | """Get the last evaluated checkpoint path."""
123 | if gfile.exists(last_eval_path):
124 | logging.info('Reading last_evaluated from: %s', last_eval_path)
125 | with gfile.GFile(last_eval_path, 'rb') as f:
126 | last_evaluated = f.read().decode('utf-8')
127 | logging.info('Found last_evaluated: %s', last_evaluated)
128 | else:
129 | last_evaluated = None
130 | return last_evaluated
131 |
132 |
133 | def sorted_checkpoints(
134 | ckpt_dir: Union[str, os.PathLike], # pylint: disable=g-bare-generic
135 | prefix: str = 'checkpoint_') -> List[str]:
136 | """Retrieve the path of all checkpoints in a directory.
137 |
138 | Args:
139 | ckpt_dir: str: directory of checkpoints to restore from.
140 | prefix: str: name prefix of checkpoint files.
141 |
142 | Returns:
143 | A list of checkpoint paths.
144 | """
145 | ckpt_dir = os.fspath(ckpt_dir) # Pathlib -> str
146 | glob_path = gfile.glob(os.path.join(ckpt_dir, f'{prefix}*'))
147 | glob_tmp = frozenset(gfile.glob(os.path.join(ckpt_dir, f'{prefix}*tmp*')))
148 | glob_path = [f for f in glob_path if f not in glob_tmp]
149 | checkpoint_files = checkpoints.natural_sort(glob_path)
150 | return checkpoint_files
151 |
152 |
153 | def checkpoints_iterator_from_oldest(model_dir: str,
154 | last_eval_path: str,
155 | min_interval_secs: float = 0.,
156 | last_evaluated: Optional[str] = None):
157 | """Iterate checkpoints in a dir from the oldest, and wait for new."""
158 | logging.info('Monitoring checkpoints in dir: %s', model_dir)
159 | while True:
160 | ckpts = sorted_checkpoints(model_dir)[1:]
161 | if last_evaluated is not None:
162 | for i, x in enumerate(ckpts):
163 | if x == last_evaluated:
164 | ckpts = ckpts[i + 1:]
165 | break
166 |
167 | for x in ckpts:
168 | if gfile.exists(x):
169 | last_evaluated = x
170 | if jax.process_index() == 0:
171 | with gfile.GFile(last_eval_path, 'w') as f:
172 | f.write(x)
173 | yield x
174 |
175 | if not ckpts:
176 | time.sleep(min_interval_secs)
177 |
178 |
179 | def last_score(last_score_path: str) -> float:
180 | """Get the last evaluated score."""
181 | score = None
182 | if gfile.exists(last_score_path):
183 | logging.info('Reading last score from: %s', last_score_path)
184 | with gfile.GFile(last_score_path, 'rb') as f:
185 | score = f.read().decode('utf-8')
186 | logging.info('Found last score: %s', score)
187 | if not score:
188 | score = '-inf'
189 | return float(score)
190 |
--------------------------------------------------------------------------------
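A minimal sketch (paths are hypothetical) of a continuous-eval loop built on the helpers above:

```python
from jestimator import checkpoint_utils

model_dir = '/tmp/my_model'                          # hypothetical
last_eval_path = model_dir + '/last_evaluated.txt'   # hypothetical

last = checkpoint_utils.last_evaluated_ckpt(last_eval_path)
for ckpt_path in checkpoint_utils.checkpoints_iterator_from_oldest(
    model_dir, last_eval_path, min_interval_secs=10., last_evaluated=last):
  # Restore `ckpt_path` and run evaluation here; the iterator loops forever,
  # waiting for new checkpoints to appear in `model_dir`.
  print('Would evaluate:', ckpt_path)
```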
/jestimator/data/pipeline_lm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Data pipeline for language models: fixed-length sequences of token ids."""
16 | from typing import Callable, Optional, Sequence
17 |
18 | import tensorflow as tf
19 |
20 |
21 | def pipeline_from_filenames(
22 | filenames: Sequence[str],
23 | seq_length: int,
24 | allow_remainder: bool = False,
25 | dataset_fn: Optional[Callable[[str], tf.data.Dataset]] = None,
26 | cache: bool = False,
27 | random_skip: bool = False,
28 | feature_fn: Optional[Callable] = None, # pylint: disable=g-bare-generic
29 | interleave: bool = False,
30 | shard_num: int = 1,
31 | shard_index: int = 0,
32 | epochs: Optional[int] = 1,
33 | num_take: int = -1):
34 | r"""Creates a tensorflow dataset from filenames.
35 |
36 | Args:
37 | filenames: A list of file name strings.
38 |     seq_length: Length of token-id sequences in the output dataset.
39 | allow_remainder: bool. Whether to allow the last sequence to be shorter.
40 | dataset_fn: A function that maps a path str to a tf.data.Dataset. If None,
41 | defaults to tf.data.Dataset.load().
42 | cache: Whether to cache the constructed datasets in memory.
43 | random_skip: bool. Whether to randomly skip some tokens in the beginning of
44 | each dataset. This is used to increase randomness in training.
45 | feature_fn: A function that maps a token-id sequence to model-specific
46 | features. This is called before batching.
47 | interleave: bool. Whether to randomly interleave multiple files.
48 | shard_num: int. Number of shards.
49 | shard_index: int. Worker index.
50 | epochs: Number of epochs to repeat. If None, repeat forever.
51 | num_take: int. If not -1, take the first n examples.
52 |
53 | Returns:
54 | An instance of tf.data.Dataset.
55 | """
56 | num_files = len(filenames)
57 |
58 | shard_data = True
59 | if num_files % shard_num == 0 or num_files / shard_num > 9:
60 | filenames = filenames[shard_index::shard_num]
61 | shard_data = False
62 |
63 | ds = []
64 | for path in filenames:
65 | if dataset_fn is None:
66 | d = tf.data.Dataset.load(path)
67 | else:
68 | d = dataset_fn(path)
69 | if cache and num_take == -1:
70 | d = d.cache()
71 | ds.append(d)
72 | fd = tf.data.Dataset.from_tensor_slices(ds)
73 |
74 | if interleave and num_files > 1:
75 | fd = fd.shuffle(num_files, seed=num_files + 37)
76 | fd = fd.repeat(epochs)
77 | if random_skip:
78 | fd = tf.data.Dataset.zip((fd, tf.data.Dataset.random(seed=num_files + 19)))
79 |
80 | def seq_fn(d, *rnd_):
81 | if random_skip:
82 | (rnd,) = rnd_
83 | d = d.skip(rnd % seq_length)
84 | d = d.batch(seq_length, drop_remainder=(not allow_remainder))
85 | if shard_data:
86 | d = d.shard(shard_num, shard_index)
87 | if feature_fn is not None:
88 | d = d.map(feature_fn, num_parallel_calls=tf.data.AUTOTUNE)
89 | return d
90 |
91 | if interleave and num_files > 1:
92 | ret = fd.interleave(
93 | seq_fn, deterministic=False, num_parallel_calls=tf.data.AUTOTUNE)
94 | else:
95 | ret = fd.flat_map(seq_fn)
96 | ret = ret.take(num_take)
97 | if num_take >= 0 and cache:
98 | ret = ret.cache()
99 | return ret
100 |
101 |
102 | def lm_data(
103 | seq_length: int,
104 | allow_remainder: bool = False,
105 | dataset_fn: Optional[Callable[[str], tf.data.Dataset]] = None,
106 | cache: bool = False,
107 | random_skip: bool = False,
108 | feature_fn: Optional[Callable] = None, # pylint: disable=g-bare-generic
109 | interleave: bool = False):
110 | """Builds a data pipeline for language modeling.
111 |
112 | Args:
113 |     seq_length: Length of token-id sequences in the output dataset.
114 | allow_remainder: bool. Whether to allow the last sequence to be shorter.
115 | dataset_fn: A function that maps a path str to a tf.data.Dataset. If None,
116 | defaults to tf.data.Dataset.load().
117 | cache: Whether to cache the constructed datasets in memory.
118 | random_skip: bool. Whether to randomly skip some tokens in the beginning of
119 | each dataset. This is used to increase randomness in training.
120 | feature_fn: A function that maps a token-id sequence to model-specific
121 | features. This is called before batching.
122 | interleave: bool. Whether to randomly interleave multiple files.
123 |
124 | Returns:
125 | A `data_fn` to be used by jestimator.
126 | """
127 |
128 | def data_fn(filenames, **kwargs):
129 | return pipeline_from_filenames(
130 | filenames,
131 | seq_length,
132 | allow_remainder=allow_remainder,
133 | dataset_fn=dataset_fn,
134 | cache=cache,
135 | random_skip=random_skip,
136 | feature_fn=feature_fn,
137 | interleave=interleave,
138 | **kwargs)
139 |
140 | return data_fn
141 |
--------------------------------------------------------------------------------
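A sketch of how `lm_data` might be used; the file paths and the next-token `feature_fn` are assumptions. The returned `data_fn` is what jestimator consumes:

```python
import tensorflow as tf
from jestimator.data.pipeline_lm import lm_data

def feature_fn(ids):
  # Next-token prediction: inputs are all but the last id, labels are shifted.
  return {'input_ids': ids[:-1], 'labels': ids[1:]}

# Expects datasets of token ids saved with tf.data.Dataset.save().
data_fn = lm_data(
    seq_length=129, feature_fn=feature_fn, random_skip=True, interleave=True)
ds = data_fn(['/tmp/tokens_00', '/tmp/tokens_01'], epochs=None)  # repeat forever
```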
/jestimator/data/pipeline_rec.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Data pipeline for record datasets."""
16 | from typing import Callable, Optional, Sequence
17 |
18 | import tensorflow as tf
19 |
20 |
21 | def pipeline_from_filenames(
22 | filenames: Sequence[str],
23 | dataset_fn: Optional[Callable[[str], tf.data.Dataset]] = None,
24 | cache: bool = False,
25 | feature_fn: Optional[Callable] = None, # pylint: disable=g-bare-generic
26 | interleave: bool = False,
27 | shard_num: int = 1,
28 | shard_index: int = 0,
29 | epochs: Optional[int] = 1,
30 | num_take: int = -1):
31 | r"""Creates a tensorflow dataset from filenames.
32 |
33 | Args:
34 | filenames: A list of file name strings.
35 | dataset_fn: A function that maps a path str to a tf.data.Dataset. If None,
36 | defaults to tf.data.Dataset.load().
37 | cache: Whether to cache the constructed datasets in memory.
38 |     feature_fn: A function that maps a record to model-specific features.
39 |       This is called before batching.
40 | interleave: bool. Whether to randomly interleave multiple files.
41 | shard_num: int. Number of shards.
42 | shard_index: int. Worker index.
43 | epochs: Number of epochs to repeat. If None, repeat forever.
44 | num_take: int. If not -1, take the first n examples.
45 |
46 | Returns:
47 | An instance of tf.data.Dataset.
48 | """
49 | num_files = len(filenames)
50 |
51 | shard_data = True
52 | if num_files % shard_num == 0 or num_files / shard_num > 9:
53 | filenames = filenames[shard_index::shard_num]
54 | shard_data = False
55 |
56 | ds = []
57 | for path in filenames:
58 |     d = tf.data.Dataset.load(path) if dataset_fn is None else dataset_fn(path)
59 | if cache and num_take == -1:
60 | if feature_fn is not None:
61 | d = d.map(feature_fn, num_parallel_calls=tf.data.AUTOTUNE)
62 | d = d.cache()
63 | ds.append(d)
64 |
65 | if interleave and num_files > 1:
66 | rnd = tf.data.Dataset.random(seed=num_files + 11)
67 | if epochs is not None:
68 | rnd = rnd.take(epochs)
69 | ret = rnd.flat_map(
70 | lambda x: tf.data.Dataset.sample_from_datasets(ds, seed=x))
71 | else:
72 | ret = ds[0]
73 | for d in ds[1:]:
74 | ret = ret.concatenate(d)
75 | ret = ret.repeat(epochs)
76 | if shard_data:
77 | ret = ret.shard(shard_num, shard_index)
78 |
79 | ret = ret.take(num_take)
80 | if not cache or num_take != -1:
81 | if feature_fn is not None:
82 | ret = ret.map(feature_fn, num_parallel_calls=tf.data.AUTOTUNE)
83 | if cache and num_take != -1:
84 | ret = ret.cache()
85 | return ret
86 |
87 |
88 | def rec_data(
89 | dataset_fn: Optional[Callable[[str], tf.data.Dataset]] = None,
90 | cache: bool = False,
91 | feature_fn: Optional[Callable] = None, # pylint: disable=g-bare-generic
92 | interleave: bool = False):
93 | """Builds a data pipeline for records.
94 |
95 | Args:
96 | dataset_fn: A function that maps a path str to a tf.data.Dataset. If None,
97 | defaults to tf.data.Dataset.load().
98 | cache: Whether to cache the constructed datasets in memory.
99 |     feature_fn: A function that maps a record to model-specific features.
100 |       This is called before batching.
101 | interleave: bool. Whether to randomly interleave multiple files.
102 |
103 | Returns:
104 | A `data_fn` to be used by jestimator.
105 | """
106 |
107 | def data_fn(filenames, **kwargs):
108 | return pipeline_from_filenames(
109 | filenames,
110 | dataset_fn=dataset_fn,
111 | cache=cache,
112 | feature_fn=feature_fn,
113 | interleave=interleave,
114 | **kwargs)
115 |
116 | return data_fn
117 |
--------------------------------------------------------------------------------
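A hedged sketch wiring `rec_data` to TFRecord files through the `reader` module; the file path and the example schema are illustrative:

```python
import tensorflow as tf
from jestimator.data import reader
from jestimator.data.pipeline_rec import rec_data

def parse_fn(serialized):
  # Hypothetical schema: each record holds a fixed-size float vector 'x'.
  return tf.io.parse_single_example(
      serialized, {'x': tf.io.FixedLenFeature([16], tf.float32)})

data_fn = rec_data(
    dataset_fn=lambda path: reader.get_record_dataset(path, 'tfrecord'),
    feature_fn=parse_fn,
    interleave=True)
ds = data_fn(['/tmp/data-00000.tfrecord'], epochs=1)
```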
/jestimator/data/pipeline_seqio.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Data pipeline that wraps the seqio lib."""
16 | from typing import Mapping, Optional, Sequence
17 |
18 | import seqio
19 |
20 | SEQIO_PREFIX = "seqio://"
21 |
22 |
23 | def is_seqio(filenames: Optional[Sequence[str]]) -> bool:
24 | return bool(filenames and filenames[0].startswith(SEQIO_PREFIX))
25 |
26 |
27 | def pipeline_from_mixture_or_task_name(
28 | mixture_or_task_name: str,
29 | task_feature_lengths: Mapping[str, int],
30 | feature_converter: seqio.FeatureConverter,
31 | use_cached: bool = False,
32 | shuffle: bool = False,
33 | seed: Optional[int] = None,
34 | shard_num: int = 1,
35 | shard_index: int = 0,
36 | epochs: Optional[int] = 1,
37 | num_take: int = -1):
38 | r"""Creates a tensorflow dataset from filenames.
39 |
40 | Args:
41 | mixture_or_task_name: str. Name of task or mixture.
42 | task_feature_lengths: Dict of sequence lengths of features.
43 | feature_converter: Model-specific feature converter.
44 | use_cached: bool. Whether to use a precomputed version of the dataset from a
45 | cache dir. Defaults to False.
46 | shuffle: bool. Whether to shuffle data. Defaults to False.
47 | seed: int. Random seed. Defaults to None.
48 | shard_num: int. Number of shards.
49 | shard_index: int. Worker index.
50 | epochs: Number of epochs to repeat. If None, repeat forever.
51 | num_take: int. If not -1, take the first n examples.
52 |
53 | Returns:
54 | An instance of tf.data.Dataset.
55 | """
56 | if mixture_or_task_name.startswith(SEQIO_PREFIX):
57 | mixture_or_task_name = mixture_or_task_name[len(SEQIO_PREFIX):]
58 |
59 | sp = mixture_or_task_name.split("/")
60 | if sp[-1].startswith("split="):
61 | mixture_or_task_name = "/".join(sp[:-1])
62 | split = sp[-1][len("split="):]
63 | else:
64 | split = None
65 |
66 | ret = seqio.get_dataset(
67 | mixture_or_task_name=mixture_or_task_name,
68 | task_feature_lengths=task_feature_lengths,
69 | dataset_split=split,
70 | shuffle=shuffle,
71 | num_epochs=epochs,
72 | feature_converter=feature_converter,
73 | shard_info=seqio.ShardInfo(shard_index, shard_num),
74 | use_cached=use_cached,
75 | seed=seed)
76 | ret = ret.take(num_take)
77 | return ret
78 |
79 |
80 | def seqio_data(task_feature_lengths: Mapping[str, int],
81 | feature_converter: seqio.FeatureConverter,
82 | use_cached: bool = False,
83 | shuffle: bool = False,
84 | seed: Optional[int] = None):
85 | """Wraps a seqio data pipeline."""
86 |
87 | def data_fn(filenames, **kwargs):
88 | (mixture_or_task_name,) = filenames
89 | return pipeline_from_mixture_or_task_name(
90 | mixture_or_task_name,
91 | task_feature_lengths,
92 | feature_converter,
93 | use_cached=use_cached,
94 | shuffle=shuffle,
95 | seed=seed,
96 | **kwargs)
97 |
98 | return data_fn
99 |
--------------------------------------------------------------------------------
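A sketch, assuming a registered seqio task named `my_task` (hypothetical) and enc-dec feature lengths, of how the `seqio://` address form above is consumed:

```python
import seqio
from jestimator.data.pipeline_seqio import seqio_data

data_fn = seqio_data(
    task_feature_lengths={'inputs': 512, 'targets': 114},
    feature_converter=seqio.EncDecFeatureConverter(pack=True),
    shuffle=True,
    seed=42)
# The single "filename" is a seqio address; `split=` selects the data split.
ds = data_fn(['seqio://my_task/split=train'], epochs=None)
```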
/jestimator/data/reader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Data file readers for different formats."""
16 | import dataclasses
17 | import enum
18 | from typing import Mapping, Optional, Tuple, Union
19 |
20 | import tensorflow as tf
21 | import tensorflow_datasets as tfds
22 |
23 | TFDS_PREFIX = 'tfds://'
24 |
25 |
26 | # Supported file format for record data.
27 | @enum.unique
28 | class RecordFormat(enum.Enum):
29 | TFRECORD = 'tfrecord'
30 |
31 |
32 | def get_format(path: str) -> RecordFormat:
33 | """Returns the type of record file, 'sstable, 'recordio', or 'tfrecord'."""
34 |
35 | # Try as a TFRecord.
36 | try:
37 | next(tf.data.TFRecordDataset(path).as_numpy_iterator())
38 | return RecordFormat.TFRECORD
39 | except (IOError, StopIteration, tf.errors.DataLossError, AttributeError):
40 | pass
41 |
42 | raise TypeError(f'Invalid file format: {path}')
43 |
44 |
45 | def get_record_dataset(
46 | path: str,
47 | file_format: Optional[Union[str, RecordFormat]] = None) -> tf.data.Dataset:
48 | r"""Creates a tensorflow dataset from path.
49 |
50 | Args:
51 | path: Path to the dataset.
52 | file_format: The format of dataset files.
53 |
54 | Returns:
55 | An instance of tf.data.Dataset.
56 | """
57 | if file_format is None:
58 | file_format = get_format(path)
59 | elif not isinstance(file_format, RecordFormat):
60 | file_format = RecordFormat(file_format)
61 |
62 | if file_format == RecordFormat.TFRECORD:
63 | d = tf.data.TFRecordDataset(path)
64 | return d
65 |
66 | raise TypeError(f'Unknown file format: {path}')
67 |
68 |
69 | def get_tfds_dataset(path: str) -> tf.data.Dataset:
70 | """Creates a tensorflow dataset from path."""
71 | if path.startswith(TFDS_PREFIX):
72 | path = path[len(TFDS_PREFIX):]
73 | sp = path.split(':', 1)
74 | if len(sp) == 2:
75 | data_dir, path = sp
76 | else:
77 | data_dir = None
78 |
79 | sp = path.split('/')
80 | if sp[-1].startswith('split='):
81 | path = '/'.join(sp[:-1])
82 | split = sp[-1][len('split='):]
83 | else:
84 | split = None
85 |
86 | d = tfds.load(
87 | path,
88 | split=split,
89 | as_supervised=False,
90 | shuffle_files=False,
91 | data_dir=data_dir)
92 | return d
93 |
94 |
95 | def serialize_tensor_dict(data: Mapping[str, tf.Tensor]) -> bytes:
96 | """Converts a tensor dict to bytes, via tf.train.Example."""
97 | feature = {}
98 | for k, v in data.items():
99 | bytes_list = tf.train.BytesList(value=[tf.io.serialize_tensor(v).numpy()])
100 | feature[k] = tf.train.Feature(bytes_list=bytes_list)
101 | ex = tf.train.Example(features=tf.train.Features(feature=feature))
102 | return ex.SerializeToString()
103 |
104 |
105 | def parse_tensor_dict(
106 | x, elem_spec: Mapping[str, tf.TensorSpec]) -> Mapping[str, tf.Tensor]:
107 | """Creates a tensor dict from serialized bytes."""
108 |   features = {k: tf.io.FixedLenFeature([], tf.string) for k in elem_spec}
109 |   x = tf.io.parse_single_example(x, features)
110 |   # Each value was serialized with tf.io.serialize_tensor; decode it back.
111 |   x = {k: tf.io.parse_tensor(v, elem_spec[k].dtype) for k, v in x.items()}
112 |   x = {k: tf.ensure_shape(v, elem_spec[k].shape) for k, v in x.items()}
113 |   return x
114 |
115 |
116 | def to_str(x: tf.Tensor, encoding: str = 'utf-8') -> str:
117 | """Converts an eager tf.string tensor to str. Fail-safe."""
118 | try:
119 | ret = x.numpy().decode(encoding)
120 | except UnicodeDecodeError:
121 | ret = ''
122 | return ret
123 |
124 |
125 | @dataclasses.dataclass
126 | class PyOutSpec:
127 | """Specifies the shape and type of a value returned by a py_func."""
128 |   shape: Tuple[int, ...]
129 | dtype: tf.DType
130 |
131 |
132 | def apply_py_fn(py_fn, data, out_spec):
133 | """Applies a python function to graph-mode data.
134 |
135 | Args:
136 | py_fn: A python function.
137 | data: A nested structure of graph-mode data.
138 | out_spec: A nested structure of PyOutSpec of the returned values.
139 |
140 | Returns:
141 | A nested structure of graph-mode values.
142 | """
143 | flat_data = tf.nest.flatten(data, expand_composites=True)
144 |
145 | def fn(*flat):
146 | data_eager = tf.nest.pack_sequence_as(data, flat, expand_composites=True)
147 | ret = py_fn(data_eager)
148 | return tf.nest.flatten(ret)
149 |
150 | flat_out_spec = tf.nest.flatten(out_spec)
151 | ret = tf.py_function(fn, flat_data, [x.dtype for x in flat_out_spec])
152 | ret = [tf.reshape(y, x.shape) for x, y in zip(flat_out_spec, ret)]
153 | return tf.nest.pack_sequence_as(out_spec, ret)
154 |
155 |
156 | def lines_iterator(path: str,
157 | encoding: str = 'utf-8',
158 | split: bool = False,
159 | allow_empty: bool = False):
160 | """Line iterator from a file."""
161 | with tf.io.gfile.GFile(path, 'rb') as f:
162 | for line in f:
163 | try:
164 | line = line.decode(encoding)
165 | except UnicodeDecodeError:
166 | continue
167 | ret = line.split() if split else line.rstrip('\r\n')
168 | if allow_empty or ret:
169 | yield ret
170 |
--------------------------------------------------------------------------------
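A small round trip through `serialize_tensor_dict`/`parse_tensor_dict` above (a sketch with arbitrary tensors, assuming the decode step shown in `parse_tensor_dict`):

```python
import tensorflow as tf
from jestimator.data import reader

spec = {
    'ids': tf.TensorSpec((3,), tf.int32),
    'score': tf.TensorSpec((), tf.float32),
}
data = {'ids': tf.constant([1, 2, 3]), 'score': tf.constant(0.5)}

raw = reader.serialize_tensor_dict(data)        # bytes via tf.train.Example
restored = reader.parse_tensor_dict(raw, spec)  # tensor dict matching `spec`
assert restored['ids'].shape == (3,)
```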
/jestimator/data_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utility functions for building data pipelines."""
16 | from typing import Callable, List, Sequence, Tuple, Union
17 |
18 | import jax
19 | from jax.experimental.multihost_utils import host_local_array_to_global_array
20 | import tensorflow as tf
21 | from tensorflow.io import gfile
22 |
23 | TFDS_PREFIX = 'tfds://'
24 | SEQIO_PREFIX = 'seqio://'
25 | DUMMY_PREFIX = 'dummy://'
26 |
27 |
28 | class StringIterable(object):
29 | """Converts `x` to iterable of strings.
30 |
31 | `x` can be a string, list of strings, or a tf.data.Dataset of tf.string.
32 |   This prevents a Python string from being treated as an iterable of characters.
33 | """
34 |
35 | def __init__(self, x: Union[str, Sequence[str], tf.data.Dataset]):
36 | self.x = x
37 |
38 | def __iter__(self):
39 | if isinstance(self.x, str):
40 | return iter([self.x])
41 | if isinstance(self.x, tf.data.Dataset):
42 | return self.x.as_numpy_iterator()
43 | return iter(self.x)
44 |
45 |
46 | def get_dataset_filenames(pattern: Union[str, Sequence[str], tf.data.Dataset],
47 | do_glob: bool = True) -> List[str]:
48 | """Returns a list of file names for a given sharded/glob file pattern.
49 |
50 | Args:
51 |     pattern: A string, list of strings, or tf.data.Dataset containing file
52 |       paths, sharded paths, tfds addresses, or glob patterns.
53 | do_glob: Whether to do gfile glob.
54 |
55 | Returns:
56 | A list of filenames.
57 |
58 | Raises:
59 | ValueError: When some filename pattern does not match any file.
60 | """
61 | paths = []
62 | for pat in StringIterable(pattern):
63 | if pat.startswith(DUMMY_PREFIX):
64 | continue
65 | if pat.startswith(SEQIO_PREFIX):
66 | paths.append(pat)
67 | continue
68 | if pat.startswith(TFDS_PREFIX):
69 | paths.append(pat)
70 | continue
71 |
72 | if do_glob:
73 | glob = gfile.glob(pat)
74 | if not glob:
75 | raise ValueError(f'File pattern {pat} has no match.') from None
76 | paths.extend(glob)
77 | else:
78 | paths.append(pat)
79 | return paths
80 |
81 |
82 | def count_dataset(d: tf.data.Dataset, batch_size: int) -> Tuple[int, int]:
83 | """Count the iterator length of a dataset, and the last batch size.
84 |
85 | Args:
86 | d: tf.data.Dataset.
87 | batch_size: int. Batch size.
88 |
89 | Returns:
90 | (dataset_length, last_batch_size).
91 | """
92 | d = d.batch(batch_size)
93 | d = d.prefetch(tf.data.AUTOTUNE)
94 | dataset_length = 0
95 | x = None
96 | for x in d:
97 | dataset_length += 1
98 | last_batch_size = tf.shape(tf.nest.flatten(x)[0])[0].numpy()
99 | return dataset_length, last_batch_size
100 |
101 |
102 | def transpose_dataset(d: tf.data.Dataset, size_d: int, size_per_elem: int,
103 | bs: int) -> tf.data.Dataset:
104 | """Transpose a dataset of datasets."""
105 | flat = d.flat_map(lambda x: x)
106 | ret = flat.window(size_d, 1, size_per_elem, drop_remainder=True)
107 | ret = ret.flat_map(lambda x: x.batch(bs, drop_remainder=True))
108 | return ret
109 |
110 |
111 | def create_data_pipeline(filenames: List[str],
112 | data_fn: Callable[..., tf.data.Dataset],
113 | data_layout,
114 | shuffle_buf=None,
115 | consecutive=None,
116 | shard_source=False,
117 | **kwargs):
118 | """Builds a data pipeline with partitioning, batching and shuffle.
119 |
120 | When `consecutive` is not None, this pipeline produces consecutive batches,
121 | which can be used to divide very long sequences into multiple batches.
122 |
123 | Args:
124 | filenames: List of data file names.
125 | data_fn: A function that returns a tf dataset.
126 | data_layout: Partitioning data layout.
127 | shuffle_buf: int. Buffer size for shuffling. Do not shuffle if None.
128 | consecutive: int. If not None, every n batches are consecutive.
129 | shard_source: bool. For multiple workers, whether to shard the data source
130 | instead of sharding at the end of data pipeline. Defaults to False.
131 | **kwargs: Other kwargs passed to `data_fn`.
132 |
133 | Returns:
134 | A tf dataset instance.
135 | """
136 | shard_id = data_layout.shard_id
137 | num_shards = data_layout.num_shards
138 | batch_size = data_layout.batch_size
139 | assert num_shards == 1 or batch_size is not None
140 | if batch_size is not None:
141 | bs, r = divmod(batch_size, num_shards)
142 | assert r == 0, f'{batch_size} % {num_shards} != 0'
143 |
144 | if consecutive is None:
145 | if shard_source:
146 | d = data_fn(
147 | filenames, shard_num=num_shards, shard_index=shard_id, **kwargs)
148 | else:
149 | d = data_fn(filenames, **kwargs)
150 |
151 | if shuffle_buf is not None:
152 | d = d.shuffle(shuffle_buf, seed=shuffle_buf)
153 | if batch_size is not None:
154 | d = d.batch(bs, drop_remainder=True)
155 | if num_shards > 1 and not shard_source:
156 | d = d.shard(num_shards, shard_id)
157 |
158 | else:
159 | epochs = kwargs.pop('epochs', 1)
160 | d = data_fn(filenames, **kwargs)
161 | d = d.window(consecutive, drop_remainder=True).repeat(epochs)
162 | if num_shards > 1:
163 | d = d.shard(num_shards, shard_id)
164 | if shuffle_buf is not None:
165 | d = d.shuffle(shuffle_buf, seed=shuffle_buf)
166 | if batch_size is None:
167 | d = d.flat_map(lambda x: x)
168 | else:
169 | d = d.window(bs, drop_remainder=True)
170 | d = d.flat_map(lambda x: transpose_dataset(x, bs, consecutive, bs))
171 |
172 | d = d.prefetch(tf.data.AUTOTUNE)
173 | return d
174 |
175 |
176 | class DataIterable(object):
177 | """Converts tf.data.Dataset to iterable of numpy arrays."""
178 |
179 | def __init__(self, x: tf.data.Dataset, partitioner):
180 | self.x = x
181 | self.partitioner = partitioner
182 |
183 | def __iter__(self):
184 | ret = self.x.as_numpy_iterator()
185 | mesh = self.partitioner.mesh
186 | spec = self.partitioner.data_partition_spec
187 | it = (host_local_array_to_global_array(batch, mesh, spec) for batch in ret)
188 | return it
189 |
--------------------------------------------------------------------------------
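A quick, self-contained illustration of `count_dataset` above on a toy dataset:

```python
import tensorflow as tf
from jestimator import data_utils

d = tf.data.Dataset.range(10)
n_batches, last_bs = data_utils.count_dataset(d, batch_size=4)
# 10 examples in batches of 4 -> 3 batches, with 2 examples in the last one.
assert (n_batches, int(last_bs)) == (3, 2)
```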
/jestimator/modeling.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Modeling utilities."""
16 | import inspect
17 | import threading
18 | from typing import Callable, Optional, Tuple
19 |
20 | from flax import linen as nn
21 | import jax
22 | import jax.numpy as jnp
23 |
24 | from jax.typing import ArrayLike
25 | from flaxformer.types import DType, Initializer, PRNGKey, Shape # pylint: disable=g-multiple-import
26 |
27 |
28 | def sparse_xe_with_logits(labels: ArrayLike,
29 | logits: ArrayLike,
30 | mask: Optional[ArrayLike] = None,
31 | normalized: bool = False,
32 | reduce_all: bool = True):
33 | """Sparse cross entropy from softmax logits.
34 |
35 | Args:
36 | labels: int tensor.
37 | logits: float tensor of shape (labels.shape + [num_labels]).
38 | mask: 0/1 float tensor, the same shape as `labels`.
39 | normalized: Whether `logits` is normalized.
40 | reduce_all: Whether to reduce_sum the loss tensor.
41 |
42 | Returns:
43 | Cross-entropy loss. If `reduce_all` is True, returns a scalar tensor.
44 | Otherwise returns a float tensor of the same shape as `labels`.
45 | """
46 | if not normalized:
47 | logits = jax.nn.log_softmax(logits)
48 |
49 | labels_ = jnp.expand_dims(labels, -1)
50 | llh_ = jnp.take_along_axis(logits, labels_, axis=-1)
51 | llh = jnp.squeeze(llh_, -1)
52 |
53 | if mask is not None:
54 | llh = jnp.where(mask, llh, 0.)
55 | if reduce_all:
56 | loss = -jnp.sum(llh)
57 | else:
58 | loss = -llh
59 | return loss
60 |
61 |
62 | def normalize_loss_by_size(
63 | loss: ArrayLike, size: ArrayLike
64 | ) -> Tuple[ArrayLike, ArrayLike]:
65 | """Normalize a loss value by size of labels."""
66 | loss = jnp.asarray(loss)
67 | size = jnp.asarray(size, loss.dtype)
68 | loss = loss * jax.lax.rsqrt(size)
69 | size = jnp.sqrt(size)
70 | return loss, size
71 |
72 |
73 | def unstack(x, axis):
74 | """Unstack a tensor along axis."""
75 | return [
76 | jax.lax.index_in_dim(x, i, axis, keepdims=False)
77 | for i in range(x.shape[axis])
78 | ]
79 |
80 |
81 | _thread_local = threading.local()
82 |
83 |
84 | def global_kwargs(*inherits: str, pass_down: bool = False):
85 | """Function decorator to use global kwargs.
86 |
87 | A utility for passing keyword arguments down to nested sub-calls.
88 |
89 | # Example
90 |
91 | ```
92 | @global_kwargs('attention_mask', 'training', 'dropout_rate')
93 | def func1(x, attention_mask=None, training=False, dropout_rate=0.1):
94 | # calculation code...
95 | if attention_mask is not None:
96 |     attention += attention_mask
97 | if training:
98 | attention = tf.nn.dropout(attention, dropout_rate)
99 | # ...
100 |
101 | @global_kwargs(pass_down=True)
102 | def func2(hidden):
103 | # call `func1`
104 | func1(hidden)
105 | # ...
106 |
107 | # Then, one can pass arguments 'attention_mask', 'training', 'dropout_rate'
108 | # from `func2` down to `func1`, without explicitly declaring those arguments
109 | # in `func2`:
110 |
111 | func2(a, attention_mask=b, training=True, dropout_rate=0.5) # It works!
112 | ```
113 |
114 | Args:
115 | *inherits: Keys to be inherited from the global context.
116 | pass_down: bool. If True, unrecognized keys will be passed to sub-routines.
117 |
118 | Returns:
119 | The function wrapper.
120 | """
121 |
122 | def wrap(func: Callable) -> Callable: # pylint: disable=g-bare-generic
123 | func_signature = inspect.signature(func, follow_wrapped=False)
124 | func_params = func_signature.parameters
125 | for v in func_params.values():
126 | assert v.kind != inspect.Parameter.VAR_KEYWORD, (
127 | '`func` should not have VAR_KEYWORD parameter.')
128 | for k in inherits:
129 | assert k in func_params, (
130 | f'The inherit key ({k}) is not an argument of `func`.')
131 |
132 | def wrapped(*args, **kwargs):
133 | current = getattr(_thread_local, 'current_inherit_kwargs', {})
134 | func_kwargs = {k: current[k] for k in inherits if k in current}
135 | if pass_down:
136 | subrout = {**current}
137 |
138 | for k, v in kwargs.items():
139 | if k in func_params:
140 | func_kwargs[k] = v
141 | else:
142 | assert pass_down, f'Unrecognized kwarg ({k}).'
143 | subrout[k] = v
144 |
145 | if pass_down:
146 | _thread_local.current_inherit_kwargs = subrout
147 | ret = func(*args, **func_kwargs)
148 | if pass_down:
149 | _thread_local.current_inherit_kwargs = current
150 | return ret
151 |
152 | return wrapped
153 |
154 | return wrap
155 |
156 |
157 | def truncated_normal_initializer(stddev: ArrayLike) -> Initializer:
158 | """Truncated normal initializer."""
159 |
160 | def init(key: PRNGKey, shape: Shape, dtype: DType) -> ArrayLike:
161 | return jax.random.truncated_normal(
162 | key=key, lower=-2., upper=2., shape=shape, dtype=dtype) * stddev
163 |
164 | return init
165 |
166 |
167 | class Dropout(nn.Module):
168 | """Dropout layer with fast random generator."""
169 | rate: float
170 |
171 | @global_kwargs('enable_dropout')
172 | def __call__(self, inputs: ArrayLike, enable_dropout: bool = False):
173 | """Applies a random dropout mask to the input."""
174 | if not enable_dropout:
175 | return inputs
176 | if self.rate == 0.:
177 | return inputs
178 | # Prevent gradient NaNs in 1.0 edge-case.
179 | if self.rate == 1.0:
180 | return jnp.zeros_like(inputs)
181 |
182 | inputs = jnp.asarray(inputs)
183 | mask = jax.lax.rng_uniform(0., 1., inputs.shape) < self.rate
184 | return jnp.where(mask, 0., inputs / (1. - self.rate))
185 |
--------------------------------------------------------------------------------
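A small, self-contained check of `sparse_xe_with_logits` above; the logits and labels are arbitrary:

```python
import jax.numpy as jnp
from jestimator import modeling

logits = jnp.array([[2.0, 0.0, -1.0],
                    [0.5, 1.5, 0.0]])  # (batch=2, num_labels=3)
labels = jnp.array([0, 1])

# Scalar sum over the batch.
loss = modeling.sparse_xe_with_logits(labels, logits)
# Per-example losses; a `mask` would zero out positions before negation.
per_example = modeling.sparse_xe_with_logits(labels, logits, reduce_all=False)
assert per_example.shape == (2,)
```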
/jestimator/models/bert/finetune.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Sequence classification.
16 |
17 | # For debug run locally:
18 |
19 | ## Train:
20 |
21 | ```
22 | PYTHONPATH=. python3 \
23 | jestimator/estimator.py \
24 | --module_imp="jestimator.models.bert.finetune" \
25 | --module_config="jestimator/models/bert/finetune.py" \
26 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \
27 | --module_config.segment_names="sentence1,sentence2" \
28 | --module_config.model_config.num_labels=2 \
29 | --train_pattern="tfds://glue/rte/split=train" \
30 | --valid_pattern="tfds://glue/rte/split=validation" \
31 | --model_dir="$HOME/models/bert_rte" \
32 | --checkpoint_path="gs://gresearch/checkpoints_in_amos_paper/\
33 | adamw/bert-base/checkpoint_300000" \
34 | --train_batch_size=4 --valid_batch_size=4 --num_valid_examples=4 \
35 | --check_every_steps=10 --logtostderr
36 | ```
37 |
38 | ## Eval:
39 |
40 | ```
41 | PYTHONPATH=. python3 \
42 | jestimator/estimator.py \
43 | --module_imp="jestimator.models.bert.finetune" \
44 | --module_config="jestimator/models/bert/finetune.py" \
45 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \
46 | --module_config.segment_names="sentence1,sentence2" \
47 | --module_config.model_config.num_labels=2 \
48 | --module_config.eval_metric="accuracy" \
49 | --eval_pattern="tfds://glue/rte/split=validation" \
50 | --model_dir="$HOME/models/bert_rte" \
51 | --eval_batch_size=4 --num_eval_examples=4 \
52 | --logtostderr
53 | ```
54 |
55 | ## Predict:
56 |
57 | ```
58 | PYTHONPATH=. python3 \
59 | jestimator/estimator.py \
60 | --module_imp="jestimator.models.bert.finetune" \
61 | --module_config="jestimator/models/bert/finetune.py" \
62 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \
63 | --module_config.segment_names="sentence1,sentence2" \
64 | --module_config.model_config.num_labels=2 \
65 | --module_config.label_names="entailment,not_entailment" \
66 | --pred_pattern="tfds://glue/rte/split=test" \
67 | --model_dir="$HOME/models/bert_rte" \
68 | --pred_batch_size=4 --num_pred_examples=4 \
69 | --logtostderr
70 | ```
71 | """
72 |
73 | import dataclasses
74 |
75 | import jax
76 | import jax.numpy as jnp
77 | from jestimator.data import reader
78 | from jestimator.data.pipeline_rec import rec_data
79 | from jestimator.data.reader import PyOutSpec
80 | from jestimator.models.bert import modeling
81 | from jestimator.states import Evaluator, InferState, MeanMetrics, Predictor, TrainState # pylint: disable=g-multiple-import
82 | import ml_collections
83 | from ml_collections.config_dict import config_dict
84 | import optax
85 | from scipy import stats as scipy_stats
86 | from sklearn import metrics as sklearn_metrics
87 | import tensorflow as tf
88 |
89 | import sentencepiece as spm
90 |
91 |
92 | def get_config():
93 | """Returns a config object for modeling flags."""
94 | module_config = ml_collections.ConfigDict()
95 |
96 | # Model config.
97 | model_config = modeling.ModelConfig()
98 | model_config = ml_collections.ConfigDict(dataclasses.asdict(model_config))
99 | module_config.model_config = model_config
100 |
101 | # Optimizer config.
102 | opt_config = ml_collections.ConfigDict()
103 | opt_config.optimizer = 'adam'
104 | opt_config.learning_rate = 5e-6
105 | module_config.opt_config = opt_config
106 |
107 | # Other config.
108 | module_config.vocab_path = config_dict.placeholder(str)
109 | module_config.segment_names = config_dict.placeholder(str)
110 | module_config.eval_metric = config_dict.placeholder(str)
111 | module_config.output_path = config_dict.placeholder(str)
112 | module_config.label_names = config_dict.placeholder(str)
113 | module_config.stsb = False
114 | return module_config
115 |
116 |
117 | def load_config(global_flags):
118 | """Init config data from global flags."""
119 | config = ml_collections.ConfigDict()
120 | config.update(global_flags.module_config)
121 |
122 | tokenizer = spm.SentencePieceProcessor()
123 | tokenizer.Load(config.vocab_path)
124 | config.model_config.vocab_size = tokenizer.GetPieceSize()
125 |
126 | segment_names = config.segment_names.split(',')
127 | num_segments = len(segment_names)
128 | config.model_config.num_segments = num_segments + 1
129 |
130 | # Only a frozen config (hashable object) can be passed to jit functions
131 | # (i.e. train_step/valid_step/infer_step).
132 | config.frozen = ml_collections.FrozenConfigDict(config)
133 |
134 |   # Construct the data pipelines below (using TensorFlow):
135 | max_length = config.model_config.max_length
136 | max_len_1 = (max_length - 1) // num_segments
137 |   cls_token_id = tokenizer.PieceToId('<cls>')
138 |   sep_token_id = tokenizer.PieceToId('<sep>')
139 |   eos_token_id = tokenizer.PieceToId('</s>')
140 | data_keys = ['idx', 'label'] + config.segment_names.split(',')
141 | mode = global_flags.mode
142 |
143 | def tokenize_fn(texts):
144 | ids = []
145 | for s in texts:
146 | s = tf.strings.lower(s).numpy()
147 | ids.append(tf.convert_to_tensor(tokenizer.EncodeAsIds(s), tf.int32))
148 | return ids
149 |
150 | def example_fn(data):
151 | data = {k: data[k] for k in data_keys if k in data}
152 | texts = [data[k] for k in segment_names]
153 | out_spec = [PyOutSpec((-1,), tf.int32)] * num_segments
154 | tokenized = reader.apply_py_fn(tokenize_fn, texts, out_spec)
155 |
156 | max_len_0 = max_length - 1
157 | input_ids = [tf.concat([[cls_token_id], tokenized[0]], 0)]
158 | for x in tokenized[1:]:
159 | x = tf.concat([[sep_token_id], x], 0)[:max_len_1]
160 | input_ids.append(x)
161 | max_len_0 = max_len_0 - tf.shape(x)[0]
162 | input_ids[0] = input_ids[0][:max_len_0]
163 | input_ids.append([eos_token_id])
164 |
165 | segment_ids = [tf.ones_like(x) * i for i, x in enumerate(input_ids)]
166 | input_ids = tf.concat(input_ids, 0)
167 | input_mask = tf.ones_like(input_ids)
168 | segment_ids = tf.concat(segment_ids, 0)
169 |
170 | pad_len = max_length - tf.shape(input_ids)[0]
171 | input_ids = tf.pad(input_ids, [[0, pad_len]])
172 | input_mask = tf.pad(input_mask, [[0, pad_len]])
173 | segment_ids = tf.pad(segment_ids, [[0, pad_len]])
174 |
175 | ret = {
176 | 'input_ids': tf.ensure_shape(input_ids, (max_length,)),
177 | 'input_mask': tf.ensure_shape(input_mask, (max_length,)),
178 | 'segment_ids': tf.ensure_shape(segment_ids, (max_length,)),
179 | }
180 | if mode == 'train':
181 | ret['label'] = data['label']
182 | if config.stsb:
183 | ret['label'] /= 5.0
184 | else:
185 | ret['idx'] = data['idx']
186 | if mode.startswith('eval'):
187 | ret = (data['label'], ret)
188 | return ret
189 |
190 | def dataset_fn(path: str) -> tf.data.Dataset:
191 | d = reader.get_tfds_dataset(path)
192 | d = d.map(example_fn, tf.data.AUTOTUNE)
193 | return d
194 |
195 | config.train_data_fn = rec_data(
196 | dataset_fn=dataset_fn, cache=True, interleave=True)
197 | config.eval_data_fn = config.valid_data_fn = rec_data(
198 | dataset_fn=dataset_fn, cache=True)
199 | config.pred_data_fn = rec_data(dataset_fn=dataset_fn)
200 | return config
201 |
202 |
203 | def get_train_state(config, rng):
204 | """Create train state."""
205 | model_config = modeling.ModelConfig(**config.model_config.to_dict())
206 | model = modeling.ModelForSeqCls(model_config)
207 |
208 | opt_config = config.opt_config
209 | if opt_config.optimizer == 'adam':
210 | optimizer = optax.adam(learning_rate=opt_config.learning_rate)
211 |
212 | metrics_mod = MeanMetrics.create('train_loss', 'valid_loss')
213 | return TrainState.create(metrics_mod, optimizer, model, rng, jnp.array([[0]]))
214 |
215 |
216 | def train_step(config, train_batch, state: TrainState, metrics):
217 | """Training step."""
218 | loss_fn = (
219 | modeling.ModelForSeqCls.mse_loss
220 | if config.stsb else modeling.ModelForSeqCls.xe_loss)
221 | (loss, size), grads = state.value_and_grad_apply_fn(has_aux=True)(
222 | state.params,
223 | train_batch['label'],
224 | train_batch['input_ids'],
225 | segment_ids=train_batch['segment_ids'],
226 | input_mask=train_batch['input_mask'],
227 | enable_dropout=True,
228 | method=loss_fn)
229 | _, metrics = state.metrics_mod.apply(
230 | metrics,
231 | 'train_loss',
232 | loss,
233 | size,
234 | method=MeanMetrics.update,
235 | mutable=['metrics'])
236 | return state.apply_gradients(grads=grads), metrics
237 |
238 |
239 | def valid_step(config, valid_batch, state: TrainState, metrics):
240 | """Validation step."""
241 | loss_fn = (
242 | modeling.ModelForSeqCls.mse_loss
243 | if config.stsb else modeling.ModelForSeqCls.xe_loss)
244 | loss, size = state.apply_fn(
245 | state.variables(),
246 | valid_batch['label'],
247 | valid_batch['input_ids'],
248 | segment_ids=valid_batch['segment_ids'],
249 | input_mask=valid_batch['input_mask'],
250 | method=loss_fn)
251 | _, metrics = state.metrics_mod.apply(
252 | metrics,
253 | 'valid_loss',
254 | loss,
255 | size,
256 | method=MeanMetrics.update,
257 | mutable=['metrics'])
258 | return metrics
259 |
260 |
261 | def get_infer_state(config):
262 | """Create infer state."""
263 | model_config = modeling.ModelConfig(**config.model_config.to_dict())
264 | model = modeling.ModelForSeqCls(model_config)
265 | return InferState.create(model, jnp.array([[0]]))
266 |
267 |
268 | def infer_step(config, batch, state: InferState) -> InferState:
269 | """Infer step."""
270 | logits = state.apply_fn(
271 | state.variables(),
272 | batch['input_ids'],
273 | segment_ids=batch['segment_ids'],
274 | input_mask=batch['input_mask'])
275 | if config.stsb:
276 | pred = jax.nn.softmax(logits)[..., 0] * 5.0
277 | else:
278 | pred = jnp.argmax(logits, axis=-1)
279 | return state.replace(ret={
280 | 'idx': batch['idx'],
281 | 'prediction': pred,
282 | })
283 |
284 |
285 | def get_evaluator(config) -> Evaluator:
286 | """Create evaluator."""
287 | eval_fns = {
288 | 'accuracy': sklearn_metrics.accuracy_score,
289 | 'f1': sklearn_metrics.f1_score,
290 | 'spearmanr': lambda x, y: scipy_stats.spearmanr(x, y)[0],
291 | }
292 |
293 | def proc_fn(infer):
294 | return infer['prediction']
295 |
296 | metric = config.eval_metric
297 | return Evaluator({metric: (proc_fn, eval_fns[metric])})
298 |
299 |
300 | def get_predictor(config) -> Predictor:
301 | """Create predictor."""
302 | pre_str = 'index\tprediction'
303 | label_names = (None if config.label_names is None else
304 | config.label_names.split(','))
305 |
306 | def proc_fn(infer):
307 | ret = []
308 | for x, y in zip(infer['idx'], infer['prediction']):
309 | z = y if label_names is None else label_names[y]
310 | ret.append(f'{x}\t{z}')
311 | return ret
312 |
313 | return Predictor(proc_fn, config.output_path, pre_str=pre_str)
314 |
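315 | # NOTE: an illustrative sketch (token ids hypothetical) of what example_fn
316 | # above produces for a 2-segment task like RTE, with max_length=8:
317 | #   input_ids   = [<cls>, s1_0, s1_1, <sep>, s2_0, </s>, 0, 0]
318 | #   segment_ids = [0,     0,    0,    1,     1,    2,    0, 0]
319 | #   input_mask  = [1,     1,    1,    1,     1,    1,    0, 0]
320 | # Each segment after the first is prefixed with <sep> and capped at
321 | # (max_length - 1) // num_segments tokens; the first segment takes the
322 | # remaining room, and a trailing </s> closes the sequence before padding.
323 |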
--------------------------------------------------------------------------------
/jestimator/models/bert/pretrain.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Masked language model (MLM) pretraining.
16 |
17 | # For a local debug run:
18 |
19 | ```
20 | PYTHONPATH=. python3 \
21 | jestimator/estimator.py \
22 | --module_imp="jestimator.models.bert.pretrain" \
23 | --module_config="jestimator/models/bert/pretrain.py" \
24 | --module_config.model_config.vocab_size=32000 \
25 | --module_config.mask_token_id=4 \
26 | --train_pattern="gs://gresearch/checkpoints_in_amos_paper/data/\
27 | books-00000-of-00500" \
28 | --valid_pattern="gs://gresearch/checkpoints_in_amos_paper/data/ptb" \
29 | --model_dir="$HOME/models/bert_pretrain" \
30 | --train_batch_size=4 --valid_batch_size=4 --num_valid_examples=4 \
31 | --check_every_steps=10 --logtostderr
32 | ```
33 | """
34 |
35 | import dataclasses
36 | from typing import Mapping
37 |
38 | import jax
39 | import jax.numpy as jnp
40 | from jestimator import amos
41 | from jestimator.data.pipeline_lm import lm_data
42 | from jestimator.models.bert import modeling
43 | from jestimator.states import TrainState, MeanMetrics # pylint: disable=g-multiple-import
44 | import ml_collections
45 | from ml_collections.config_dict import config_dict
46 | import optax
47 | import tensorflow as tf
48 |
49 |
50 | def get_config():
51 | """Returns a config object for modeling flags."""
52 | module_config = ml_collections.ConfigDict()
53 |
54 | # Model config.
55 | model_config = modeling.ModelConfig()
56 | model_config = ml_collections.ConfigDict(dataclasses.asdict(model_config))
57 | module_config.model_config = model_config
58 |
59 | # Optimizer config.
60 | opt_config = ml_collections.ConfigDict()
61 | opt_config.optimizer = 'adamw'
62 | opt_config.learning_rate = 1e-4
63 | opt_config.warmup_steps = 10000
64 | opt_config.linear_decay_to_step = config_dict.placeholder(int)
65 | opt_config.momentum = 0.9
66 | opt_config.beta = 0.999
67 | opt_config.weight_decay = 0.01
68 | module_config.opt_config = opt_config
69 |
70 | # Other config.
71 | module_config.mask_token_id = config_dict.placeholder(int)
72 | module_config.mask_rate = 0.15
73 | return module_config
74 |
75 |
76 | def load_config(global_flags):
77 | """Init config data from global flags."""
78 | config = ml_collections.ConfigDict()
79 | config.update(global_flags.module_config)
80 |
81 | # Only a frozen config (hashable object) can be passed to jit functions
82 | # (i.e. train_step/valid_step/infer_step).
83 | config.frozen = ml_collections.FrozenConfigDict(config)
84 |
85 |   # Construct data pipelines below (using TensorFlow):
86 | seq_length = config.model_config.max_length
87 |
88 | def feature_fn(token_ids: tf.Tensor) -> Mapping[str, tf.Tensor]:
89 | """Builds a feature dict to be compatible with seqio."""
90 | return {'targets': tf.ensure_shape(token_ids, (seq_length,))}
91 |
92 | config.train_data_fn = lm_data(
93 | seq_length, random_skip=True, feature_fn=feature_fn, interleave=True)
94 | config.valid_data_fn = lm_data(seq_length, feature_fn=feature_fn)
95 | return config
96 |
97 |
98 | def get_train_state(config, rng) -> TrainState:
99 | """Create train state."""
100 | model_config = modeling.ModelConfig(**config.model_config.to_dict())
101 | model = modeling.ModelForPretrain(model_config)
102 | opt_config = config.opt_config
103 | warmup = opt_config.warmup_steps
104 | decay = opt_config.linear_decay_to_step
105 |
106 | def lr_schedule(step):
107 | lr = opt_config.learning_rate
108 | if warmup is not None:
109 | lr *= jnp.minimum(1., step / warmup)
110 |       if decay is not None:
111 |         lr *= 1. - jnp.maximum(0., step - warmup) / (decay - warmup)
112 | elif decay is not None:
113 | lr *= 1. - step / decay
114 | return lr
115 |
116 | if opt_config.optimizer == 'adamw':
117 | optimizer = optax.adamw(
118 | learning_rate=lr_schedule,
119 | b1=opt_config.momentum,
120 | b2=opt_config.beta,
121 | weight_decay=opt_config.weight_decay)
122 | elif opt_config.optimizer == 'amos':
123 | optimizer = amos.amos(
124 | lr_schedule,
125 | modeling.get_eta_fn(model_config),
126 | shape_fn=modeling.get_shape_fn(model_config),
127 | beta=opt_config.beta,
128 | momentum=opt_config.momentum,
129 | clip_value=1.)
130 |
131 | metrics_mod = MeanMetrics.create('train_loss', 'valid_loss', 'valid_mrr')
132 | return TrainState.create(metrics_mod, optimizer, model, rng, jnp.array([[0]]))
133 |
134 |
135 | def train_step(config, train_batch, state: TrainState, metrics):
136 | """Training step."""
137 | (loss, size), grads = state.value_and_grad_apply_fn(has_aux=True)(
138 | state.params,
139 | train_batch['targets'],
140 | config.mask_token_id,
141 | mask_rate=config.mask_rate,
142 | input_mask=train_batch.get('input_mask'),
143 | enable_dropout=True,
144 | method=modeling.ModelForPretrain.mlm_train_loss)
145 | _, metrics = state.metrics_mod.apply(
146 | metrics,
147 | 'train_loss',
148 | loss,
149 | size,
150 | method=MeanMetrics.update,
151 | mutable=['metrics'])
152 | return state.apply_gradients(grads=grads), metrics
153 |
154 |
155 | def valid_step(config, valid_batch, state: TrainState, metrics):
156 | """Validation step."""
157 |
158 | def body(i, metrics):
159 | del i # Unused.
160 | loss, mrr, size = state.apply_fn(
161 | state.variables(),
162 | valid_batch['targets'],
163 | config.mask_token_id,
164 | mask_rate=config.mask_rate,
165 | input_mask=valid_batch.get('input_mask'),
166 | method=modeling.ModelForPretrain.mlm_valid_metrics)
167 | _, metrics = state.metrics_mod.apply(
168 | metrics,
169 | 'valid_loss',
170 | loss,
171 | size,
172 | method=MeanMetrics.update,
173 | mutable=['metrics'])
174 | _, metrics = state.metrics_mod.apply(
175 | metrics,
176 | 'valid_mrr',
177 | mrr,
178 | size,
179 | method=MeanMetrics.update,
180 | mutable=['metrics'])
181 | return metrics
182 |
183 | return jax.lax.fori_loop(0, 20, body, metrics)
184 |
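185 | # NOTE: a worked example of lr_schedule above, assuming the default
186 | # learning_rate=1e-4 and warmup_steps=10000, with a hypothetical
187 | # linear_decay_to_step=1000000:
188 | #   step 5000    -> 1e-4 * (5000 / 10000)           = 5e-5  (warmup)
189 | #   step 10000   -> 1e-4                                     (peak)
190 | #   step 505000  -> 1e-4 * (1 - 495000 / 990000)    = 5e-5  (decay)
191 | #   step 1000000 -> 0.
192 |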
--------------------------------------------------------------------------------
/jestimator/models/bert_rpe/finetune.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Sequence classification.
16 |
17 | # For a local debug run:
18 |
19 | ## Train:
20 |
21 | ```
22 | PYTHONPATH=. python3 \
23 | jestimator/estimator.py \
24 | --module_imp="jestimator.models.bert_rpe.finetune" \
25 | --module_config="jestimator/models/bert_rpe/finetune.py" \
26 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \
27 | --module_config.segment_names="sentence1,sentence2" \
28 | --module_config.model_config.num_labels=2 \
29 | --train_pattern="tfds://glue/rte/split=train" \
30 | --valid_pattern="tfds://glue/rte/split=validation" \
31 | --model_dir="$HOME/models/bert_rte" \
32 | --checkpoint_path="gs://gresearch/checkpoints_in_amos_paper/\
33 | adamw/rpe/checkpoint_300000" \
34 | --train_batch_size=4 --valid_batch_size=4 --num_valid_examples=4 \
35 | --check_every_steps=10 --logtostderr
36 | ```
37 |
38 | ## Eval:
39 |
40 | ```
41 | PYTHONPATH=. python3 \
42 | jestimator/estimator.py \
43 | --module_imp="jestimator.models.bert_rpe.finetune" \
44 | --module_config="jestimator/models/bert_rpe/finetune.py" \
45 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \
46 | --module_config.segment_names="sentence1,sentence2" \
47 | --module_config.model_config.num_labels=2 \
48 | --module_config.eval_metric="accuracy" \
49 | --eval_pattern="tfds://glue/rte/split=validation" \
50 | --model_dir="$HOME/models/bert_rte" \
51 | --eval_batch_size=4 --num_eval_examples=4 \
52 | --logtostderr
53 | ```
54 |
55 | ## Predict:
56 |
57 | ```
58 | PYTHONPATH=. python3 \
59 | jestimator/estimator.py \
60 | --module_imp="jestimator.models.bert_rpe.finetune" \
61 | --module_config="jestimator/models/bert_rpe/finetune.py" \
62 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \
63 | --module_config.segment_names="sentence1,sentence2" \
64 | --module_config.model_config.num_labels=2 \
65 | --module_config.label_names="entailment,not_entailment" \
66 | --pred_pattern="tfds://glue/rte/split=test" \
67 | --model_dir="$HOME/models/bert_rte" \
68 | --pred_batch_size=4 --num_pred_examples=4 \
69 | --logtostderr
70 | ```
71 | """
72 |
73 | import dataclasses
74 |
75 | import jax
76 | import jax.numpy as jnp
77 | from jestimator.data import reader
78 | from jestimator.data.pipeline_rec import rec_data
79 | from jestimator.data.reader import PyOutSpec
80 | from jestimator.models.bert_rpe import modeling
81 | from jestimator.states import Evaluator, InferState, MeanMetrics, Predictor, TrainState # pylint: disable=g-multiple-import
82 | import ml_collections
83 | from ml_collections.config_dict import config_dict
84 | import optax
85 | from scipy import stats as scipy_stats
86 | from sklearn import metrics as sklearn_metrics
87 | import tensorflow as tf
88 |
89 | import sentencepiece as spm
90 |
91 |
92 | def get_config():
93 | """Returns a config object for modeling flags."""
94 | module_config = ml_collections.ConfigDict()
95 |
96 | # Model config.
97 | model_config = modeling.ModelConfig()
98 | model_config = ml_collections.ConfigDict(dataclasses.asdict(model_config))
99 | module_config.model_config = model_config
100 |
101 | # Optimizer config.
102 | opt_config = ml_collections.ConfigDict()
103 | opt_config.optimizer = 'adam'
104 | opt_config.learning_rate = 5e-6
105 | module_config.opt_config = opt_config
106 |
107 | # Other config.
108 | module_config.vocab_path = config_dict.placeholder(str)
109 | module_config.segment_names = config_dict.placeholder(str)
110 | module_config.eval_metric = config_dict.placeholder(str)
111 | module_config.output_path = config_dict.placeholder(str)
112 | module_config.label_names = config_dict.placeholder(str)
113 | module_config.stsb = False
114 | return module_config
115 |
116 |
117 | def load_config(global_flags):
118 | """Init config data from global flags."""
119 | config = ml_collections.ConfigDict()
120 | config.update(global_flags.module_config)
121 |
122 | tokenizer = spm.SentencePieceProcessor()
123 | tokenizer.Load(config.vocab_path)
124 | config.model_config.vocab_size = tokenizer.GetPieceSize()
125 |
126 | segment_names = config.segment_names.split(',')
127 | num_segments = len(segment_names)
128 | config.model_config.num_segments = num_segments + 1
129 |
130 | # Only a frozen config (hashable object) can be passed to jit functions
131 | # (i.e. train_step/valid_step/infer_step).
132 | config.frozen = ml_collections.FrozenConfigDict(config)
133 |
134 |   # Construct data pipelines below (using TensorFlow):
135 | max_length = config.model_config.max_length
136 | max_len_1 = (max_length - 1) // num_segments
137 |   cls_token_id = tokenizer.PieceToId('<cls>')
138 |   sep_token_id = tokenizer.PieceToId('<sep>')
139 |   eos_token_id = tokenizer.PieceToId('</s>')
140 | data_keys = ['idx', 'label'] + config.segment_names.split(',')
141 | mode = global_flags.mode
142 |
143 | def tokenize_fn(texts):
144 | ids = []
145 | for s in texts:
146 | s = tf.strings.lower(s).numpy()
147 | ids.append(tf.convert_to_tensor(tokenizer.EncodeAsIds(s), tf.int32))
148 | return ids
149 |
150 | def example_fn(data):
151 | data = {k: data[k] for k in data_keys if k in data}
152 | texts = [data[k] for k in segment_names]
153 | out_spec = [PyOutSpec((-1,), tf.int32)] * num_segments
154 | tokenized = reader.apply_py_fn(tokenize_fn, texts, out_spec)
155 |
156 | max_len_0 = max_length - 1
157 | input_ids = [tf.concat([[cls_token_id], tokenized[0]], 0)]
158 | for x in tokenized[1:]:
159 | x = tf.concat([[sep_token_id], x], 0)[:max_len_1]
160 | input_ids.append(x)
161 | max_len_0 = max_len_0 - tf.shape(x)[0]
162 | input_ids[0] = input_ids[0][:max_len_0]
163 | input_ids.append([eos_token_id])
164 |
165 | segment_ids = [tf.ones_like(x) * i for i, x in enumerate(input_ids)]
166 | input_ids = tf.concat(input_ids, 0)
167 | input_mask = tf.ones_like(input_ids)
168 | segment_ids = tf.concat(segment_ids, 0)
169 |
170 | pad_len = max_length - tf.shape(input_ids)[0]
171 | input_ids = tf.pad(input_ids, [[0, pad_len]])
172 | input_mask = tf.pad(input_mask, [[0, pad_len]])
173 | segment_ids = tf.pad(segment_ids, [[0, pad_len]])
174 |
175 | ret = {
176 | 'input_ids': tf.ensure_shape(input_ids, (max_length,)),
177 | 'input_mask': tf.ensure_shape(input_mask, (max_length,)),
178 | 'segment_ids': tf.ensure_shape(segment_ids, (max_length,)),
179 | }
180 | if mode == 'train':
181 | ret['label'] = data['label']
182 | if config.stsb:
183 | ret['label'] /= 5.0
184 | else:
185 | ret['idx'] = data['idx']
186 | if mode.startswith('eval'):
187 | ret = (data['label'], ret)
188 | return ret
189 |
190 | def dataset_fn(path: str) -> tf.data.Dataset:
191 | d = reader.get_tfds_dataset(path)
192 | d = d.map(example_fn, tf.data.AUTOTUNE)
193 | return d
194 |
195 | config.train_data_fn = rec_data(
196 | dataset_fn=dataset_fn, cache=True, interleave=True)
197 | config.eval_data_fn = config.valid_data_fn = rec_data(
198 | dataset_fn=dataset_fn, cache=True)
199 | config.pred_data_fn = rec_data(dataset_fn=dataset_fn)
200 | return config
201 |
202 |
203 | def get_train_state(config, rng):
204 | """Create train state."""
205 | model_config = modeling.ModelConfig(**config.model_config.to_dict())
206 | model = modeling.ModelForSeqCls(model_config)
207 |
208 | opt_config = config.opt_config
209 | if opt_config.optimizer == 'adam':
210 | optimizer = optax.adam(learning_rate=opt_config.learning_rate)
211 |
212 | metrics_mod = MeanMetrics.create('train_loss', 'valid_loss')
213 | dummy_input = jnp.array([[0] * config.model_config.max_length])
214 | return TrainState.create(metrics_mod, optimizer, model, rng, dummy_input)
215 |
216 |
217 | def train_step(config, train_batch, state: TrainState, metrics):
218 | """Training step."""
219 | loss_fn = (
220 | modeling.ModelForSeqCls.mse_loss
221 | if config.stsb else modeling.ModelForSeqCls.xe_loss)
222 | (loss, size), grads = state.value_and_grad_apply_fn(has_aux=True)(
223 | state.params,
224 | train_batch['label'],
225 | train_batch['input_ids'],
226 | segment_ids=train_batch['segment_ids'],
227 | input_mask=train_batch['input_mask'],
228 | enable_dropout=True,
229 | method=loss_fn)
230 | _, metrics = state.metrics_mod.apply(
231 | metrics,
232 | 'train_loss',
233 | loss,
234 | size,
235 | method=MeanMetrics.update,
236 | mutable=['metrics'])
237 | return state.apply_gradients(grads=grads), metrics
238 |
239 |
240 | def valid_step(config, valid_batch, state: TrainState, metrics):
241 | """Validation step."""
242 | loss_fn = (
243 | modeling.ModelForSeqCls.mse_loss
244 | if config.stsb else modeling.ModelForSeqCls.xe_loss)
245 | loss, size = state.apply_fn(
246 | state.variables(),
247 | valid_batch['label'],
248 | valid_batch['input_ids'],
249 | segment_ids=valid_batch['segment_ids'],
250 | input_mask=valid_batch['input_mask'],
251 | method=loss_fn)
252 | _, metrics = state.metrics_mod.apply(
253 | metrics,
254 | 'valid_loss',
255 | loss,
256 | size,
257 | method=MeanMetrics.update,
258 | mutable=['metrics'])
259 | return metrics
260 |
261 |
262 | def get_infer_state(config):
263 | """Create infer state."""
264 | model_config = modeling.ModelConfig(**config.model_config.to_dict())
265 | model = modeling.ModelForSeqCls(model_config)
266 | dummy_input = jnp.array([[0] * config.model_config.max_length])
267 | return InferState.create(model, dummy_input)
268 |
269 |
270 | def infer_step(config, batch, state: InferState) -> InferState:
271 | """Infer step."""
272 | logits = state.apply_fn(
273 | state.variables(),
274 | batch['input_ids'],
275 | segment_ids=batch['segment_ids'],
276 | input_mask=batch['input_mask'])
277 | if config.stsb:
278 | pred = jax.nn.softmax(logits)[..., 0] * 5.0
279 | else:
280 | pred = jnp.argmax(logits, axis=-1)
281 | return state.replace(ret={
282 | 'idx': batch['idx'],
283 | 'prediction': pred,
284 | })
285 |
286 |
287 | def get_evaluator(config) -> Evaluator:
288 | """Create evaluator."""
289 | eval_fns = {
290 | 'accuracy': sklearn_metrics.accuracy_score,
291 | 'f1': sklearn_metrics.f1_score,
292 | 'spearmanr': lambda x, y: scipy_stats.spearmanr(x, y)[0],
293 | }
294 |
295 | def proc_fn(infer):
296 | return infer['prediction']
297 |
298 | metric = config.eval_metric
299 | return Evaluator({metric: (proc_fn, eval_fns[metric])})
300 |
301 |
302 | def get_predictor(config) -> Predictor:
303 | """Create predictor."""
304 | pre_str = 'index\tprediction'
305 | label_names = (None if config.label_names is None else
306 | config.label_names.split(','))
307 |
308 | def proc_fn(infer):
309 | ret = []
310 | for x, y in zip(infer['idx'], infer['prediction']):
311 | z = y if label_names is None else label_names[y]
312 | ret.append(f'{x}\t{z}')
313 | return ret
314 |
315 | return Predictor(proc_fn, config.output_path, pre_str=pre_str)
316 |
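317 | # NOTE: this module mirrors jestimator/models/bert/finetune.py; the main
318 | # difference is in get_train_state/get_infer_state, where the dummy input
319 | # used for shape inference spans the full max_length (presumably so the
320 | # relative-position-embedding tables of the bert_rpe model are created at
321 | # their final size) instead of the single-token dummy jnp.array([[0]]).
322 |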
--------------------------------------------------------------------------------
/jestimator/models/bert_rpe/pretrain.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Masked language model (MLM) pretraining.
16 |
17 | # For a local debug run:
18 |
19 | ```
20 | PYTHONPATH=. python3 \
21 | jestimator/estimator.py \
22 | --module_imp="jestimator.models.bert_rpe.pretrain" \
23 | --module_config="jestimator/models/bert_rpe/pretrain.py" \
24 | --module_config.model_config.vocab_size=32000 \
25 | --module_config.mask_token_id=4 \
26 | --train_pattern="gs://gresearch/checkpoints_in_amos_paper/data/\
27 | books-00000-of-00500" \
28 | --valid_pattern="gs://gresearch/checkpoints_in_amos_paper/data/ptb" \
29 | --model_dir="$HOME/models/bert_rpe_pretrain" \
30 | --train_batch_size=4 --valid_batch_size=4 --num_valid_examples=4 \
31 | --check_every_steps=10 --logtostderr
32 | ```
33 | """
34 |
35 | import dataclasses
36 | from typing import Mapping
37 |
38 | import jax
39 | import jax.numpy as jnp
40 | from jestimator import amos
41 | from jestimator.data.pipeline_lm import lm_data
42 | from jestimator.models.bert_rpe import modeling
43 | from jestimator.states import TrainState, MeanMetrics # pylint: disable=g-multiple-import
44 | import ml_collections
45 | from ml_collections.config_dict import config_dict
46 | import optax
47 | import tensorflow as tf
48 |
49 |
50 | def get_config():
51 | """Returns a config object for modeling flags."""
52 | module_config = ml_collections.ConfigDict()
53 |
54 | # Model config.
55 | model_config = modeling.ModelConfig()
56 | model_config = ml_collections.ConfigDict(dataclasses.asdict(model_config))
57 | module_config.model_config = model_config
58 |
59 | # Optimizer config.
60 | opt_config = ml_collections.ConfigDict()
61 | opt_config.optimizer = 'adamw'
62 | opt_config.learning_rate = 1e-4
63 | opt_config.warmup_steps = 10000
64 | opt_config.linear_decay_to_step = config_dict.placeholder(int)
65 | opt_config.momentum = 0.9
66 | opt_config.beta = 0.999
67 | opt_config.weight_decay = 0.01
68 | module_config.opt_config = opt_config
69 |
70 | # Other config.
71 | module_config.mask_token_id = config_dict.placeholder(int)
72 | module_config.mask_rate = 0.15
73 | return module_config
74 |
75 |
76 | def load_config(global_flags):
77 | """Init config data from global flags."""
78 | config = ml_collections.ConfigDict()
79 | config.update(global_flags.module_config)
80 |
81 | # Only a frozen config (hashable object) can be passed to jit functions
82 | # (i.e. train_step/valid_step/infer_step).
83 | config.frozen = ml_collections.FrozenConfigDict(config)
84 |
85 |   # Construct data pipelines below (using TensorFlow):
86 | seq_length = config.model_config.max_length
87 |
88 | def feature_fn(token_ids: tf.Tensor) -> Mapping[str, tf.Tensor]:
89 | """Builds a feature dict to be compatible with seqio."""
90 | return {'targets': tf.ensure_shape(token_ids, (seq_length,))}
91 |
92 | config.train_data_fn = lm_data(
93 | seq_length, random_skip=True, feature_fn=feature_fn, interleave=True)
94 | config.valid_data_fn = lm_data(seq_length, feature_fn=feature_fn)
95 | return config
96 |
97 |
98 | def get_train_state(config, rng) -> TrainState:
99 | """Create train state."""
100 | model_config = modeling.ModelConfig(**config.model_config.to_dict())
101 | model = modeling.ModelForPretrain(model_config)
102 | opt_config = config.opt_config
103 | warmup = opt_config.warmup_steps
104 | decay = opt_config.linear_decay_to_step
105 |
106 | def lr_schedule(step):
107 | lr = opt_config.learning_rate
108 | if warmup is not None:
109 | lr *= jnp.minimum(1., step / warmup)
110 |       if decay is not None:
111 |         lr *= 1. - jnp.maximum(0., step - warmup) / (decay - warmup)
112 | elif decay is not None:
113 | lr *= 1. - step / decay
114 | return lr
115 |
116 | if opt_config.optimizer == 'adamw':
117 | optimizer = optax.adamw(
118 | learning_rate=lr_schedule,
119 | b1=opt_config.momentum,
120 | b2=opt_config.beta,
121 | weight_decay=opt_config.weight_decay)
122 | elif opt_config.optimizer == 'amos':
123 | optimizer = amos.amos(
124 | lr_schedule,
125 | modeling.get_eta_fn(model_config),
126 | shape_fn=modeling.get_shape_fn(model_config),
127 | beta=opt_config.beta,
128 | momentum=opt_config.momentum,
129 | clip_value=1.)
130 |
131 | metrics_mod = MeanMetrics.create('train_loss', 'valid_loss', 'valid_mrr')
132 | dummy_input = jnp.array([[0] * config.model_config.max_length])
133 | return TrainState.create(metrics_mod, optimizer, model, rng, dummy_input)
134 |
135 |
136 | def train_step(config, train_batch, state: TrainState, metrics):
137 | """Training step."""
138 | (loss, size), grads = state.value_and_grad_apply_fn(has_aux=True)(
139 | state.params,
140 | train_batch['targets'],
141 | config.mask_token_id,
142 | mask_rate=config.mask_rate,
143 | input_mask=train_batch.get('input_mask'),
144 | enable_dropout=True,
145 | method=modeling.ModelForPretrain.mlm_train_loss)
146 | _, metrics = state.metrics_mod.apply(
147 | metrics,
148 | 'train_loss',
149 | loss,
150 | size,
151 | method=MeanMetrics.update,
152 | mutable=['metrics'])
153 | return state.apply_gradients(grads=grads), metrics
154 |
155 |
156 | def valid_step(config, valid_batch, state: TrainState, metrics):
157 | """Validation step."""
158 |
159 | def body(i, metrics):
160 | del i # Unused.
161 | loss, mrr, size = state.apply_fn(
162 | state.variables(),
163 | valid_batch['targets'],
164 | config.mask_token_id,
165 | mask_rate=config.mask_rate,
166 | input_mask=valid_batch.get('input_mask'),
167 | method=modeling.ModelForPretrain.mlm_valid_metrics)
168 | _, metrics = state.metrics_mod.apply(
169 | metrics,
170 | 'valid_loss',
171 | loss,
172 | size,
173 | method=MeanMetrics.update,
174 | mutable=['metrics'])
175 | _, metrics = state.metrics_mod.apply(
176 | metrics,
177 | 'valid_mrr',
178 | mrr,
179 | size,
180 | method=MeanMetrics.update,
181 | mutable=['metrics'])
182 | return metrics
183 |
184 | return jax.lax.fori_loop(0, 20, body, metrics)
185 |
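186 | # NOTE: valid_step above evaluates the same batch 20 times inside
187 | # jax.lax.fori_loop. Since mlm_valid_metrics masks tokens at random
188 | # (mask_rate=0.15 by default), this appears to average the validation loss
189 | # and MRR over 20 random maskings, reducing the variance of the metrics.
190 |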
--------------------------------------------------------------------------------
/jestimator/models/linear_regression/linear_regression.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""A toy linear regression model.
16 |
17 | Using jestimator as the entry point.
18 |
19 | # For a local debug run:
20 |
21 | ## Train:
22 |
23 | ```
24 | PYTHONPATH=. python3 \
25 | jestimator/estimator.py \
26 | --module_imp="jestimator.models.linear_regression.linear_regression" \
27 | --module_config="jestimator/models/linear_regression/\
28 | linear_regression.py" \
29 | --train_pattern="dummy://" \
30 | --valid_pattern="dummy://" \
31 | --model_dir="$HOME/experiments/linear_regression/models" \
32 | --train_batch_size=4 --valid_batch_size=4 \
33 | --max_train_steps=200 --train_shuffle_buf=32 \
34 | --check_every_steps=10 --logtostderr
35 | ```
36 |
37 | ## Eval:
38 |
39 | ```
40 | PYTHONPATH=. python3 \
41 | jestimator/estimator.py \
42 | --module_imp="jestimator.models.linear_regression.linear_regression" \
43 | --module_config="jestimator/models/linear_regression/\
44 | linear_regression.py" \
45 | --eval_pattern="dummy://" \
46 | --model_dir="$HOME/experiments/linear_regression/models" \
47 | --eval_batch_size=4 \
48 | --logtostderr
49 | ```
50 |
51 | ## Predict:
52 |
53 | ```
54 | PYTHONPATH=. python3 \
55 | jestimator/estimator.py \
56 | --module_imp="jestimator.models.linear_regression.linear_regression" \
57 | --module_config="jestimator/models/linear_regression/\
58 | linear_regression.py" \
59 | --pred_pattern="dummy://" \
60 | --model_dir="$HOME/experiments/linear_regression/models" \
61 | --pred_batch_size=4 \
62 | --logtostderr
63 | ```
64 | """
65 | from typing import Tuple
66 |
67 | from flax import linen as nn
68 | import jax
69 | import jax.numpy as jnp
70 | from jax.typing import ArrayLike
71 | from jestimator.states import Evaluator, InferState, MeanMetrics, Predictor, TrainState # pylint: disable=g-multiple-import
72 | import ml_collections
73 | import optax
74 | from sklearn.metrics import mean_squared_error
75 | import tensorflow as tf
76 |
77 | from flaxformer.components.dense import DenseGeneral
78 |
79 |
80 | def get_config():
81 | """Returns a config object for modeling flags."""
82 | module_config = ml_collections.ConfigDict()
83 | module_config.num_train = 20
84 | module_config.num_eval = 20
85 | module_config.x_dim = 10
86 | module_config.y_dim = 5
87 | return module_config
88 |
89 |
90 | def load_config(global_flags):
91 | """Init config data from global flags."""
92 | config = ml_collections.ConfigDict()
93 | config.update(global_flags.module_config)
94 |
95 | # Only a frozen config (hashable object) can be passed to jit functions
96 | # (i.e. train_step/valid_step/infer_step).
97 | config.frozen = ml_collections.FrozenConfigDict(config)
98 |
99 |   # Construct data pipelines below (using TensorFlow):
100 | def get_data_fn(ds):
101 |
102 | def data_fn(filenames, shard_num=1, shard_index=0, epochs=1, num_take=-1):
103 | del filenames # Unused.
104 | return ds.repeat(epochs).shard(shard_num, shard_index).take(num_take)
105 |
106 | return data_fn
107 |
108 | # Generate random ground truth W and b.
109 | W = tf.random.normal((config.x_dim, config.y_dim), seed=11) # pylint: disable=invalid-name
110 | b = tf.random.normal((config.y_dim,), seed=12)
111 |
112 | # Generate samples with additional noise.
113 | x_train = tf.random.normal((config.num_train, config.x_dim), seed=13)
114 | y_train = tf.matmul(x_train, W) + b + 0.1 * tf.random.normal(
115 | (config.num_train, config.y_dim), seed=14)
116 | ds_train = tf.data.Dataset.from_tensor_slices({'x': x_train, 'y': y_train})
117 |
118 | x_eval = tf.random.normal((config.num_eval, config.x_dim), seed=15)
119 | y_eval = tf.matmul(x_eval, W) + b + 0.1 * tf.random.normal(
120 | (config.num_eval, config.y_dim), seed=16)
121 | ds_valid = tf.data.Dataset.from_tensor_slices({'x': x_eval, 'y': y_eval})
122 | ds_eval = tf.data.Dataset.from_tensor_slices((y_eval, x_eval))
123 | ds_pred = tf.data.Dataset.from_tensor_slices(x_eval)
124 |
125 | config.train_data_fn = get_data_fn(ds_train)
126 | config.valid_data_fn = get_data_fn(ds_valid)
127 | config.eval_data_fn = get_data_fn(ds_eval)
128 | config.pred_data_fn = get_data_fn(ds_pred)
129 | return config
130 |
131 |
132 | class LinearRegression(nn.Module):
133 | """A simple linear regression module."""
134 | y_dim: int
135 |
136 | @nn.compact
137 | def __call__(self, x: ArrayLike) -> ArrayLike:
138 | """Applies linear on the input."""
139 | linear = DenseGeneral(
140 | features=self.y_dim,
141 | use_bias=True,
142 | kernel_init=nn.zeros,
143 | kernel_axis_names=('x', 'y'))
144 | return linear(x)
145 |
146 | def mse(self, x: ArrayLike, y: ArrayLike) -> Tuple[ArrayLike, ArrayLike]:
147 | """Mean squared error."""
148 | loss = jnp.mean(jnp.square(self(x) - y), axis=-1)
149 | size = jnp.asarray(loss.size, loss.dtype)
150 |     num_hosts = jnp.asarray(jax.process_count(), loss.dtype)
151 | loss = jnp.sum(loss) * jax.lax.rsqrt(size * num_hosts)
152 | size = jnp.sqrt(size / num_hosts)
153 | return loss, size
154 |
155 |
156 | def get_train_state(config, rng):
157 | """Create train state."""
158 | model = LinearRegression(y_dim=config.y_dim)
159 |
160 | def lr_schedule(step):
161 | return 0.5 / (1. + 0.1 * step)
162 |
163 | optimizer = optax.sgd(learning_rate=lr_schedule)
164 | metrics_mod = MeanMetrics.create('train_loss', 'valid_loss')
165 | dummy_x = jnp.zeros((config.x_dim,), jnp.float32)
166 | return TrainState.create(metrics_mod, optimizer, model, rng, dummy_x)
167 |
168 |
169 | def train_step(config, train_batch, state: TrainState, metrics):
170 | """Training step."""
171 | del config # Unused.
172 | (loss, size), grads = state.value_and_grad_apply_fn(has_aux=True)(
173 | state.params,
174 | train_batch['x'],
175 | train_batch['y'],
176 | method=LinearRegression.mse)
177 | _, metrics = state.metrics_mod.apply(
178 | metrics,
179 | 'train_loss',
180 | loss,
181 | size,
182 | method=MeanMetrics.update,
183 | mutable=['metrics'])
184 | return state.apply_gradients(grads=grads), metrics
185 |
186 |
187 | def valid_step(config, valid_batch, state: TrainState, metrics):
188 | """Validation step."""
189 | del config # Unused.
190 | loss, size = state.apply_fn(
191 | state.variables(),
192 | valid_batch['x'],
193 | valid_batch['y'],
194 | method=LinearRegression.mse)
195 | _, metrics = state.metrics_mod.apply(
196 | metrics,
197 | 'valid_loss',
198 | loss,
199 | size,
200 | method=MeanMetrics.update,
201 | mutable=['metrics'])
202 | return metrics
203 |
204 |
205 | def get_infer_state(config):
206 | """Create infer state."""
207 | model = LinearRegression(y_dim=config.y_dim)
208 | dummy_x = jnp.zeros((config.x_dim,), jnp.float32)
209 | return InferState.create(model, dummy_x)
210 |
211 |
212 | def infer_step(config, batch, state: InferState):
213 | """Infer step."""
214 | del config # Unused.
215 | return state.replace(ret=state.apply_fn(state.variables(), batch))
216 |
217 |
218 | def get_evaluator(config) -> Evaluator:
219 | """Create evaluator."""
220 | del config # Unused.
221 | return Evaluator({'mse': (lambda y: y, mean_squared_error)})
222 |
223 |
224 | def get_predictor(config) -> Predictor:
225 | """Create predictor."""
226 | del config # Unused.
227 |
228 | def proc_fn(y_batched):
229 | return [str(y) for y in y_batched]
230 |
231 | return Predictor(proc_fn)
232 |
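233 | # NOTE on LinearRegression.mse: per process it returns
234 | #   loss = L / sqrt(n * h)   and   size = sqrt(n / h),
235 | # where L is the local sum of per-example losses, n the local example count
236 | # and h the number of processes. When MeanMetrics sums both quantities
237 | # across the h processes and divides, the reported mean is
238 | #   (sum_j L_j) / sqrt(n * h) / (h * sqrt(n / h)) = (sum_j L_j) / (n * h),
239 | # i.e. the global per-example mean squared error, which is why it can be
240 | # compared against sklearn's mean_squared_error used by the evaluator.
241 |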
--------------------------------------------------------------------------------
/jestimator/models/linear_regression/linear_regression_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for linear_regression."""
16 | import os
17 |
18 | from absl import flags
19 | from absl.testing import absltest
20 | import jax
21 | from jestimator import checkpoint_utils
22 | from jestimator import estimator
23 | from jestimator.models.linear_regression import linear_regression
24 | import tensorflow as tf
25 |
26 | FLAGS = flags.FLAGS
27 |
28 |
29 | class LinearRegressionTest(absltest.TestCase):
30 |
31 | def setUp(self):
32 |     super().setUp()
33 | tmp_model_dir = self.create_tempdir('tmp_model_dir')
34 | model_dir = os.fspath(tmp_model_dir)
35 | FLAGS.model_dir = model_dir
36 |
37 | FLAGS.module_config = linear_regression.get_config()
38 | self.config = linear_regression.load_config(FLAGS)
39 | self.partitioner = estimator.get_partitioner(self.config)
40 |
41 | def test_train_eval(self):
42 | FLAGS.train_pattern = 'dummy://'
43 | FLAGS.valid_pattern = 'dummy://'
44 | FLAGS.train_batch_size = 4
45 | FLAGS.valid_batch_size = 4
46 | FLAGS.train_shuffle_buf = 32
47 | FLAGS.check_every_steps = 10
48 |
49 | FLAGS.max_train_steps = 100
50 | seed = 100
51 | tf.random.set_seed(seed)
52 | rng = jax.random.PRNGKey(seed)
53 | estimator.train(None, False, rng, linear_regression, self.config,
54 | self.partitioner)
55 | middle_ckpt_path = os.path.join(FLAGS.model_dir, 'checkpoint_100')
56 | self.assertTrue(os.path.exists(middle_ckpt_path))
57 | ckpt_path, same_dir = checkpoint_utils.latest_ckpt_path(FLAGS.model_dir)
58 | self.assertTrue(os.path.samefile(middle_ckpt_path, ckpt_path))
59 | self.assertTrue(same_dir)
60 |
61 | FLAGS.max_train_steps = 200
62 | final_metrics = estimator.train(ckpt_path, same_dir, rng, linear_regression,
63 | self.config, self.partitioner)
64 | self.assertLess(final_metrics['train_loss'], 0.01)
65 | valid_loss = final_metrics['valid_loss']
66 | self.assertLess(valid_loss, 0.05)
67 | final_ckpt_path = os.path.join(FLAGS.model_dir, 'checkpoint_200')
68 | self.assertTrue(os.path.exists(final_ckpt_path))
69 | ckpt_path, same_dir = checkpoint_utils.latest_ckpt_path(FLAGS.model_dir)
70 | self.assertTrue(os.path.samefile(final_ckpt_path, ckpt_path))
71 | self.assertTrue(same_dir)
72 |
73 | FLAGS.eval_pattern = 'dummy://'
74 | FLAGS.eval_batch_size = 4
75 | mode = estimator.RunMode.EVAL_ONCE
76 | eval_metrics = estimator.eval_or_predict(ckpt_path, mode, linear_regression,
77 | self.config, self.partitioner)
78 | self.assertAlmostEqual(eval_metrics['mse'], valid_loss)
79 |
80 | FLAGS.save_low = ['mse']
81 | mode = estimator.RunMode.EVAL_WAIT
82 | estimator.eval_or_predict(None, mode, linear_regression, self.config,
83 | self.partitioner)
84 | ckpt_dir = os.path.join(FLAGS.model_dir, 'eval', 'low_mse')
85 | ckpt_path, _ = checkpoint_utils.latest_ckpt_path(ckpt_dir)
86 | self.assertTrue(ckpt_path)
87 | mode = estimator.RunMode.EVAL_ONCE
88 | eval_metrics = estimator.eval_or_predict(ckpt_path, mode, linear_regression,
89 | self.config, self.partitioner)
90 | self.assertAlmostEqual(eval_metrics['mse'], valid_loss)
91 |
92 | def test_predict(self):
93 | FLAGS.pred_pattern = 'dummy://'
94 | FLAGS.pred_batch_size = 4
95 | mode = estimator.RunMode.PREDICT
96 | estimator.eval_or_predict(None, mode, linear_regression, self.config,
97 | self.partitioner)
98 |
99 |
100 | if __name__ == '__main__':
101 | absltest.main()
102 |
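103 | # To run this test directly (assuming the package and its test
104 | # dependencies are installed, as in the CI workflow):
105 | #
106 | #   PYTHONPATH=. python3 \
107 | #     jestimator/models/linear_regression/linear_regression_test.py
108 |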
--------------------------------------------------------------------------------
/jestimator/models/lstm/lm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Language modeling on PTB-like datasets.
16 |
17 | Using jestimator as the entry point.
18 |
19 | # For a local debug run:
20 |
21 | ## Train:
22 |
23 | ```
24 | PYTHONPATH=. python3 jestimator/estimator.py \
25 | --module_imp="jestimator.models.lstm.lm" \
26 | --module_config="jestimator/models/lstm/lm.py" \
27 | --module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt" \
28 |   --train_pattern="jestimator/models/lstm/ptb/ptb.train.txt" \
29 | --model_dir="$HOME/experiments/ptb_lstm/models" \
30 | --train_batch_size=64 --train_consecutive=113 \
31 | --check_every_steps=10 --logtostderr
32 | ```
33 |
34 | ## Eval:
35 |
36 | ```
37 | PYTHONPATH=. python3 jestimator/estimator.py \
38 | --module_imp="jestimator.models.lstm.lm" \
39 | --module_config="jestimator/models/lstm/lm.py" \
40 | --module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt" \
41 | --eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt" \
42 | --model_dir="$HOME/experiments/ptb_lstm/models" \
43 | --eval_batch_size=1 --logtostderr
44 | ```
45 | """
46 | import dataclasses
47 | import math
48 |
49 | import jax
50 | import jax.numpy as jnp
51 | from jestimator import amos
52 | from jestimator.data.pipeline_lm import lm_data
53 | from jestimator.data.reader import lines_iterator
54 | from jestimator.models.lstm import modeling
55 | from jestimator.states import TrainState, MeanMetrics, InferState # pylint: disable=g-multiple-import
56 | import ml_collections
57 | from ml_collections.config_dict import config_dict
58 | import optax
59 | import tensorflow as tf
60 |
61 |
62 | def get_config():
63 | """Returns a config object for modeling flags."""
64 | module_config = ml_collections.ConfigDict()
65 |
66 | # Model config.
67 | model_config = modeling.ModelConfig()
68 | model_config = ml_collections.ConfigDict(dataclasses.asdict(model_config))
69 | module_config.model_config = model_config
70 |
71 | # Optimizer config.
72 | opt_config = ml_collections.ConfigDict()
73 | opt_config.optimizer = 'adam'
74 | opt_config.learning_rate = 5e-4
75 | opt_config.momentum = 0.9
76 | opt_config.beta = 0.99
77 | opt_config.weight_decay = 0.01
78 | module_config.opt_config = opt_config
79 |
80 | # Other config.
81 | module_config.seq_length = 64
82 | module_config.vocab_path = config_dict.placeholder(str)
83 | return module_config
84 |
85 |
86 | def load_config(global_flags):
87 | """Init config data from global flags."""
88 | config = ml_collections.ConfigDict()
89 | config.update(global_flags.module_config)
90 |
91 | mode = global_flags.mode
92 | if mode == 'train':
93 | batch_size = global_flags.train_batch_size
94 | train_consecutive = global_flags.train_consecutive
95 | assert train_consecutive is not None, (
96 | 'Should set --train_consecutive for training LSTM language models.')
97 | config.train_consecutive = train_consecutive
98 | elif mode.startswith('eval'):
99 | batch_size = global_flags.eval_batch_size
100 | assert batch_size == 1, 'Should set --eval_batch_size to 1 for LSTM.'
101 | assert jax.process_count() == 1, 'Should evaluate on single process.'
102 | else:
103 | batch_size = global_flags.pred_batch_size
104 | config.mode = mode
105 | config.batch_size = batch_size
106 |
107 | # Read vocab file.
108 | count = 0
109 | word_dict = {}
110 | for w, _ in lines_iterator(config.vocab_path, split=True):
111 | word_dict[w] = count
112 | count += 1
113 |
114 | config.model_config.vocab_size = count
115 |   config.model_config.start_token_id = word_dict['<s>']
116 |   eos_token_id = word_dict['</s>']
117 |
118 | # Only a frozen config (hashable object) can be passed to jit functions
119 | # (i.e. train_step/valid_step/infer_step).
120 | config.frozen = ml_collections.FrozenConfigDict(config)
121 |
122 |   # Construct data pipelines below (using TensorFlow):
123 | def corpus_fn(path: str) -> tf.data.Dataset:
124 |
125 | def gen():
126 | for tokens in lines_iterator(path, split=True):
127 | ids = [word_dict[w] for w in tokens] + [eos_token_id]
128 | for x in ids:
129 | yield x
130 |
131 | return tf.data.Dataset.from_generator(
132 | gen, output_signature=tf.TensorSpec(shape=(), dtype=tf.int32))
133 |
134 | seq_length = config.seq_length
135 |
136 | def eval_feature_fn(x):
137 | length = tf.shape(x)[0] # The last sequence might be shorter.
138 | y = tf.pad(x, [(0, seq_length - length)])
139 |
140 | # Eval dataset requires the gold label to be returned as first arg.
141 | # For language modeling, gold is not used. We return length as gold.
142 | return length, {'y': y, 'length': length}
143 |
144 | config.train_data_fn = lm_data(
145 | seq_length, dataset_fn=corpus_fn, cache=True, random_skip=True)
146 | config.eval_data_fn = lm_data(
147 | seq_length,
148 | allow_remainder=True,
149 | dataset_fn=corpus_fn,
150 | cache=True,
151 | feature_fn=eval_feature_fn)
152 | return config
153 |
154 |
155 | def get_train_state(config, rng) -> TrainState:
156 | """Create train state."""
157 | model_config = modeling.ModelConfig(**config.model_config.to_dict())
158 | model = modeling.SingleLstmLM(model_config, config.batch_size)
159 |
160 | opt_config = config.opt_config
161 | if opt_config.optimizer == 'adam':
162 | optimizer = optax.adam(
163 | learning_rate=opt_config.learning_rate,
164 | b1=opt_config.momentum,
165 | b2=opt_config.beta)
166 | elif opt_config.optimizer == 'adamw':
167 | optimizer = optax.adamw(
168 | learning_rate=opt_config.learning_rate,
169 | b1=opt_config.momentum,
170 | b2=opt_config.beta,
171 | weight_decay=opt_config.weight_decay)
172 | elif opt_config.optimizer == 'amos':
173 | optimizer = amos.amos(
174 | opt_config.learning_rate,
175 | modeling.get_eta_fn(model_config),
176 | shape_fn=modeling.get_shape_fn(model_config),
177 | beta=opt_config.beta,
178 | momentum=opt_config.momentum,
179 | clip_value=1.)
180 |
181 | metrics_mod = MeanMetrics.create('train_loss')
182 | dummy = jnp.zeros((config.batch_size, config.seq_length), jnp.int32)
183 | return TrainState.create(metrics_mod, optimizer, model, rng, dummy, False)
184 |
185 |
186 | def train_step(config, train_batch, state: TrainState, metrics):
187 | """Training step."""
188 | (loss, (size, vars_)), grads = state.value_and_grad_apply_fn(has_aux=True)(
189 | state.params,
190 | train_batch,
191 | state.step % config.train_consecutive != 0,
192 | enable_dropout=True)
193 | _, metrics = state.metrics_mod.apply(
194 | metrics,
195 | 'train_loss',
196 | loss,
197 | size,
198 | method=MeanMetrics.update,
199 | mutable=['metrics'])
200 | state = state.apply_gradients(grads=grads)
201 | state = state.replace(_vars=vars_)
202 | return state, metrics
203 |
204 |
205 | def get_infer_state(config):
206 | """Create infer state."""
207 | model_config = modeling.ModelConfig(**config.model_config.to_dict())
208 | model = modeling.SingleLstmLM(model_config, config.batch_size)
209 | dummy = jnp.zeros((config.batch_size, config.seq_length), jnp.int32)
210 | return InferState.create(model, dummy, True, mode=config.mode)
211 |
212 |
213 | def infer_step(config, batch, state: InferState) -> InferState:
214 | """Infer step."""
215 | if config.mode.startswith('eval'):
216 | (loss, mrr, size), vars_ = state.apply_fn(
217 | state.variables(),
218 | batch['y'],
219 | True,
220 | mode=config.mode,
221 | length=batch['length'],
222 | mutable=state.mutable())
223 |
224 | return state.replace(
225 | _vars=vars_,
226 | ret={
227 | 'loss': jnp.expand_dims(loss, 0),
228 | 'mrr': jnp.expand_dims(mrr, 0),
229 | 'size': jnp.expand_dims(size, 0),
230 | })
231 |
232 | raise NotImplementedError(f'Infer-step for {config.mode} not implemented.')
233 |
234 |
235 | class Evaluator(object):
236 | """Evaluator class for language modeling."""
237 |
238 | def reset_states(self):
239 | self._total_loss = 0.
240 | self._total_mrr = 0.
241 | self._total_size = 0.
242 |
243 | def update_state(self, gold, infer):
244 | del gold # Unused.
245 | self._total_loss += infer['loss'].sum()
246 | self._total_mrr += infer['mrr'].sum()
247 | self._total_size += infer['size'].sum()
248 |
249 | def result(self):
250 | cost = self._total_loss / self._total_size
251 | return {
252 | 'cost': cost,
253 | 'perplexity': math.exp(cost),
254 | 'mrr': self._total_mrr / self._total_size,
255 | }
256 |
257 |
258 | def get_evaluator(config) -> Evaluator:
259 | del config # Unused.
260 | return Evaluator()
261 |
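262 | # NOTE: Evaluator.result above reports `cost` (mean per-token negative
263 | # log-likelihood, in nats), `perplexity` = exp(cost), and `mrr` (presumably
264 | # the mean reciprocal rank of the gold next token). For example, a cost of
265 | # 4.5 corresponds to perplexity exp(4.5) ~= 90.
266 |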
--------------------------------------------------------------------------------
/jestimator/models/lstm/modeling.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """LSTM model implemented in Flax."""
16 | import dataclasses
17 | import math
18 | from typing import Optional, Tuple
19 |
20 | from flax import linen as nn
21 | from flax.linen.partitioning import variable_with_axes
22 | import jax
23 | import jax.numpy as jnp
24 | from jax.typing import ArrayLike
25 | from jestimator.modeling import global_kwargs, sparse_xe_with_logits, normalize_loss_by_size, unstack, truncated_normal_initializer, Dropout # pylint: disable=g-multiple-import
26 |
27 | DType = jnp.dtype
28 | Shape = Tuple[int, ...]
29 |
30 |
31 | @dataclasses.dataclass
32 | class ModelConfig:
33 | """Config object."""
34 | hidden_size: int = 256
35 | memory_size: int = 1024
36 | forget_gate_bias: float = 1.0
37 |
38 | hidden_dropout_rate: float = 0.55
39 | memory_dropout_rate: float = 0.1
40 |
41 | vocab_size: int = -1
42 | start_token_id: int = -1
43 |
44 |
45 | class ShiftNorm(nn.Module):
46 | """Shifted normalization."""
47 |
48 | @nn.compact
49 | def __call__(self, x: ArrayLike) -> ArrayLike:
50 | x = jnp.asarray(x)
51 | shift = self.param('shift', nn.zeros, x.shape[-1], x.dtype)
52 | x = x - shift
53 | x = x - jnp.mean(x, axis=-1, keepdims=True)
54 | mean2 = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
55 |
56 |     # Instead of normalizing to 1, we normalize to d**(-0.25).
57 | x = x * jax.lax.rsqrt(jnp.maximum(mean2 * math.sqrt(x.shape[-1]), 1.))
58 | return x
59 |
60 |
61 | class LstmCell(nn.Module):
62 | """LSTM cell with some working modifications."""
63 | config: ModelConfig
64 |
65 | def setup(self):
66 | config = self.config
67 | inputs_size = 2 * config.hidden_size
68 | core_init = truncated_normal_initializer(math.sqrt(1 / inputs_size))
69 | self.core = nn.DenseGeneral(
70 | features=(4, config.memory_size), use_bias=True, kernel_init=core_init)
71 | self.normalize = ShiftNorm()
72 | self.out = nn.Dense(
73 | features=config.hidden_size, use_bias=True, kernel_init=nn.zeros)
74 |
75 | def __call__(self,
76 | inputs: ArrayLike,
77 | memory: ArrayLike,
78 | memory_mask: Optional[ArrayLike] = None):
79 | """Call LSTM cell with some working modifications.
80 |
81 | Args:
82 | inputs: Inputs for the current step.
83 | memory: Long-term memory.
84 | memory_mask: Optional memory mask.
85 |
86 | Returns:
87 | (out, next_memory).
88 | """
89 | xa, xb, xg, xo = unstack(self.core(inputs), -2)
90 | fb = self.config.forget_gate_bias
91 | gate = nn.sigmoid(fb - xg)
92 | xb = jnp.clip(1 / (1 + math.exp(fb)) * jnp.tanh(xb), -gate, gate)
93 | next_memory = memory * (1. - gate) + xb * 2. * nn.silu(xa)
94 | memory_out = self.normalize(next_memory) * 2. * nn.sigmoid(xo)
95 |
96 | if memory_mask is not None:
97 | memory_out *= memory_mask
98 | out = self.out(memory_out)
99 | return out, next_memory
100 |
101 |
102 | class LstmLayer(nn.Module):
103 | """LSTM layer which encodes a sequence."""
104 | config: ModelConfig
105 |
106 | def setup(self):
107 | config = self.config
108 | self.hidden_dropout = Dropout(config.hidden_dropout_rate)
109 | self.memory_dropout = Dropout(config.memory_dropout_rate)
110 | self.normalize = ShiftNorm()
111 | self.cell = LstmCell(config)
112 |
113 | @global_kwargs('enable_dropout')
114 | def __call__(self,
115 | xs: ArrayLike,
116 | seq_axis: int = -2,
117 | init_carry: Optional[Tuple[ArrayLike, ArrayLike]] = None,
118 | enable_dropout: bool = False):
119 | """Encode a sequence with LSTM.
120 |
121 | Args:
122 | xs: Input sequence tensor.
123 | seq_axis: int. The sequence axis.
124 | init_carry: Initial state.
125 | enable_dropout: Whether to enable dropout.
126 |
127 | Returns:
128 | Encoded sequence of the same shape as `xs`.
129 | """
130 | xs = jnp.asarray(xs)
131 | if init_carry is None:
132 | batch_shape = xs.shape[:seq_axis] + xs.shape[seq_axis + 1:-1]
133 | init_carry = self.zero_carry(batch_shape, xs.dtype)
134 |
135 | memory_mask = None
136 | if enable_dropout:
137 | memory_mask = self.memory_dropout(jnp.ones_like(init_carry[1]))
138 |
139 | def body_fn(self, carry: Tuple[ArrayLike, ArrayLike], x: ArrayLike):
140 | recur, memory = carry
141 | inputs = jnp.concatenate((recur, x), -1)
142 | inputs = self.hidden_dropout(self.normalize(inputs))
143 | out, next_memory = self.cell(inputs, memory, memory_mask=memory_mask)
144 | return (out, next_memory), out
145 |
146 | last_carry, outs = nn.scan(
147 | body_fn,
148 | in_axes=seq_axis,
149 | out_axes=seq_axis,
150 | variable_broadcast='params',
151 | split_rngs={'params': False})(self, init_carry, xs)
152 | outs = self.hidden_dropout(outs)
153 | return outs, last_carry
154 |
155 | def zero_carry(self, batch_shape: Shape, dtype: DType = jnp.float32):
156 | """Creates a zero state."""
157 | recur = jnp.zeros(batch_shape + (self.config.hidden_size,), dtype)
158 | memory = jnp.zeros(batch_shape + (self.config.memory_size,), dtype)
159 | return (recur, memory)
160 |
161 |
162 | class SingleLstmLM(nn.Module):
163 | """Single layer LSTM language model."""
164 | config: ModelConfig
165 | batch_size: int
166 |
167 | def setup(self):
168 | config = self.config
169 | embed_init = truncated_normal_initializer(math.sqrt(1 / config.hidden_size))
170 | self.embed = nn.Embed(
171 | config.vocab_size, config.hidden_size, embedding_init=embed_init
172 | )
173 | self.lstm = LstmLayer(config)
174 | self.bias = self.param('bias', nn.zeros, config.vocab_size)
175 |
176 | # Variables to keep context from previous batch.
177 | self.ctx_prev = variable_with_axes(
178 | 'context',
179 | 'prev',
180 | jnp.full,
181 | (self.batch_size,),
182 | config.start_token_id,
183 | axes=('data',),
184 | )
185 | self.ctx_recur = variable_with_axes(
186 | 'context',
187 | 'recur',
188 | jnp.zeros,
189 | (self.batch_size, config.hidden_size),
190 | axes=('data', 'model'),
191 | )
192 | self.ctx_memory = variable_with_axes(
193 | 'context',
194 | 'memory',
195 | jnp.zeros,
196 | (self.batch_size, config.memory_size),
197 | axes=('data', 'model'),
198 | )
199 |
200 | @global_kwargs(pass_down=True)
201 | def __call__(self,
202 | y: ArrayLike,
203 | carry_mask: ArrayLike,
204 | mode: str = 'train',
205 | length: Optional[ArrayLike] = None):
206 |     """Computes logits (predict) or loss (train/eval) for batch-major `y`."""
207 | y = jnp.asarray(y)
208 | _, seq_length = y.shape
209 | ty = jnp.transpose(y) # `ty` is time-major.
210 |
211 | if mode == 'predict':
212 | x = ty
213 | else: # `y` is label. Shift one position to create input ids.
214 | s = self.config.start_token_id
215 | prev_ids = jnp.expand_dims(
216 | jnp.where(carry_mask, self.ctx_prev.value, s), 0)
217 | x = jnp.concatenate((prev_ids, ty[:-1]), 0)
218 | self.ctx_prev.value = y[:, -1]
219 |
220 | x = self.embed(x)
221 | carry_mask = jnp.asarray(carry_mask, x.dtype)
222 | carry = (self.ctx_recur.value * carry_mask,
223 | self.ctx_memory.value * carry_mask)
224 | x, (last_recur, last_memory) = self.lstm(x, seq_axis=0, init_carry=carry)
225 | self.ctx_recur.value = last_recur
226 | self.ctx_memory.value = last_memory
227 |
228 | if mode == 'predict':
229 | if length is None:
230 | x = x[-1]
231 | else:
232 | x = jnp.swapaxes(x, 0, 1)
233 | x = jnp.take_along_axis(x, jnp.expand_dims(length - 1, 1), 1)
234 |
235 | logits = self.embed.attend(x) + self.bias
236 | if mode == 'predict':
237 | return logits
238 |
239 | if length is None:
240 | size = jnp.asarray(y.size, x.dtype)
241 | mask = None
242 | else:
243 | size = jnp.sum(jnp.asarray(length, x.dtype))
244 | mask = (jnp.expand_dims(jnp.arange(seq_length), 1) < length)
245 |
246 | if mode == 'train':
247 | loss = sparse_xe_with_logits(ty, logits, mask=mask)
248 | return normalize_loss_by_size(loss, size)
249 |
250 | # Evaluation with loss and Mean Reciprocal Rank (MRR).
251 | logits = nn.log_softmax(logits)
252 | gold = sparse_xe_with_logits(
253 | ty, logits, mask=mask, normalized=True, reduce_all=False)
254 | loss = jnp.sum(gold)
255 | higher = (logits + jnp.expand_dims(gold, -1) >= 0)
256 | ranks = jnp.sum(jnp.asarray(higher, x.dtype), axis=-1)
257 | rcpl_ranks = jnp.reciprocal(ranks)
258 | if mask is not None:
259 | rcpl_ranks = jnp.where(mask, rcpl_ranks, 0.)
260 | mrr = jnp.sum(rcpl_ranks)
261 | return loss, mrr, size
262 |
263 |
264 | def get_eta_fn(config: ModelConfig):
265 | """Get the `eta_fn` function for Amos optimizer."""
266 | hidden_size = config.hidden_size
267 | memory_size = config.memory_size
268 |
269 | def eta_fn(name: Tuple[str, ...], shape: Shape) -> ArrayLike:
270 | del shape # Unused.
271 | if name[-4:] == ('lstm', 'cell', 'core', 'kernel'):
272 | return math.pow(2 * hidden_size, -0.25)
273 |
274 | if name[-4:] == ('lstm', 'cell', 'normalize', 'shift'):
275 | return 0.5
276 |
277 | if name[-4:] == ('lstm', 'cell', 'out', 'kernel'):
278 | return math.pow(memory_size * hidden_size, -0.25)
279 |
280 | if name[-4:] == ('lstm', 'cell', 'out', 'bias'):
281 | return 0.5 * math.pow(hidden_size, -0.25)
282 |
283 | if name[-3:] == ('lstm', 'normalize', 'shift'):
284 | return 0.5 * math.pow(hidden_size, -0.25)
285 |
286 | if name[-2:] == ('embed', 'embedding'):
287 | return math.pow(hidden_size, -0.25)
288 |
289 | if name[-1] == 'bias':
290 | return 0.5
291 |
292 | raise ValueError(f'`eta_fn` for {name} not defined.')
293 |
294 | return eta_fn
295 |
296 |
297 | def get_shape_fn(config):
298 | """Get the `shape_fn` function for Amos optimizer."""
299 | del config # Unused.
300 |
301 | def shape_fn(name: Tuple[str, ...], shape: Shape) -> Shape:
302 | if name[-1] == 'kernel':
303 | assert len(shape) >= 2
304 | return (1,) + shape[1:]
305 |
306 | if name[-1] == 'embedding':
307 | assert len(shape) == 2
308 | return (shape[0], 1)
309 |
310 | return ()
311 |
312 | return shape_fn
313 |
--------------------------------------------------------------------------------
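The `get_eta_fn`/`get_shape_fn` pair above is meant to be handed to the Amos optimizer. As a minimal sketch (not code from this repo; the vocab size, batch count, and clip value are illustrative), wiring them up follows the same pattern as `models/rope/pretrain.py`:

```python
import math

from jestimator import amos
from jestimator.models.lstm import modeling

config = modeling.ModelConfig(vocab_size=10000, start_token_id=1)

# Rule of thumb from the MNIST notebook: global learning rate ~ 1/sqrt(N),
# where N is the number of training batches (hypothetical value here).
num_batches = 1000
optimizer = amos.amos(
    learning_rate=1. / math.sqrt(num_batches),
    eta_fn=modeling.get_eta_fn(config),
    shape_fn=modeling.get_shape_fn(config),
    beta=0.98,
    clip_value=1.,  # should match the magnitude of the (normalized) loss
)
```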
/jestimator/models/mnist/mnist.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "b841868d-4cf1-4fcf-9e93-9ddee582b82e",
6 | "metadata": {},
7 | "source": [
8 | "### 1. Imports"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "id": "d3218139-816a-412d-8a19-1bbfb219ad40",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import jax\n",
19 | "import jax.numpy as jnp # JAX NumPy\n",
20 | "from jestimator import amos # The Amos optimizer implementation\n",
21 | "from jestimator import amos_helper # Helper module for Amos\n",
22 | "\n",
23 | "from flax import linen as nn # The Linen API\n",
24 | "from flax.training import train_state # Useful dataclass to keep train state\n",
25 | "\n",
26 | "import math\n",
27 | "import tensorflow_datasets as tfds # TFDS for MNIST\n",
28 | "from sklearn.metrics import accuracy_score"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "id": "ba3ea98c-080c-4ca8-bbc5-276c95d8196e",
34 | "metadata": {},
35 | "source": [
36 | "### 2. Load data"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 2,
42 | "id": "f11a52ce-e9e4-490a-94d6-d76758305729",
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "def get_datasets():\n",
47 | " \"\"\"Load MNIST train and test datasets into memory.\"\"\"\n",
48 | "\n",
49 | " ds_builder = tfds.builder('mnist')\n",
50 | " ds_builder.download_and_prepare()\n",
51 | " train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))\n",
52 | " test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))\n",
53 | " train_ds['image'] = jnp.float32(train_ds['image']) / 255.\n",
54 | " test_ds['image'] = jnp.float32(test_ds['image']) / 255.\n",
55 | " return train_ds, test_ds"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "id": "9996a480-b5dc-48b1-8fc1-d65c38aa9100",
61 | "metadata": {},
62 | "source": [
63 | "### 3. Build model"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 3,
69 | "id": "0e81e1c5-30db-4ad3-a67d-ebab67a5c27f",
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "class CNN(nn.Module):\n",
74 | " \"\"\"A simple CNN model.\"\"\"\n",
75 | "\n",
76 | " @nn.compact\n",
77 | " def __call__(self, x):\n",
78 | " x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n",
79 | " x = nn.relu(x)\n",
80 | " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
81 | " x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n",
82 | " x = nn.relu(x)\n",
83 | " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
84 | " x = x.reshape((x.shape[0], -1)) # flatten\n",
85 | " x = nn.Dense(features=256)(x)\n",
86 | " x = nn.relu(x)\n",
87 | " x = nn.Dense(features=10)(x)\n",
88 | " return x\n",
89 | "\n",
90 | " def classify_xe_loss(self, x, labels):\n",
91 | " # Labels read from the tfds MNIST are integers from 0 to 9. \n",
92 | " # Logits are arrays of size 10.\n",
93 | " logits = self(x)\n",
94 | " logits = jax.nn.log_softmax(logits)\n",
95 | " labels_ = jnp.expand_dims(labels, -1)\n",
96 | " llh_ = jnp.take_along_axis(logits, labels_, axis=-1)\n",
97 | " loss = -jnp.sum(llh_)\n",
98 | " return loss"
99 | ]
100 | },
101 | {
102 | "cell_type": "markdown",
103 | "id": "71146ddd-6e10-4bb8-9e7f-4fc8f952a6e9",
104 | "metadata": {},
105 | "source": [
106 | "### 4. Create train state\n",
107 | "\n",
108 | "A `TrainState` object keeps the model parameters and optimizer states, and can be checkpointed into files.\n",
109 | "\n",
110 | "We create the model and optimizer in this function.\n",
111 | "\n",
112 | "**For the optimizer, we use Amos here.** The following hyper-parameters are set:\n",
113 | "\n",
114 | " * *learning_rate*: The global learning rate.\n",
115 | " * *eta_fn*: The model-specific 'eta'.\n",
116 | " * *shape_fn*: Memory reduction setting.\n",
117 | " * *beta*: Rate for running average of gradient squares.\n",
118 | " * *clip_value*: Gradient clipping for stable training.\n",
119 | "\n",
120 |     "The global learning rate is usually set to 1/sqrt(N), where N is the number of batches in the training data. For MNIST, we have 60k training examples and the batch size is 32, so learning_rate=1/sqrt(60000/32).\n",
121 | "\n",
122 |     "The model-specific 'eta_fn' requires a function that, given a variable name and shape, returns a float indicating the expected scale of that variable. Hopefully in the near future we will have libraries that can automatically calculate this 'eta_fn' from the modeling code, but for now we have to specify it manually.\n",
123 | "\n",
124 |     "One can use the amos_helper.params_fn_from_assign_map() helper function to create 'eta_fn' from an assign_map. An assign_map is a dict that maps regex rules to values or simple Python expressions. It finds the first regex rule that matches the name of a variable, and evaluates the Python expression if necessary to return the value. See our example below.\n",
125 | "\n",
126 | "The 'shape_fn' similarly requires a function that, given a variable name and shape, returns a reduced shape for the corresponding slot variables. We can use the amos_helper.params_fn_from_assign_map() helper function to create 'shape_fn' from an assign_map as well.\n",
127 | "\n",
128 | "'beta' is the exponential decay rate for running average of gradient squares. We set it to 0.98 here.\n",
129 | "\n",
130 |     "'clip_value' is the gradient clipping value, which should match the magnitude of the loss function. If the loss function is a sum of cross-entropy losses, then we should set 'clip_value' to the sqrt of the number of labels.\n",
131 | "\n",
132 | "Please refer to our [paper](https://arxiv.org/abs/2210.11693) for more details of the hyper-parameters."
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": 4,
138 | "id": "eb049df8-70dc-447c-9a11-7166feb12d25",
139 | "metadata": {},
140 | "outputs": [],
141 | "source": [
142 | "def get_train_state(rng):\n",
143 | " model = CNN()\n",
144 | " dummy_x = jnp.ones([1, 28, 28, 1])\n",
145 | " params = model.init(rng, dummy_x)\n",
146 | "\n",
147 | " eta_fn = amos_helper.params_fn_from_assign_map(\n",
148 | " {\n",
149 | " '.*/bias': 0.5,\n",
150 | " '.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))',\n",
151 | " '.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',\n",
152 | " '.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',\n",
153 | " '.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',\n",
154 | " },\n",
155 | " eval_str_value=True,\n",
156 | " )\n",
157 | " shape_fn = amos_helper.params_fn_from_assign_map(\n",
158 | " {\n",
159 | " '.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',\n",
160 | " '.*Dense_0/kernel': '(1, SHAPE[1])',\n",
161 | " '.*': (),\n",
162 | " },\n",
163 | " eval_str_value=True,\n",
164 | " )\n",
165 | " optimizer = amos.amos(\n",
166 | " learning_rate=1/math.sqrt(60000/32),\n",
167 | " eta_fn=eta_fn,\n",
168 | " shape_fn=shape_fn,\n",
169 | " beta=0.98,\n",
170 | " clip_value=math.sqrt(32),\n",
171 | " )\n",
172 | " return train_state.TrainState.create(\n",
173 | " apply_fn=model.apply, params=params, tx=optimizer)"
174 | ]
175 | },
176 | {
177 | "cell_type": "markdown",
178 | "id": "2b00564e-504e-4275-b80a-8deed0fde177",
179 | "metadata": {},
180 | "source": [
181 | "### 5. Training step\n",
182 | "\n",
183 | "Use JAX’s @jit decorator to just-in-time compile the function for better performance."
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "execution_count": 5,
189 | "id": "ca15ee35-eb1d-4685-8e98-568c1cafc08c",
190 | "metadata": {},
191 | "outputs": [],
192 | "source": [
193 | "@jax.jit\n",
194 | "def train_step(batch, state):\n",
195 | " grad_fn = jax.grad(state.apply_fn)\n",
196 | " grads = grad_fn(\n",
197 | " state.params,\n",
198 | " batch['image'],\n",
199 | " batch['label'],\n",
200 | " method=CNN.classify_xe_loss)\n",
201 | " return state.apply_gradients(grads=grads)"
202 | ]
203 | },
204 | {
205 | "cell_type": "markdown",
206 | "id": "7f0a64e7-5507-49aa-8399-c3f20aa72a0e",
207 | "metadata": {},
208 | "source": [
209 | "### 6. Infer step\n",
210 | "\n",
211 | "Use JAX’s @jit decorator to just-in-time compile the function for better performance."
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": 6,
217 | "id": "ad5ef40d-7fac-479f-ad5f-8bda7f839b66",
218 | "metadata": {},
219 | "outputs": [],
220 | "source": [
221 | "@jax.jit\n",
222 | "def infer_step(batch, state):\n",
223 | " logits = state.apply_fn(state.params, batch['image'])\n",
224 | " return jnp.argmax(logits, -1)"
225 | ]
226 | },
227 | {
228 | "cell_type": "markdown",
229 | "id": "5429829e-2f7a-4fb6-acff-d4caaa3a20f6",
230 | "metadata": {},
231 | "source": [
232 | "### 7. Main\n",
233 | "\n",
234 | "Run the training loop and evaluate on test set."
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": 7,
240 | "id": "2b181696-9268-4f8e-b359-249a544c53b1",
241 | "metadata": {},
242 | "outputs": [
243 | {
244 | "name": "stderr",
245 | "output_type": "stream",
246 | "text": [
247 | "WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
248 | ]
249 | },
250 | {
251 | "name": "stdout",
252 | "output_type": "stream",
253 | "text": [
254 | "epoch: 1, test accuracy: 97.28\n",
255 | "epoch: 2, test accuracy: 98.46\n",
256 | "epoch: 3, test accuracy: 98.63\n",
257 | "epoch: 4, test accuracy: 97.91\n",
258 | "epoch: 5, test accuracy: 98.59\n",
259 | "epoch: 6, test accuracy: 99.05\n",
260 | "epoch: 7, test accuracy: 99.15\n",
261 | "epoch: 8, test accuracy: 99.21\n",
262 | "epoch: 9, test accuracy: 99.26\n"
263 | ]
264 | }
265 | ],
266 | "source": [
267 | "train_ds, test_ds = get_datasets()\n",
268 | "\n",
269 | "rng = jax.random.PRNGKey(0)\n",
270 | "rng, init_rng = jax.random.split(rng)\n",
271 | "state = get_train_state(init_rng)\n",
272 | "del init_rng # Must not be used anymore.\n",
273 | "\n",
274 | "num_epochs = 9\n",
275 | "for epoch in range(1, num_epochs + 1):\n",
276 | " # Use a separate PRNG key to permute image data during shuffling\n",
277 | " rng, input_rng = jax.random.split(rng)\n",
278 | " perms = jax.random.permutation(input_rng, 60000)\n",
279 | " del input_rng\n",
280 | " perms = perms.reshape((60000 // 32, 32))\n",
281 | " for perm in perms:\n",
282 | " batch = {k: v[perm, ...] for k, v in train_ds.items()}\n",
283 | " state = train_step(batch, state)\n",
284 | "\n",
285 | " pred = jax.device_get(infer_step(test_ds, state))\n",
286 | " accuracy = accuracy_score(test_ds['label'], pred)\n",
287 | " print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100))"
288 | ]
289 | },
290 | {
291 | "cell_type": "code",
292 | "execution_count": null,
293 | "id": "9236e19d-3fa1-43b0-90d4-a29df8613deb",
294 | "metadata": {},
295 | "outputs": [],
296 | "source": []
297 | }
298 | ],
299 | "metadata": {
300 | "kernelspec": {
301 | "display_name": "Python 3 (ipykernel)",
302 | "language": "python",
303 | "name": "python3"
304 | },
305 | "language_info": {
306 | "codemirror_mode": {
307 | "name": "ipython",
308 | "version": 3
309 | },
310 | "file_extension": ".py",
311 | "mimetype": "text/x-python",
312 | "name": "python",
313 | "nbconvert_exporter": "python",
314 | "pygments_lexer": "ipython3",
315 | "version": "3.10.7"
316 | }
317 | },
318 | "nbformat": 4,
319 | "nbformat_minor": 5
320 | }
321 |
--------------------------------------------------------------------------------
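The assign-map semantics described in the notebook (first matching regex wins; string values are evaluated with the variable's SHAPE in scope) can be mimicked by a small stand-alone sketch. `resolve` below is a hypothetical illustration of that behavior, not the `amos_helper` implementation:

```python
import math
import re

def resolve(assign_map, name, shape):
  """Returns the value of the first regex rule matching `name`."""
  for pattern, value in assign_map.items():
    if re.match(pattern, name):
      if isinstance(value, str):  # e.g. 'sqrt(2/SHAPE[0])'
        return eval(value, {'sqrt': math.sqrt, 'prod': math.prod,
                            'SHAPE': shape})
      return value
  raise KeyError(f'No rule matches {name}')

# The Dense_0 kernel of the notebook's CNN has shape (3136, 256):
eta = resolve(
    {'.*/bias': 0.5, '.*Dense_0/kernel': 'sqrt(2/SHAPE[0])'},
    'params/Dense_0/kernel', (3136, 256))
print(eta)  # ~0.0252
```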
/jestimator/models/mnist/mnist.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""MNIST Example using JEstimator.
16 |
17 | # For a local debug run:
18 |
19 | ## Train:
20 |
21 | ```
22 | PYTHONPATH=. python3 jestimator/estimator.py \
23 | --module_imp="jestimator.models.mnist.mnist" \
24 | --module_config="jestimator/models/mnist/mnist.py" \
25 | --train_pattern="tfds://mnist/split=train" \
26 | --model_dir="$HOME/experiments/mnist/models" \
27 | --train_batch_size=32 \
28 | --train_shuffle_buf=4096 \
29 | --train_epochs=9 \
30 | --check_every_steps=100 \
31 | --logtostderr
32 | ```
33 |
34 | ## Eval continuously:
35 |
36 | ```
37 | PYTHONPATH=. python3 jestimator/estimator.py \
38 | --module_imp="jestimator.models.mnist.mnist" \
39 | --module_config="jestimator/models/mnist/mnist.py" \
40 | --eval_pattern="tfds://mnist/split=test" \
41 | --model_dir="$HOME/experiments/mnist/models" \
42 | --eval_batch_size=32 \
43 | --mode="eval_wait" \
44 | --check_ckpt_every_secs=1 \
45 | --save_high="test_accuracy" \
46 | --logtostderr
47 | ```
48 | """
49 | import math
50 | import re
51 |
52 | from flax import linen as nn
53 | import jax
54 | import jax.numpy as jnp
55 | from jestimator import amos
56 | from jestimator import amos_helper
57 | from jestimator.data.pipeline_rec import rec_data
58 | from jestimator.states import Evaluator, InferState, MeanMetrics, TrainState # pylint: disable=g-multiple-import
59 | import ml_collections
60 | from sklearn.metrics import accuracy_score
61 | import tensorflow as tf
62 | import tensorflow_datasets as tfds
63 |
64 |
65 | def get_config():
66 | """Returns a config object for modeling flags."""
67 | module_config = ml_collections.ConfigDict()
68 | module_config.warmup = 2000
69 | module_config.amos_beta = 0.98
70 | return module_config
71 |
72 |
73 | def load_config(global_flags):
74 | """Init config data from global flags."""
75 | config = ml_collections.ConfigDict()
76 | config.update(global_flags.module_config)
77 | config.train_batch_size = global_flags.train_batch_size
78 |
79 | # Only a frozen config (hashable object) can be passed to jit functions
80 | # (i.e. train_step/valid_step/infer_step).
81 | config.frozen = ml_collections.FrozenConfigDict(config)
82 |
83 |   # Construct data pipelines below (using TensorFlow):
84 | def dataset_fn(path: str) -> tf.data.Dataset:
85 |     # Assume `path` is of the form 'tfds://{name}/split={split}'.
86 | m = re.match('tfds://(.*)/split=(.*)', path)
87 | assert m is not None, (f'Cannot parse "{path}" (should be of the form '
88 | '"tfds://{name}/split={split}").')
89 | name = m.group(1)
90 | split = m.group(2)
91 | builder = tfds.builder(name)
92 |
93 | # Use dataset info to setup model.
94 | info = builder.info
95 | config.image_shape = info.features['image'].shape
96 | config.num_classes = info.features['label'].num_classes
97 | config.num_examples = info.splits[split].num_examples
98 |
99 | builder.download_and_prepare()
100 | return builder.as_dataset(split=split)
101 |
102 | def eval_feature_fn(x):
103 | # For evaluation, we should return a (gold, data) tuple.
104 | label = x.pop('label')
105 | return label, x
106 |
107 |   # `pipeline_rec.rec_data` wraps a dataset of record type
108 |   # (i.e. each record is a single data point).
109 | config.train_data_fn = rec_data(dataset_fn, interleave=True)
110 | config.eval_data_fn = rec_data(dataset_fn, feature_fn=eval_feature_fn)
111 | return config
112 |
113 |
114 | class CNN(nn.Module):
115 | """A simple CNN model."""
116 | num_classes: int
117 |
118 | @nn.compact
119 | def __call__(self, x):
120 |     x /= 255.  # Each pixel value in the input image ranges from 0 to 255.
121 | x = nn.Conv(features=32, kernel_size=(3, 3))(x)
122 | x = nn.relu(x)
123 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
124 | x = nn.Conv(features=64, kernel_size=(3, 3))(x)
125 | x = nn.relu(x)
126 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
127 | x = x.reshape((x.shape[0], -1)) # flatten
128 | x = nn.Dense(features=256)(x)
129 | x = nn.relu(x)
130 | x = nn.Dense(features=self.num_classes)(x)
131 | return x
132 |
133 | def classify_xe_loss(self, x, labels):
134 | # Labels read from the tfds MNIST are integers from 0 to 9.
135 | # Logits are arrays of size 10.
136 | logits = self(x)
137 | logits = jax.nn.log_softmax(logits)
138 | labels_ = jnp.expand_dims(labels, -1)
139 | llh_ = jnp.take_along_axis(logits, labels_, axis=-1)
140 | loss = -jnp.sum(llh_)
141 | return loss
142 |
143 |
144 | def get_train_state(config, rng):
145 | """Create train state."""
146 | model = CNN(num_classes=config.num_classes)
147 | # Shape of input is (batch_size,) + image_shape.
148 | dummy_x = jnp.zeros((1,) + config.image_shape)
149 | # Use a separate module to record training time metrics.
150 | metrics_mod = MeanMetrics.create('train_loss')
151 |
152 | def lr_schedule(step):
153 | # Set up a warm-up schedule.
154 | lr = math.sqrt(config.train_batch_size / config.num_examples)
155 | lr *= jnp.minimum(1., step / config.warmup)
156 | return lr
157 |
158 | eta_fn = amos_helper.params_fn_from_assign_map(
159 | {
160 | '.*/bias': 0.5,
161 | '.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-2]))',
162 | '.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',
163 | '.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',
164 | '.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',
165 | },
166 | eval_str_value=True,
167 | )
168 | shape_fn = amos_helper.params_fn_from_assign_map(
169 | {
170 | '.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',
171 | '.*Dense_0/kernel': '(1, SHAPE[1])',
172 | '.*': (),
173 | },
174 | eval_str_value=True,
175 | )
176 | optimizer = amos.amos(
177 | learning_rate=lr_schedule,
178 | eta_fn=eta_fn,
179 | shape_fn=shape_fn,
180 | beta=config.amos_beta,
181 | clip_value=math.sqrt(config.train_batch_size),
182 | )
183 | return TrainState.create(metrics_mod, optimizer, model, rng, dummy_x)
184 |
185 |
186 | def train_step(config, train_batch, state: TrainState, metrics):
187 | """Training step."""
188 | loss, grads = state.value_and_grad_apply_fn()(
189 | state.params,
190 | train_batch['image'],
191 | train_batch['label'],
192 | method=CNN.classify_xe_loss)
193 | _, metrics = state.metrics_mod.apply(
194 | metrics,
195 | 'train_loss',
196 | loss,
197 | config.train_batch_size,
198 | method=MeanMetrics.update,
199 | mutable=['metrics'])
200 | return state.apply_gradients(grads=grads), metrics
201 |
202 |
203 | def get_infer_state(config):
204 | """Create infer state."""
205 | model = CNN(num_classes=config.num_classes)
206 | dummy_x = jnp.zeros((1,) + config.image_shape)
207 | return InferState.create(model, dummy_x)
208 |
209 |
210 | def infer_step(config, batch, state: InferState):
211 | """Infer step."""
212 | del config # Unused.
213 | logits = state.apply_fn(state.variables(), batch['image'])
214 | return state.replace(ret=jnp.argmax(logits, -1))
215 |
216 |
217 | def get_evaluator(config) -> Evaluator:
218 | """Create evaluator."""
219 | del config # Unused.
220 | return Evaluator({'test_accuracy': (lambda y: y, accuracy_score)})
221 |
--------------------------------------------------------------------------------
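For intuition, the warm-up schedule in `get_train_state` above ramps linearly to sqrt(train_batch_size / num_examples) and then stays flat. A plain-Python sketch (illustrative, using the MNIST defaults):

```python
import math

def lr_schedule(step, batch_size=32, num_examples=60000, warmup=2000):
  peak = math.sqrt(batch_size / num_examples)
  return peak * min(1., step / warmup)

print(lr_schedule(500))    # a quarter of the peak rate
print(lr_schedule(2000))   # peak: sqrt(32 / 60000) ~= 0.0231
print(lr_schedule(50000))  # still the peak; this schedule has no decay
```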
/jestimator/models/rope/finetune.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Sequence classification.
16 |
17 | # For a local debug run:
18 |
19 | ## Train:
20 |
21 | ```
22 | PYTHONPATH=. python3 \
23 | jestimator/estimator.py \
24 | --module_imp="jestimator.models.rope.finetune" \
25 | --module_config="jestimator/models/rope/finetune.py" \
26 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \
27 | --module_config.segment_names="sentence1,sentence2" \
28 | --module_config.model_config.num_labels=2 \
29 | --train_pattern="tfds://glue/rte/split=train" \
30 | --valid_pattern="tfds://glue/rte/split=validation" \
31 | --model_dir="$HOME/models/rope_rte" \
32 | --checkpoint_path="gs://gresearch/checkpoints_in_amos_paper/\
33 | adamw/rope-base/checkpoint_300000" \
34 | --train_batch_size=4 --valid_batch_size=4 --num_valid_examples=4 \
35 | --check_every_steps=10 --logtostderr
36 | ```
37 |
38 | ## Eval:
39 |
40 | ```
41 | PYTHONPATH=. python3 \
42 | jestimator/estimator.py \
43 | --module_imp="jestimator.models.rope.finetune" \
44 | --module_config="jestimator/models/rope/finetune.py" \
45 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \
46 | --module_config.segment_names="sentence1,sentence2" \
47 | --module_config.model_config.num_labels=2 \
48 | --module_config.eval_metric="accuracy" \
49 | --eval_pattern="tfds://glue/rte/split=validation" \
50 | --model_dir="$HOME/models/rope_rte" \
51 | --eval_batch_size=4 --num_eval_examples=4 \
52 | --logtostderr
53 | ```
54 |
55 | ## Predict:
56 |
57 | ```
58 | PYTHONPATH=. python3 \
59 | jestimator/estimator.py \
60 | --module_imp="jestimator.models.rope.finetune" \
61 | --module_config="jestimator/models/rope/finetune.py" \
62 | --module_config.vocab_path="$HOME/data/sentence_piece/sp.model" \
63 | --module_config.segment_names="sentence1,sentence2" \
64 | --module_config.model_config.num_labels=2 \
65 | --module_config.label_names="entailment,not_entailment" \
66 | --pred_pattern="tfds://glue/rte/split=test" \
67 | --model_dir="$HOME/models/rope_rte" \
68 | --pred_batch_size=4 --num_pred_examples=4 \
69 | --logtostderr
70 | ```
71 | """
72 |
73 | import dataclasses
74 |
75 | import jax
76 | import jax.numpy as jnp
77 | from jestimator.data import reader
78 | from jestimator.data.pipeline_rec import rec_data
79 | from jestimator.data.reader import PyOutSpec
80 | from jestimator.models.rope import modeling
81 | from jestimator.states import Evaluator, InferState, MeanMetrics, Predictor, TrainState # pylint: disable=g-multiple-import
82 | import ml_collections
83 | from ml_collections.config_dict import config_dict
84 | import optax
85 | from scipy import stats as scipy_stats
86 | from sklearn import metrics as sklearn_metrics
87 | import tensorflow as tf
88 |
89 | import sentencepiece as spm
90 |
91 |
92 | def get_config():
93 | """Returns a config object for modeling flags."""
94 | module_config = ml_collections.ConfigDict()
95 |
96 | # Model config.
97 | model_config = modeling.ModelConfig()
98 | model_config = ml_collections.ConfigDict(dataclasses.asdict(model_config))
99 | module_config.model_config = model_config
100 |
101 | # Optimizer config.
102 | opt_config = ml_collections.ConfigDict()
103 | opt_config.optimizer = 'adam'
104 | opt_config.learning_rate = 5e-6
105 | module_config.opt_config = opt_config
106 |
107 | # Other config.
108 | module_config.vocab_path = config_dict.placeholder(str)
109 | module_config.segment_names = config_dict.placeholder(str)
110 | module_config.eval_metric = config_dict.placeholder(str)
111 | module_config.output_path = config_dict.placeholder(str)
112 | module_config.label_names = config_dict.placeholder(str)
113 | module_config.stsb = False
114 | return module_config
115 |
116 |
117 | def load_config(global_flags):
118 | """Init config data from global flags."""
119 | config = ml_collections.ConfigDict()
120 | config.update(global_flags.module_config)
121 |
122 | tokenizer = spm.SentencePieceProcessor()
123 | tokenizer.Load(config.vocab_path)
124 | config.model_config.vocab_size = tokenizer.GetPieceSize()
125 |
126 | segment_names = config.segment_names.split(',')
127 | num_segments = len(segment_names)
128 | config.model_config.num_segments = num_segments + 1
129 |
130 | # Only a frozen config (hashable object) can be passed to jit functions
131 | # (i.e. train_step/valid_step/infer_step).
132 | config.frozen = ml_collections.FrozenConfigDict(config)
133 |
134 |   # Construct data pipelines below (using TensorFlow):
135 | max_length = config.model_config.max_length
136 | max_len_1 = (max_length - 1) // num_segments
137 |   cls_token_id = tokenizer.PieceToId('<cls>')
138 |   sep_token_id = tokenizer.PieceToId('<sep>')
139 |   eos_token_id = tokenizer.PieceToId('<eos>')
140 | data_keys = ['idx', 'label'] + config.segment_names.split(',')
141 | mode = global_flags.mode
142 |
143 | def tokenize_fn(texts):
144 | ids = []
145 | for s in texts:
146 | s = tf.strings.lower(s).numpy()
147 | ids.append(tf.convert_to_tensor(tokenizer.EncodeAsIds(s), tf.int32))
148 | return ids
149 |
150 | def example_fn(data):
151 | data = {k: data[k] for k in data_keys if k in data}
152 | texts = [data[k] for k in segment_names]
153 | out_spec = [PyOutSpec((-1,), tf.int32)] * num_segments
154 | tokenized = reader.apply_py_fn(tokenize_fn, texts, out_spec)
155 |
156 | max_len_0 = max_length - 1
157 | input_ids = [tf.concat([[cls_token_id], tokenized[0]], 0)]
158 | for x in tokenized[1:]:
159 | x = tf.concat([[sep_token_id], x], 0)[:max_len_1]
160 | input_ids.append(x)
161 | max_len_0 = max_len_0 - tf.shape(x)[0]
162 | input_ids[0] = input_ids[0][:max_len_0]
163 | input_ids.append([eos_token_id])
164 |
165 | segment_ids = [tf.ones_like(x) * i for i, x in enumerate(input_ids)]
166 | input_ids = tf.concat(input_ids, 0)
167 | input_mask = tf.ones_like(input_ids)
168 | segment_ids = tf.concat(segment_ids, 0)
169 |
170 | pad_len = max_length - tf.shape(input_ids)[0]
171 | input_ids = tf.pad(input_ids, [[0, pad_len]])
172 | input_mask = tf.pad(input_mask, [[0, pad_len]])
173 | segment_ids = tf.pad(segment_ids, [[0, pad_len]])
174 |
175 | ret = {
176 | 'input_ids': tf.ensure_shape(input_ids, (max_length,)),
177 | 'input_mask': tf.ensure_shape(input_mask, (max_length,)),
178 | 'segment_ids': tf.ensure_shape(segment_ids, (max_length,)),
179 | }
180 | if mode == 'train':
181 | ret['label'] = data['label']
182 | if config.stsb:
183 | ret['label'] /= 5.0
184 | else:
185 | ret['idx'] = data['idx']
186 | if mode.startswith('eval'):
187 | ret = (data['label'], ret)
188 | return ret
189 |
190 | def dataset_fn(path: str) -> tf.data.Dataset:
191 | d = reader.get_tfds_dataset(path)
192 | d = d.map(example_fn, tf.data.AUTOTUNE)
193 | return d
194 |
195 | config.train_data_fn = rec_data(
196 | dataset_fn=dataset_fn, cache=True, interleave=True)
197 | config.eval_data_fn = config.valid_data_fn = rec_data(
198 | dataset_fn=dataset_fn, cache=True)
199 | config.pred_data_fn = rec_data(dataset_fn=dataset_fn)
200 | return config
201 |
202 |
203 | def get_train_state(config, rng):
204 | """Create train state."""
205 | model_config = modeling.ModelConfig(**config.model_config.to_dict())
206 | model = modeling.ModelForSeqCls(model_config)
207 |
208 | opt_config = config.opt_config
209 | if opt_config.optimizer == 'adam':
210 | optimizer = optax.adam(learning_rate=opt_config.learning_rate)
211 |
212 | metrics_mod = MeanMetrics.create('train_loss', 'valid_loss')
213 | return TrainState.create(metrics_mod, optimizer, model, rng, jnp.array([[0]]))
214 |
215 |
216 | def train_step(config, train_batch, state: TrainState, metrics):
217 | """Training step."""
218 | loss_fn = (
219 | modeling.ModelForSeqCls.mse_loss
220 | if config.stsb else modeling.ModelForSeqCls.xe_loss)
221 | (loss, size), grads = state.value_and_grad_apply_fn(has_aux=True)(
222 | state.params,
223 | train_batch['label'],
224 | train_batch['input_ids'],
225 | segment_ids=train_batch['segment_ids'],
226 | input_mask=train_batch['input_mask'],
227 | enable_dropout=True,
228 | method=loss_fn)
229 | _, metrics = state.metrics_mod.apply(
230 | metrics,
231 | 'train_loss',
232 | loss,
233 | size,
234 | method=MeanMetrics.update,
235 | mutable=['metrics'])
236 | return state.apply_gradients(grads=grads), metrics
237 |
238 |
239 | def valid_step(config, valid_batch, state: TrainState, metrics):
240 | """Validation step."""
241 | loss_fn = (
242 | modeling.ModelForSeqCls.mse_loss
243 | if config.stsb else modeling.ModelForSeqCls.xe_loss)
244 | loss, size = state.apply_fn(
245 | state.variables(),
246 | valid_batch['label'],
247 | valid_batch['input_ids'],
248 | segment_ids=valid_batch['segment_ids'],
249 | input_mask=valid_batch['input_mask'],
250 | method=loss_fn)
251 | _, metrics = state.metrics_mod.apply(
252 | metrics,
253 | 'valid_loss',
254 | loss,
255 | size,
256 | method=MeanMetrics.update,
257 | mutable=['metrics'])
258 | return metrics
259 |
260 |
261 | def get_infer_state(config):
262 | """Create infer state."""
263 | model_config = modeling.ModelConfig(**config.model_config.to_dict())
264 | model = modeling.ModelForSeqCls(model_config)
265 | return InferState.create(model, jnp.array([[0]]))
266 |
267 |
268 | def infer_step(config, batch, state: InferState) -> InferState:
269 | """Infer step."""
270 | logits = state.apply_fn(
271 | state.variables(),
272 | batch['input_ids'],
273 | segment_ids=batch['segment_ids'],
274 | input_mask=batch['input_mask'])
275 | if config.stsb:
276 | pred = jax.nn.softmax(logits)[..., 0] * 5.0
277 | else:
278 | pred = jnp.argmax(logits, axis=-1)
279 | return state.replace(ret={
280 | 'idx': batch['idx'],
281 | 'prediction': pred,
282 | })
283 |
284 |
285 | def get_evaluator(config) -> Evaluator:
286 | """Create evaluator."""
287 | eval_fns = {
288 | 'accuracy': sklearn_metrics.accuracy_score,
289 | 'f1': sklearn_metrics.f1_score,
290 | 'spearmanr': lambda x, y: scipy_stats.spearmanr(x, y)[0],
291 | }
292 |
293 | def proc_fn(infer):
294 | return infer['prediction']
295 |
296 | metric = config.eval_metric
297 | return Evaluator({metric: (proc_fn, eval_fns[metric])})
298 |
299 |
300 | def get_predictor(config) -> Predictor:
301 | """Create predictor."""
302 | pre_str = 'index\tprediction'
303 | label_names = (None if config.label_names is None else
304 | config.label_names.split(','))
305 |
306 | def proc_fn(infer):
307 | ret = []
308 | for x, y in zip(infer['idx'], infer['prediction']):
309 | z = y if label_names is None else label_names[y]
310 | ret.append(f'{x}\t{z}')
311 | return ret
312 |
313 | return Predictor(proc_fn, config.output_path, pre_str=pre_str)
314 |
--------------------------------------------------------------------------------
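To make the truncation and segment-id layout in `example_fn` above concrete, here is a tiny plain-Python model of the two-segment case (token ids and max length are made up; the real code operates on TF tensors). The packed sequence is `[CLS] seg1 [SEP] seg2 [EOS]`, with one segment id per chunk and zero padding to `max_length`:

```python
CLS, SEP, EOS, MAX_LEN = 101, 102, 103, 12  # illustrative ids and length

def pack(seg_a, seg_b):
  max_len_1 = (MAX_LEN - 1) // 2        # budget for each later segment
  b = ([SEP] + seg_b)[:max_len_1]       # truncate the second segment
  a = ([CLS] + seg_a)[:MAX_LEN - 1 - len(b)]  # first segment gets the rest
  chunks = [a, b, [EOS]]
  ids = [t for c in chunks for t in c]
  segment_ids = [i for i, c in enumerate(chunks) for _ in c]
  pad = MAX_LEN - len(ids)
  return (ids + [0] * pad,              # input_ids
          [1] * len(ids) + [0] * pad,   # input_mask
          segment_ids + [0] * pad)      # segment_ids

ids, mask, segs = pack([5, 6, 7], [8, 9])
assert ids == [101, 5, 6, 7, 102, 8, 9, 103, 0, 0, 0, 0]
assert segs == [0, 0, 0, 0, 1, 1, 1, 2, 0, 0, 0, 0]
```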
/jestimator/models/rope/pretrain.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The jestimator Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Pretrain.
16 |
17 | # For a local debug run:
18 |
19 | ```
20 | PYTHONPATH=. python3 \
21 | jestimator/estimator.py \
22 | --module_imp="jestimator.models.rope.pretrain" \
23 | --module_config="jestimator/models/rope/pretrain.py" \
24 | --module_config.model_config.vocab_size=32000 \
25 | --module_config.mask_token_id=4 \
26 | --train_pattern="gs://gresearch/checkpoints_in_amos_paper/data/\
27 | books-00000-of-00500" \
28 | --valid_pattern="gs://gresearch/checkpoints_in_amos_paper/data/ptb" \
29 | --model_dir="$HOME/models/rope_pretrain" \
30 | --train_batch_size=4 --valid_batch_size=4 --num_valid_examples=4 \
31 | --check_every_steps=10 --logtostderr
32 | ```
33 | """
34 |
35 | import dataclasses
36 | from typing import Mapping
37 |
38 | import jax
39 | import jax.numpy as jnp
40 | from jestimator import amos
41 | from jestimator.data.pipeline_lm import lm_data
42 | from jestimator.models.rope import modeling
43 | from jestimator.states import TrainState, MeanMetrics # pylint: disable=g-multiple-import
44 | import ml_collections
45 | from ml_collections.config_dict import config_dict
46 | import optax
47 | import tensorflow as tf
48 |
49 |
50 | def get_config():
51 | """Returns a config object for modeling flags."""
52 | module_config = ml_collections.ConfigDict()
53 |
54 | # Model config.
55 | model_config = modeling.ModelConfig()
56 | model_config = ml_collections.ConfigDict(dataclasses.asdict(model_config))
57 | module_config.model_config = model_config
58 |
59 | # Optimizer config.
60 | opt_config = ml_collections.ConfigDict()
61 | opt_config.optimizer = 'adamw'
62 | opt_config.learning_rate = 1e-4
63 | opt_config.warmup_steps = 10000
64 | opt_config.linear_decay_to_step = config_dict.placeholder(int)
65 | opt_config.momentum = 0.9
66 | opt_config.beta = 0.999
67 | opt_config.weight_decay = 0.01
68 | module_config.opt_config = opt_config
69 |
70 | # Other config.
71 | module_config.mask_token_id = config_dict.placeholder(int)
72 | module_config.mask_rate = 0.15
73 | return module_config
74 |
75 |
76 | def load_config(global_flags):
77 | """Init config data from global flags."""
78 | config = ml_collections.ConfigDict()
79 | config.update(global_flags.module_config)
80 |
81 | # Only a frozen config (hashable object) can be passed to jit functions
82 | # (i.e. train_step/valid_step/infer_step).
83 | config.frozen = ml_collections.FrozenConfigDict(config)
84 |
85 |   # Construct data pipelines below (using TensorFlow):
86 | seq_length = config.model_config.max_length
87 |
88 | def feature_fn(token_ids: tf.Tensor) -> Mapping[str, tf.Tensor]:
89 | """Builds a feature dict to be compatible with seqio."""
90 | return {'targets': tf.ensure_shape(token_ids, (seq_length,))}
91 |
92 | config.train_data_fn = lm_data(
93 | seq_length, random_skip=True, feature_fn=feature_fn, interleave=True)
94 | config.valid_data_fn = lm_data(seq_length, feature_fn=feature_fn)
95 | return config
96 |
97 |
98 | def get_train_state(config, rng) -> TrainState:
99 | """Create train state."""
100 | model_config = modeling.ModelConfig(**config.model_config.to_dict())
101 | model = modeling.ModelForPretrain(model_config)
102 | opt_config = config.opt_config
103 | warmup = opt_config.warmup_steps
104 | decay = opt_config.linear_decay_to_step
105 |
106 | def lr_schedule(step):
107 | lr = opt_config.learning_rate
108 | if warmup is not None:
109 | lr *= jnp.minimum(1., step / warmup)
110 | if decay is not None:
111 | lr *= 1. - jnp.maximum(0., step - warmup) / (decay - warmup)
112 | elif decay is not None:
113 | lr *= 1. - step / decay
114 | return lr
115 |
116 | if opt_config.optimizer == 'adamw':
117 | optimizer = optax.adamw(
118 | learning_rate=lr_schedule,
119 | b1=opt_config.momentum,
120 | b2=opt_config.beta,
121 | weight_decay=opt_config.weight_decay)
122 | elif opt_config.optimizer == 'amos':
123 | optimizer = amos.amos(
124 | lr_schedule,
125 | modeling.get_eta_fn(model_config),
126 | shape_fn=modeling.get_shape_fn(model_config),
127 | beta=opt_config.beta,
128 | momentum=opt_config.momentum,
129 | clip_value=1.)
130 |
131 | metrics_mod = MeanMetrics.create('train_loss', 'valid_loss', 'valid_mrr')
132 | return TrainState.create(metrics_mod, optimizer, model, rng, jnp.array([[0]]))
133 |
134 |
135 | def train_step(config, train_batch, state: TrainState, metrics):
136 | """Training step."""
137 | (loss, size), grads = state.value_and_grad_apply_fn(has_aux=True)(
138 | state.params,
139 | train_batch['targets'],
140 | config.mask_token_id,
141 | mask_rate=config.mask_rate,
142 | input_mask=train_batch.get('input_mask'),
143 | enable_dropout=True,
144 | method=modeling.ModelForPretrain.mlm_train_loss)
145 | _, metrics = state.metrics_mod.apply(
146 | metrics,
147 | 'train_loss',
148 | loss,
149 | size,
150 | method=MeanMetrics.update,
151 | mutable=['metrics'])
152 | return state.apply_gradients(grads=grads), metrics
153 |
154 |
155 | def valid_step(config, valid_batch, state: TrainState, metrics):
156 | """Validation step."""
157 |
158 | def body(i, metrics):
159 | del i # Unused.
160 | loss, mrr, size = state.apply_fn(
161 | state.variables(),
162 | valid_batch['targets'],
163 | config.mask_token_id,
164 | mask_rate=config.mask_rate,
165 | input_mask=valid_batch.get('input_mask'),
166 | method=modeling.ModelForPretrain.mlm_valid_metrics)
167 | _, metrics = state.metrics_mod.apply(
168 | metrics,
169 | 'valid_loss',
170 | loss,
171 | size,
172 | method=MeanMetrics.update,
173 | mutable=['metrics'])
174 | _, metrics = state.metrics_mod.apply(
175 | metrics,
176 | 'valid_mrr',
177 | mrr,
178 | size,
179 | method=MeanMetrics.update,
180 | mutable=['metrics'])
181 | return metrics
182 |
183 | return jax.lax.fori_loop(0, 20, body, metrics)
184 |
--------------------------------------------------------------------------------
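A quick numeric sketch of `lr_schedule` above when both `warmup_steps` and `linear_decay_to_step` are set (plain Python; the decay step is illustrative):

```python
def lr_schedule(step, lr=1e-4, warmup=10000, decay=300000):
  lr *= min(1., step / warmup)
  lr *= 1. - max(0., step - warmup) / (decay - warmup)
  return lr

print(lr_schedule(5000))    # 5e-05: halfway through warm-up
print(lr_schedule(10000))   # 1e-04: the peak rate
print(lr_schedule(155000))  # 5e-05: halfway through the linear decay
print(lr_schedule(300000))  # 0.0: fully decayed
```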
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "jestimator"
3 | description = "Implementation of the Amos optimizer from the JEstimator lib."
4 | readme = "README.md"
5 | requires-python = ">=3.8"
6 | license = {file = "LICENSE"}
7 | authors = [{name = "jestimator authors", email="no-reply@google.com"}]
8 | classifiers = [
9 | "Programming Language :: Python :: 3",
10 | "Programming Language :: Python :: 3 :: Only",
11 | "License :: OSI Approved :: Apache Software License",
12 | "Intended Audience :: Science/Research",
13 | ]
14 | keywords = []
15 |
16 | # pip dependencies of the project
17 | dependencies = [
18 | "jax>=0.4.3",
19 | "flax",
20 | ]
21 |
22 | # This is set automatically by flit using `jestimator.__version__`
23 | dynamic = ["version"]
24 |
25 | [project.urls]
26 | homepage = "https://github.com/google-research/jestimator"
27 | repository = "https://github.com/google-research/jestimator"
28 | # Other: `documentation`, `changelog`
29 |
30 | [project.optional-dependencies]
31 | # Installed through `pip install .[test]`
32 | test = [
33 | "absl-py",
34 | ]
35 |
36 | [build-system]
37 | requires = ["flit_core >=3.5,<4"]
38 | build-backend = "flit_core.buildapi"
39 |
40 | # Only publish the Amos optimizer to pip.
41 | [tool.flit.sdist]
42 | exclude = [
43 | "jestimator/amos_helper_test.py",
44 | "jestimator/amos_test.py",
45 | "jestimator/checkpoint_utils.py",
46 | "jestimator/data/",
47 | "jestimator/data_utils.py",
48 | "jestimator/estimator.py",
49 | "jestimator/modeling.py",
50 | "jestimator/models/",
51 | "jestimator/states.py",
52 | ]
53 |
--------------------------------------------------------------------------------