├── .github
└── workflows
│ └── linux.yml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── nanodo
├── __init__.py
├── configs
│ └── default.py
├── data.py
├── evaluate.py
├── fsdp.py
├── loss.py
├── main.py
├── metrics.py
├── model.py
├── model_factory.py
├── optimizer.py
└── train.py
├── pyproject.toml
└── tests
├── data_test.py
├── evaluate_test.py
├── metrics_test.py
├── model_factory_test.py
├── model_test.py
├── optimizer_test.py
├── testdata
└── sentencepiece_cc_all.32000.100extra-sentencepiece.model
└── train_test.py
/.github/workflows/linux.yml:
--------------------------------------------------------------------------------
# This workflow installs the Python dependencies and runs the test suite
# (with coverage) against a matrix of Python versions and JAX x64 settings.
# Copied from github.com/google/jax-md and github.com/google/jax
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: linux

on:
  push:
    branches:
      - main
      - 'test_*'

  pull_request:
    branches:
      - main

jobs:
  Linux:

    timeout-minutes: 120

    strategy:
      matrix:
        # Versions are quoted so YAML keeps them as strings (an unquoted 3.10
        # would be read as the float 3.1).
        python-version: ['3.10', '3.11']
        JAX_ENABLE_X64: [0, 1]

    runs-on: ubuntu-latest

    steps:

      - name: Cancel previous
        uses: styfle/cancel-workflow-action@0.12.1
        with:
          access_token: ${{ github.token }}
      - uses: actions/checkout@v4.1.1

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5.0.0
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          pip install --upgrade pip
          pip install -e .
          pip install -e .[test]
          pip install pytest
          pip install pytest-xdist
          pip install pytest-cov

      - name: Test with pytest and generate coverage report (Ubuntu)
        run: |
          JAX_ENABLE_X64=${{ matrix.JAX_ENABLE_X64 }} PYTHONHASHSEED=0 pytest -n auto --cov=nanodo --cov-report=xml --cov-report=term

      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v4.0.1
        with:
          file: ./coverage.xml

      # The below step just reports the success or failure of tests as a "commit status".
      # This is needed for copybara integration.
      - name: Report success or failure as github status
        if: always()
        shell: bash
        run: |
          status="${{ job.status }}"
          lowercase_status=$(echo "$status" | tr '[:upper:]' '[:lower:]')
          curl -sS --request POST \
          --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
          --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
          --header 'content-type: application/json' \
          --data '{
             "state": "'$lowercase_status'",
             "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
             "description": "'$status'",
             "context": "github-actions/linux"
             }'
78 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NanoDO: A minimal ("nano-sized") Transformer decoder-only language model implementation in JAX.
2 | Inspired by minGPT/nanoGPT and flax/examples we provide a minimal
3 | implementation of a Transformer decoder-only language model in Jax.
4 |
5 | The purpose is to be maximally hackable, forkable, and readable for researchers,
6 | to enable highly exploratory research. Magic is great for products, but it is
7 | harmful in many cases for research and so we minimize abstraction as a design
8 | goal.
9 |
10 | Currently we use:
11 |
12 | * [flax](https://github.com/google/flax) for modules
13 | * [optax](https://github.com/google-deepmind/optax) for optimization
14 | * [orbax](https://github.com/google/orbax) for checkpointing
15 | * [tfds](https://github.com/tensorflow/datasets) for data
16 | * [pygrain](https://github.com/google/grain) for data loading
17 | * [ConfigDict](https://github.com/google/ml_collections) for hyper-parameters.
18 |
19 |
20 | Design opinions:
21 |
22 | * Tensors have short names similar to math and have shapes in their names.
23 | No more shapes in comments. This violates the
24 | python style guide, but that was written for non-ML code.
25 | * We avoid long docstrings and let code self-document when possible. In
26 | particular, type hints make a lot of python documentation redundant.
27 |
28 |
29 | Current model and training:
30 |
31 | * gelu activation function
32 | * learned position embedding
33 | * adamw optimizer
34 | * shared input and output embedding
35 | * Use both BOS and EOS
36 | * No biases on layernorm or weight parameters, which PaLM found to improve
37 | stability and speed
38 |
39 | Current parallelism:
40 |
41 | We use Fully Sharded Data Parallel (FSDP) for parallelism. Model parameters
42 | and the optimizer state are sharded among the devices. These shardings are
43 | passed to jit, which is responsible for determining how to all-gather weights
44 | when necessary.
45 |
46 | ## Setup (open-source, Linux/CPU)
47 |
48 | ```
49 | python3.11 -m venv /tmp/nanodo_test_env
50 | source /tmp/nanodo_test_env/bin/activate
51 | cd [path_to_repo]
52 | pip install -e .
53 |
54 | # Run tests
55 | pip install pytest pytest-xdist
56 | PYTHONHASHSEED=0 pytest -n auto -rA
57 |
58 | # Run training example:
59 | python nanodo/main.py \
60 | --config=nanodo/configs/default.py \
61 | --config.workdir=/tmp/nanodo_workdir \
62 | --config.vocab_path=tests/testdata/sentencepiece_cc_all.32000.100extra-sentencepiece.model \
63 | --config.model.L=128 \
64 | --config.batch_size=2 \
65 | --config.pygrain_worker_count=0 \
66 | 2> stderr.log
67 | ```
68 |
69 | Then point your [Tensorboard](https://github.com/tensorflow/tensorboard) to the workdir:
70 |
71 | ```
72 | tensorboard --logdir /tmp/nanodo_workdir
73 | ```
74 |
75 | If input-bound, try adjusting `--config.pygrain_worker_count` to enable pygrain multi-processing.
76 |
77 | To use accelerators, ensure the appropriate JAX package is installed by following these [instructions](https://jax.readthedocs.io/en/latest/installation.html).
78 |
79 | ## Maintenance
80 |
81 | There are no guarantees that the software will be maintained going forward. The software is designed to be easily forked and modified.
82 |
83 | ## Citing NanoDO
84 |
85 | To cite this repository:
86 |
87 | ```
88 | @software{nanodo,
89 | author = {Peter J. Liu and Roman Novak and Jaehoon Lee and Mitchell Wortsman and Lechao Xiao and Katie Everett and Alexander A. Alemi and Mark Kurzeja and Pierre Marcenac and Izzeddin Gur and Simon Kornblith and Kelvin Xu and Gamaleldin Elsayed and Ian Fischer and Jeffrey Pennington and Ben Adlam and Jascha Sohl-Dickstein},
90 | title = {NanoDO: A minimal Transformer decoder-only language model implementation in {JAX}.},
91 | url = {http://github.com/google-deepmind/nanodo},
92 | version = {0.1.0},
93 | year = {2024},
94 | }
95 | ```
96 |
97 |
98 | Authors all performed work while at Google Brain / DeepMind. We also acknowledge the help of Anselm Levskaya, Gellért Weisz, Xinyang Geng, Yotam Doron, and Noah Fiedel.
99 |
100 | The first published paper to use (a fork of) the library was:
101 |
102 | [Wortsman et al. "Small-scale proxies for large-scale Transformer training instabilities." *ICLR 2024*.](https://openreview.net/forum?id=d8w0pmvXbZ)
--------------------------------------------------------------------------------
/nanodo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Nanodo public API."""
15 |
16 |
17 | __version__ = "0.0.1"
18 |
--------------------------------------------------------------------------------
/nanodo/configs/default.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Default Hyperparameter configuration.
15 |
16 | Usage:
17 | /bin/bash third_party/py/nanodo/run.sh --config=default
18 | """
19 |
20 | import ml_collections
21 |
22 |
def get_config() -> ml_collections.ConfigDict:
  """Returns the default hyperparameter configuration as a ConfigDict."""
  config = ml_collections.ConfigDict()
  config.seed = 42

  # ---- Data ----
  config.batch_size = 256  # Global batch size. Must be divisible by the #devices.
  config.train_epochs = None  # None => infinite.
  config.ds_name = "lm1b:1.1.0"
  config.vocab_path = ""  # Set to a local path.

  # ---- Transformer ----
  model_fields = dict(
      D=512,  # Model/embed dim = qkv dim.
      H=8,  # Number of attention heads.
      L=512,  # Max context/sequence length (move out of config?).
      N=6,  # Number of transformer block layers.
      F=2048,  # Feed-forward inner dimension.
      dtype="bfloat16",  # Computation dtype.
      fsdp_enabled=True,  # True to shard the model.
      remat=False,  # Transformer block gradient checkpointing to save memory.
  )
  config.model = ml_collections.config_dict.create(**model_fields)

  # ---- Optimizer ----
  optimizer_fields = dict(
      num_train_steps=100_000,  # Note: lm1b has 30,301,028 training examples.
      peak_learning_rate=0.0016,
      init_learning_rate=0.00016,
      final_learning_rate=0.00016,
      warmup_steps=1000,
      decay_type="cosine",
      weight_decay=0.1,
      clip_by_global_norm=None,  # 1.0 is common for many well-known LLMs.
      optimizer="adamw",
  )
  config.opt = ml_collections.config_dict.create(**optimizer_fields)

  # ---- Checkpointing ----
  config.workdir = ""
  config.checkpoint = True
  config.checkpoint_every_steps = 2000
  # Path to the checkpoint to be restored. Note that new checkpoints will be
  # saved to the new workdir.
  config.checkpoint_restore_dir = None
  config.max_to_keep = 100

  # ---- Eval ----
  config.eval_every_steps = 100
  config.eval_split = "test"  # 306,688 examples.
  config.eval_steps = 100  # Less if this exceeds 1 epoch.
  config.eval_max_target_length = 512

  # ---- Logging ----
  config.write_train_metrics_every_steps = 1  # Train loss, gradient norms, etc.
  config.write_perf_metrics_every_steps = 100  # steps_per_sec, uptime.
  # For Vizier interface, we currently require write_to_xm_measurements=True.
  config.write_to_xm_measurements = True
  # Option to turn on internal statistics: rms_norm, mean, std of per-layer,
  # module-wise statistics. Due to high load, when setting this to True
  # consider turning off writing to XM measurements and rely on Datatables.
  config.log_internal_metrics = True

  # ---- pygrain ----
  config.pygrain_worker_count = 16  # Might increase this if input-bound.
  # Buffer size (in units of batches) for the data loader. Defaults to 2 so we
  # always prefetch another batch.
  config.pygrain_worker_buffer_size = 2

  return config
91 |
--------------------------------------------------------------------------------
/nanodo/data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Data pipeline."""
15 |
16 | from collections.abc import Mapping, Sequence
17 | import dataclasses
18 | import enum
19 | import functools
20 | from typing import Iterator
21 |
22 | import grain.python as grain
23 | import jax
24 | import jax.numpy as jnp
25 | import numpy as np
26 | import tensorflow_datasets as tfds
27 |
28 | import sentencepiece as spm
29 |
30 | PAD_ID = 0
31 | ### pure python helpers for use with grain ###
32 |
33 |
@enum.unique
class Preprocess(enum.Enum):
  """How tokenized examples are turned into fixed-length sequences."""

  # Examples are concatenated and cut into context-sized chunks.
  NOAM_PACKED = 1
  # Each example is padded or truncated to the context size on its own.
  PADDED = 2
37 |
38 |
def py_batched_tfds(
    *,
    tfds_name: str,
    split: str,
    context_size: int,
    worker_count: int,
    vocab_path: str,
    batch_size: int,
    seed: int | None = 1234,
    num_epochs: int | None = None,
    num_records: int | None = None,
    preprocessing: Preprocess = Preprocess.NOAM_PACKED,
    worker_buffer_size: int = 2,
    shuffle: bool = True,
) -> grain.DataLoader:
  """Builds a pygrain DataLoader of regularly batched, tokenized text.

  Args:
    tfds_name: dataset name handed to `tfds.data_source`.
    split: dataset split handed to `tfds.data_source`.
    context_size: length of each packed chunk (NOAM_PACKED) or the pad/truncate
      length of each example (PADDED).
    worker_count: forwarded to `grain.DataLoader`.
    vocab_path: SentencePiece model path used for tokenization.
    batch_size: number of sequences per batch; incomplete batches are dropped.
    seed: forwarded to `grain.IndexSampler`.
    num_epochs: forwarded to `grain.IndexSampler`.
    num_records: number of records to sample from; defaults to the full length
      of the data source.
    preprocessing: which `Preprocess` strategy to apply after tokenization.
    worker_buffer_size: forwarded to `grain.DataLoader`.
    shuffle: forwarded to `grain.IndexSampler`.

  Returns:
    The configured `grain.DataLoader`.

  Raises:
    ValueError: if `preprocessing` is not a known `Preprocess` member.
  """
  source = tfds.data_source(tfds_name, split=split)
  if num_records is None:
    num_records = len(source)
  sampler = grain.IndexSampler(
      num_records=num_records,
      num_epochs=num_epochs,
      shard_options=grain.NoSharding(),
      shuffle=shuffle,
      seed=seed,
  )
  tokenizer = _SPTokenizer(vocab_path)

  # Packing works on raw-length sequences, so padding is only requested for
  # the non-packed mode.
  use_packing = preprocessing == Preprocess.NOAM_PACKED
  tokenize_op = grain.MapOperation(
      map_function=functools.partial(
          _py_tokenize,
          spt=tokenizer,
          pad_len=None if use_packing else context_size,
      )
  )

  if use_packing:
    shaping_op = _NoamPack(context_size=context_size)
  elif preprocessing == Preprocess.PADDED:
    shaping_op = grain.MapOperation(map_function=np.array)
  else:
    raise ValueError(f'Unknown preprocessing: {preprocessing}')

  batch_op = grain.Batch(batch_size=batch_size, drop_remainder=True)
  return grain.DataLoader(
      data_source=source,
      operations=[tokenize_op, shaping_op, batch_op],
      sampler=sampler,
      worker_count=worker_count,
      worker_buffer_size=worker_buffer_size,
  )
90 |
91 |
def get_py_tokenizer(path: str) -> spm.SentencePieceProcessor:
  """Loads a SentencePiece model and sanity-checks its special token ids.

  Args:
    path: path to the SentencePiece model file.

  Returns:
    The loaded `SentencePieceProcessor`.

  Raises:
    AssertionError: if the model's pad id differs from `PAD_ID`, or if its
      EOS or BOS id is -1.
  """
  sp = spm.SentencePieceProcessor()
  sp.Load(path)
  # The messages name the offending file so a bad vocab is easy to identify.
  assert sp.pad_id() == PAD_ID, (
      f'{path}: expected pad id {PAD_ID}, got {sp.pad_id()}'
  )
  assert sp.eos_id() != -1, f'{path}: eos id is -1'
  assert sp.bos_id() != -1, f'{path}: bos id is -1'
  return sp
99 |
100 |
# Need this because we can't pickle SentencePieceProcessor object: the wrapper
# stores only the vocab path and builds the processor lazily on first use.
class _SPTokenizer:
  """Wrapper class that lazily loads a SentencePiece tokenizer."""

  def __init__(self, vocab_path):
    # Created on the first get_tokenizer() call.
    self._tokenizer = None
    self._vocab_path = vocab_path

  def get_tokenizer(self) -> 'spm.SentencePieceProcessor':
    """Returns the cached tokenizer, loading it from the vocab path if needed."""
    # Compare with None explicitly: a truthiness test would treat any falsy
    # cached object as "not loaded" and reload it on every call.
    if self._tokenizer is None:
      self._tokenizer = get_py_tokenizer(self._vocab_path)
    return self._tokenizer
113 |
114 |
def _py_tokenize(
    features: Mapping[str, str],
    spt: _SPTokenizer,
    pad_len: int | None = None,
    pad_id: int = PAD_ID,
) -> Sequence[int]:
  """Encodes `features['text']` as BOS + ids + EOS, then fits it to pad_len.

  Args:
    features: example mapping; only the 'text' entry is read.
    spt: tokenizer wrapper providing the SentencePiece processor.
    pad_len: if given, the result is right-padded with `pad_id` or truncated
      to exactly this length; if None the sequence is returned as is.
    pad_id: id used for padding.

  Returns:
    The list of token ids.
  """
  text = features['text']
  sp = spt.get_tokenizer()
  token_ids = [sp.bos_id(), *sp.EncodeAsIds(text), sp.eos_id()]

  if pad_len is None:
    return token_ids

  shortfall = pad_len - len(token_ids)
  if shortfall > 0:
    token_ids.extend([pad_id] * shortfall)
  elif shortfall < 0:
    # Truncation keeps the leading tokens (the trailing EOS may be cut off).
    token_ids = token_ids[:pad_len]
  return token_ids
136 |
137 |
@dataclasses.dataclass
class _NoamPack:
  """Pygrain operation that packs token-id records into fixed-size chunks.

  Incoming records are concatenated back to back; every time `context_size`
  ids have accumulated, one int32 record is emitted.
  """

  context_size: int

  def __call__(
      self, idseq_iterator: Iterator[grain.Record]
  ) -> Iterator[grain.Record]:
    buffer = []
    for record in idseq_iterator:
      ids = record.data
      offset = 0
      while offset < len(ids):
        remaining = ids[offset:]
        room = self.context_size - len(buffer)
        if len(remaining) < room:
          # The rest of this example fits; keep filling from the next record.
          buffer.extend(remaining)
          break
        # The buffer is completed by (part of) this example: emit it, keyed by
        # the metadata of the record that filled it.
        buffer.extend(remaining[:room])
        yield grain.Record(
            record.metadata.remove_record_key(),
            np.array(buffer, dtype=np.int32),
        )
        offset += room
        buffer = []
    # Whatever is left in the buffer is dropped for simplicity.
    # We lose the rest of the example on restore.
166 |
167 |
168 | # pylint: disable=invalid-name
169 |
170 |
def get_in_out(
    in_BxL: jax.Array,
    pad_id: int = PAD_ID,
) -> tuple[jax.Array, jax.Array, jax.Array]:
  """Derives (inputs, targets, loss weights) from a batch of token ids.

  Args:
    in_BxL: batch of token ids.
    pad_id: id treated as padding.

  Returns:
    A tuple `(x_BxL, y_BxL, weights_BxL)`: the unchanged inputs, the inputs
    shifted left by one position with `pad_id` filling the last column, and
    float32 weights that are 1 where the target is not `pad_id` and 0 elsewhere.
  """
  # NOTE(review): presumably each row is BOS, ids..., EOS followed by optional
  # padding -- confirm against the data pipeline.
  next_tokens_BxL = jnp.pad(
      in_BxL[:, 1:],
      ((0, 0), (0, 1)),
      mode='constant',
      constant_values=pad_id,
  )
  is_target_BxL = next_tokens_BxL != pad_id
  return in_BxL, next_tokens_BxL, is_target_BxL.astype(jnp.float32)
187 |
--------------------------------------------------------------------------------
/nanodo/evaluate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Functions to evaluate nanodo runs."""
15 |
16 | # pylint: disable=invalid-name,g-importing-member,g-import-not-at-top
17 |
18 | import functools
19 | import math
20 | import os
21 | from typing import Any, Iterator, TYPE_CHECKING
22 |
23 | from absl import logging
24 | import jax
25 | from jax.sharding import Mesh
26 | from jax.sharding import NamedSharding
27 | from jax.sharding import PartitionSpec as P
28 | from nanodo import data
29 | from nanodo import metrics as metrics_lib
30 | import numpy as np
31 | from optax import losses
32 |
33 |
34 | if TYPE_CHECKING:
35 | from flax import linen as nn
36 | import ml_collections
37 |
38 |
39 | PyTree = Any
40 |
41 |
# Bits per nat: multiplying a quantity expressed in nats by this factor
# expresses it in bits.
_BPN = 1.0 / math.log(2)

# (tfds_name, vocab file basename) -> factor converting a loss in nats per
# token into bits per byte. Each factor is `_BPN` times a fixed count ratio.
# NOTE(review): the two integers in each ratio are presumably the token and
# byte counts of the corresponding evaluation split -- confirm.
_TO_BPB = {
    ("lm1b:1.1.0", "cc_all.32000.100extra.bos.model"): (
        _BPN * (10_449_751 / 41_715_169.0)
    ),  # 0.36139860649310773
    ("c4:3.1.0", "cc_all.32000.100extra.bos.model"): (
        _BPN * (183_808_378 / 789_615_977.0)
    ),  # 0.3358334217374176
    (
        "huggingface:cerebras__slimpajama_627b",  # validation
        "cc_all.32000.100extra.bos.model",
    ): (
        _BPN * (560_013_105 / 2_174_889_064.0)
    ),  # 0.3714801562937696
}
66 |
67 |
class Evaluator:
  """Runs the jitted eval step over an evaluation dataset and reports losses.

  Attributes:
    step_fn: jitted `_eval_step` bound to `model` and `mesh`.
    c: the config; `eval_steps`, `ds_name` and `vocab_path` are read from it.
    ds: the evaluation dataset, iterated once per `eval` call.
    bpB: factor converting nats per token to bits per byte, or None when the
      (dataset, vocab) pair is not present in `_TO_BPB`.
  """

  def __init__(
      self,
      c: "ml_collections.ConfigDict",
      model: "nn.Module",
      eval_ds: "Iterator[np.ndarray]",
      mesh: Mesh,
      shardings: PyTree,
  ):
    replicated = NamedSharding(mesh, P())
    eval_step = functools.partial(_eval_step, model=model, mesh=mesh)
    # NOTE(review): `params` is listed in `donate_argnames`, yet `eval` passes
    # the same `params` to every step -- confirm that donation is intended.
    self.step_fn = jax.jit(
        eval_step,
        in_shardings=(shardings.params, replicated),
        out_shardings=replicated,
        donate_argnames=("params", "in_BxL"),
    )
    self.c = c
    self.ds = eval_ds
    # Conversion factor to bits per Byte from nats per tokens (None if the
    # dataset/vocab combination is unknown).
    lookup_key = (c.ds_name, os.path.basename(c.vocab_path))
    self.bpB = _TO_BPB.get(lookup_key, None)

  def eval(self, params: PyTree) -> dict[str, float]:
    """Run eval with at most one epoch.

    Args:
      params: model parameters handed to the jitted eval step.

    Returns:
      A dict with `eval_loss`, `eval_loss_std` and `eval_loss_uc` (mean plus
      three standard errors), their `_bpB` counterparts when a conversion
      factor is known, and `_eval_loss` duplicating `eval_loss`.
    """
    total = metrics_lib.Average()
    in_flight = metrics_lib.Average()
    step = 0
    for step, batch in enumerate(iter(self.ds)):
      # Dispatch the new step first so that fetching the previous step's
      # result below can overlap with it.
      dispatched = self.step_fn(params, batch)
      total = total.merge(jax.device_get(in_flight))
      in_flight = dispatched
      # NOTE(review): `step` is zero-based, so this stops after
      # `eval_steps + 1` batches -- confirm this is the intended count.
      if step == self.c.eval_steps:
        logging.info(
            "Ended eval at step %d (batch size %d)", step, batch.shape[0]
        )
        break
    if step < self.c.eval_steps:
      logging.warning("Ran out of data at step %d. Stopping.", step)
    # Fold in the result of the last dispatched step.
    total = total.merge(jax.device_get(in_flight))

    output = {
        "loss": total.mean,
        "loss_std": total.sem,
        "loss_uc": total.mean + 3 * total.sem,
    }
    if self.bpB:
      # Adds loss_bpB, loss_std_bpB and loss_uc_bpB.
      output |= {
          name + "_bpB": value * self.bpB
          for name, value in list(output.items())
      }

    prefixed = {"eval_" + name: value for name, value in output.items()}
    # Dummy scalar to show high up in XM measurements.
    prefixed["_eval_loss"] = prefixed["eval_loss"]
    return prefixed
126 |
127 |
def _eval_step(
    params: PyTree,
    in_BxL: jax.Array,
    model: "nn.Module",
    mesh: Mesh | None = None,
) -> metrics_lib.Average:
  """Computes evaluation metrics for one batch.

  Args:
    params: model parameters, passed to `model.apply` under the "params" key.
    in_BxL: batch of token ids.
    model: module whose `apply` produces the logits.
    mesh: when given, the batch is constrained to be sharded along the mesh's
      "data" axis.

  Returns:
    The `Average` of the per-token losses produced by
    `_compute_unnormed_metrics`.
  """
  if mesh is not None:
    batch_sharding = NamedSharding(mesh, P("data"))
    in_BxL = jax.lax.with_sharding_constraint(in_BxL, batch_sharding)
  inputs_BxL, targets_BxL, mask_BxL = data.get_in_out(in_BxL)
  variables = {"params": params}
  logits_BxLxV = model.apply(variables, inputs_BxL)
  return _compute_unnormed_metrics(logits_BxLxV, targets_BxL, mask_BxL)
142 |
143 |
def _compute_unnormed_metrics(
    logits_BxLxV: jax.Array,
    labels_BxL: jax.Array,
    weights_BxL: jax.Array,
) -> metrics_lib.Average:
  """Summarises the per-token cross-entropy of a batch as an `Average`.

  Args:
    logits_BxLxV: model logits.
    labels_BxL: integer target ids.
    weights_BxL: mask handed to `Average.from_array`; positions with a zero
      weight are excluded from the statistics.

  Returns:
    `Average` statistics of the per-token softmax cross-entropy.
  """
  cross_entropy = losses.softmax_cross_entropy_with_integer_labels
  per_token_loss_BxL = cross_entropy(logits_BxLxV, labels_BxL)
  return metrics_lib.Average.from_array(per_token_loss_BxL, mask=weights_BxL)
154 |
--------------------------------------------------------------------------------
/nanodo/fsdp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utils to assist with FSDP."""
15 |
16 | # pylint: disable=g-importing-member
17 |
18 | from typing import Any
19 | from flax import linen as nn
20 |
21 |
22 | DoConfig = Any # model.DoConfig; model.py imports this module.
23 |
24 |
25 | # For a tensor with dims (n1, n2, ..., nk) a partitioning must be specified of
26 | # size (p1, p2, ..., pk).
27 | # Here we partition over one dim only, so exactly one pi = "data" and the rest
28 | # should be None. This means, partition the tensor on dim i over the "data" axis
29 | # and not on the rest. Note that the "data" axis is the axis used for data
30 | # parallel, and corresponds to the number of devices.
31 | # The condition is that ni must be divisible by number of devices, so this
32 | # partitioning therefore chooses the partitioning axis to be the model dim
33 | # as this is usually divisible by number of devices.
def init(layer_type: str, docfg: DoConfig) -> nn.initializers.Initializer:
  """Returns the (optionally partitioned) initializer for a layer type.

  Args:
    layer_type: one of "embedding", "attn_in_proj", "attn_out_proj",
      "mlp_kernel" or "head".
    docfg: model config providing `fsdp_enabled`, `embed_init`, `kernel_init`
      and, optionally, `head_init`.

  Returns:
    The config's initializer for that layer. When `docfg.fsdp_enabled` is set
    it is wrapped with `nn.with_partitioning` so that the model dimension is
    partitioned over the "data" axis; otherwise it is returned unchanged.

  Raises:
    ValueError: if `layer_type` is not recognized.
  """

  def _no_partitioning(initializer, unused_spec):
    return initializer

  if docfg.fsdp_enabled:
    partition_fn = nn.with_partitioning
  else:
    partition_fn = _no_partitioning

  # (layer type, config attribute holding the initializer, partition spec).
  layouts = (
      ("embedding", "embed_init", (None, "data")),  # [V, D]
      ("attn_in_proj", "kernel_init", ("data", None, None)),  # [D, H, Dh]
      ("attn_out_proj", "kernel_init", (None, None, "data")),  # [H, Dh, D]
      ("mlp_kernel", "kernel_init", ("data", None)),  # [D, F]
      ("head", "head_init", ("data", None)),  # [D, V]
  )
  for known_type, init_attr, spec in layouts:
    if layer_type != known_type:
      continue
    # The head falls back to the generic kernel initializer when the config
    # does not define a dedicated one.
    if known_type == "head" and not hasattr(docfg, "head_init"):
      init_attr = "kernel_init"
    return partition_fn(getattr(docfg, init_attr), spec)

  raise ValueError(f"unrecognized layer type: {layer_type}")
52 |
--------------------------------------------------------------------------------
/nanodo/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Loss functions."""
15 |
16 | # pylint: disable=invalid-name,g-import-not-at-top,g-bare-generic
17 |
18 | from typing import Any, Callable, TYPE_CHECKING
19 |
20 | from flax.struct import dataclass
21 | import jax
22 | import jax.numpy as jnp
23 | from nanodo import data
24 | from optax import losses
25 |
26 | if TYPE_CHECKING:
27 | import ml_collections
28 |
29 |
30 | PyTree = Any
31 |
32 |
@dataclass
class LossAuxData:
  """Auxiliary values returned next to the scalar loss by a `LossFn`."""
  # Sum of the per-token loss weights of the batch (see `get_default_loss_fn`).
  ntokens: jax.Array
  # Second value returned by the model's `apply_fn`, i.e. the collections that
  # were requested as mutable.
  state: PyTree
  # Weighted mean of the per-token cross-entropy losses.
  log_perplexity: jax.Array
38 |
# loss(params) function to be used in `jax.value_and_grad`; returns the scalar
# loss together with its auxiliary data.
LossFn = Callable[[PyTree], tuple[jax.Array, LossAuxData]]

# Builds a `LossFn`. The positional arguments mirror `get_default_loss_fn`:
# the batch of token ids, the model apply function, and the config.
LossFnFactory = Callable[
    [jax.Array, Callable, "ml_collections.ConfigDict"],
    LossFn,
]
46 |
47 |
def get_default_loss_fn(
    in_BxL: jax.Array,
    apply_fn: Callable,
    c: "ml_collections.ConfigDict",
) -> LossFn:
  """Standard next-token-prediction language modeling loss.

  Args:
    in_BxL: batch of token ids the returned loss is evaluated on.
    apply_fn: model apply function; called with the variables dict, the
      inputs and a `mutable` keyword, and expected to return `(logits, state)`.
    c: config; only the optional `log_internal_metrics` entry is read.

  Returns:
    A `loss(params)` closure returning the weighted mean cross-entropy and a
    `LossAuxData` with the token count, the mutable state and the same mean
    loss as `log_perplexity`.
  """

  def loss_fn(params: PyTree) -> tuple[jax.Array, LossAuxData]:
    inputs_BxL, targets_BxL, mask_BxL = data.get_in_out(in_BxL)

    # Intermediate activations are only collected when internal metrics are
    # going to be logged.
    if c.get("log_internal_metrics", False):
      mutable = ("intermediate_acts",)
    else:
      mutable = ()
    logits_BxLxV, state = apply_fn(
        {"params": params},
        inputs_BxL,
        mutable=mutable,
    )

    per_token_loss_BxL = losses.softmax_cross_entropy_with_integer_labels(
        logits_BxLxV, targets_BxL
    )
    ntokens = mask_BxL.sum()
    # NOTE(review): no guard for ntokens == 0 (a batch whose targets are all
    # padding) -- confirm such batches cannot occur.
    mean_loss = jnp.sum(per_token_loss_BxL * mask_BxL) / ntokens
    aux = LossAuxData(ntokens=ntokens, state=state, log_perplexity=mean_loss)
    return mean_loss, aux

  return loss_fn
74 |
--------------------------------------------------------------------------------
/nanodo/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Main file for running the Language Modelling example with nanodo.
15 |
16 | This file is intentionally kept short. The majority of the logic is in libraries
17 | that can be easily tested and imported in Colab.
18 | """
19 |
20 | from absl import app
21 | from absl import flags
22 | from absl import logging
23 | from clu import platform
24 | import jax
25 | from ml_collections import config_flags
26 | from nanodo import train
27 |
FLAGS = flags.FLAGS

# Defines the `--config` flag pointing at the training configuration file;
# it defaults to 'configs/default.py' and is created with `lock_config=True`.
config_flags.DEFINE_config_file(
    'config',
    'configs/default.py',
    'File path to the training hyperparameter configuration.',
    lock_config=True,
)
# NOTE(review): the flag already has a default value, so marking it as
# required presumably only rejects an explicitly empty value -- confirm.
flags.mark_flags_as_required(['config'])
37 |
38 |
def main(argv):
  """Logs the JAX process/device setup and starts training and evaluation.

  Args:
    argv: command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: if extra positional arguments are supplied.
  """
  extra_args = argv[1:]
  if extra_args:
    raise app.UsageError('Too many command-line arguments.')

  logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
  logging.info('JAX local devices: %r', jax.local_devices())

  # Add a note so that we can tell which task is which JAX host.
  # (Depending on the platform task 0 is not guaranteed to be host 0)
  work_unit = platform.work_unit()
  task_status = (
      f'process_index: {jax.process_index()}, '
      f'process_count: {jax.process_count()}'
  )
  work_unit.set_task_status(task_status)

  train.train_and_evaluate(FLAGS.config)
51 |
52 |
if __name__ == '__main__':
  # NOTE(review): presumably registers JAX's configuration options as absl
  # flags so that `app.run` can parse them -- confirm.
  jax.config.config_with_absl()
  app.run(main)
56 |
--------------------------------------------------------------------------------
/nanodo/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Computing metrics tracked during training and evaluation."""
15 |
16 | # pylint: disable=invalid-name,g-importing-member,g-import-not-at-top
17 |
18 | from typing import Any, TYPE_CHECKING
19 |
20 | from flax.struct import dataclass
21 | import jax
22 | import jax.numpy as jnp
23 | from nanodo import optimizer
24 | import numpy as np
25 |
26 |
27 | if TYPE_CHECKING:
28 | from flax.training.train_state import TrainState
29 | import ml_collections
30 | from nanodo import loss as loss_lib
31 |
32 |
33 | PyTree = Any
34 |
35 |
def get_init_metrics(
    state: "TrainState",
) -> dict[str, float | int]:
  """Compute metrics only at init, as they are constant throughout training.

  Args:
    state: train state whose `params` tree is measured.

  Returns:
    Parameter counts keyed by `n_params/...`: the total, the embedding
    ("embed" plus "pos_embed") and non-embedding counts, one entry per
    parameter leaf and, when a "head" entry exists, the head count and the
    count excluding both embeddings and head.
  """
  params = state.params
  n_params_all = _size(params)

  # Token and position embeddings count as embedding parameters when present.
  n_params_embedding = sum(
      _size(params[name]) for name in ("embed", "pos_embed") if name in params
  )

  metrics = {
      "n_params/all": n_params_all,
      "n_params/embedding": n_params_embedding,
      "n_params/non_embedding": n_params_all - n_params_embedding,
  }
  metrics.update(_counts_from_tree(params))

  if "head" in params:
    n_params_head = _size(params["head"])
    metrics.update({
        "n_params/head": n_params_head,
        "n_params/non_embedding_head": (
            n_params_all - n_params_embedding - n_params_head
        ),
    })
  return metrics
69 |
70 |
def get_metrics(
    aux_data: "loss_lib.LossAuxData",
    c: "ml_collections.ConfigDict",
    loss: float,
    state: "TrainState",
    grads: PyTree,
    updates: PyTree,
) -> dict[str, float | jax.Array]:
  """Compute metrics tracked at every training step.

  Args:
    aux_data: auxiliary outputs of the loss function.
    c: config; `c.opt` and the optional `log_internal_metrics` are read.
    loss: loss value of the current step.
    state: train state; its `opt_state` must expose `gradient_step`,
      `acc_grads` and `mini_step`.
    grads: gradients of the current (micro-)step.
    updates: parameter updates of the current step.

  Returns:
    A dict of scalar metrics; per-tensor statistics are added when
    `log_internal_metrics` is enabled.
  """
  opt_state = state.opt_state
  # Access final gradient through opt_state.acc_grad
  step = opt_state.gradient_step  # pytype: disable=attribute-error
  n_mini_steps = opt_state.mini_step + 1  # pytype: disable=attribute-error

  # Use Welford algorithm for numerically stable aggregation of mean.
  # TODO: Consider computing Welford var/std as accumulated stats.
  def _fold_in(running_mean, new_grad):
    return running_mean + (new_grad - running_mean) / n_mini_steps

  acc_grads = jax.tree.map(
      _fold_in,
      opt_state.acc_grads,  # pytype: disable=attribute-error
      grads,
  )

  lr = optimizer.get_learning_rate_schedule(c.opt)(step)
  # Normalized update scale (w/o global learning rate factor).
  updates = jax.tree.map(lambda update: update / (lr + 1e-20), updates)

  metrics = {
      "__train_loss": loss,  # dummy scalar to be first alphabetically in XM.
      "train_loss": loss,
      "log_perplexity": aux_data.log_perplexity,
      "train_ntokens": aux_data.ntokens,
      "learning_rate": jnp.array(lr),
      "train_fraction": step / c.opt.num_train_steps,
      "train_tokens_seen": aux_data.ntokens * step,
  }
  metrics.update(_global_stats_from_tree("grads/all/", acc_grads))
  metrics.update(_global_stats_from_tree("params/all/", state.params))
  metrics.update(_global_stats_from_tree("updates/all/", updates))

  if c.get("log_internal_metrics", False):
    metrics.update(_stats_from_state(aux_data.state))
    metrics.update(_stats_from_tree("grads/", acc_grads))
    metrics.update(_stats_from_tree("params/", state.params))
    metrics.update(_stats_from_tree("updates/", updates))
  return metrics
116 |
117 |
118 | def aggregate_microbatch_metrics(
119 | microbatch_metrics: list[dict[str, int | float | jax.Array]],
120 | ) -> dict[str, int | float | jax.Array]:
121 | """Accumulate train metrics weighted by `train_ntokens`.
122 |
123 | Accumulates train metrics with micro-batching logic. The logic assumes the
124 | default metrics are averaging metrics. `train_ntokens` is the only summed
125 | metrics and metrics including norm-based metrics are correctly computed
126 | after actual updates.
127 |
128 | Args:
129 | microbatch_metrics: a list of metric dictionaries, one for each microbatch.
130 |
131 | Returns:
132 | a single metric dictionary for the entire batch.
133 | """
134 | def _is_non_accumulating_metric(k):
135 | return (
136 | k.startswith("grads/") or
137 | k.startswith("params/") or
138 | k.startswith("updates/")
139 | )
140 |
141 | # Accumulate
142 | metrics = {}
143 | for m in microbatch_metrics:
144 | train_ntokens = float(m["train_ntokens"])
145 | for k, v in m.items():
146 | multiplier = train_ntokens if k != "train_ntokens" else 1.0
147 | if _is_non_accumulating_metric(k):
148 | metrics[k] = v
149 | elif k in metrics:
150 | metrics[k] += multiplier * v
151 | else:
152 | metrics[k] = multiplier * v
153 |
154 | # Normalize
155 | train_ntokens = metrics["train_ntokens"]
156 | for k, v in metrics.items():
157 | if _is_non_accumulating_metric(k):
158 | continue
159 | elif k != "train_ntokens":
160 | metrics[k] = v / train_ntokens
161 |
162 | # Perplexity is exponential of average, so compute after accumulation.
163 | metrics["train_perplexity"] = np.minimum(
164 | np.exp(metrics["log_perplexity"]),
165 | 1.0e4,
166 | )
167 | return metrics
168 |
169 |
def _stats_from_state(state: dict[str, dict[str, float]]) -> dict[str, float]:
  """Flattens the intermediates returned by the model into a single dict.

  Each top-level key of `state` becomes the "<key>/" prefix of the flattened
  names of the tree stored under it.
  """
  flat_stats = {}
  for collection, tree in state.items():
    flat_stats.update(_tree_to_dict(collection + "/", tree))
  return flat_stats
176 |
177 |
def _stats_from_tree(prefix: str, g: PyTree) -> dict[str, float]:
  """Returns per-leaf rms/std/mean statistics of `g` as a flat dict."""
  per_leaf_stats = jax.tree.map(_get_stats, g)
  return _tree_to_dict(prefix, per_leaf_stats)
180 |
181 |
def _global_stats_from_tree(prefix: str, g: PyTree) -> dict[str, float]:
  """Returns rms/std/mean statistics over all leaves of `g` as a flat dict."""
  whole_tree_stats = _get_stats(g)
  return _tree_to_dict(prefix, whole_tree_stats)
184 |
185 |
def _welford_mean(g: PyTree) -> float:
  """Returns the mean over all elements of all leaves of `g`.

  The mean is updated leaf by leaf, each leaf contributing in proportion to
  its number of elements, so no global sum over the whole tree is formed.

  Args:
    g: pytree whose leaves expose `.size` and can be summed with `jnp.sum`.

  Returns:
    The element-wise mean over the whole tree; 0.0 if the tree holds no
    elements.
  """

  def step(mean_and_size, x):
    mean, size = mean_and_size
    # An empty leaf carries no information. Skipping it also avoids the
    # 0.0 / 0.0 division that would otherwise happen when an empty leaf is
    # visited before any non-empty one.
    if x.size == 0:
      return mean, size
    new_size = size + x.size
    new_mean = mean * (size / new_size) + jnp.sum(x) / new_size
    return new_mean, new_size

  mean, _ = jax.tree.reduce(step, g, (0., 0.))
  return mean
196 |
197 |
def _get_stats(g: PyTree) -> dict[str, float]:
  """Returns the root-mean-square, standard deviation and mean of `g`.

  The statistics are taken over every element of every leaf of `g`.
  """
  first_moment = _welford_mean(g)
  second_moment = _welford_mean(jax.tree.map(jnp.square, g))
  # Clamp at zero: rounding can make E[x^2] - E[x]^2 slightly negative.
  variance = jnp.maximum(second_moment - first_moment**2, 0.)
  return {
      "rms": jnp.sqrt(second_moment),
      "std": jnp.sqrt(variance),
      "mean": first_moment,
  }
208 |
209 |
def _counts_from_tree(g: PyTree) -> dict[str, int]:
  """Returns the element count of every leaf of `g`, keyed by `n_params/...`."""
  leaf_sizes = jax.tree.map(jnp.size, g)
  return _tree_to_dict("n_params/", leaf_sizes)
213 |
214 |
def _tree_to_dict(prefix: str, g: PyTree) -> dict[str, Any]:
  """Flattens a pytree into a dict keyed by `prefix` plus the leaf's path.

  The name of a leaf is formed by joining with "_" the `key` of every path
  entry that has one; path entries without a `key` attribute are left out of
  the name.
  """
  flat = {}
  for path, leaf in jax.tree_util.tree_leaves_with_path(g):
    name = "_".join(entry.key for entry in path if hasattr(entry, "key"))
    flat[prefix + name] = leaf
  return flat
218 |
219 |
def _size(g: PyTree) -> int:
  """Returns the total number of elements over all leaves of `g`."""
  total = 0
  for leaf in jax.tree_util.tree_leaves(g):
    total = total + jnp.size(leaf)
  return total
222 |
223 |
224 | # A dataclass version of the welford metric.
225 | #
226 | # Computes a running mean and standard deviation for a set of measurements.
227 | #
228 | # For more details see:
229 | #
230 | # https://www.johndcook.com/blog/standard_deviation/
231 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
232 | # Chan, Tony F.; Golub, Gene H.; LeVeque, Randall J. (1983). "Algorithms for
233 | # computing the sample variance: Analysis and recommendations" (PDF). The
234 | # American Statistician. 37 (3): 242–247. doi:10.1080/00031305.1983.10483115.
235 | # JSTOR 2683386. Archived (PDF) from the original on 9 October 2022.
236 | # Schubert, Erich; Gertz, Michael (9 July 2018). Numerically stable parallel
237 | # computation of (co-)variance. ACM. p. 10. doi:10.1145/3221269.3223036. ISBN
238 | # 9781450365055. S2CID 49665540.
239 | #
240 | # In particular, what is implemented here is a version of the parallel algorithm
241 | # from Chan et al. This should be more numerically stable than the naive
242 | # sum of squares minus square of sum method (which loses a lot of precision).
243 | #
244 | # As an example of the usage:
245 | #
246 | # average = Average()
247 | # for x in values:
248 | # update = Average.from_array(x)
249 | # average = average.merge(update)
250 | # print(average)
251 |
252 |
253 | @dataclass
254 | class Average:
255 | """Computes a running mean and standard deviation from a set of measurements.
256 |
257 | Assumes the resulting value is a scalar but will count all values
258 | fed in, so will average across all dimensions by default.
259 | """
260 |
261 | count: int = 0
262 | mean: float = 0
263 | m2: float = 0
264 | variance: float = 0
265 | sem: float = 0
266 |
267 | @classmethod
268 | def from_array(
269 | cls,
270 | x: np.ndarray | jax.Array,
271 | mask: np.ndarray | jax.Array | None = None,
272 | ) -> "Average":
273 | """Compute the mean/std of a numpy array.
274 |
275 | Args:
276 | x: array of values.
277 | mask: optional mask.
278 |
279 | Returns:
280 | An `Average` instance from the array values.
281 | """
282 | if mask is None:
283 | count = x.size
284 | else:
285 | nnz = np.count_nonzero if isinstance(x, np.ndarray) else jnp.count_nonzero
286 | count = nnz(mask)
287 |
288 | total = x.sum(where=mask)
289 | mean = total / count
290 | delta2 = (x - mean)**2
291 | m2 = delta2.sum(where=mask)
292 | variance = m2 / count
293 | sem = (variance / count)**0.5
294 | return Average(count=count, mean=mean, m2=m2, variance=variance, sem=sem)
295 |
296 | def merge(self, other: "Average") -> "Average":
297 | """Compute the average statistics given two averages.
298 |
299 | Args:
300 | other: `Average` statistics of another collection.
301 |
302 | Returns:
303 | `Average` of `self` and `other`.
304 | """
305 | count = other.count + self.count
306 |
307 | if count == 0:
308 | return self
309 |
310 | delta = other.mean - self.mean
311 | # TODO: in cases where na ~ nb >> 1, instead use
312 | # mean = (self.count * self.mean + other.count * other.mean) / count
313 | mean = self.mean + delta * other.count / count
314 | m2 = self.m2 + other.m2 + delta * delta * self.count * other.count / count
315 | variance = m2 / count
316 | sem = (variance / count)**0.5
317 | return Average(count=count, mean=mean, m2=m2, variance=variance, sem=sem)
318 |
--------------------------------------------------------------------------------
/nanodo/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Transformer Decoder-only model."""
15 |
16 | # pylint: disable=g-importing-member
17 | # pylint: disable=invalid-name
18 |
19 | import dataclasses
20 | from functools import partial
21 |
22 | from flax import linen as nn
23 | import jax
24 | import jax.numpy as jnp
25 |
26 | from nanodo import fsdp
27 |
28 |
@dataclasses.dataclass
class DoConfig:
  """Hyper-parameters for Transformer decoder-only."""
  D: int  # model/embed dim = qkv dim
  H: int  # num attention heads
  L: int  # max context/sequence length (move out of config?)
  N: int  # number of transformer block layers
  V: int  # vocab size
  F: int  # FF inner dimension
  # Initializer for the attention projection and MLP kernels.
  kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform()
  # Initializer for the token and position embedding tables.
  embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling(
      1.0, 'fan_in', 'normal', out_axis=0)
  # dtype handed to the LayerNorm, Dense and DenseGeneral layers.
  dtype: jnp.dtype = jnp.float32
  # When True, `fsdp.init` wraps the initializers with partitioning over the
  # "data" axis.
  fsdp_enabled: bool = True

  # Transformer block rematerialization / gradient checkpointing to save memory.
  remat: bool = False
46 |
47 |
class TransformerDo(nn.Module):
  """Transformer decoder-only.

  Token and learned position embeddings are summed, passed through `N`
  transformer blocks and a final LayerNorm, and projected to logits with the
  token embedding table (`embed.attend`).
  """
  docfg: DoConfig

  def setup(self):
    cfg = self.docfg

    def make_embedding(num_embeddings):
      return nn.Embed(
          num_embeddings=num_embeddings,
          features=cfg.D,
          embedding_init=fsdp.init('embedding', cfg),
      )

    self.embed = make_embedding(cfg.V)
    self.pos_embed = make_embedding(cfg.L)

    # Optionally rematerialize block activations to save memory.
    block_cls = TBlock
    if cfg.remat:
      block_cls = nn.remat(TBlock)
    self.blocks = [block_cls(cfg) for _ in range(cfg.N)]
    self.out_ln = nn.LayerNorm(dtype=cfg.dtype, use_bias=False)

  def __call__(self, y_BxL: jax.Array):
    """Returns the logits for a batch of token ids."""
    # For training on concatenated examples.
    seq_len = y_BxL.shape[1]
    positions_1xL = jnp.arange(0, seq_len)[None, ...]
    h_BxLxD = self.embed(y_BxL) + self.pos_embed(positions_1xL)
    for block in self.blocks:
      h_BxLxD = block(h_BxLxD)
    h_BxLxD = self.out_ln(h_BxLxD)
    # The output projection reuses the token embedding table, in float32.
    return self.embed.attend(h_BxLxD.astype(jnp.float32))
78 |
79 |
class Mlp(nn.Module):
  """Two-layer feed-forward network: Dense(F) -> GELU -> Dense(D), no biases."""
  cfg: DoConfig

  @nn.compact
  def __call__(self, x_BxLxD: jax.Array):
    cfg = self.cfg
    dense_kwargs = dict(
        kernel_init=fsdp.init('mlp_kernel', cfg),
        use_bias=False,
        dtype=cfg.dtype,
    )
    hidden_BxLxF = nn.Dense(cfg.F, **dense_kwargs)(x_BxLxD)
    hidden_BxLxF = jax.nn.gelu(hidden_BxLxF)
    return nn.Dense(cfg.D, **dense_kwargs)(hidden_BxLxF)
95 |
96 |
class TBlock(nn.Module):
  """Pre-layernorm transformer block: attention and MLP, each with a residual."""
  docfg: DoConfig

  @nn.compact
  def __call__(self, in_BxLxD: jax.Array):
    cfg = self.docfg

    def layer_norm():
      return nn.LayerNorm(dtype=cfg.dtype, use_bias=False)

    # Attention sub-layer: normalize, attend, add the residual.
    normed_BxLxD = layer_norm()(in_BxLxD)
    attn_BxLxD = CausalAttn(cfg)(normed_BxLxD)
    resid_BxLxD = attn_BxLxD + in_BxLxD

    # MLP sub-layer: normalize, transform, add the residual.
    normed_BxLxD = layer_norm()(resid_BxLxD)
    mlp_BxLxD = Mlp(cfg)(normed_BxLxD)

    return resid_BxLxD + mlp_BxLxD
114 |
115 |
class CausalAttn(nn.Module):
  """Multi-head self-attention with a causal (lower-triangular) mask."""
  cfg: DoConfig

  @nn.compact
  def __call__(self, x_BxLxD: jax.Array):
    cfg = self.cfg

    assert cfg.D % cfg.H == 0, f'D {cfg.D} not divisible by H {cfg.H}'
    head_dim = cfg.D // cfg.H

    in_proj_init = fsdp.init('attn_in_proj', cfg)

    def project(name):
      # Maps D -> (H, Dh)
      return nn.DenseGeneral(
          axis=-1,
          features=(cfg.H, head_dim),
          kernel_init=in_proj_init,
          use_bias=False,
          dtype=cfg.dtype,
          name=name,
      )(x_BxLxD)

    q_BxLxHxDh = project('query')
    k_BxLxHxDh = project('key')
    v_BxLxHxDh = project('value')

    # Scaled dot-product scores, cast to fp32 for the softmax.
    q_BxLxHxDh = q_BxLxHxDh / head_dim**0.5
    scores_BxHxLxL = jnp.einsum(
        '...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh
    ).astype(jnp.float32)

    # Causal attention mask: position q may only attend to positions k <= q.
    seq_len = x_BxLxD.shape[1]
    causal_1x1xLxL = jnp.tril(
        jnp.ones((1, 1, seq_len, seq_len), dtype=jnp.bool_)
    )
    masked_value = jnp.finfo(cfg.dtype).min
    scores_BxHxLxL = jnp.where(causal_1x1xLxL, scores_BxHxLxL, masked_value)

    probs_BxHxLxL = jax.nn.softmax(scores_BxHxLxL, axis=-1).astype(cfg.dtype)
    ctx_BxLxHxDh = jnp.einsum(
        '...hqk,...khd->...qhd', probs_BxHxLxL, v_BxLxHxDh
    )

    # Output projection followed by contraction back to original dims
    out_proj = nn.DenseGeneral(
        features=cfg.D,
        name='attn_out_proj',
        axis=(-2, -1),
        kernel_init=fsdp.init('attn_out_proj', cfg),
        use_bias=False,
        dtype=cfg.dtype,
    )
    return out_proj(ctx_BxLxHxDh)
166 |
--------------------------------------------------------------------------------
/nanodo/model_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Factory for producing experimental models."""
15 |
16 | # pylint: disable=invalid-name,g-import-not-at-top,unused-import
17 |
18 | from typing import TYPE_CHECKING
19 |
20 | from flax import linen as nn
21 | from nanodo import loss as loss_lib
22 | from nanodo import model
23 |
24 | if TYPE_CHECKING:
25 | import ml_collections
26 |
27 |
def get_model_and_loss(
    c: "ml_collections.ConfigDict",
    vocab_size: int,
) -> tuple[nn.Module, loss_lib.LossFnFactory]:
  """Returns an instantiated (potentially experimental) model and its loss.

  Args:
    c: config; `c.model` supplies the `DoConfig` fields other than `V`.
    vocab_size: vocabulary size, used as `DoConfig.V`.

  Returns:
    The `TransformerDo` module and the default loss-function factory.
  """
  # Only the default model and loss are wired up here.
  model_cfg = model.DoConfig(**c.model, V=vocab_size)  # pytype:disable=attribute-error
  module = model.TransformerDo(model_cfg)  # pytype:disable=attribute-error
  return module, loss_lib.get_default_loss_fn
41 |
--------------------------------------------------------------------------------
/nanodo/optimizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Optimizer."""
15 |
16 | # pylint: disable=g-import-not-at-top
17 |
18 | import functools
19 | from typing import Iterable, TYPE_CHECKING
20 |
21 | import jax
22 | import optax
23 |
24 | if TYPE_CHECKING:
25 | import ml_collections
26 |
27 |
def get_optimizer(c: "ml_collections.ConfigDict") -> optax.MultiSteps:
  """Builds the optimizer described by the config.

  The base optimizer is optionally followed by a per-layer learning-rate
  scaling, optionally preceded by global-norm gradient clipping, and finally
  wrapped in `optax.MultiSteps` for gradient accumulation.

  Args:
    c: optimizer config; the optional `layerwise_lr_multiplier`,
      `clip_by_global_norm` and `grad_accumulation_steps` entries are read
      here.

  Returns:
    The `optax.MultiSteps`-wrapped optimizer.
  """
  optimizer = _get_base_optimizer(c)

  has_layerwise_lr = c.get("layerwise_lr_multiplier", None) is not None
  if has_layerwise_lr:
    layerwise_scale = _scale_by_dict(dict(c.layerwise_lr_multiplier))
    optimizer = optax.chain(optimizer, layerwise_scale)

  max_grad_norm = c.get("clip_by_global_norm", None)
  if max_grad_norm:
    clipping = optax.clip_by_global_norm(max_grad_norm)
    optimizer = optax.chain(clipping, optimizer)

  # Multistep gradient accumulation
  accumulation_steps = c.get("grad_accumulation_steps", 1)
  return optax.MultiSteps(optimizer, accumulation_steps)
45 |
46 |
def get_learning_rate_schedule(
    c: "ml_collections.ConfigDict",
) -> optax.Schedule:
  """Creates a learning rate schedule based on the config.

  Unless `decay_type` is "constant_without_warmup", the schedule starts with a
  linear warmup from `init_learning_rate` to `peak_learning_rate` over
  `warmup_steps`, followed by the decay selected by `decay_type` (default
  "cosine"): "rsqrt", "cosine", "linear", "constant", or
  "constant_linear_decay_<K>" / "constant_linear_decay_<P>p", which hold the
  peak rate and then decay linearly over the last K steps / the last P percent
  of `num_train_steps`.

  Args:
    c: optimizer config.

  Returns:
    The learning rate schedule.

  Raises:
    ValueError: if the steps or percentage encoded in a
      "constant_linear_decay_" type are out of range.
    NotImplementedError: if `decay_type` is not supported.
  """
  warmup = optax.linear_schedule(
      init_value=c.init_learning_rate,
      end_value=c.peak_learning_rate,
      transition_steps=c.warmup_steps,
  )

  decay_type = c.get("decay_type", "cosine")

  if decay_type == "constant_without_warmup":
    return optax.constant_schedule(value=c.peak_learning_rate)

  if decay_type == "rsqrt":
    decay = _rsqrt_schedule(
        init_value=c.peak_learning_rate,
        shift=1 + c.warmup_steps,
    )

  elif decay_type == "cosine":
    decay_steps = c.get("decay_steps", c.num_train_steps - c.warmup_steps)
    decay = optax.cosine_decay_schedule(
        init_value=c.peak_learning_rate,
        decay_steps=decay_steps,
        alpha=c.final_learning_rate / c.peak_learning_rate,
        exponent=1.0,
    )

  elif decay_type == "linear":
    decay = optax.linear_schedule(
        init_value=c.peak_learning_rate,
        end_value=c.final_learning_rate,
        transition_steps=c.num_train_steps - c.warmup_steps,
    )

  elif decay_type == "constant":
    decay = optax.constant_schedule(value=c.peak_learning_rate)

  elif decay_type.startswith("constant_linear_decay_"):
    # The trailing token is either "<percent>p" or a number of steps.
    suffix = decay_type.split("_")[-1]
    if decay_type.endswith("p"):
      percent_decay = float(suffix.split("p")[0]) / 100
      if percent_decay < 0 or percent_decay > 1:
        raise ValueError(f"Invalid decay % provided in {decay_type}")
      transition_steps = int(c.num_train_steps * percent_decay)
    else:
      decay_steps = int(suffix)
      if decay_steps < 0 or decay_steps > c.num_train_steps:
        raise ValueError(f"Invalid decay steps provided in {decay_type}")
      transition_steps = decay_steps

    hold = optax.constant_schedule(value=c.peak_learning_rate)
    final_decay = optax.linear_schedule(
        init_value=c.peak_learning_rate,
        end_value=c.final_learning_rate,
        transition_steps=transition_steps,
    )
    # warmup -> hold at peak -> linear decay over the last transition_steps.
    return optax.join_schedules(
        [warmup, hold, final_decay],
        boundaries=[c.warmup_steps, c.num_train_steps - transition_steps],
    )

  else:
    raise NotImplementedError(f"Unsupported decay type: {c.decay_type}")

  return optax.join_schedules([warmup, decay], boundaries=[c.warmup_steps])
122 |
123 |
def _rsqrt_schedule(*, init_value: float, shift: int) -> optax.Schedule:
  """Returns a schedule that decays with the reciprocal square root of time.

  Args:
    init_value: value of the schedule at `count == 0`.
    shift: offset added to the step count before taking the inverse root.

  Returns:
    A function of the step count computing
    `init_value * (count + shift) ** -0.5 * shift ** 0.5`.
  """

  def schedule(count):
    inverse_root = (count + shift) ** -0.5
    # Multiplying by sqrt(shift) normalises the value at count == 0.
    return init_value * inverse_root * shift**0.5

  return schedule
131 |
132 |
def _params_mask(
    params: optax.Params, exclude_names: Iterable[str] = ("bias", "scale")
) -> optax.Params:
  """Generates a boolean mask over `params` that drops `exclude_names`.

  Each leaf's key path is rendered as a "/"-joined string built from the path
  entries that expose a `key` attribute. The mask is `False` for a leaf when
  any of `exclude_names` occurs as a substring of that string, and `True`
  otherwise.

  Args:
    params: parameter PyTree to build the mask for.
    exclude_names: substrings marking the parameters to mask out. Any
      iterable is accepted, including one-shot iterators.

  Returns:
    A PyTree with the structure of `params` and boolean leaves.
  """
  # Materialise once: `exclude_names` is only promised to be an Iterable, and
  # a one-shot iterator would otherwise be exhausted after the first leaf,
  # silently leaving every later leaf unmasked.
  patterns = tuple(exclude_names)

  def _is_included(key_path, _):
    # Build the path string once per leaf instead of once per pattern.
    path = "/".join(k.key for k in key_path if hasattr(k, "key"))
    # Mask is True for parameters that do not match any excluded pattern.
    return not any(pattern in path for pattern in patterns)

  return jax.tree_util.tree_map_with_path(_is_included, params)
148 |
149 |
def _get_base_optimizer(
    c: "ml_collections.ConfigDict",
) -> optax.GradientTransformation:
  """Creates the optimizer named by `c.optimizer`, without any wrappers.

  Supported names are "adafactor", "adamw" and "lion". Each one is driven by
  the schedule from `get_learning_rate_schedule` and applies weight decay
  only to parameters whose path contains none of
  `c.weight_decay_exclusion_names`.

  Args:
    c: optimizer config.

  Returns:
    The configured optax gradient transformation.

  Raises:
    ValueError: if `c.optimizer` is not a supported name.
  """
  learning_rate_fn = get_learning_rate_schedule(c)
  kind = c.optimizer
  decay_mask = functools.partial(
      _params_mask,
      exclude_names=c.get("weight_decay_exclusion_names", []),
  )

  # With "independent" weight decay the configured value is divided by the
  # peak learning rate before being handed to the optimizer.
  weight_decay = c.weight_decay
  if c.get("independent_weight_decay", False):
    weight_decay = c.weight_decay / c.peak_learning_rate

  if kind == "adafactor":
    # NOTE(review): adafactor receives the raw `c.weight_decay`, not the value
    # rescaled for `independent_weight_decay` -- presumably intentional;
    # confirm.
    return optax.adafactor(
        learning_rate_fn,
        multiply_by_parameter_scale=c.get(
            "multiply_by_parameter_scale", True),
        decay_rate=c.get("decay_rate", 0.8),
        momentum=c.get("momentum", None),
        factored=c.get("factored", True),
        eps=c.get("eps", 1e-30),
        weight_decay_rate=c.weight_decay,
        weight_decay_mask=decay_mask,
    )

  if kind == "adamw":
    return optax.adamw(
        learning_rate_fn,
        b1=c.get("b1", 0.9),
        b2=c.get("b2", 0.98),
        eps=c.get("eps", 1e-9),
        weight_decay=weight_decay,
        mask=decay_mask,
    )

  if kind == "lion":
    return optax.lion(
        learning_rate_fn,
        b1=c.get("b1", 0.9),
        b2=c.get("b2", 0.98),
        weight_decay=weight_decay,
        mask=decay_mask,
    )

  raise ValueError(kind)
200 |
201 |
def _scale_by_dict(
    scale_dict: dict[str, float]) -> optax.GradientTransformation:
  """Optax transform that rescales updates per parameter name.

  Args:
    scale_dict: maps a name fragment to a multiplier, e.g. {"kernel": 3.}
      multiplies the update of every parameter whose path contains "kernel"
      by 3. When several fragments match, the first one in the dictionary's
      iteration order is used.

  Returns:
    A stateless Optax transform suitable for chaining (should be applied
    after the optimizer).
  """

  def _rescale_leaf(key_path, update):
    # Render the path as "module_name_1/module_name_2/.../kernel".
    path = "/".join([k.key for k in key_path if hasattr(k, "key")])
    for fragment, factor in scale_dict.items():
      if fragment in path:
        return update * factor
    # No fragment matched: the update passes through untouched.
    return update

  def _init(_):
    return optax.EmptyState()

  def _update(updates, state, params=None):
    del params  # Unused; kept for the optax update signature.
    return jax.tree_util.tree_map_with_path(_rescale_leaf, updates), state

  return optax.GradientTransformation(_init, _update)
234 |
--------------------------------------------------------------------------------
/nanodo/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Training loop."""
15 |
16 | # pylint: disable=invalid-name,g-importing-member,g-import-not-at-top,unused-import
17 |
18 | import functools
19 | import time
20 | from typing import Any, Iterator, TYPE_CHECKING
21 |
22 | from absl import logging
23 | from clu import metric_writers
24 | from clu import periodic_actions
25 | from flax import linen as nn
26 | from flax.training.train_state import TrainState
27 | import grain.python as grain
28 | import jax
29 | from jax.experimental import mesh_utils
30 | import jax.numpy as jnp
31 | from jax.sharding import Mesh
32 | from jax.sharding import NamedSharding
33 | from jax.sharding import PartitionSpec as P
34 | from nanodo import data
35 | from nanodo import evaluate
36 | from nanodo import loss as loss_lib
37 | from nanodo import metrics as metrics_lib
38 | from nanodo import model_factory
39 | from nanodo import optimizer
40 | import numpy as np
41 | import optax
42 | import orbax.checkpoint as ocp
43 |
44 | import os
45 |
46 |
47 | if TYPE_CHECKING:
48 | import ml_collections # pylint: disable=g-bad-import-order
49 |
50 |
51 | PyTree = Any
52 |
53 |
def train_and_evaluate(c: "ml_collections.ConfigDict"):
  """Runs the training loop with periodic evaluation and checkpointing.

  Builds a one-dimensional "data" device mesh, the tokenizer, the model, the
  initial train state and the train/eval datasets, optionally restores a
  checkpoint, and then iterates from the current trainer step up to and
  including `c.opt.num_train_steps`. Each iteration dispatches
  `grad_accumulation_steps` micro-batches and afterwards processes the metrics
  of the previous iteration, so host-side work overlaps with device work.

  Args:
    c: experiment config.

  Raises:
    ValueError: if `c.batch_size` is not divisible by the number of gradient
      accumulation steps, or if the micro batch size or the eval batch size
      is not divisible by the number of devices.
    FloatingPointError: if the aggregated train loss is NaN or Inf.
  """

  mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("data",))
  # For multistep gradient accumulator to simulate large batch sizes.
  # NOTE(review): this reads `grad_accumulation_steps` from the top-level
  # config, while the optimizer is built from `c.opt` (see
  # `_init_train_state`) and reads the same key there -- confirm both agree.
  grad_accumulation_steps = c.get("grad_accumulation_steps", 1)
  micro_batch_size, r = divmod(c.batch_size, grad_accumulation_steps)
  if grad_accumulation_steps > 1:
    logging.info("Gradient accumulation steps: %d", grad_accumulation_steps)
    logging.info(
        "Using total batch size = %d, micro batch size = %d",
        c.batch_size, micro_batch_size
    )
  if r:
    raise ValueError(
        "Batch size must be divisible by the gradient accumulation steps."
    )
  if micro_batch_size % jax.device_count() != 0:
    raise ValueError("Batch size must be divisible by the number of devices.")

  os.makedirs(c.workdir, exist_ok=True)
  rng = jax.random.PRNGKey(c.seed)

  tokenizer = data.get_py_tokenizer(c.vocab_path)
  vocab_size = tokenizer.GetPieceSize()

  model, get_loss_fn = model_factory.get_model_and_loss(c, vocab_size)

  # Time the (jitted) state initialisation so it can be reported below.
  tic = time.time()
  shardings, state = _init_train_state(c, model, rng=rng, mesh=mesh)
  init_time = time.time() - tic
  logging.info("[TIMING]: get_new_state (jit init) time: %.2fs", init_time)

  # The training dataset yields micro-batches, not full batches.
  train_ds = data.py_batched_tfds(
      tfds_name=c.ds_name,
      split="train",
      context_size=c.model.L,
      worker_count=c.pygrain_worker_count,
      vocab_path=c.vocab_path,
      batch_size=micro_batch_size,
      num_epochs=c.train_epochs,
      preprocessing=data.Preprocess.NOAM_PACKED,
      worker_buffer_size=c.pygrain_worker_buffer_size,
  )
  train_iter = iter(train_ds)

  if c.checkpoint:
    ckpt_mngr = _get_ckpt_manager(c.workdir, c)
    # An explicit restore directory takes precedence over a checkpoint that
    # already exists in the work directory.
    if c.checkpoint_restore_dir is not None:
      logging.info("Restoring checkpoint from %s", c.checkpoint_restore_dir)
      ex_ckpt_mngr = _get_ckpt_manager(c.checkpoint_restore_dir, c)
      state, train_iter = _restore_ckpt(ex_ckpt_mngr, state, train_iter)

    elif ckpt_mngr.latest_step() is not None:
      latest_step = ckpt_mngr.latest_step()
      logging.info("Restoring checkpoint %d from %s", latest_step, c.workdir)
      state, train_iter = _restore_ckpt(ckpt_mngr, state, train_iter)

  trainer = Trainer(
      c=c,
      state=state,
      mesh=mesh,
      shardings=shardings,
      get_loss_fn=get_loss_fn,
  )

  # We may evaluate on larger context length than training to measure length
  # generalization.
  if c.model.L < c.eval_max_target_length:
    logging.warning(
        "L (context length) %d is smaller than eval_max_target_length %d",
        c.model.L,
        c.eval_max_target_length,
    )
  # Evaluation defaults to the micro batch size unless configured otherwise.
  eval_batch_size = c.get("eval_batch_size", micro_batch_size)
  if eval_batch_size % jax.device_count() != 0:
    raise ValueError(
        "Eval Batch size must be divisible by the number of devices.")

  eval_ds = data.py_batched_tfds(
      tfds_name=c.ds_name,
      split=c.eval_split,
      context_size=c.model.L,
      worker_count=c.pygrain_worker_count,
      vocab_path=c.vocab_path,
      batch_size=eval_batch_size,
      num_epochs=1,
      num_records=None,
      preprocessing=data.Preprocess.PADDED,
      shuffle=False,
  )
  evaluator = evaluate.Evaluator(c, model, eval_ds, mesh, shardings)

  # Processes other than the first one only log their metrics.
  writer = metric_writers.create_default_writer(
      c.workdir,
      just_logging=jax.process_index() > 0,
  )
  # Hyperparameters are written only when starting from step 0.
  if trainer.step == 0:
    writer.write_hparams(dict(c))
  writer.write_scalars(trainer.step, {"jit_compilation_time": init_time})

  report_progress = periodic_actions.ReportProgress(
      num_train_steps=c.opt.num_train_steps,
      writer=writer,
      every_steps=c.write_perf_metrics_every_steps,
      every_secs=None,
  )

  # Progress reporting and profiling hooks run on the first process only.
  if jax.process_index() == 0:
    hooks = [
        report_progress,
        periodic_actions.Profile(logdir=c.workdir, num_profile_steps=5),
    ]
  else:
    hooks = []

  with metric_writers.ensure_flushes(writer):
    def _eval():
      # Evaluates the current parameters and writes the result at the
      # trainer's current step.
      with report_progress.timed("eval"):
        step = trainer.step
        eval_metrics = evaluator.eval(trainer.state.params)
        writer.write_scalars(step, eval_metrics)

    def _checkpoint():
      # Saves the train state and the data iterator; no-op when
      # checkpointing is disabled.
      if c.checkpoint:
        step = trainer.step
        logging.info("Saving last checkpoint step %d", step)
        ckpt_mngr.save(step, {"state": trainer.state, "data": train_iter})  # pylint: disable=undefined-variable

    def _process_metrics(step, microbatch_metrics):
      # Fetches, aggregates and writes the metrics of one training step, but
      # only on steps selected by `c.write_train_metrics_every_steps`.
      if microbatch_metrics and step % c.write_train_metrics_every_steps == 0:
        microbatch_metrics = [trainer.get_metrics(step, m)
                              for m in microbatch_metrics]
        metrics = metrics_lib.aggregate_microbatch_metrics(microbatch_metrics)
        writer.write_scalars(step, metrics)
        # Simple check for NaN/Inf for early termination.
        loss = metrics["train_loss"]
        if np.isnan(loss) or np.isinf(loss):
          # Terminate training. The next step has already been dispatched.
          logging.error(
              "[TRAINING ERROR] Nan/Inf encountered in training loop.\n "
              "Terminating training loop at step: %d", step + 1
          )
          _eval()
          raise FloatingPointError(step + 1, loss)

    pending_microbatch_metrics = []
    # The range is inclusive of `num_train_steps`, so the final iteration
    # evaluates, checkpoints and still dispatches one more step.
    for step in range(trainer.step, c.opt.num_train_steps + 1):
      is_final_step = step == c.opt.num_train_steps
      if step % c.eval_every_steps == 0 or is_final_step:
        _eval()
      if step % c.checkpoint_every_steps == 0 or is_final_step:
        _checkpoint()

      for h in hooks:
        h(step)

      # Schedule this step's tasks.
      # Initialize metrics for microbatch accumulation.
      new_microbatch_metrics = []
      for _ in range(grad_accumulation_steps):
        try:
          in_BxL = next(train_iter)
        except StopIteration:
          # NOTE(review): this `break` only leaves the micro-batch loop; the
          # outer step loop keeps iterating (and keeps evaluating and
          # checkpointing on schedule) although the message says "Stopping".
          # Confirm whether the outer loop should terminate here as well.
          logging.warning("Ran out of data at step %d. Stopping.", step)
          break
        # Async dispatch next step.
        new_microbatch_metrics.append(trainer.do_step(step, in_BxL))

      # Download to host and process the previous step's metrics after having
      # asynchronously dispatched the new step.
      _process_metrics(step - 1, pending_microbatch_metrics)
      pending_microbatch_metrics = new_microbatch_metrics
      logging.log_first_n(
          logging.INFO, "Finished training step %d.", 5, step - 1)
    # Download to host and process the final step's metrics.
    _process_metrics(c.opt.num_train_steps, pending_microbatch_metrics)

  if c.checkpoint:
    ckpt_mngr.close()  # pylint: disable=undefined-variable
234 |
235 |
class Trainer:
  """Holds the train state and dispatches jitted training steps."""

  def __init__(
      self,
      c: "ml_collections.ConfigDict",
      state: TrainState,
      mesh: Mesh,
      shardings: PyTree,
      get_loss_fn: loss_lib.LossFnFactory = loss_lib.get_default_loss_fn,
  ):
    """Jit-compiles `_train_step` bound to the config, loss and mesh.

    Args:
      c: experiment config forwarded to `_train_step`.
      state: initial train state.
      mesh: device mesh forwarded to `_train_step`.
      shardings: shardings of `state`, used for the state input and output.
      get_loss_fn: factory producing the loss function for a batch.
    """
    self.state = state
    self.init_metrics = None

    # One sharding per positional argument and per output of the step
    # function: the state keeps `shardings` on the way in and out, while the
    # batch argument and the returned metrics use an empty PartitionSpec.
    # `_train_step` itself constrains the batch to P("data") when it is given
    # a mesh.
    empty_spec_sharding = NamedSharding(mesh, P())
    bound_train_step = functools.partial(
        _train_step,
        c=c,
        get_loss_fn=get_loss_fn,
        mesh=mesh,
    )
    self.step_fn = jax.jit(
        bound_train_step,
        in_shardings=(shardings, empty_spec_sharding),
        out_shardings=(shardings, empty_spec_sharding),
        donate_argnames=("state", "in_BxL"),
    )

  @property
  def step(self) -> int:
    """Current value of `state.step` as a Python int."""
    return int(self.state.step)

  def get_metrics(
      self, step: int, metrics: dict[str, float]
  ) -> dict[str, float]:
    """Fetches `metrics` to the host, adding the init metrics at step 0."""
    # Grab the (possibly previous step's) metrics from device.
    host_metrics = jax.device_get(metrics)
    if step == 0:
      host_metrics |= self.init_metrics
    return host_metrics

  def do_step(self, step: int, in_BxL: jax.Array) -> dict[str, float]:
    """Async dispatch one training step and return metrics."""
    # Note that the device may be busy with the previous step.
    # Avoid calling self.step as that would block until the device is ready.
    needs_init_metrics = step == 0 or self.init_metrics is None
    if needs_init_metrics:
      self.init_metrics = metrics_lib.get_init_metrics(self.state)

    new_state, metrics = self.step_fn(self.state, in_BxL)
    self.state = new_state
    return metrics
294 |
295 |
def _train_step(
    state: TrainState,
    in_BxL: jax.Array,
    c: "ml_collections.ConfigDict",
    get_loss_fn: loss_lib.LossFnFactory = loss_lib.get_default_loss_fn,
    mesh: Mesh | None = None,
) -> tuple[TrainState, dict[str, float | jax.Array]]:
  """Runs one forward/backward pass and applies the optimizer update.

  Args:
    state: train state before the step.
    in_BxL: input batch.
    c: experiment config, forwarded to the loss factory and the metrics.
    get_loss_fn: factory producing the loss function for this batch.
    mesh: if given, the batch is constrained to be sharded along "data".

  Returns:
    The updated train state and the metrics of this step.
  """
  if mesh is not None:
    data_sharding = NamedSharding(mesh, P("data"))
    in_BxL = jax.lax.with_sharding_constraint(in_BxL, data_sharding)

  loss_fn = get_loss_fn(in_BxL, state.apply_fn, c)
  value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, aux_data), grads = value_and_grad_fn(state.params)

  # Call the optax transformation directly so the updates stay available for
  # the metrics below.
  updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params)
  new_state = state.replace(
      # Keep gradient_step as Trainer's step.
      step=state.opt_state.gradient_step + 1,  # pytype: disable=attribute-error
      params=optax.apply_updates(state.params, updates),
      opt_state=new_opt_state,
  )

  # Metrics are computed against the state from before the update.
  metrics = metrics_lib.get_metrics(aux_data, c, loss, state, grads, updates)
  return new_state, metrics
325 |
326 |
def _init_train_state(
    c: "ml_collections.ConfigDict",
    module: nn.Module,
    rng: jax.Array,
    mesh: Mesh,
) -> tuple[PyTree, TrainState]:
  """Creates the train state together with its shardings.

  Args:
    c: experiment config; `c.model.L` sets the input length and `c.opt`
      configures the optimizer.
    module: model to initialise.
    rng: random key used for parameter initialisation.
    mesh: device mesh the shardings are derived for.

  Returns:
    A `(shardings, state)` tuple.
  """
  input_spec = jax.ShapeDtypeStruct(shape=(1, c.model.L), dtype=jnp.int32)

  def init(rng, inputs):
    variables = module.init(rng, inputs)
    return TrainState.create(
        apply_fn=module.apply,
        params=variables["params"],
        tx=optimizer.get_optimizer(c.opt),
    )

  # Trace the initialiser abstractly first: the resulting state structure is
  # what the shardings are derived from.
  abstract_state = jax.eval_shape(init, rng, input_spec)
  shardings = nn.get_sharding(abstract_state, mesh)

  # Run the real initialisation with the state laid out per `shardings`.
  sharded_init = jax.jit(init, out_shardings=shardings)
  return shardings, sharded_init(rng, input_spec)
348 |
349 |
def _get_ckpt_manager(
    ckpt_dir: str,
    c: "ml_collections.ConfigDict",
) -> ocp.CheckpointManager:
  """Creates a checkpoint manager for the train state and the data iterator.

  Args:
    ckpt_dir: directory the manager reads from and writes to.
    c: experiment config; `c.max_to_keep` is passed to the manager options.

  Returns:
    A manager with a "state" and a "data" checkpointer.
  """
  manager_options = ocp.CheckpointManagerOptions(max_to_keep=c.max_to_keep)
  state_checkpointer = ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler())
  data_checkpointer = ocp.Checkpointer(grain.PyGrainCheckpointHandler())  # pytype:disable=wrong-arg-types
  return ocp.CheckpointManager(
      ckpt_dir,
      {"state": state_checkpointer, "data": data_checkpointer},
      manager_options,
  )
360 |
361 |
def _restore_ckpt(
    ckpt_mngr: ocp.CheckpointManager,
    state: TrainState,
    train_iter: Iterator[jax.Array],
    step: int | None = None,
) -> tuple[TrainState, Iterator[jax.Array]]:
  """Restores the train state and the data iterator from a checkpoint.

  Args:
    ckpt_mngr: manager to restore from.
    state: train state passed as the restore target.
    train_iter: data iterator passed as the restore target.
    step: step to restore; the manager's latest step when `None`.

  Returns:
    The restored `(state, train_iter)` pair.
  """
  state_restore_args = ocp.checkpoint_utils.construct_restore_args(state)
  target_step = ckpt_mngr.latest_step() if step is None else step
  restored = ckpt_mngr.restore(
      target_step,
      items={"state": state, "data": train_iter},
      restore_kwargs={"state": {"restore_args": state_restore_args}},
  )
  return restored["state"], restored["data"]
377 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["flit_core>=3.9.0"]
3 | build-backend = "flit_core.buildapi"
4 |
5 | [project]
6 | name = "nanodo"
7 | dynamic = ["version"]
8 | description = "A minimal ('nano') Transformer decoder-only ('do') library in JAX."
9 | readme = "README.md"
10 | license = { file = "LICENSE" }
11 | requires-python = ">=3.10"
12 | authors = [
13 | {name = "Google DeepMind", email = "nanodo-team@google.com"},
14 | ]
15 | keywords = [
16 | "python",
17 | "machine learning",
18 | "llm",
19 | "jax",
20 | "flax",
21 | "decoder-only",
22 | "large language model",
23 | "language modelling",
24 | "artificial intelligence",
25 | ]
26 | classifiers = [
27 | "Environment :: Console",
28 | "Development Status :: 4 - Beta",
29 | "License :: OSI Approved :: Apache Software License",
30 | # copybara:strip_begin(internal)
31 | # TODO: add support for python 3.12 after
32 | # https://github.com/tensorflow/datasets/issues/4666
33 | # https://github.com/google/array_record/issues/94
34 | # are fixed.
35 | # copybara:strip_end
36 | 'Programming Language :: Python :: 3.10',
37 | 'Programming Language :: Python :: 3.11',
38 | 'Operating System :: MacOS',
39 | 'Operating System :: POSIX :: Linux',
40 | 'Topic :: Software Development',
41 | 'Topic :: Software Development :: Libraries',
42 | 'Topic :: Software Development :: Libraries :: Python Modules',
43 | 'Topic :: Scientific/Engineering',
44 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
45 | 'Topic :: Scientific/Engineering :: Mathematics',
46 | 'Intended Audience :: Science/Research',
47 | 'Intended Audience :: Developers',
48 | 'Intended Audience :: Education',
49 | ]
50 |
51 | dependencies = [
52 | "absl-py>=2.1.0",
53 | "clu>=0.0.12",
54 | "flax>=0.8.2",
55 | "grain>=0.1.0",
56 | "jax>=0.4.26",
57 | "jaxlib>=0.4.26",
58 | "ml-collections>=0.1.1",
59 | "numpy>=1.26.0",
60 | "optax>=0.2.2",
61 | "orbax>=0.1.7",
62 | "sentencepiece>=0.2.0",
63 | "tensorflow_datasets>=4.9.5",
64 | "tensorflow>=2.16.1",
65 | ]
66 |
67 | [project.urls]
68 | homepage = "https://github.com/google-deepmind/nanodo"
69 | repository = "https://github.com/google-deepmind/nanodo"
70 | # documentation = "https://nanodo.readthedocs.io/"
71 |
72 | [project.optional-dependencies]
73 | test = [
74 | "chex>=0.1.86",
75 | ]
76 |
77 | [tool.setuptools.packages.find]
78 | where = ["nanodo"]
79 | include = ["README.md", "LICENSE"]
80 | exclude = ["*_test.py"]
81 |
--------------------------------------------------------------------------------
/tests/data_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for `../data.py`."""
15 |
16 | # pylint: disable=invalid-name
17 |
18 | import os
19 |
20 | from absl.testing import absltest
21 | from absl.testing import parameterized
22 | import chex
23 | import grain.python as grain
24 | import jax
25 | import jax.numpy as jnp
26 | from nanodo import data
27 | import numpy as np
28 | import tensorflow_datasets as tfds
29 |
30 |
31 | jax.config.parse_flags_with_absl()
32 | jax.config.update("jax_numpy_rank_promotion", "raise")
33 |
34 |
35 | def _get_vocab_path():
36 | return os.path.join(
37 | os.path.dirname(__file__),
38 | "testdata/sentencepiece_cc_all.32000.100extra-sentencepiece.model",
39 | )
40 |
41 |
def _assert_grain_records(records: list[grain.Record], expected: np.ndarray):
  """Asserts that the `data` payloads of `records` equal `expected`."""
  payloads = [record.data for record in records]
  np.testing.assert_equal(payloads, expected)
45 |
46 |
class DataTest(parameterized.TestCase):
  """Tests for the data pipeline helpers in `nanodo.data`."""

  def test_py_batched_tfds(self):
    """The first batch has shape (batch_size, context_size)."""
    batch_size, context_size = 2, 512
    with tfds.testing.mock_data(num_examples=100):
      ds = data.py_batched_tfds(
          tfds_name="lm1b",
          split="train",
          context_size=context_size,
          worker_count=0,
          vocab_path=_get_vocab_path(),
          batch_size=batch_size,
      )
      first_batch = next(iter(ds))
      self.assertEqual((batch_size, context_size), first_batch.shape)

  def test_py_noam_pack(self):
    """Token lists are concatenated and re-cut into rows of length 4."""
    token_lists = [[2, 3, 4, 1], [5, 6, 7, 8, 9, 10, 11, 1]]
    pyg_records = [
        grain.Record(metadata=grain.RecordMetadata(index=i), data=tokens)
        for i, tokens in enumerate(token_lists)
    ]
    packer = data._NoamPack(4)
    packed_records = list(packer(iter(pyg_records)))
    _assert_grain_records(
        packed_records,
        np.array([[2, 3, 4, 1], [5, 6, 7, 8], [9, 10, 11, 1]]),
    )

  def test_py_batched_tfds_noam_packed(self):
    """The first two packed batches are full-size and contain no padding."""
    with tfds.testing.mock_data():
      ds = data.py_batched_tfds(
          tfds_name="lm1b",
          split="train",
          context_size=1024,
          batch_size=2,
          worker_count=0,
          vocab_path=_get_vocab_path(),
          num_records=10,
          preprocessing=data.Preprocess.NOAM_PACKED,
      )
      batches = iter(ds)
      for _ in range(2):
        batch = next(batches)
        self.assertEqual(batch.shape, (2, 1024))
        self.assertEqual(np.sum(batch == data.PAD_ID), 0)

  def test_py_batched_tfds_padded(self):
    """The first padded batch is full-size and contains padding."""
    with tfds.testing.mock_data():
      ds = data.py_batched_tfds(
          tfds_name="lm1b",
          split="train",
          context_size=1024,
          batch_size=2,
          worker_count=0,
          vocab_path=_get_vocab_path(),
          num_records=10,
          preprocessing=data.Preprocess.PADDED,
      )
      batch = next(iter(ds))
      self.assertEqual(batch.shape, (2, 1024))
      self.assertGreater(np.sum(batch == data.PAD_ID), 0)  # sanity check

  def test_py_tokenize(self):
    """Tokenizing yields ids, and `pad_len` fixes the output length."""
    tokenizer = data._SPTokenizer(_get_vocab_path())
    example = {"text": "some text"}

    unpadded_ids = data._py_tokenize(example, spt=tokenizer)
    self.assertNotEmpty(unpadded_ids)

    padded_ids = data._py_tokenize(
        example, spt=tokenizer, pad_len=128, pad_id=0
    )
    self.assertLen(padded_ids, 128)

  def test_get_in_out(self):
    """Targets are the inputs shifted left; the last position is masked."""
    length = 256
    batch_size = 8
    last = length - 1
    in_BxL = jax.random.randint(
        jax.random.PRNGKey(42),
        shape=(batch_size, length),
        minval=1,
        maxval=256,
    )
    x_BxL, y_BxL, weights_BxL = data.get_in_out(in_BxL)

    # All three outputs keep the shape of the input.
    for out_BxL in (x_BxL, y_BxL, weights_BxL):
      self.assertEqual(out_BxL.shape, in_BxL.shape)

    # Inputs pass through unchanged.
    chex.assert_trees_all_equal(x_BxL, in_BxL)

    # Targets are shifted by one position and end with a pad token.
    chex.assert_trees_all_equal(y_BxL[:, :last], in_BxL[:, 1:length])
    chex.assert_trees_all_equal(
        y_BxL[:, last],
        jnp.ones_like(y_BxL[:, last]) * data.PAD_ID,
    )

    # Weights are one everywhere except at the final position.
    chex.assert_trees_all_equal(
        weights_BxL[:, :last],
        jnp.ones_like(weights_BxL[:, :last]),
    )
    chex.assert_trees_all_equal(
        weights_BxL[:, last], jnp.zeros_like(weights_BxL[:, last])
    )
146 |
147 |
148 | if __name__ == "__main__":
149 | absltest.main()
150 |
--------------------------------------------------------------------------------
/tests/evaluate_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for `../evaluate.py`."""
15 |
16 | from absl.testing import absltest
17 | from absl.testing import parameterized
18 | import jax
19 | import jax.numpy as jnp
20 | from nanodo import evaluate
21 | from nanodo import metrics as metrics_lib
22 | from nanodo import model
23 |
24 |
25 | jax.config.parse_flags_with_absl()
26 | jax.config.update("jax_numpy_rank_promotion", "raise")
27 |
28 |
class EvalTest(parameterized.TestCase):
  """Tests for `nanodo.evaluate`."""

  def test_eval_step(self):
    """Merging three eval steps gives positive summary statistics."""
    docfg = model.DoConfig(D=128, H=16, L=256, N=4, V=1024, F=4 * 4)
    m = model.TransformerDo(docfg)
    _, init_rng = jax.random.split(jax.random.PRNGKey(42))
    x = jnp.ones((2, 256), dtype=jnp.int32)
    params = jax.jit(m.init)(init_rng, x)["params"]

    metrics = metrics_lib.Average()
    num_eval_steps = 3
    for _ in range(num_eval_steps):
      metrics = metrics.merge(evaluate._eval_step(params, x, m))

    for statistic in (
        metrics.mean, metrics.sem, metrics.variance, metrics.count):
      self.assertGreater(statistic, 0)
48 |
49 |
50 | if __name__ == "__main__":
51 | absltest.main()
52 |
--------------------------------------------------------------------------------
/tests/metrics_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for `../metrics.py`."""
15 |
16 | # pylint: disable=invalid-name,g-importing-member,g-import-not-at-top
17 |
18 | from typing import TYPE_CHECKING
19 |
20 | from absl.testing import absltest
21 | from absl.testing import parameterized
22 | import chex
23 | from flax.training.train_state import TrainState
24 | import jax
25 | from jax import random
26 | import jax.numpy as jnp
27 | from nanodo import metrics as metrics_lib
28 | from nanodo import model
29 | from nanodo import optimizer as opt
30 | from nanodo import train
31 | from nanodo.configs import default
32 |
33 | if TYPE_CHECKING:
34 | import ml_collections
35 |
36 |
37 | jax.config.parse_flags_with_absl()
38 | jax.config.update("jax_numpy_rank_promotion", "raise")
39 |
40 |
def _get_config() -> "ml_collections.ConfigDict":
  """Returns the default config with a small batch, vocabulary and model."""
  c = default.get_config()

  # Top-level overrides.
  for name, value in (("batch_size", 2), ("eval_steps", 1), ("V", 32)):
    setattr(c, name, value)

  # Model-size overrides.
  for name, value in (("L", 256), ("D", 32), ("F", 128), ("N", 2), ("H", 4)):
    setattr(c.model, name, value)

  return c
56 |
57 |
class MetricsTest(parameterized.TestCase):
  """Tests for `_welford_mean`, microbatch aggregation and `Average`."""

  def test_welford_mean_large_array(self):
    """`_welford_mean` on one array with more elements than int32 max."""
    if jax.default_backend() != "gpu":
      self.skipTest("Not enough RAM on TPU/CPU to generate a contiguous array.")

    bf16 = jnp.bfloat16
    expected_mean = 5.
    noise = random.normal(random.PRNGKey(1), (2**10, 2**(31 - 10)), bf16)
    shifted = noise + expected_mean

    # The element count must exceed what an int32 can represent.
    int32_max = jnp.iinfo(jnp.int32).max
    self.assertGreater(shifted.size, int32_max)

    # The result keeps the input dtype and equals the reference mean.
    result = metrics_lib._welford_mean(shifted)
    self.assertEqual(bf16, jnp.dtype(result))
    self.assertEqual(expected_mean, result)

  def test_welford_mean_large_pytree(self):
    """`_welford_mean` on a list of arrays whose total size exceeds int32 max."""
    if jax.default_backend() == "cpu":
      self.skipTest("Test too slow on CPU.")

    bf16 = jnp.bfloat16
    num_leaves = 2**4
    leaf_shape = (2**10, 2**(31 - 10 - 4))
    expected_means = range(num_leaves)
    leaf_keys = random.split(random.PRNGKey(1), num_leaves)

    # Leaf `i` is normal noise shifted by `i`.
    tree = []
    for leaf_mean, leaf_key in zip(expected_means, leaf_keys):
      tree.append(leaf_mean + random.normal(leaf_key, leaf_shape, bf16))

    # The combined element count must exceed what an int32 can represent.
    int32_max = jnp.iinfo(jnp.int32).max
    self.assertGreater(metrics_lib._size(tree), int32_max)

    # The result keeps the input dtype and equals the mean of the shifts.
    result = metrics_lib._welford_mean(tree)
    self.assertEqual(bf16, jnp.dtype(result))
    self.assertEqual(sum(expected_means) / num_leaves, result)

  def test_aggregate_microbatch_metrics(self):
    """One full-batch step agrees with four accumulated microbatch steps."""
    cfg = _get_config()
    transformer = model.TransformerDo(model.DoConfig(**cfg.model, V=cfg.V))
    key = jax.random.PRNGKey(42)
    in_BxL = jax.random.categorical(key, jnp.ones((16, cfg.model.L, cfg.V)))
    variables = jax.jit(transformer.init)(key, in_BxL)

    def make_state():
      # `cfg.opt` is read at call time, so the second call below sees the
      # updated `grad_accumulation_steps`.
      return TrainState.create(
          apply_fn=transformer.apply,
          params=variables["params"],
          tx=opt.get_optimizer(cfg.opt),
      )

    # Reference: a single step on the whole batch.
    state_single = make_state()
    state_single, metrics_single = train._train_step(
        state_single, in_BxL, cfg)
    metrics_single = metrics_lib.aggregate_microbatch_metrics([metrics_single])

    # Same batch, split along axis 0 and fed as one microbatch per step.
    num_microbatches = 4
    cfg.opt.grad_accumulation_steps = num_microbatches
    state_multistep = make_state()

    per_microbatch_metrics = []
    for microbatch in jnp.array_split(in_BxL, num_microbatches, axis=0):
      state_multistep, step_metrics = train._train_step(
          state_multistep, microbatch, cfg)
      per_microbatch_metrics.append(step_metrics)
    metrics_multistep = metrics_lib.aggregate_microbatch_metrics(
        per_microbatch_metrics)

    self.assertEqual(state_single.step, state_multistep.step)

    tolerances = dict(rtol=1e-2, atol=1e-1)
    # Metrics, updated params and optimizer state must all agree.
    chex.assert_trees_all_close(
        metrics_single, metrics_multistep, **tolerances)
    chex.assert_trees_all_close(
        state_single.params, state_multistep.params, **tolerances)
    chex.assert_trees_all_close(
        state_single.opt_state, state_multistep.opt_state, **tolerances)

  def test_gaussian(self):
    """Merged per-sample `Average`s agree with one built from all samples."""
    samples = jax.random.normal(jax.random.PRNGKey(0), (100,))

    # Fold the samples in one at a time via `merge`.
    running = None
    for sample in samples:
      single = metrics_lib.Average.from_array(sample)
      if running is None:
        running = single
      else:
        running = running.merge(single)

    self.assertIsNotNone(running)

    # The incremental mean is within three standard errors of zero.
    self.assertAlmostEqual(running.mean, 0.0, delta=3 * running.sem)

    # The same holds for the average built from the whole array at once.
    batched = metrics_lib.Average.from_array(samples)
    self.assertAlmostEqual(batched.mean, 0.0, delta=3 * batched.sem)

    # Both estimates agree within their combined standard error.
    combined_sem = (running.sem ** 2 + batched.sem ** 2) ** 0.5
    self.assertAlmostEqual(running.mean, batched.mean, delta=combined_sem)
176 |
177 |
178 | if __name__ == "__main__":
179 | absltest.main()
180 |
--------------------------------------------------------------------------------
/tests/model_factory_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Test experimental models and model factory."""
15 |
16 | # pylint: disable=invalid-name
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import chex
21 | import flax
22 | import jax
23 | import jax.numpy as jnp
24 | from nanodo import model as default_model
25 | from nanodo import model_factory
26 | from nanodo.configs import default
27 |
28 |
29 | jax.config.parse_flags_with_absl()
30 | jax.config.update('jax_numpy_rank_promotion', 'raise')
31 |
32 |
class ModelTest(parameterized.TestCase):
  """Checks that the model factory reproduces the default model's output."""

  def _default_output(self, rng):
    """Run the default model once and build a config describing it.

    Args:
      rng: PRNG key; it is split into one key for sampling the input tokens
        and one for initializing the model.

    Returns:
      A tuple `(model_output, variables, x_BxL, config)`.
    """
    batch_size, seq_len = 2, 128
    do_cfg = default_model.DoConfig(
        D=16, H=4, L=seq_len, N=4, V=256, F=4 * 4)
    reference_model = default_model.TransformerDo(do_cfg)

    data_key, init_key = jax.random.split(rng)
    x_BxL = jax.random.randint(
        data_key,
        shape=(batch_size, seq_len),
        minval=0,
        maxval=do_cfg.V,
        dtype=jnp.int32,
    )
    variables = reference_model.init(init_key, x_BxL)
    reference_out = reference_model.apply(variables, x_BxL)

    # Mirror the hand-built `DoConfig` in the default experiment config.
    c = default.get_config()
    c.V = do_cfg.V
    for field in ('D', 'H', 'L', 'N', 'F'):
      setattr(c.model, field, getattr(do_cfg, field))
    c.model.dtype = 'float32'

    return reference_out, variables, x_BxL, c

  def test_default_model(self):
    """The factory-built model matches the default model on the same input."""
    reference_out, variables, x_BxL, c = self._default_output(
        jax.random.PRNGKey(42))
    factory_model, _ = model_factory.get_model_and_loss(c, c.V)
    factory_out = factory_model.apply(variables, x_BxL)
    self.assertTrue(jnp.allclose(factory_out, reference_out))
64 |
65 |
66 | if __name__ == '__main__':
67 | absltest.main()
68 |
--------------------------------------------------------------------------------
/tests/model_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests `../model.py`."""
15 |
16 | # pylint: disable=invalid-name
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import chex
21 | import jax
22 | import jax.numpy as jnp
23 | from nanodo import model
24 | from optax import losses
25 |
26 |
27 | jax.config.parse_flags_with_absl()
28 | jax.config.update("jax_numpy_rank_promotion", "raise")
29 |
30 |
class ModelTest(chex.TestCase):
  """Tests for `TransformerDo` and `Mlp`, with and without `jax.jit`."""

  @chex.variants(with_jit=True, without_jit=True)
  def test_full_model(self):
    """The forward pass is finite and has shape `(B, L, V)`."""
    batch, seq_len = 2, 128
    cfg = model.DoConfig(D=16, H=4, L=seq_len, N=4, V=256, F=4 * 4)
    transformer = model.TransformerDo(cfg)

    data_key, init_key = jax.random.split(jax.random.PRNGKey(0), 2)
    x_BxL = jax.random.randint(data_key, (batch, seq_len), 0, cfg.V, jnp.int32)
    variables = transformer.init(init_key, x_BxL)
    y_BxLxV = self.variant(transformer.apply)(variables, x_BxL)

    chex.assert_tree_all_finite(y_BxLxV)
    chex.assert_shape(y_BxLxV, (batch, seq_len, cfg.V))

  @chex.variants(with_jit=True, without_jit=True)
  @parameterized.named_parameters(
      ("beginning", 0),
      ("near-beginning", 5),
      ("near-end", 200),
  )
  def test_causality(self, token_loc: int):
    """Checks that the loss at one position ignores later positions.

    The test differentiates the cross-entropy loss at position `token_loc`
    with respect to the learned positional embedding and inspects the
    gradient row by row:

    1. Rows before `token_loc` must carry some non-zero gradient: earlier
       positions contribute to the prediction.
    2. The row at `token_loc` must carry some non-zero gradient: the input at
       that position is what predicts the output at that position.
    3. Rows after `token_loc` must be zero: the prediction may not depend on
       the future.

    Each row of the positional embedding belongs to exactly one position, so
    its gradient acts as a proxy for perturbing the token at that position.

    Args:
      token_loc: position whose loss is differentiated.
    """
    cfg = model.DoConfig(D=16, H=4, L=256, N=4, V=512, F=16)
    transformer = model.TransformerDo(cfg)

    def loss(params, x: chex.Array, y: chex.Array) -> chex.Array:
      # Cross-entropy with integer labels, meant to emulate the loss in
      # nanodo/train.py.
      per_token_loss = losses.softmax_cross_entropy_with_integer_labels(
          transformer.apply(params, x),
          y,
      )
      # `jax.grad` needs a scalar output, so pick a single position.
      return per_token_loss[0][token_loc]

    x_1xL = jnp.arange(cfg.L)[None, :]
    variables = transformer.init(jax.random.PRNGKey(42), x_1xL)
    grad_fn = self.variant(jax.grad(loss))
    grads = grad_fn(variables, x=x_1xL, y=x_1xL + 1)
    pos_grads = grads["params"]["pos_embed"]["embedding"].value

    # The gradient computation itself must have succeeded.
    chex.assert_tree_all_finite(pos_grads)

    def squared_norm(g) -> float:
      return float(jnp.sum(jnp.square(g)))

    # Rows before the position: non-zero somewhere. There are no such rows
    # when the position is 0, so the check is skipped.
    if token_loc != 0:
      chex.assert_scalar_in(squared_norm(pos_grads[0:token_loc]), 1e-2, 1000)

    # Row at the position: non-zero somewhere.
    chex.assert_scalar_in(squared_norm(pos_grads[token_loc]), 1e-2, 1000)

    # Rows after the position: zero everywhere (up to `atol`).
    future_grads = pos_grads[token_loc + 1 :]
    chex.assert_trees_all_close(
        future_grads,
        jnp.zeros_like(future_grads),
        atol=1e-5,
    )

  def test_heads_divides_dimension(self):
    """`init` raises `AssertionError` when `H` does not divide `D`."""
    cfg = model.DoConfig(D=16, H=3, L=256, N=4, V=256, F=4 * 4)
    transformer = model.TransformerDo(cfg)
    x_BxL = jnp.ones((2, cfg.L), dtype=jnp.int32)

    with self.assertRaises(AssertionError):
      transformer.init(jax.random.PRNGKey(42), x_BxL)

  @chex.variants(with_jit=True, without_jit=True)
  def test_mlp(self):
    """`Mlp` preserves the input shape and dtype and stays finite."""
    batch, seq_len, width = 3, 4, 16
    bf16 = jnp.bfloat16
    cfg = model.DoConfig(
        D=width, H=4, L=seq_len, N=4, V=256, F=4 * 4, dtype=bf16)
    mlp = model.Mlp(cfg)

    x_BxLxD = jnp.ones((batch, seq_len, width), dtype=bf16)
    variables = mlp.init(jax.random.PRNGKey(42), x_BxLxD)
    out_BxLxD = self.variant(mlp.apply)(variables, x_BxLxD)

    chex.assert_shape(out_BxLxD, (batch, seq_len, width))
    chex.assert_type(out_BxLxD, bf16)
    chex.assert_tree_all_finite(out_BxLxD)

  @chex.variants(with_jit=True, without_jit=True)
  def test_remat_forward(self):
    """`remat=True` leaves the init params and the forward output unchanged."""
    batch, seq_len = 2, 128
    shared = dict(D=16, H=4, L=seq_len, N=4, V=256, F=4 * 4)
    cfg_base = model.DoConfig(**shared, remat=False)
    cfg_remat = model.DoConfig(**shared, remat=True)

    data_key, init_key = jax.random.split(jax.random.PRNGKey(0), 2)
    x_BxL = jax.random.randint(
        data_key, (batch, seq_len), 0, cfg_base.V, jnp.int32)

    plain = model.TransformerDo(cfg_base)
    rematted = model.TransformerDo(cfg_remat)
    variables = plain.init(init_key, x_BxL)
    variables_remat = rematted.init(init_key, x_BxL)

    # The same init key must produce identical parameter trees.
    chex.assert_trees_all_equal(variables, variables_remat)

    y_BxLxV = self.variant(plain.apply)(variables, x_BxL)
    y_remat_BxLxV = self.variant(rematted.apply)(variables_remat, x_BxL)

    chex.assert_tree_all_finite(y_BxLxV)
    chex.assert_shape(y_BxLxV, (batch, seq_len, cfg_base.V))
    # Norm of the difference, divided by the element count, is near zero.
    error = jnp.linalg.norm(y_BxLxV - y_remat_BxLxV) / y_BxLxV.size
    self.assertLess(error, 1e-8)
172 |
173 |
174 | if __name__ == "__main__":
175 | absltest.main()
176 |
--------------------------------------------------------------------------------
/tests/optimizer_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for `../optimizer.py`."""
15 |
16 | # pylint: disable=invalid-name,g-importing-member
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import jax
21 | import jax.numpy as jnp
22 | import ml_collections
23 | from nanodo import model
24 | from nanodo import optimizer
25 |
26 |
27 | jax.config.parse_flags_with_absl()
28 | jax.config.update("jax_numpy_rank_promotion", "raise")
29 |
30 |
def _get_test_opt_config() -> ml_collections.ConfigDict:
  """Return the optimizer config shared by the learning-rate tests."""
  settings = dict(
      num_train_steps=10_000,
      peak_learning_rate=0.01,
      init_learning_rate=0.001,
      final_learning_rate=0.0001,
      warmup_steps=10,
      decay_steps=100,
      weight_decay=0.1,
  )
  return ml_collections.config_dict.create(**settings)
42 |
43 |
class OptimizerTest(parameterized.TestCase):
  """Tests for the learning-rate schedules and `_scale_by_dict`."""

  @parameterized.parameters("cosine", "rsqrt")
  def test_create_lr(self, decay_type: str):
    """Schedule with warmup: start, peak and post-warmup values."""
    opt_cfg = _get_test_opt_config()
    opt_cfg.decay_type = decay_type
    schedule = optimizer.get_learning_rate_schedule(opt_cfg)

    # Step 0 sits at the (positive) initial rate and warmup increases it.
    lr_at_start = schedule(0)
    self.assertGreater(lr_at_start, 0)
    self.assertLess(jnp.abs(lr_at_start - opt_cfg.init_learning_rate), 1e-9)
    self.assertGreater(schedule(1), lr_at_start)

    # The peak is reached exactly when warmup ends.
    lr_at_peak = schedule(opt_cfg.warmup_steps)
    self.assertEqual(lr_at_peak, opt_cfg.peak_learning_rate)

    if decay_type != "rsqrt":
      # Non-rsqrt decay ends at the configured final rate.
      self.assertEqual(
          schedule(opt_cfg.num_train_steps), opt_cfg.final_learning_rate)
      return

    # One step past warmup matches `_rsqrt_schedule` evaluated at step 1.
    expected = optimizer._rsqrt_schedule(
        init_value=lr_at_peak, shift=1 + opt_cfg.warmup_steps
    )(1)
    self.assertEqual(schedule(opt_cfg.warmup_steps + 1), expected)

  def test_create_lr_no_warmup(self):
    """Without warmup the schedule starts at the peak and decays."""
    opt_cfg = _get_test_opt_config()
    opt_cfg.warmup_steps = 0
    schedule = optimizer.get_learning_rate_schedule(opt_cfg)

    lr_at_start = schedule(0)
    lr_after_one = schedule(1)
    self.assertGreater(lr_at_start, 0)
    self.assertLess(jnp.abs(lr_at_start - opt_cfg.peak_learning_rate), 1e-9)
    self.assertGreater(lr_after_one, 0)
    self.assertGreater(lr_at_start, lr_after_one)
    self.assertEqual(
        schedule(opt_cfg.num_train_steps), opt_cfg.final_learning_rate)

  def test_scale_by_dict(self):
    """`_scale_by_dict` scales `kernel` leaves and leaves the rest alone."""
    docfg = model.DoConfig(
        D=128, H=16, L=256, N=4, V=1024, F=4 * 4, fsdp_enabled=False)
    transformer = model.TransformerDo(docfg)
    variables = jax.jit(transformer.init)(
        jax.random.PRNGKey(42),
        jnp.ones((2, 256), dtype=jnp.int32),
    )

    multiplier = 10
    # With all-ones gradients, an unscaled leaf gives
    # `update - multiplier * grad == 1 - multiplier`.
    residual = 1. - multiplier

    tx = optimizer._scale_by_dict({"kernel": multiplier})
    zero_params = jax.tree_util.tree_map(jnp.zeros_like, variables)
    unit_grads = jax.tree_util.tree_map(jnp.ones_like, variables)
    updates, _ = tx.update(unit_grads, tx.init(zero_params))
    delta = jax.tree_util.tree_map(
        lambda update, grad: update - multiplier * grad, updates, unit_grads)

    def _assert_close(x, scalar=0.):
      # Norm of the deviation from `scalar`, divided by the element count.
      error = jnp.linalg.norm(x - scalar) / x.size
      self.assertLess(error, 1e-8)

    delta_params = delta["params"]
    for i in range(docfg.N):
      block = delta_params[f"blocks_{i}"]
      # Kernels are scaled, so their delta is zero.
      for name in ["key", "value", "query", "attn_out_proj"]:
        _assert_close(block["CausalAttn_0"][name]["kernel"])
      for name in ["Dense_0", "Dense_1"]:
        _assert_close(block["Mlp_0"][name]["kernel"])
      # Layer-norm scales are not scaled, so their delta is `residual`.
      for name in ["LayerNorm_0", "LayerNorm_1"]:
        _assert_close(block[name]["scale"], scalar=residual)

    # Embeddings and the output layer norm are not scaled either.
    _assert_close(delta_params["embed"]["embedding"], scalar=residual)
    _assert_close(delta_params["pos_embed"]["embedding"], scalar=residual)
    _assert_close(delta_params["out_ln"]["scale"], scalar=residual)
113 |
114 |
115 | if __name__ == "__main__":
116 | absltest.main()
117 |
--------------------------------------------------------------------------------
/tests/testdata/sentencepiece_cc_all.32000.100extra-sentencepiece.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/nanodo/10aefdeed40a63293daf112b91a5538cd24fa3a4/tests/testdata/sentencepiece_cc_all.32000.100extra-sentencepiece.model
--------------------------------------------------------------------------------
/tests/train_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for `../train.py`."""
15 |
16 | # pylint: disable=invalid-name,g-importing-member,g-import-not-at-top
17 |
18 | import os
19 |
20 | from typing import TYPE_CHECKING
21 |
22 | from absl import logging
23 | from absl.testing import parameterized
24 | import chex
25 | from flax.training.train_state import TrainState
26 | import jax
27 | from jax.experimental import mesh_utils
28 | import jax.numpy as jnp
29 | from jax.sharding import Mesh
30 | from nanodo import data
31 | from nanodo import model
32 | from nanodo import optimizer as opt
33 | from nanodo import train
34 | from nanodo.configs import default
35 | import tensorflow_datasets as tfds
36 |
37 | from absl.testing import absltest
38 |
39 | if TYPE_CHECKING:
40 | import ml_collections
41 |
42 |
43 | jax.config.parse_flags_with_absl()
44 | jax.config.update("jax_numpy_rank_promotion", "raise")
45 |
46 |
47 | _VOCAB_PATH = "testdata/sentencepiece_cc_all.32000.100extra-sentencepiece.model"
48 |
49 |
def _get_config(self: parameterized.TestCase) -> "ml_collections.ConfigDict":
  """Return the default config with the overrides used by the train tests.

  Args:
    self: the running test case; its `create_tempdir()` provides `workdir`.

  Returns:
    The default config with the vocab path, optimizer, top-level and model
    overrides applied, and `workdir` pointing at a fresh temporary directory.
  """
  config = default.get_config()
  tests_dir = os.path.dirname(__file__)
  config.vocab_path = os.path.join(tests_dir, _VOCAB_PATH)

  # Learning-rate schedule and optimizer hyperparameters.
  opt_overrides = {
      "peak_learning_rate": 0.01,
      "init_learning_rate": 0.001,
      "final_learning_rate": 0.0001,
      "num_train_steps": 1,
      "warmup_steps": 10,
      "decay_steps": 100,
      "b1": 0.9,
      "b2": 0.98,
      "eps": 1e-9,
      "weight_decay": 0.1,
  }
  for field, value in opt_overrides.items():
    setattr(config.opt, field, value)

  # Overrides written directly on the top-level config.
  top_level_overrides = {
      "batch_size": 2,
      "eval_steps": 1,
      "checkpoint_every_steps": 1,
      "pygrain_worker_count": 2,
      "V": 32,
  }
  for field, value in top_level_overrides.items():
    setattr(config, field, value)

  # Overrides written on the `model` sub-config.
  model_overrides = {"L": 64, "D": 32, "F": 128, "N": 2, "H": 4}
  for field, value in model_overrides.items():
    setattr(config.model, field, value)

  config.workdir = self.create_tempdir().full_path
  return config
81 |
82 |
class TrainTest(parameterized.TestCase):
  """Tests for trainer construction, train steps and checkpoint restore."""

  @parameterized.parameters(True, False)
  def test_trainer(self, fsdp_enabled: bool = False):
    """A newly built `Trainer` reports step 0.

    Args:
      fsdp_enabled: value assigned to `DoConfig.fsdp_enabled`.
    """
    c = _get_config(self)
    cfg = model.DoConfig(**c.model, V=c.V)
    cfg.fsdp_enabled = fsdp_enabled
    m = model.TransformerDo(cfg)
    rng = jax.random.PRNGKey(42)
    # One-dimensional mesh over all available devices, named "data".
    mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("data",))
    shardings, state = train._init_train_state(c, m, rng, mesh=mesh)
    t = train.Trainer(c, state, mesh, shardings)

    self.assertEqual(t.step, 0)

  def test_train_step(self):
    """One `_train_step` advances the step and reproduces reference metrics."""
    c = _get_config(self)
    docfg = model.DoConfig(**c.model, V=c.V)
    m = model.TransformerDo(docfg)
    init_rng, data_rng = jax.random.split(jax.random.PRNGKey(42))
    in_BxL = jax.random.randint(
        data_rng,
        (2, c.model.L),
        0,
        c.V,
        jnp.int32,
    )
    params = jax.jit(m.init)(init_rng, in_BxL)
    optimizer = opt.get_optimizer(c.opt)
    state = TrainState.create(
        apply_fn=m.apply, params=params["params"],
        tx=optimizer,
    )

    self.assertEqual(state.step, 0)
    state, metrics = train._train_step(state, in_BxL, c)
    self.assertEqual(state.step, 1)

    # Expected metric values after one step; compared to two decimal places.
    reference = {
        "__train_loss": 3.945808,
        "train_loss": 3.945808,
        "train_ntokens": 124,

        "grads/all/rms": 0.01341779,
        "grads/all/mean": 3.026796e-05,
        "grads/all/std": 0.01341776,

        "updates/all/rms": 0.99979043,
        "updates/all/mean": -0.00446726,
        "updates/all/std": 0.9997804,

        "params/all/rms": 0.16059065,
        "params/all/mean": 0.00510256,
        "params/all/std": 0.16050959,

        "learning_rate": c.opt.init_learning_rate,

        "train_fraction": 0,
        "train_tokens_seen": 0,
    }
    # Only metrics that have a reference value are checked.
    metrics_subset = {k: v for k, v in metrics.items() if k in reference}
    print(metrics_subset)
    warning = (
        " metric after doing a single gradient step of the default model "
        "have changed. If you did not intend to change the model's behavior "
        "(e.g. refactoring), this may indicate a bug. If the change is "
        "expected (e.g. change in parameterization, default hyperparameters, "
        "random seed, renaming or removing metrics, etc.), then please update "
        "the `reference` dictionary above with the new expected values."
    )

    # The failing metric's key path is prepended to the warning message.
    jax.tree_util.tree_map_with_path(
        lambda k, x, y: self.assertAlmostEqual(
            x,
            y,
            places=2,
            msg=jax.tree_util.keystr(k) + warning),
        reference,
        metrics_subset,
    )

  @parameterized.parameters(data.Preprocess.NOAM_PACKED, data.Preprocess.PADDED)
  def test_train_and_evaluate(self, preprocessing):
    """Training writes a checkpoint per step and resumes from the latest one.

    Args:
      preprocessing: `data.Preprocess` mode passed to `data.py_batched_tfds`.
    """
    c = _get_config(self)
    c.checkpoint = True

    cfg = model.DoConfig(**c.model, V=c.V)
    m = model.TransformerDo(cfg)
    rng = jax.random.PRNGKey(42)
    mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("data",))
    _, state = train._init_train_state(c, m, rng, mesh)
    ckpt_dir = c.workdir
    # NOTE(review): the whole sequence is kept under the TFDS mock because
    # both `train_and_evaluate` calls load data; confirm the intended scope.
    with tfds.testing.mock_data(num_examples=100):
      train_ds = data.py_batched_tfds(
          tfds_name=c.ds_name,
          split="train",
          context_size=c.model.L,
          worker_count=c.pygrain_worker_count,
          vocab_path=c.vocab_path,
          batch_size=c.batch_size,
          num_epochs=c.train_epochs,
          preprocessing=preprocessing,
      )
      train_iter = iter(train_ds)
      # `num_train_steps` is 1, so this run should leave a step-1 checkpoint.
      train.train_and_evaluate(c)

      ckpt_mngr = train._get_ckpt_manager(ckpt_dir, c)

      self.assertEqual(ckpt_mngr.latest_step(), 1)
      restored_state, _ = train._restore_ckpt(ckpt_mngr, state, train_iter)
      self.assertEqual(restored_state.step, 1)

      logging.info("Trigger restore, check step is updated.")
      # Raising the step budget makes the second run resume and reach step 2.
      c.opt.num_train_steps = 2
      train.train_and_evaluate(c)
      ckpt_mngr = train._get_ckpt_manager(ckpt_dir, c)
      self.assertEqual(ckpt_mngr.latest_step(), 2)
      restored_state, _ = train._restore_ckpt(ckpt_mngr, state, train_iter)
      self.assertEqual(restored_state.step, 2)

  def test_train_step_remat(self):
    """A train step with `remat=True` agrees with one with `remat=False`."""
    c = _get_config(self)

    docfg = model.DoConfig(**c.model, V=c.V)
    docfg.remat = False
    m = model.TransformerDo(docfg)

    # The second model must enable remat; with `remat = False` here the test
    # would compare two identical non-remat models and check nothing.
    docfg_remat = model.DoConfig(**c.model, V=c.V)
    docfg_remat.remat = True
    m_remat = model.TransformerDo(docfg_remat)

    init_rng = jax.random.PRNGKey(42)
    in_BxL = jax.random.categorical(init_rng, jnp.ones((16, c.model.L, c.V)))
    initial_variables = jax.jit(m.init)(
        init_rng,
        in_BxL,
    )
    optimizer = opt.get_optimizer(c.opt)
    # Both states start from the same params and the same optimizer.
    state = TrainState.create(
        apply_fn=m.apply, params=initial_variables["params"],
        tx=optimizer,
    )

    state_remat = TrainState.create(
        apply_fn=m_remat.apply, params=initial_variables["params"],
        tx=optimizer,
    )
    new_state, metrics = train._train_step(state, in_BxL, c)
    new_state_remat, metrics_remat = train._train_step(state_remat, in_BxL, c)

    # Check metrics agreement
    chex.assert_trees_all_close(metrics, metrics_remat, rtol=1e-2, atol=1e-1)
    # Check updated params agreement
    chex.assert_trees_all_close(
        new_state.params, new_state_remat.params, rtol=1e-2, atol=1e-1
    )
    # Check optimizer state agreement: baseline vs. remat, not remat vs. itself.
    chex.assert_trees_all_close(
        new_state.opt_state,
        new_state_remat.opt_state,
        rtol=1e-2,
        atol=1e-1,
    )
247 |
248 |
249 | if __name__ == "__main__":
250 | absltest.main()
251 |
--------------------------------------------------------------------------------