├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── base.yaml
│   ├── c4_a100x8_1b.yaml
│   ├── c4_a100x8_270m.yaml
│   ├── c4_a100x8_2b.yaml
│   ├── c4_a100x8_540m.yaml
│   ├── c4_a100x8_84m.yaml
│   ├── c4_a100x8_base.yaml
│   ├── c4_a100x8x4_1b.yaml
│   ├── c4_a100x8x4_2b.yaml
│   ├── c4_a100x8x4_540m.yaml
│   ├── flat_tokens_c4_a100x1_84m.yaml
│   ├── huggingface_c4_a100x1_84m.yaml
│   └── local_test_synthetic.yaml
├── docs
│   ├── flat-tokens.md
│   ├── matx.svg
│   └── pytree-zarr-checkpoint.md
├── env.py
├── input_loader.py
├── jax_extra.py
├── requirements-cpu.txt
├── shardlib
│   ├── shardops.py
│   └── shardtypes.py
├── synthetic_dataset.zarr
│   ├── .zgroup
│   ├── train
│   │   ├── .zattrs
│   │   ├── .zgroup
│   │   ├── encoded_tokens
│   │   │   ├── 0
│   │   │   └── .zarray
│   │   └── seq_starts
│   │       ├── 0
│   │       └── .zarray
│   └── validation
│       ├── .zattrs
│       ├── .zgroup
│       ├── encoded_tokens
│       │   ├── 0
│       │   └── .zarray
│       └── seq_starts
│           ├── 0
│           └── .zarray
├── tools
│   ├── configs
│   │   ├── c4_en.yaml
│   │   ├── synthetic_dataset.yaml
│   │   └── tinystories.yaml
│   ├── flat_tokens.py
│   ├── huggingface_to_flat_tokens.py
│   ├── requirements.txt
│   └── write_synthetic_dataset.py
├── train.py
└── training_io.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Hydra
2 | outputs/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # poetry
101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105 | #poetry.lock
106 |
107 | # pdm
108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109 | #pdm.lock
110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111 | # in version control.
112 | # https://pdm.fming.dev/#use-with-ide
113 | .pdm.toml
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
157 |
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2024 MatX Inc
2 |
3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
4 |
5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
6 |
7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8 |
9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
10 |
11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # seqax = sequence modeling + JAX
6 |
7 | seqax is a codebase for small-to-medium-scale LLM pretraining research. The entire training program---including the model implementation; optimizer; multihost FSDP and tensor parallel partitioning---is [500 lines of code](/train.py), which scales well up to ~100 GPUs or TPUs[^1] and [typically achieves good MFUs of 30-50%](#performance).
8 |
9 | [^1]: Achieving good performance at larger scale requires pipeline parallelism (which we have not yet implemented). At that scale, you may also care more about using custom kernels to further improve performance at the cost of code simplicity.
10 |
11 | seqax is written in a style that makes the important information visible, rather than being hidden behind abstractions and indirections or being inferred automatically and unpredictably. This shows up in:
12 |
13 | * **Math**. seqax implements all of the training step's math, rather than calling into external libraries. If you want to understand or change the math, it's right there!
14 |
15 | * **Memory**. All tensors that go into a model checkpoint on disk are explicit. All tensors that occupy a lot of memory, including activations saved for the backwards pass, are explicit. You can straightforwardly read the memory footprint from the source code.
16 |
17 | * **Partitioning and communication**. The partitioned layout of all tensors and operations is explicit. All interchip communication is explicit.
18 |
19 | ## Getting started
20 |
21 | ### Installation
22 |
23 | 1. Install `graphviz` from your system package manager: e.g. `brew install graphviz` or `apt install graphviz`.
24 | 2. Install Python dependencies, typically inside a virtualenv: `python -m pip install -r requirements-cpu.txt`.
25 |
26 | NOTE: the `requirements-cpu.txt` is configured for CPU-based installation. For GPU or TPU installation, you may need a different install of JAX and jaxlib. Consult the [JAX install documentation](https://jax.readthedocs.io/en/latest/installation.html). If your GPU environment has a Torch-GPU installation, you may need to switch it to a Torch-CPU installation to avoid conflicts with JAX-GPU.
27 |
28 | ### Run on CPU for local development
29 |
30 | For development and testing you can run on CPU. Typically you'd use our synthetic dataset (which is [checked into this repository](/synthetic_dataset.zarr)) or the [Huggingface data loader](#data-loaders) and you'd set XLA flags to simulate multiple devices so as to test that parallelism is working as intended:
31 |
32 | ```
33 | XLA_FLAGS=--xla_force_host_platform_device_count=8 python -m train --config-name=local_test_synthetic +paths.model_name=synthetic_000
34 | ```
35 |
36 | The `paths.model_name` flag specifies which subdirectory on disk (inside `/tmp`) to write model checkpoints to. You'll typically want to change this when starting a new model run.
37 |
38 | ### Run on GPUs
39 |
40 | We have configured a range of model sizes, to be trained on the C4 dataset with the Llama tokenizer. Browse the `configs/` directory to select your preferred configuration file. Each configuration file lists how to run it at the top.
41 |
42 | You typically want to set `paths.model_name` to a unique name for each distinct training run. This name selects the subdirectory on disk (under `paths.root_working_dir`) where model checkpoints are written.
43 |
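44 | For example, the header comment of [configs/c4_a100x8_270m.yaml](/configs/c4_a100x8_270m.yaml) gives its launch command:
45 |
46 | ```
47 | python -m train --config-name=c4_a100x8_270m +paths.model_name=270m
48 | ```
49 |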
44 | ## Performance
45 |
46 | Recent benchmark results on A100 clusters:
47 |
48 | Single-host A100x8:
49 | | Model Size | MFU (%) |
50 | |------------|---------|
51 | | 84m | 14 |
52 | | 270m | 24 |
53 | | 540m | 35 |
54 | | 1b | 41.6 |
55 | | 2b | 50.66 |
56 |
57 | On 4 A100x8 hosts connected with InfiniBand:
58 | | Model Size | MFU (%) |
59 | |------------|---------|
60 | | 1b | 32.4 |
61 | | 2b | 39.0 |
62 |
63 | ## Data loaders
64 |
65 | seqax can stream training data directly from Huggingface (see [example config](/configs/huggingface_c4_a100x1_84m.yaml)), or can first convert the training data to a pre-tokenized format on disk which we call [flat-tokens](/docs/flat-tokens.md) (see [example config](/configs/flat_tokens_c4_a100x1_84m.yaml)). Streaming from Huggingface allows you to quickly experiment with different datasets, but it doesn't offer an efficient way to resume training from a checkpoint after a job is aborted, and it wastes some tokens from the dataset at batch boundaries. The flat-tokens format supports efficiently resuming training from a checkpoint, uses 100% of tokens for training, and also consumes less CPU time during training.
66 |
67 | To pre-tokenize the training data, you can run [huggingface_to_flat_tokens.py](/tools/huggingface_to_flat_tokens.py). You'll need to first install the requirements in [/tools/requirements.txt](/tools/requirements.txt), and then you can invoke the command listed at the top of [/tools/configs/c4_en.yaml](/tools/configs/c4_en.yaml). On modern CPUs this script processes about 100M tokens per minute. You can limit the number of output tokens it processes with a configuration flag.
68 |
69 | ## Expressing partitioning and communication with `shardlib`
70 |
71 | seqax ships with a new library called [shardlib](/shardlib) for expressing partitioning and communication with JAX, building on the ideas and style of [jaxtyping](https://docs.kidger.site/jaxtyping/), [einops](https://einops.rocks/), [equinox](https://docs.kidger.site/equinox/), and [shard_map](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). Here we demonstrate its core ideas, to implement fully sharded data parallelism (FSDP) for a simple fully connected neural network.
72 |
73 |
74 | ```python
75 | # XLA_FLAGS=--xla_force_host_platform_device_count=8 python -m shardlib_example
76 | from shardlib.shardtypes import bf16, bool_, f32, pytree_dataclass, typed_shard_map, u32, make_shardings
77 | from shardlib import shardtypes
78 | shardtypes.register_with_typeguard()
79 | import shardlib.shardops as shardops
80 | from jax.sharding import Mesh
81 | from jax.experimental import mesh_utils
82 | import jax
83 | import jax.numpy as jnp
84 |
85 | # We set up a device mesh where 'd' refers to the "data parallel" axis.
86 | MESH = Mesh(mesh_utils.create_device_mesh([8], jax.devices()), ('d',))
87 |
88 | # At rest, weights are all sharded over the data parallel axis, making them fully sharded.
89 | #
90 | # The `hidden1/d` syntax means that the second axis has size `hidden1` and is sharded over device axis `d`.
91 | # Equivalently, you can view this as saying that the per-device shape is `(in, hidden1/d)`, where `/`
92 | # indicates division.
93 | @pytree_dataclass
94 | class Weights:
95 | w1: f32['in hidden1/d']
96 | w2: f32['hidden1 hidden2/d']
97 | w3: f32['hidden2/d']
98 |
99 | with MESH:
100 | # Create dummy weights.
101 | w = Weights(
102 | w1=jnp.zeros((8, 8), dtype=jnp.float32),
103 | w2=jnp.zeros((8, 8), dtype=jnp.float32),
104 | w3=jnp.zeros((8,), dtype=jnp.float32),
105 | )
106 | # Apply sharding to weights. The sharding specs are inferred from the type annotations on the Weights class.
107 | w = jax.tree.map(jax.device_put, w, make_shardings(Weights))
108 |
109 | # We use `typed_shard_map` to allow us to write per-device code with explicit communication.
110 | #
111 | # Compared to untyped `jax.shard_map`, the `in_specs` and `out_specs` do not need to be specified:
112 | # they're inferred from the sharding on the function's signature.
113 | @typed_shard_map
114 | def forward_pass(x: f32[b'batch/d in'], w: Weights) -> f32[b'batch/d']:
115 | # Weights are all-gathered just prior to their use. (This is the core idea of fully-sharded data parallelism.)
116 | # The `in hidden1/d -> in hidden1` syntax expresses what this all-gather operation should do: it removes the
117 | # `d` sharding on the `hidden1` axis, resulting in a fully replicated output.
118 | w1 = shardops.all_gather('in hidden1/d -> in hidden1', w.w1)
119 | # The `einsum_unreduced` operation is a chip-local einsum. Unlike `jnp.einsum`, it supports sharding syntax,
120 | # and it performs shape checking using the current typing environment, so it will raise an error if for example
121 | # you use `batch` in two different ways within a function.
122 | #
123 | # We call this einsum "unreduced", because it does not do any cross-chip reductions, even if they are necessary.
124 | # For example, in an `a b/d, b/d c -> a c` einsum, a cross-chip reduction over the `d` sharding axis is required,
125 | # and it is the caller's responsibility to perform this reduction.
126 | y = jax.nn.relu(shardops.einsum_unreduced('batch/d in, in hidden1 -> batch/d hidden1', x, w1))
127 | w2 = shardops.all_gather('hidden1 hidden2/d -> hidden1 hidden2', w.w2)
128 | z = jax.nn.relu(shardops.einsum_unreduced('batch/d hidden1, hidden1 hidden2 -> batch/d hidden2', y, w2))
129 | w3 = shardops.all_gather('hidden2/d -> hidden2', w.w3)
130 | return shardops.einsum_unreduced('batch/d hidden2, hidden2 -> batch/d', z, w3)
131 |
132 | x = forward_pass(jnp.zeros((32, 8), dtype=jnp.float32), w)
133 | assert(x.shape == (32,))
134 | ```
135 |
136 | There are several other APIs exported by shardlib in addition to the ones demonstrated here. [Browse the code](/shardlib/) to see the full list.
137 |
138 | ## Expressing activation checkpointing using `save_for_backward`
139 |
140 | Which intermediate computations in the forwards pass are saved to HBM for later use in the backwards pass? [The default answer](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html) is: JAX saves _all_ intermediates for use in the backwards pass, but in JIT mode the XLA compiler optimizes many of these away so as to save memory.
141 |
142 | While JAX provides many sophisticated policies for making these choices, we offer a very simple one: calling `save_for_backward` causes its argument to be saved for the backwards pass. Here is an example:
143 |
144 | ```python
145 | import jax
146 | from jax_extra import explicit_activation_checkpointing, save_for_backward
146 |
147 | # The @explicit_activation_checkpointing switches JAX from its default
148 | # policy of saving all intermediates, and instead only saves the
149 | # arguments to the annotated function, plus any intermediates marked
150 | # with `save_for_backward`.
151 | @explicit_activation_checkpointing
152 | def forward_pass(x, w1, w2):
153 | # save_for_backward marks `y` as being saved.
154 | y = save_for_backward(x @ w1)
155 | # `z` is not saved for the backwards pass.
156 |     z = jax.nn.relu(y)
157 | return z @ w2
158 | ```
159 |
160 | ## Profiling
161 |
162 | Every training run gathers and reports performance information:
163 | * the time for two training steps (including data fetching in between them). This is written to stdout.
164 | * model FLOPS utilization (MFU) efficiency for these steps. This is written to stdout.
165 | * an XLA performance profile. This is written into the model directory at `/plugins/profile//perfetto_trace.json.gz`
166 | * a rendered SVG of the optimized XLA computation graph. This is written into the model directory at `/training_step_optimized_hlo_.svg`.
167 |
168 |
169 | ## File formats
170 |
171 | We write checkpoints and datasets in simple file formats based on [zarr](https://zarr.dev/). See our file format specifications:
172 | * [our checkpoint format](/docs/pytree-zarr-checkpoint.md)
173 | * [our dataset format](/docs/flat-tokens.md)
174 |
175 | ## Contact
176 |
177 | `seqax` is developed by the [MatX team](https://matx.com). If you're interested in working with us, you can reach us at [founders@matx.com](mailto:founders@matx.com).
178 |
179 |
180 | ## Citing seqax
181 |
182 | To cite this repository:
183 |
184 | ```
185 | @software{seqax2024github,
186 | author = {Reiner Pope and Vaclav Cvicek and Daniel Heinlein and Akshay Mishra and Mahdi Nazemi and Sanjit Neelam and Rachit Tibrewal},
187 | title = {seqax = sequence modeling + {JAX}},
188 | url = {https://github.com/MatX-inc/seqax},
189 | year = {2024},
190 | }
191 | ```
192 |
193 |
194 | ## Acknowledgements
195 |
196 | seqax's implementation style was substantially inspired by [jaxtyping](https://docs.kidger.site/jaxtyping/), [einops](https://einops.rocks/), [equinox](https://docs.kidger.site/equinox/), and [shard_map](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html).
197 |
198 | Thanks to [MaxText](https://github.com/google/maxtext) for demonstrating good practices for production LLM use of JAX.
199 |
200 | Thanks to the [JAX](https://github.com/google/jax) team for ongoing support and advice.
201 |
202 | Thanks to the [Google TPU Research Cloud](https://sites.research.google/trc/about/), which partially supported this work.
203 |
--------------------------------------------------------------------------------
/configs/base.yaml:
--------------------------------------------------------------------------------
1 | training:
2 | # AdamW optimizer parameters
3 | # We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
4 | adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
5 | adam_b2: 0.95 # Exponential decay rate to track the second moment of past gradients.
6 | adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
7 | adam_eps_root: 0. # A small constant applied to denominator inside the square root.
8 | weight_decay: 0.1 # AdamW Weight decay
9 | # We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
10 | # Learning rate schedule has two parts:
11 |   # 1) Linear warmup from 0 to [learning_rate] over steps 0 to [warmup_steps]
12 |   # 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from [warmup_steps] to [steps_for_lr]
13 | learning_rate: 3.e-5
14 | cosine_learning_rate_final_fraction: 0.1
15 | seed: 0
16 |
17 | io:
18 | max_io_threads: 1024
--------------------------------------------------------------------------------
/configs/c4_a100x8_1b.yaml:
--------------------------------------------------------------------------------
1 | # python -m train --config-name=c4_a100x8_1b +paths.model_name=1b
2 | defaults:
3 | - c4_a100x8_base
4 | - _self_
5 |
6 | training:
7 | warmup_steps: 37000
8 | steps: 370000
9 | steps_for_lr: 370000
10 | learning_rate: 1.0e-5
11 | tokens:
12 | batch: 64
13 |
14 | training_data:
15 | streams: 1
16 |
17 | model:
18 | d_model: 2048
19 | n_q_per_kv: 1
20 | n_kv: 16
21 | d_head: 128
22 | layers: 8
23 | d_ff: 16384
24 | vocab: 32768
25 | rope_max_timescale: 10000
26 |
27 | checkpoint_interval: 10000
--------------------------------------------------------------------------------
/configs/c4_a100x8_270m.yaml:
--------------------------------------------------------------------------------
1 | # python -m train --config-name=c4_a100x8_270m +paths.model_name=270m
2 | defaults:
3 | - c4_a100x8_base
4 | - _self_
5 |
6 | training:
7 | warmup_steps: 9200
8 | steps: 92000
9 | steps_for_lr: 92000
10 | learning_rate: 3.0e-4
11 |
12 | model:
13 | d_model: 1024
14 | n_q_per_kv: 1
15 | n_kv: 16
16 | d_head: 128
17 | layers: 8
18 | d_ff: 8192
19 | vocab: 32768
20 | rope_max_timescale: 10000
21 |
22 | checkpoint_interval: 9200
23 |
--------------------------------------------------------------------------------
/configs/c4_a100x8_2b.yaml:
--------------------------------------------------------------------------------
1 | # python -m train --config-name=c4_a100x8_2b +paths.model_name=2b
2 | defaults:
3 | - c4_a100x8_base
4 | - _self_
5 |
6 |
7 | training:
8 | warmup_steps: 74000
9 | steps: 740000
10 | steps_for_lr: 740000
11 | learning_rate: 1.0e-5
12 | tokens:
13 | batch: 64
14 |
15 | training_data:
16 | streams: 4
17 |
18 | model:
19 | d_model: 4096
20 | n_q_per_kv: 1
21 | n_kv: 16
22 | d_head: 128
23 | layers: 8
24 | d_ff: 16384
25 | vocab: 32768
26 | rope_max_timescale: 10000
27 |
28 | checkpoint_interval: 10000
--------------------------------------------------------------------------------
/configs/c4_a100x8_540m.yaml:
--------------------------------------------------------------------------------
1 | # python -m train --config-name=c4_a100x8_540m +paths.model_name=540m
2 | defaults:
3 | - c4_a100x8_base
4 | - _self_
5 |
6 | training:
7 | warmup_steps: 18500
8 | steps: 185000
9 | steps_for_lr: 185000
10 | learning_rate: 3.0e-4
11 |
12 | model:
13 | d_model: 2048
14 | n_q_per_kv: 1
15 | n_kv: 16
16 | d_head: 128
17 | layers: 8
18 | d_ff: 8192
19 | vocab: 32768
20 | rope_max_timescale: 10000
21 |
22 | checkpoint_interval: 4000
23 |
--------------------------------------------------------------------------------
/configs/c4_a100x8_84m.yaml:
--------------------------------------------------------------------------------
1 | # python -m train --config-name=c4_a100x8_84m +paths.model_name=84m
2 | defaults:
3 | - c4_a100x8_base
4 | - _self_
5 |
6 | training:
7 | warmup_steps: 2600
8 | steps: 26000
9 | steps_for_lr: 26000
10 | learning_rate: 3.0e-3
11 |
12 | model:
13 | d_model: 512
14 | n_q_per_kv: 1
15 | n_kv: 8
16 | d_head: 128
17 | layers: 8
18 | d_ff: 4096
19 | vocab: 32768
20 | rope_max_timescale: 10000
21 |
22 | checkpoint_interval: 2600
23 |
--------------------------------------------------------------------------------
/configs/c4_a100x8_base.yaml:
--------------------------------------------------------------------------------
1 | training:
2 | seed: 0
3 | tokens:
4 | batch: 64
5 | len: 1024
6 |
7 | # AdamW optimizer parameters
8 | # We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
9 | adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
10 | adam_b2: 0.95 # Exponential decay rate to track the second moment of past gradients.
11 | adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
12 | adam_eps_root: 0. # A small constant applied to denominator inside the square root.
13 | weight_decay: 0.1 # AdamW Weight decay
14 | # We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
15 | # Learning rate schedule has two parts:
16 |   # 1) Linear warmup from 0 to [learning_rate] over steps 0 to [warmup_steps]
17 |   # 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from [warmup_steps] to [steps_for_lr]
18 |
19 | # Learning rate is not yet tuned.
20 | learning_rate: 3.e-4
21 | cosine_learning_rate_final_fraction: 0.1
22 |
23 | flat_tokens:
24 | # filespec can also be a path on a local filesystem
25 | filespec: 'gcs://path/to/your/dataset' # Fill in
26 | streams: 1
27 | read_blocks_per_shuffle_buffer: 128
28 | sequences_per_read_block: 1024
29 | seed: 0
30 | sequence_packing: true
31 |
32 | paths:
33 | # root_working_dir can also be a path on a local filesystem
34 | root_working_dir: 'gcs://path/to/your/outputs'
35 |
36 | num_hosts: 1
37 |
38 | mesh:
39 | d: 8
40 | t: 1
41 |
42 | io:
43 | max_io_threads: 1024
--------------------------------------------------------------------------------
/configs/c4_a100x8x4_1b.yaml:
--------------------------------------------------------------------------------
1 | # python -m train --config-name=c4_a100x8x4_1b +paths.model_name=1b
2 | defaults:
3 | - c4_a100x8_base
4 | - _self_
5 |
6 | num_hosts: 4
7 |
8 | mesh:
9 | d: 32
10 | t: 1
11 |
12 | training:
13 | warmup_steps: 9250
14 | steps: 92500
15 | steps_for_lr: 92500
16 | learning_rate: 1.0e-5
17 | tokens:
18 | batch: 256
19 |
20 | training_data:
21 | streams: 4
22 |
23 | model:
24 | d_model: 2048
25 | n_q_per_kv: 1
26 | n_kv: 16
27 | d_head: 128
28 | layers: 8
29 | d_ff: 16384
30 | vocab: 32768
31 | rope_max_timescale: 10000
32 |
33 | checkpoint_interval: 2500
--------------------------------------------------------------------------------
/configs/c4_a100x8x4_2b.yaml:
--------------------------------------------------------------------------------
1 | # python -m train --config-name=c4_a100x8x4_2b +paths.model_name=2b
2 | defaults:
3 | - c4_a100x8_base
4 | - _self_
5 |
6 | num_hosts: 4
7 |
8 | mesh:
9 | d: 32
10 | t: 1
11 |
12 | training:
13 | warmup_steps: 18500
14 | steps: 185000
15 | steps_for_lr: 185000
16 | learning_rate: 1.0e-5
17 | tokens:
18 | batch: 256
19 |
20 | model:
21 | d_model: 4096
22 | n_q_per_kv: 1
23 | n_kv: 16
24 | d_head: 128
25 | layers: 8
26 | d_ff: 16384
27 | vocab: 32768
28 | rope_max_timescale: 10000
29 |
30 | checkpoint_interval: 2500
--------------------------------------------------------------------------------
/configs/c4_a100x8x4_540m.yaml:
--------------------------------------------------------------------------------
1 | # python -m train --config-name=c4_a100x8x4_540m +paths.model_name=540m
2 | defaults:
3 | - c4_a100x8_base
4 | - _self_
5 |
6 | num_hosts: 4
7 |
8 | mesh:
9 | d: 32
10 | t: 1
11 |
12 | training:
13 | warmup_steps: 4625
14 | steps: 46250
15 | steps_for_lr: 46250
16 | learning_rate: 3.0e-4
17 | tokens:
18 | batch: 256
19 |
20 | model:
21 | d_model: 2048
22 | n_q_per_kv: 1
23 | n_kv: 16
24 | d_head: 128
25 | layers: 16
26 | d_ff: 8192
27 | vocab: 32768
28 | rope_max_timescale: 10000
29 |
30 | checkpoint_interval: 2500
--------------------------------------------------------------------------------
/configs/flat_tokens_c4_a100x1_84m.yaml:
--------------------------------------------------------------------------------
1 | # python -m train --config-name=flat_tokens_c4_a100x1_84m +paths.model_name=flat_token_84m
2 | training:
3 | seed: 0
4 | tokens:
5 | batch: 64
6 | len: 1024
7 |
8 | # AdamW optimizer parameters
9 | # We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
10 | adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
11 | adam_b2: 0.95 # Exponential decay rate to track the second moment of past gradients.
12 | adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
13 | adam_eps_root: 0. # A small constant applied to denominator inside the square root.
14 | weight_decay: 0.1 # AdamW Weight decay
15 | # We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
16 | # Learning rate schedule has two parts:
17 |   # 1) Linear warmup from 0 to [learning_rate] over steps 0 to [warmup_steps]
18 |   # 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from [warmup_steps] to [steps_for_lr]
19 | warmup_steps: 2600
20 | steps: 26000
21 | steps_for_lr: 26000
22 | learning_rate: 3.0e-4
23 |
24 | cosine_learning_rate_final_fraction: 0.1
25 |
26 |
27 | model:
28 | d_model: 512
29 | n_q_per_kv: 1
30 | n_kv: 8
31 | d_head: 128
32 | layers: 8
33 | d_ff: 4096
34 | vocab: 32768
35 | rope_max_timescale: 10000
36 |
37 | paths:
38 |   # can also be a path to GCS, e.g. 'gcs://your_bucket/your_output_path'
39 | root_working_dir: '~/seqax_outputs'
40 |
41 | num_hosts: 1
42 |
43 | io:
44 | max_io_threads: 1024
45 |
46 | # Define either hf_dataset or flat_tokens. Do not use both.
47 | # flat_tokens requires more setup, but is better tested and doesn't waste tokens.
48 | # Using flat_tokens requires setting up a flat_tokens dataset using the script in tools/huggingface_to_flat_tokens.py
49 | flat_tokens:
50 |   filespec: 'gcs://path/to/your/dataset' # can be a path to a GCS directory, or a local copy of the dataset.
51 | streams: 1
52 | read_blocks_per_shuffle_buffer: 128
53 | sequences_per_read_block: 1024
54 | seed: 0
55 | sequence_packing: true
56 |
57 |
58 | mesh:
59 | d: 1
60 | t: 1
61 |
62 | checkpoint_interval: 100
--------------------------------------------------------------------------------
/configs/huggingface_c4_a100x1_84m.yaml:
--------------------------------------------------------------------------------
1 | # python -m train --config-name=huggingface_c4_a100x1_84m +paths.model_name=hf_84m
2 | training:
3 | seed: 0
4 | tokens:
5 | batch: 64
6 | len: 1024
7 |
8 | # AdamW optimizer parameters
9 | # We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
10 | adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
11 | adam_b2: 0.95 # Exponential decay rate to track the second moment of past gradients.
12 | adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
13 | adam_eps_root: 0. # A small constant applied to denominator inside the square root.
14 | weight_decay: 0.1 # AdamW Weight decay
15 | # We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
16 | # Learning rate schedule has two parts:
17 | # 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
18 | # 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from warmup to learning_rate_schedule_steps
19 | warmup_steps: 2600
20 | steps: 26000
21 | steps_for_lr: 26000
22 | learning_rate: 3.0e-4
23 |
24 | cosine_learning_rate_final_fraction: 0.1
25 |
26 |
27 | model:
28 | d_model: 512
29 | n_q_per_kv: 1
30 | n_kv: 8
31 | d_head: 128
32 | layers: 8
33 | d_ff: 4096
34 | vocab: 32768
35 | rope_max_timescale: 10000
36 |
37 | paths:
38 |   # can also be a path to GCS, e.g. 'gcs://your_bucket/your_output_path'
39 | root_working_dir: '~/seqax_outputs'
40 |
41 | num_hosts: 1
42 |
43 | io:
44 | max_io_threads: 1024
45 |
46 | # Define either hf_dataset or flat_tokens. Do not use both.
47 | # hf_dataset should work as long as you have the necessary permissions to access
48 | # the dataset and tokenizer.
49 | hf_dataset:
50 | path: allenai/c4
51 | name: en
52 | num_workers: 64
53 | tokenizer: mistralai/Mistral-7B-v0.1 # may require huggingface-cli login
54 | sequences_packed_per_batch: 120
55 |
56 |
57 |
58 | mesh:
59 | d: 1
60 | t: 1
61 |
62 | checkpoint_interval: 100
--------------------------------------------------------------------------------
/configs/local_test_synthetic.yaml:
--------------------------------------------------------------------------------
1 | # Command to run on your CPU:
2 | # XLA_FLAGS=--xla_force_host_platform_device_count=8 python -m train --config-name=local_test_synthetic +paths.model_name=synthetic_000
3 |
4 | defaults:
5 | - base
6 | - _self_
7 |
8 | training:
9 | warmup_steps: 10
10 | steps: 50
11 | steps_for_lr: 100
12 | tokens:
13 | batch: 64
14 | len: 64
15 |
16 | model:
17 | d_model: 256
18 | n_q_per_kv: 2
19 | n_kv: 2
20 | d_head: 32
21 | layers: 2
22 | vocab: 1280
23 | d_ff: 1024
24 | rope_max_timescale: 256
25 |
26 | paths:
27 | root_working_dir: '/tmp'
28 |
29 | checkpoint_interval: 10
30 | num_hosts: 1
31 |
32 | mesh:
33 | d: 4
34 | t: 2
35 |
36 | flat_tokens:
37 | filespec: 'synthetic_dataset.zarr'
38 | streams: 2
39 | read_blocks_per_shuffle_buffer: 8
40 | sequences_per_read_block: 16
41 | seed: 0
42 | sequence_packing: true
43 |
--------------------------------------------------------------------------------
/docs/flat-tokens.md:
--------------------------------------------------------------------------------
1 | # `flat-tokens` data format
2 |
3 | ## Introduction
4 |
5 | The `flat-tokens` data format is a very simple data format for storing language model training data.
6 | Unlike some other dataset libraries, it supports efficient seeking after job restarts. It also
7 | allows batch size, sequence length, and whether to use sequence packing to be chosen at training
8 | time.
9 |
10 | It is based on the simplest possible design: a concatenation of all tokens in the dataset, together
11 | with start indices of each sequence.
12 |
13 | ## Specification
14 |
15 | ### Flat-tokens array
16 |
17 | A *flat-tokens array* is a [`zarr` Group](https://zarr.readthedocs.io/en/stable/) of the following format:
18 |
19 | ```
20 | arrays: {
21 | "encoded_tokens": uint32[token_count],
22 | "seq_starts": uint64[seq_count + 1],
23 | }
24 | attributes: {
25 | "max_token_id": int32
26 | }
27 | ```
28 |
29 | That is, it has two arrays, named `encoded_tokens` and `seq_starts`.
30 |
31 | 1. The `encoded_tokens` array is a concatenation of all sequences in the dataset into a long array of tokens.
32 | There are no padding, beginning-of-sequence, or end-of-sequence tokens included. Tokens are encoded
33 | as `token_id*2+1` if they are the start of a new sequence, or `token_id*2` if not. The maximum supported `token_id` is `2^31-1`.
34 | 2. The `seq_starts` array lists (in increasing order) the indices of the `encoded_tokens` array where each
35 | sequence starts, plus one final index which equals `token_count`, indicating the end of the final
36 | sequence.
37 |
38 | Additionally, it has one attribute, named `max_token_id`. All decoded `token_id` values in `encoded_tokens`
39 | must be `<= max_token_id`. (This is intended to allow readers to quickly check that their vocabulary size is
40 | large enough for the dataset.)
41 |
42 | ### Flat-tokens dataset
43 |
44 | A *flat-tokens dataset* is a `zarr` Group with entries "train", "validation", each of which are flat-tokens arrays.
45 |
46 | ## Example
47 |
48 | The token sequences `[[1, 2], [3, 4, 5], [6, 7, 8]]` are represented in a flat-tokens array as:
49 |
50 | ```
51 | arrays: {
52 | "tokens": [3, 4, 7, 8, 10, 13, 14, 16],
53 | "seq_starts": [0, 2, 5, 8],
54 | }
55 | attributes: {
56 | "max_token_id": 8
57 | }
58 | ```
59 |
60 | ## Discussion
61 |
62 | This is the simplest possible format supporting the following features:
63 | * Batch size and sequence length can be chosen at training time. They are not "baked into" the format.
64 | * Data loading can be done with or without sequence packing:
65 |   * Without sequence packing, we consult `seq_starts` to locate the tokens of a particular sequence, e.g. `encoded_tokens[seq_starts[1]:seq_starts[2]]` is `[7, 8, 10]`, corresponding to the encoded tokens of sequence 1.
66 |   * With sequence packing, we bypass `seq_starts` and directly consult `encoded_tokens`, e.g. for packed sequence length 4, sequence 1 is `encoded_tokens[4:8]`, i.e. `[10, 13, 14, 16]`.
67 | * O(1) random access to any sequence, packed or not.
68 | * This allows you to restart your training job and continue where you left off in the dataset, without retaining any state except for the step or sequence index where you left off.
69 | * This allows arbitrary shuffling at runtime.
70 | * Minimal disk seeks ("IO operations" on public clouds) per random access: just one disk seek for sequence-packed random access; just two disk seeks for non-packed random access.
71 |
72 | The sequence packing is designed such that no loss masking is required: every single token can be used as a target token. In the above example, if we used packed sequence length 8 (i.e. the whole dataset as one packed sequence),
73 | at training time we'd expand the tokens into the following input and target tokens:
74 |
75 | ```
76 | {
77 | "inputs": [0, 1, 0, 3, 4, 0, 6, 7],
78 | "targets": [1, 2, 3, 4, 5, 6, 7, 8],
79 | }
80 | ```
81 |
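82 | For concreteness, here is a minimal NumPy sketch of the encoding and of the two access patterns
83 | described above. It is illustrative only (the variable names are ad hoc); the repository's actual
84 | reader is [input_loader.py](/input_loader.py).
85 |
86 | ```python
87 | import numpy as np
88 |
89 | sequences = [[1, 2], [3, 4, 5], [6, 7, 8]]
90 |
91 | # Encode: token_id*2+1 at the start of a sequence, token_id*2 otherwise.
92 | encoded_tokens = np.array(
93 |     [tok * 2 + (1 if i == 0 else 0) for seq in sequences for i, tok in enumerate(seq)],
94 |     dtype=np.uint32,
95 | )
96 | seq_starts = np.cumsum([0] + [len(seq) for seq in sequences]).astype(np.uint64)
97 | assert list(encoded_tokens) == [3, 4, 7, 8, 10, 13, 14, 16]
98 | assert list(seq_starts) == [0, 2, 5, 8]
99 |
100 | # Decode: the low bit is the "start of sequence" flag, the remaining bits are the token id.
101 | targets = encoded_tokens >> 1
102 | is_seq_start = (encoded_tokens & 1) == 1
103 |
104 | # Random access without sequence packing: slice via seq_starts.
105 | assert list(encoded_tokens[seq_starts[1]:seq_starts[2]]) == [7, 8, 10]
106 |
107 | # Random access with sequence packing (packed sequence length 4): slice encoded_tokens directly.
108 | assert list(encoded_tokens[4:8]) == [10, 13, 14, 16]
109 | ```
110 |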
--------------------------------------------------------------------------------
/docs/matx.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/pytree-zarr-checkpoint.md:
--------------------------------------------------------------------------------
1 | # PyTree-zarr checkpoint format
2 |
3 | For `seqax` we write checkpoints of JAX PyTrees, in a simple format documented here.
4 |
5 | ## Specification
6 |
7 | The *zarr of a PyTree* is a [zarr Group](https://zarr.readthedocs.io/en/stable/api/hierarchy.html) with the following elements:
8 | * for each `path, array` in the [flattened PyTree](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_flatten_with_path.html#jax.tree_util.tree_flatten_with_path), the zarr Group contains `array` as a child array, with path equal to `jax.tree_util.keystr(path)`
9 | * additionally, there is a zarr [attribute](https://zarr.readthedocs.io/en/stable/api/attrs.html) named `write_completed` with value `True`.
10 |
11 | The zarr of a PyTree may be written to disk with any compression and chunk size settings.
12 |
13 | ## Discussion
14 |
15 | We use `zarr` to support parallel writers from different hosts in a fully-sharded training setup. (Parallel writers in this scenario must choose a chunk size that divides the data size per host, so as to avoid zarr race conditions during writing.) Readers of the checkpoint format do not need to be aware that it was written in parallel, as this is hidden by the zarr abstraction.
16 |
17 | We use the `write_completed` attribute to allow parallel writers to support a "two phase commit" protocol: all writers write their data chunks, then wait for a global barrier, then the "leader" writer sets the `write_completed` attribute. This protects readers from reading partially-written checkpoints.
18 |
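19 | As a concrete illustration, here is a minimal single-host sketch of a writer and reader that follow
20 | this specification. It is not the repository's implementation (which also supports the parallel,
21 | sharded writes described above); the helper names are ad hoc.
22 |
23 | ```python
24 | import jax
25 | import numpy as np
26 | import zarr
27 |
28 | def write_pytree_checkpoint(tree, path: str) -> None:
29 |     """Writes the zarr of a PyTree: one child array per leaf, then the completion marker."""
30 |     root = zarr.open_group(path, mode='w')
31 |     leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
32 |     for key_path, leaf in leaves:
33 |         # Child array path is jax.tree_util.keystr(key_path), as required by the spec.
34 |         root.create_dataset(jax.tree_util.keystr(key_path), data=np.asarray(leaf))
35 |     # Set write_completed only after every array has been written ("two phase commit").
36 |     root.attrs['write_completed'] = True
37 |
38 | def read_pytree_checkpoint(tree_like, path: str):
39 |     """Reads a checkpoint back into the structure of `tree_like`, an example pytree of the same shape."""
40 |     root = zarr.open_group(path, mode='r')
41 |     assert root.attrs.get('write_completed', False), 'refusing to read a partially written checkpoint'
42 |     leaves, treedef = jax.tree_util.tree_flatten_with_path(tree_like)
43 |     arrays = [root[jax.tree_util.keystr(key_path)][:] for key_path, _ in leaves]
44 |     return jax.tree_util.tree_unflatten(treedef, arrays)
45 | ```
46 |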
--------------------------------------------------------------------------------
/env.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | def set_variables():
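4 |     """Sets XLA, NCCL and LIBTPU environment variables that enable asynchronous collectives
5 |     and latency-hiding scheduling. Call this before JAX initializes its backend so that the
6 |     flags take effect."""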
4 | os.environ['XLA_FLAGS'] = (
5 | os.environ.get('XLA_FLAGS', '') + ' '
6 | '--xla_gpu_enable_async_collectives=true '
7 | '--xla_gpu_enable_latency_hiding_scheduler=true '
8 | )
9 | os.environ.update({
10 | "NCCL_LL128_BUFFSIZE": "-2",
11 | "NCCL_LL_BUFFSIZE": "-2",
12 | "NCCL_PROTO": "SIMPLE,LL,LL128",
13 | })
14 | os.environ["LIBTPU_INIT_ARGS"] = "--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
15 |
--------------------------------------------------------------------------------
/input_loader.py:
--------------------------------------------------------------------------------
1 | """Input data loading from `flat-tokens` data format.
2 |
3 | See `docs/flat-tokens.md` for details on the format.
4 |
5 | We support shuffling of the input data, by the following algorithm:
6 | * there are N independent "streams" of data, each of which has disjoint data and is
7 | shuffled independently.
8 | * within each stream, we fetch a "shuffle buffer" consisting of many "read blocks" of
9 | data. We shuffle the entire buffer in memory.
10 | * the "read blocks" attached to each shuffle buffer are themselves selected randomly.
11 |
12 | This is the standard shuffling used by e.g. Huggingface Datasets. Unlike them, we run
13 | this algorithm _after_ tokenization, so we know exactly which step number each new
14 | shuffle buffer starts at, allowing us to do instant resumes after job restarts. Our
15 | recommended default configuration also uses a much larger shuffle buffer size
16 | than Huggingface Datasets' default, which allows for more thorough shuffling, taking advantage
17 | of the fact that a single sequence of tokens uses very little memory compared to e.g.
18 | a single image.
19 |
20 | Mosaic's StreamingDatasets library uses a similar algorithm to ours, which they call py1b:
21 | https://docs.mosaicml.com/projects/streaming/en/stable/fundamentals/shuffling.html.
22 | """
23 |
24 | from concurrent.futures import ThreadPoolExecutor
25 | import functools
26 | from typing import Tuple, Union, Optional
27 |
28 | from typeguard import typechecked
29 | from shardlib.shardtypes import bool_, pytree_dataclass, u32
30 | import shardlib.shardtypes as shardtypes
31 | import zarr
32 | from dataclasses import dataclass
33 | import jax
34 | import numpy as np
35 | from jax.sharding import PartitionSpec as P
36 | import datetime
37 | import jax
38 |
39 | # imports for hf dataloader
40 | import numpy as onp
41 | from transformers import AutoTokenizer
42 | from torch.utils.data import DataLoader
43 | from datasets import load_dataset
44 |
45 | @dataclass(frozen=True)
46 | class TokenBatchParams:
47 | """The shape of a token batch."""
48 | len: int
49 | batch: int
50 |
51 |
52 | @pytree_dataclass
53 | class TokenBatch:
54 | """A batch of tokens, which are typically the input to training."""
55 | targets: u32['batch/d len']
56 | is_seq_start: bool_['batch/d len']
57 |
58 |
59 |
60 |
61 | @dataclass(frozen=True)
62 | class FlatTokensParams:
63 | filespec: str
64 |
65 | # A "stream" is what's attached to one independent shuffle buffer. There may be multiple
66 | # independent shuffle buffers, allowing parallelism.
67 | #
68 | # A "minipoch" (mini-epoch) is the set of sequences visited by one global refill of shuffle
69 | # buffers. The last minipoch may be shorter than others, but each stream in the last minipoch
70 | # must have the same number of read blocks, which must also be an integer.
71 | #
72 | # (To minimize discarded data on very small training sets, set streams=1 and make
73 | # sequences_per_read_block small.)
74 | #
75 | # Shuffling transforms the uint32[num_tokens] into uint32[streams, sequences, len], the
76 | # "shuffled tokens". We then form batches by a transformation on [streams, sequences].
77 |
78 | streams: int # Recommended: maximum number of hosts you expect to use.
79 | read_blocks_per_shuffle_buffer: int # Recommended: 1 << 10. 4GiB (uncompressed) shuffle buffer.
80 | sequences_per_read_block: int # Recommended: (1 << 20) / len. 1MiB (compressed) read block.
81 | seed: int
82 | sequence_packing: bool
83 |
84 |
85 | @dataclass
86 | class _ShuffleBuffer:
87 | minipoch: int
88 | buffer: u32['Buflen len']
89 |
90 |
91 | class ShufflingLoader:
92 | def __init__(self, split: str, params: FlatTokensParams, token_batch_params: TokenBatchParams):
93 | self.params = params
94 | self.token_batch_params = token_batch_params
95 | self.root = zarr.open_group(params.filespec, mode="r")
96 | assert split in ["train", "validation"], "Invalid split"
97 | self.encoded_tokens = self.root[split]["encoded_tokens"]
98 | self.seq_starts = self.root[split]["seq_starts"]
99 | self.max_token_id = self.root[split].attrs["max_token_id"]
100 | assert len(self.encoded_tokens.shape) == 1, "Expected 1D zarr"
101 | assert self.encoded_tokens.dtype == np.uint32, "Expected uint32 zarr"
102 | assert len(self.seq_starts.shape) == 1, "Expected 1D zarr"
103 | assert self.seq_starts.dtype == np.uint64, "Expected uint64 zarr"
104 |
105 | token_count = self.encoded_tokens.shape[0]
106 | if params.sequence_packing:
107 | self.seq_count = token_count // token_batch_params.len
108 | else:
109 | self.seq_count = self.seq_starts.shape[0] - 1
110 |
111 | # Count read blocks. Round it down to a multiple of streams
112 | read_block_count = self.seq_count // params.sequences_per_read_block
113 | read_block_count = (read_block_count // params.streams) * params.streams
114 | self.read_block_count = read_block_count
115 | assert read_block_count > 0, "Must have at least one read block per stream. Try shrinking streams and sequences_per_read_block."
116 | self.step_count = (read_block_count * params.sequences_per_read_block) // token_batch_params.batch
117 | # Count minipochs
118 | self.minipoch_count = _div_up(read_block_count, params.streams * params.read_blocks_per_shuffle_buffer)
119 | self.seq_indices_per_shuffle_buffer = params.read_blocks_per_shuffle_buffer * params.sequences_per_read_block
120 | # Calculate batch->stream mapping.
121 | self.batch_indices_per_stream = _div_exact(token_batch_params.batch, params.streams)
122 | # Calculate which streams and which batch indices this host is responsible for, based on the sharding.
123 | self.sharding = shardtypes.make_shardings(TokenBatch).targets
124 | streams = set()
125 | batch_indices = set()
126 | for batch_slices, _ in self.sharding.addressable_devices_indices_map((token_batch_params.batch, token_batch_params.len)).values():
127 | batch_lo, batch_hi, batch_step = batch_slices.indices(token_batch_params.batch)
128 | for b in range(batch_lo, batch_hi, batch_step):
129 | batch_indices.add(b)
130 | streams.add(b // self.batch_indices_per_stream)
131 | self.shuffle_buffers_by_stream = {stream_index: None for stream_index in streams}
132 | self.batch_indices = sorted(batch_indices)
133 | # Shuffle read blocks
134 | assert read_block_count < 1 << 32, "Too many read blocks. Try growing sequences_per_read_block."
135 | self.read_block_ordering = _random_permutation(params.seed, read_block_count)
136 |
137 |
138 | def load(self, step: int) -> TokenBatch:
139 | assert step < self.step_count, f"Requested step {step} but dataset only supports {self.step_count} steps at batch size {self.token_batch_params.batch}."
140 | # Conceptually, we remap IDs as follows:
141 | # 1. (step, batch_index) -> (stream, seq_index_in_stream)
142 | # 2. seq_index_in_stream -> (minipoch, seq_index_in_shuffle_buffer)
143 | #
144 | # We visit all batch_indices in increasing order. Since the map batch_index->(stream, minipoch)
145 | # is monotonic (non-decreasing), we can reload the shuffle buffer for a stream whenever
146 | # we cross to a new minipoch without thrashing back and forth between adjacent minipochs.
147 | seq_by_batch_index = {}
148 | for batch_index in self.batch_indices:
149 | # 1. (step, batch_index) -> (stream, seq_index_in_stream)
150 | stream = batch_index // self.batch_indices_per_stream
151 | seq_index_in_stream = step * self.batch_indices_per_stream + (batch_index % self.batch_indices_per_stream)
152 | # 2. seq_index_in_stream -> (minipoch, seq_index_in_shuffle_buffer)
153 | minipoch = seq_index_in_stream // self.seq_indices_per_shuffle_buffer
154 | seq_index_in_shuffle_buffer = seq_index_in_stream % self.seq_indices_per_shuffle_buffer
155 | shuffle_buffer = self._get_shuffle_buffer(stream, minipoch)
156 | seq_by_batch_index[batch_index] = shuffle_buffer[seq_index_in_shuffle_buffer]
157 |
158 | def get_shard(indexing: Tuple[slice]) -> jax.Array:
159 | seqlen_slice = indexing[1]
160 | examples = []
161 | for batch_index in range(*indexing[0].indices(self.token_batch_params.batch)):
162 | examples.append(seq_by_batch_index[batch_index][seqlen_slice])
163 | return np.stack(examples)
164 |
165 | shape = (self.token_batch_params.batch, self.token_batch_params.len)
166 | encoded_tokens = jax.make_array_from_callback(shape, self.sharding, get_shard)
167 | return _decode(encoded_tokens)
168 |
169 |
170 | def _get_shuffle_buffer(self, stream: int, minipoch: int) -> _ShuffleBuffer:
171 | if self.shuffle_buffers_by_stream[stream] is None or self.shuffle_buffers_by_stream[stream].minipoch != minipoch:
172 | self.shuffle_buffers_by_stream[stream] = None # Free the underlying memory
173 | blocks_in_shuffle_buffer = self.params.read_blocks_per_shuffle_buffer
174 | if minipoch == self.minipoch_count - 1:
175 | blocks_in_shuffle_buffer = (self.read_block_count // self.params.streams) - self.params.read_blocks_per_shuffle_buffer * minipoch
176 | # We form a mapping:
177 | # (stream, minipoch, read_block_in_minipoch) -> sequential_read_block
178 | # then we map
179 | # sequential_read_block -> shuffled_read_block
180 | # using self.shuffled_read_blocks.
181 | shuffled_read_block_indices = []
182 | for read_block_in_minipoch in range(blocks_in_shuffle_buffer):
183 | sequential_read_block = (minipoch * self.params.read_blocks_per_shuffle_buffer + read_block_in_minipoch) * self.params.streams + stream
184 | shuffled_read_block = self.read_block_ordering[sequential_read_block]
185 | shuffled_read_block_indices.append(shuffled_read_block)
186 |
187 | # Now load all of the read blocks in parallel.
188 | def load_read_block(read_block_index: int) -> u32['Buflen len']:
189 | start_seq = read_block_index * self.params.sequences_per_read_block
190 | end_seq = start_seq + self.params.sequences_per_read_block
191 | block_shape = (self.params.sequences_per_read_block, self.token_batch_params.len)
192 | if self.params.sequence_packing:
193 | flat_tokens = self.encoded_tokens[start_seq * self.token_batch_params.len : end_seq * self.token_batch_params.len]
194 | return flat_tokens.reshape(block_shape)
195 | else:
196 | seq_starts = self.seq_starts[start_seq : end_seq + 1]
197 | flat_tokens = self.encoded_tokens[seq_starts[0] : seq_starts[-1]]
198 | # Read the ragged array into a (padded) dense array.
199 | #
200 | # We pad with 1s, which decode to (0, new_sequence=true).
201 | result = np.ones(block_shape, dtype=np.uint32)
202 | for i in range(self.params.sequences_per_read_block):
203 | start = seq_starts[i]
204 | end = seq_starts[i + 1]
205 | result[i, :end - start] = flat_tokens[start:end]
206 | return result
207 |
208 | print(f'[{datetime.datetime.now()}] Loading shuffle buffer')
209 | # Loading a read block is IO-dominated work, with very little CPU time involved, so we can afford
210 | # to run a huge number of these in parallel with little concern about thrashing the CPU by having
211 | # excessively many threads doing CPU-intensive work. At the recommended read block sizing of 1MiB,
212 | # the memory footprint of a read block is typically bigger than the memory footprint of a CPU thread,
213 | # so we're also unlikely to waste a significant fraction of memory by having too many threads. In
214 | # net, allow a lot of threads, potentially way more than we have CPUs! Other overheads will
215 | # bite us before thread overheads do.
216 | with ThreadPoolExecutor(max_workers=len(shuffled_read_block_indices)) as executor:
217 | shuffled_read_blocks = list(executor.map(load_read_block, shuffled_read_block_indices))
218 | shuffle_buffer = np.concatenate(shuffled_read_blocks, axis=0)
219 | print(f'[{datetime.datetime.now()}] Finished loading shuffle buffer, {shuffle_buffer.size * 4:_} bytes')
220 |
221 | # Actually shuffle it.
222 | sequences_in_shuffle_buffer = blocks_in_shuffle_buffer * self.params.sequences_per_read_block
223 | assert shuffle_buffer.shape == (sequences_in_shuffle_buffer, self.token_batch_params.len)
224 | shuffle_seed = self.params.seed + 1 + minipoch * self.params.streams + stream
225 | permutation = _random_permutation(shuffle_seed, sequences_in_shuffle_buffer)
226 | shuffle_buffer = shuffle_buffer[permutation, :]
227 | self.shuffle_buffers_by_stream[stream] = _ShuffleBuffer(minipoch, shuffle_buffer)
228 |
229 | return self.shuffle_buffers_by_stream[stream].buffer
230 |
231 | def _div_up(a: int, b: int) -> int:
232 | return (a + b - 1) // b
233 |
234 | def _div_exact(a: int, b: int) -> int:
235 | assert a % b == 0
236 | return a // b
237 |
238 | @functools.partial(jax.jit, donate_argnums=(0,))
239 | @typechecked
240 | def _decode(encoded_tokens: u32[b'batch/d len']) -> TokenBatch:
241 | # encoded_tokens encoding:
242 | # 2*id+1 for the first token in a sequence
243 | # 2*id for other tokens in the sequence
244 | return TokenBatch(
245 | targets = encoded_tokens >> 1,
246 | is_seq_start = (encoded_tokens & 1) == 1,
247 | )
248 |
249 | def _random_permutation(seed: int, n: int) -> u32['N']:
250 | """Same as `np.random.Generator.permutation`, but with a guarantee that it will always produce the same results for a given seed."""
251 | assert n < 1 << 32
252 | # We do a Fisher-Yates shuffle using the Philox BitGenerator. Unlike the rest of np.random,
253 | # which is documented as potentially changing between numpy versions or even platforms on
254 | # the same version, the Philox BitGenerator is documented as stable. Likewise, we also promise
255 | # not to change the following implementation of the Fisher-Yates shuffle.
256 | #
257 | # We calculate the random numbers using `random_uint64() % n` rather than using rejection
258 | # sampling to generate numbers in range `[0, n)`. (Rejection sampling is more complicated,
259 | # because we don't know up front how many random numbers we'll need.) Our approach
260 | # introduces some bias, but it's small: since n<2^32, the bias is at most 2^-32 for each
261 | # random number generated. We're fine with this.
262 | randoms = np.random.Philox(seed).random_raw(n) % (np.arange(n, dtype=np.uint64) + 1)
263 | result = np.arange(n, dtype=np.uint32)
264 | for i in reversed(range(n)):
265 | j = randoms[i]
266 | tmp = result[i]
267 | result[i] = result[j]
268 | result[j] = tmp
269 | return result
270 |
271 |
272 | @dataclass(frozen=True)
273 | class HuggingFaceDataParams:
274 | path: str
275 | tokenizer: str
276 | num_workers: int
277 | sequences_packed_per_batch: int
278 | name: Optional[str] = None
279 |
280 | class HuggingFaceDataLoader:
281 | """
282 | The HuggingFaceDataLoader is provided for convenience and ease of setup,
283 | but the flat tokens dataloader is recommended for production use.
284 |     This loader does not require running tools/huggingface_to_flat_tokens.py to
285 |     create a flat-tokens dataset; instead it streams directly from Huggingface.
286 |
287 |     This dataloader will waste tokens if you pack too many sequences into a batch,
288 |     and does not support instant resume to an arbitrary step.
289 | """
290 | def __init__(self, split, config: HuggingFaceDataParams, token_batch_params: TokenBatchParams):
291 | self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer)
292 | self.batch_size = token_batch_params.batch
293 | self.max_seq_len = token_batch_params.len
294 | self.sharding = shardtypes.make_shardings(TokenBatch).targets
295 | self.max_token_id = self.tokenizer.vocab_size-1
296 | assert 0 in self.tokenizer.all_special_ids, "Tokenizer must have a special 0 token"
297 |
298 | # setup an iterator over the dataset
299 | tokenize = functools.partial(self.tokenizer, padding=False, truncation=False, max_length=None, add_special_tokens=False, return_token_type_ids=False, return_attention_mask=False, return_tensors="np")
300 | dataset = load_dataset(config.path, config.name, streaming=True, split=split)
301 | tokenized = dataset.select_columns(["text"]).map(tokenize, input_columns=['text'], remove_columns=["text"])
302 | dataloader = DataLoader(tokenized, num_workers=config.num_workers, collate_fn=self.collate, drop_last=True, batch_size=config.sequences_packed_per_batch)
303 | self.iterator = iter(dataloader)
304 |
305 | def collate(self, sequences):
306 | flat_batch = onp.zeros(self.batch_size * self.max_seq_len, onp.uint32)
307 | flat_is_start = onp.zeros(self.batch_size * self.max_seq_len, onp.bool_)
308 | start = 0
309 | for seq in sequences:
310 | seq = seq['input_ids'][0]
311 | end = min(start + len(seq), len(flat_batch))
312 | flat_is_start[start] = True
313 | flat_batch[start:end] = seq[:end-start]
314 | start += len(seq)
315 | if start >= len(flat_batch):
316 | break
317 | shape = (self.batch_size, self.max_seq_len)
318 | return flat_batch.reshape(shape), flat_is_start.reshape(shape)
319 |
320 | def load(self, step):
321 | shape = (self.batch_size, self.max_seq_len)
322 | batch, is_start = next(self.iterator)
323 | def get_shard(x: jax.Array, indexing: Tuple[slice]) -> jax.Array:
324 | shard = x[indexing]
325 | return shard
326 | tokens = jax.make_array_from_callback(shape, self.sharding, functools.partial(get_shard, batch))
327 | is_start = jax.make_array_from_callback(shape, self.sharding, functools.partial(get_shard, is_start))
328 | return TokenBatch(tokens, is_start)
329 |
330 | def get_loader(split: str, config: Union[FlatTokensParams, HuggingFaceDataParams], token_batch_params: TokenBatchParams):
331 | if isinstance(config, FlatTokensParams):
332 | return ShufflingLoader(split, config, token_batch_params)
333 | elif isinstance(config, HuggingFaceDataParams):
334 | return HuggingFaceDataLoader(split, config, token_batch_params)
335 | else:
336 | raise ValueError(f"Unknown config type {type(config)}")
--------------------------------------------------------------------------------
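The token waste mentioned in the `HuggingFaceDataLoader` docstring comes from its packing rule. A minimal sketch, re-implementing the `collate` logic on plain numpy arrays with made-up values:

```python
import numpy as np

def pack(sequences, batch, seq_len):
    """Write sequences back-to-back into one flat buffer of batch*seq_len tokens; drop the overflow."""
    flat = np.zeros(batch * seq_len, np.uint32)
    is_start = np.zeros(batch * seq_len, np.bool_)
    pos = 0
    for seq in sequences:
        end = min(pos + len(seq), len(flat))
        is_start[pos] = True
        flat[pos:end] = seq[:end - pos]
        pos += len(seq)
        if pos >= len(flat):
            break
    return flat.reshape(batch, seq_len), is_start.reshape(batch, seq_len)

tokens, starts = pack([np.array([3, 4, 5]), np.array([6, 7])], batch=1, seq_len=4)
# tokens -> [[3, 4, 5, 6]], starts -> [[True, False, False, True]]: token 7 is wasted.
```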
/jax_extra.py:
--------------------------------------------------------------------------------
1 | """Extra utilities for JAX and Python."""
2 | import jax
3 | import hashlib
4 | import jax
5 | import jax.ad_checkpoint
6 | import dataclasses
7 | from typing import Union, get_args
8 | from dataclasses import fields, is_dataclass
9 |
10 |
11 | def fold_in_str(key: jax.Array, string: str) -> jax.Array:
12 | """Returns a PRNG key derived from an initial PRNG key and a string input.
13 |
14 | Args:
15 | key: The initial PRNG key.
16 | string: The string input (e.g., 'pretrain', 'query', etc.).
17 |
18 | Returns:
19 | A PRNG key derived from the initial PRNG key and the string input.
20 | """
21 | return jax.random.fold_in(
22 | key, int(hashlib.md5(string.encode()).hexdigest()[:8], base=16)
23 | )
24 |
25 | def _convert(value, target_type):
26 | if value is None and target_type is not type(None):
27 | raise ValueError(f"Cannot convert None to {target_type}")
28 | elif value is None and target_type is type(None):
29 | return None
30 | elif is_dataclass(target_type):
31 | return make_dataclass_from_dict(target_type, value)
32 | else:
33 | return target_type(value)
34 |
35 | def _handle_union(name, field_value, union_types):
36 | for type_option in union_types:
37 | try:
38 | return _convert(field_value, type_option)
39 | except (TypeError, ValueError, AssertionError):
40 | continue
41 | raise ValueError(f'could not convert Union type {name} to any of {union_types}.')
42 |
43 | def make_dataclass_from_dict(cls, data):
44 | """Recursively instantiate a dataclass from a dictionary."""
45 | if data is None:
46 | raise ValueError(f'Expected a {cls.__name__}, got None instead.')
47 | field_data = {}
48 | for field in fields(cls):
49 | field_value = data.get(field.name)
50 | if hasattr(field.type, '__origin__') and field.type.__origin__ is Union:
51 | field_data[field.name] = _handle_union(field.name, field_value, get_args(field.type))
52 | else:
53 | try:
54 | field_data[field.name] = _convert(field_value, field.type)
55 | except (TypeError, ValueError, AssertionError):
56 | raise ValueError(f'Expected {field.type} for {cls.__name__}.{field.name}, got {type(field_value)} instead.')
57 | return cls(**field_data)
58 |
59 | def explicit_activation_checkpointing(f):
60 | """Annotates a function f to be used with save_for_backward().
61 |
62 | Example:
63 |
64 | ```
65 | @explicit_activation_checkpointing
66 | def foo(W1, W2, W3, x):
67 | x = jax.nn.relu(save_for_backward(W1 @ x))
68 | x = jax.nn.relu(save_for_backward(W2 @ x))
69 |   return W3 @ x
70 | ```
71 |
72 | This causes the pre-ReLU activations to be saved for the backwards pass.
73 | """
74 | # We save everything that is named.
75 | return jax.ad_checkpoint.checkpoint(f, policy=jax.checkpoint_policies.save_any_names_but_these())
76 |
77 | def save_for_backward(x):
78 | """Saves a value for the backwards pass in a function annotated with explicit_activation_checkpointing()."""
79 | # The actual name isn't important, just the fact that it _is_ named, so that
80 | # the save_any_names_but_these() policy causes it to be saved.
81 | return jax.ad_checkpoint.checkpoint_name(x, name='seqax_save_for_backward')
--------------------------------------------------------------------------------
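A minimal sketch of `make_dataclass_from_dict` on a nested, config-style dict; the `Inner`/`Outer` dataclasses here are made up for illustration:

```python
from dataclasses import dataclass
from typing import Optional

from jax_extra import make_dataclass_from_dict

@dataclass(frozen=True)
class Inner:
    lr: float
    steps: int

@dataclass(frozen=True)
class Outer:
    name: str
    inner: Inner
    note: Optional[str] = None

cfg = make_dataclass_from_dict(Outer, {"name": "demo", "inner": {"lr": 1e-3, "steps": 10}})
# cfg.inner is an Inner instance; the missing "note" resolves to None via the Union handling.
```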
/requirements-cpu.txt:
--------------------------------------------------------------------------------
1 | zarr
2 | fsspec[gcs]
3 | jax[cpu]==0.4.26
4 | einops
5 | hydra-core
6 | clearml
7 | clearml-agent
8 | typeguard==4.1.5
9 | transformers # For Huggingface data loader
10 | datasets # For Huggingface data loader
11 | torch[cpu] # For Huggingface data loader
--------------------------------------------------------------------------------
/shardlib/shardops.py:
--------------------------------------------------------------------------------
1 | import shardlib.shardtypes as shardtypes
2 | from jax import lax
3 | import jax.numpy as jnp
4 | import jax
5 |
6 | def all_gather(spec: str, x):
7 | """String-specified all-gather operation.
8 |
9 | For example:
10 | all_gather('A/x/y B/z C/w -> A B C/w', x)
11 | """
12 | before, after = spec.split('->')
13 | before = shardtypes.ShapeSpec.parse(before)
14 | after = shardtypes.ShapeSpec.parse(after)
15 | shardtypes.check(x.dtype, before, x)
16 | for i, (before_dim, after_dim) in enumerate(zip(before.dims, after.dims)):
17 | # Check that after_dim.sharding is a prefix of before_dim.sharding
18 | after_n = len(after_dim.sharding)
19 | if before_dim.shape != after_dim.shape or before_dim.sharding[:after_n] != after_dim.sharding:
20 | raise ValueError(f'Cannot all-gather {before_dim} into {after_dim}')
21 | if len(before_dim.sharding) == after_n:
22 | continue
23 | x = lax.all_gather(x, tuple(before_dim.sharding[after_n:]), axis=i, tiled=True)
24 | shardtypes.check(x.dtype, after, x)
25 | return x
26 |
27 | def psum_scatter(spec: str, x):
28 | """String-specified reduce-scatter operation.
29 |
30 | For example:
31 | psum_scatter('A B C/w -> A/x/y B/z C/w', x)
32 | """
33 | before, after = spec.split('->')
34 | before = shardtypes.ShapeSpec.parse(before)
35 | after = shardtypes.ShapeSpec.parse(after)
36 | shardtypes.check(x.dtype, before, x)
37 | for i, (before_dim, after_dim) in enumerate(zip(before.dims, after.dims)):
38 | # Check that before_dim.sharding is a prefix of after_dim.sharding
39 | before_n = len(before_dim.sharding)
40 | if before_dim.shape != after_dim.shape or after_dim.sharding[:before_n] != before_dim.sharding:
41 | raise ValueError(f'Cannot reduce-scatter {before_dim} into {after_dim}')
42 | if len(after_dim.sharding) == before_n:
43 | continue
44 | x = lax.psum_scatter(x, tuple(after_dim.sharding[before_n:]), scatter_dimension=i, tiled=True)
45 | shardtypes.check(x.dtype, after, x)
46 | return x
47 |
48 | def einsum_unreduced(spec: str, x, y, **kwargs):
49 | """Ordinary chip-local einsum, but with sharding-aware typechecking.
50 |
51 | Note that this function does not do any chip-to-chip communication. If the inputs are
52 | sharded over the contraction dimensions, the caller is responsible for reducing the result
53 | over those dimensions. For example:
54 |
55 | c = einsum_unreduced('A/x B/y, B/y C/z -> A/x/z', a, b)
56 | # c still needs to be reduced over the y axis.
57 | d = psum_scatter('A/x/z -> A/x/z/y', c)
58 | # Now the post-einsum reduction is complete.
59 | """
60 | tmp, result = spec.split('->')
61 | lhs, rhs = tmp.split(',')
62 | lhs = shardtypes.ShapeSpec.parse(lhs)
63 | rhs = shardtypes.ShapeSpec.parse(rhs)
64 | result = shardtypes.ShapeSpec.parse(result)
65 | shardtypes.check(x.dtype, lhs, x)
66 | shardtypes.check(y.dtype, rhs, y)
67 | # Convert to jax einsum syntax, with single-letter variables.
68 | jaxspec = ''
69 |
70 | vars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
71 | var_i = 0
72 | dim_table = {}
73 | def map_var(dim):
74 | if dim in dim_table:
75 | return dim_table[dim]
76 | nonlocal var_i
77 | if var_i >= len(vars):
78 | raise ValueError('Too many dimensions in einsum, we ran out of variables')
79 | var = vars[var_i]
80 | var_i += 1
81 | dim_table[dim] = var
82 | return var
83 |
84 | for dim in lhs.dims:
85 | jaxspec += map_var(dim)
86 | jaxspec += ','
87 | for dim in rhs.dims:
88 | jaxspec += map_var(dim)
89 | jaxspec += '->'
90 | for dim in result.dims:
91 | jaxspec += map_var(dim)
92 | r = jnp.einsum(jaxspec, x, y, **kwargs)
93 | shardtypes.check(r.dtype, result, r)
94 | return r
95 |
96 | def index_unreduced(spec: str, table, indices):
97 | """String-specified sharded table lookup operation.
98 |
99 | For example:
100 |       index_unreduced('A [B/x/y] C/z, D/w A -> C/z A D/w', table, indices)
101 |
102 | In this example, the integers in `indices` are used as lookup addresses into the
103 | `B` dimension of `table`, and all other dimensions (`A`, `C`, `D`) are vmapped over.
104 |
105 | This operation does not do any chip-to-chip communication, even though the table
106 | may be sharded. If the axis inside square brackets is sharded, corresponding to
107 | different table indices on different shards, a table lookup will be performed on each
108 | shard, but only one shard will return a nonzero result: the other shards, where the
109 | index is out of bounds, will return zero. The caller is required to reduce the output
110 | over the axes specified by the square brackets: in the above example, the caller must
111 | reduce over `x` and `y` axes.
112 | """
113 | tmp, result = spec.split('->')
114 | lhs, rhs = tmp.split(',')
115 | lhs_dims = lhs.split()
116 | index_axis = None
117 | for i, dim in enumerate(lhs_dims):
118 | if dim.startswith('['):
119 | index_axis = i
120 | if not dim.endswith(']'):
121 | raise ValueError(f'Expected closing bracket in {dim}')
122 | lhs_dims[i] = dim[1:-1]
123 | break
124 | if index_axis is None:
125 | raise ValueError(f'Expected an index axis in {lhs}')
126 |
127 | lhs_dims = [shardtypes.DimSpec.parse(dim) for dim in lhs_dims]
128 | lhs_spec = shardtypes.ShapeSpec(lhs_dims)
129 | rhs_spec = shardtypes.ShapeSpec.parse(rhs)
130 | result_spec = shardtypes.ShapeSpec.parse(result)
131 | shardtypes.check(table.dtype, lhs_spec, table)
132 | shardtypes.check(indices.dtype, rhs_spec, indices)
133 |
134 | # Do the base operation on scalars, then do a sequence of vmap operations to bring it up
135 | # to the desired shape.
136 | def base_op(table, index):
137 | len_per_chip = table.shape[0]
138 | lower_bound = len_per_chip * lax.axis_index(lhs_dims[index_axis].sharding)
139 | upper_bound = lower_bound + len_per_chip
140 | in_bounds = (lower_bound <= index) & (index < upper_bound)
141 | return jnp.where(in_bounds, table[jnp.where(in_bounds, index - lower_bound, 0)], 0)
142 |
143 | op = base_op
144 |
145 | lhs_dims_handled = [False] * len(lhs_dims)
146 | lhs_dims_handled[index_axis] = True
147 | rhs_dims_handled = [False] * len(rhs_spec.dims)
148 | for dim in reversed(result_spec.dims):
149 | try:
150 | lhs_index = lhs_dims.index(dim)
151 | lhs_vmap_axis = sum(lhs_dims_handled[:lhs_index])
152 | assert not lhs_dims_handled[lhs_index]
153 | lhs_dims_handled[lhs_index] = True
154 | except ValueError:
155 | lhs_index = None
156 | lhs_vmap_axis = None
157 |
158 | try:
159 | rhs_index = rhs_spec.dims.index(dim)
160 | rhs_vmap_axis = sum(rhs_dims_handled[:rhs_index])
161 | assert not rhs_dims_handled[rhs_index]
162 | rhs_dims_handled[rhs_index] = True
163 | except ValueError:
164 | rhs_index = None
165 | rhs_vmap_axis = None
166 |
167 | op = jax.vmap(op, in_axes=(lhs_vmap_axis, rhs_vmap_axis), out_axes=0)
168 |
169 | assert all(lhs_dims_handled)
170 | assert all(rhs_dims_handled)
171 |
172 | result = op(table, indices)
173 | shardtypes.check(result.dtype, result_spec, result)
174 | return result
175 |
176 | def axis_size(name: str) -> int:
177 | """Return the size of the axis with the given name."""
178 | return jax.lax.psum(1, name)
--------------------------------------------------------------------------------
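A minimal end-to-end sketch of the `einsum_unreduced` + `psum_scatter` pattern from the docstrings above, driven by `typed_shard_map`. The 4 fake CPU devices, the single `d` mesh axis, and the shapes are made up for illustration:

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"  # 4 fake CPU devices

from functools import partial
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh

import shardlib.shardops as shardops
import shardlib.shardtypes as shardtypes
from shardlib.shardtypes import f32, typed_shard_map

shardtypes.register_with_typeguard()

with Mesh(mesh_utils.create_device_mesh([4]), ('d',)):
    @partial(typed_shard_map, check_rep=False)
    def matmul(x: f32[b'B M/d'], w: f32[b'M/d F']) -> f32[b'B F/d']:
        # Chip-local partial products over the sharded contraction dim M/d...
        acc = shardops.einsum_unreduced('B M/d, M/d F -> B F', x, w)
        # ...then reduce over 'd' while scattering the F dim across it.
        return shardops.psum_scatter('B F -> B F/d', acc)

    y = jax.jit(matmul)(jnp.ones((8, 16), jnp.float32), jnp.ones((16, 4), jnp.float32))
    # y has global shape (8, 4) and every entry equals 16.0.
```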
/shardlib/shardtypes.py:
--------------------------------------------------------------------------------
1 | """Type annotations for JAX arrays with sharding information.
2 |
3 | # Shape checking
4 |
5 | Example:
6 |
7 | ```
8 | import jax
9 | shardtypes.register_with_typeguard()
10 | from shardlib.shardtypes import f32
11 | from typeguard import typechecked
12 |
13 | @typechecked
14 | def center_channels(x: f32[b'batch/d channels']) -> f32[b'batch/d channels']:
15 | return x - jax.numpy.mean(x, axis=-1, keepdims=True)
16 | ```
17 |
18 | The type syntax is `dtype[shape]`, where `dtype` is imported from `shardlib.shardtypes`,
19 | and `shape` is a space-separated list of dimensions. Each dimension consists of a dimension
20 | name (e.g. `batch`), optionally followed by slashes and sharding axis names, e.g. `batch/d` indicates
21 | that the `batch` tensor dimension is sharded over the `d` device axis. Sharding over multiple axes
22 | is indicated by multiple axis names, e.g. `batch/d/e`.
23 |
24 | The shape string may be either a string ('foo') or a bytes object (b'foo'). Strings have special
25 | meaning in Python type annotations (they are used for forward references, and are eval'ed by typeguard),
26 | so the bytes object b'foo' is a workaround to prevent this eval'ing.
27 |
28 | Shape checking proceeds by maintaining a table of the sizes of all dimension names in a context
29 | variable, known as the shape checking scope. The first time a dimension name is encountered,
30 | its size is recorded in the current scope. Subsequent uses of the same dimension name must have
31 | the same size. Device axes (e.g. `/d`) are looked up in the currently configured JAX device mesh,
32 | to determine the size of the axis.
33 |
34 | For calls into functions or libraries, it can be useful to clear the shape checking scope, so caller
35 | and callee can use the same variable name to mean different things. This can be done with the `@scope`
36 | function decorator or the `with Scope():` context manager.
37 |
38 | # Using type annotations
39 |
40 | In addition to driving shape checking, type annotations can be used to drive sharding in JAX functions.
41 | See for example `typed_shard_map`, which is a simplification of JAX's `shard_map` by taking advantage
42 | of sharding in type signatures.
43 | """
44 | import inspect
45 | import typing
46 | from collections.abc import Sequence
47 | from contextvars import ContextVar
48 | from enum import IntEnum
49 | from typing import Any, Union
50 | from typing import get_args, get_origin
51 | from typeguard import check_type_internal, typechecked
52 | import jax
53 | import jax.numpy as jnp
54 | from types import GenericAlias
55 | from typeguard import TypeCheckError, TypeCheckerCallable
56 | import dataclasses
57 | from dataclasses import dataclass, make_dataclass
58 | from typeguard import checker_lookup_functions
59 | import jax.experimental.shard_map  # used by typed_shard_map below
60 |
61 | #### State
62 | # ContextVar(dict[str, int])
63 | _VARS = ContextVar('shardtypes._VARS', default={})
64 |
65 | class Scope:
66 | """Context manager that clears the shape checking scope."""
67 | def __enter__(self):
68 | self.token = _VARS.set({})
69 |
70 | def __exit__(self, type, value, traceback):
71 | _VARS.reset(self.token)
72 |
73 | def scope(f):
74 | """Function decorator that clears the shape checking scope."""
75 | def wrapper(*args, **kwargs):
76 | with Scope():
77 | return f(*args, **kwargs)
78 | return wrapper
79 |
80 |
81 | def check_size(name: str, size: int):
82 | """Checks that a dimension has the expected size."""
83 | try:
84 | value = int(name)
85 | if value != size:
86 | raise TypeCheckError(f'explicit dimension {value}: actually was {size}')
87 | except ValueError:
88 | v = _VARS.get()
89 | if name in v:
90 | if v[name] != size:
91 | raise TypeCheckError(f'dimension {name}: expected {v[name]}, got {size}')
92 | else:
93 | v[name] = size
94 |
95 |
96 | #### Shape specs
97 | @dataclass(frozen=True)
98 | class DimSpec:
99 | """Parsed result of a dimension in a shape string."""
100 | shape: str
101 | sharding: Sequence[str]
102 |
103 | @staticmethod
104 | def parse(spec: str) -> 'DimSpec':
105 | pieces = spec.split('/')
106 | shape = pieces[0]
107 | sharding = tuple(pieces[1:])
108 | return DimSpec(shape, sharding)
109 |
110 | def __str__(self):
111 | return '/'.join([self.shape] + list(self.sharding))
112 |
113 | @dataclass
114 | class ShapeSpec:
115 | """Parsed result of a shape string."""
116 | dims: Sequence[DimSpec]
117 |
118 | @staticmethod
119 | def parse(spec: Union[bytes, str]) -> 'ShapeSpec':
120 | if isinstance(spec, bytes):
121 | spec = spec.decode('utf-8')
122 | if not isinstance(spec, str):
123 | print(spec)
124 | raise ValueError('Expected a string')
125 | dims = spec.split() # Split on spaces, trimming excess space
126 | result = []
127 | for dim in dims:
128 | result.append(DimSpec.parse(dim))
129 | return ShapeSpec(result)
130 |
131 | def partition_spec(self) -> jax.sharding.PartitionSpec:
132 | result = []
133 | for dim_spec in self.dims:
134 | if len(dim_spec.sharding) == 0:
135 | result.append(None)
136 | elif len(dim_spec.sharding) == 1:
137 | result.append(dim_spec.sharding[0])
138 | else:
139 | result.append(tuple(dim_spec.sharding))
140 | return jax.sharding.PartitionSpec(*result)
141 |
142 | def __str__(self):
143 | return ' '.join(str(dim) for dim in self.dims)
144 |
145 | #### Shape checking
146 | def _partition_spec_equiv(lhs: jax.sharding.PartitionSpec, rhs: jax.sharding.PartitionSpec) -> bool:
147 | if len(lhs) < len(rhs):
148 | lhs, rhs = rhs, lhs
149 | if any(l is not None for l in lhs[len(rhs):]):
150 | return False
151 | return lhs[:len(rhs)] == rhs[:]
152 |
153 |
154 | def check(dtype, shape_spec: ShapeSpec, value):
155 | """Checks that a value has the expected dtype and shape."""
156 | if not isinstance(value, jax.Array):
157 | raise TypeCheckError('is not a jax.Array')
158 | if value.dtype != dtype:
159 | raise TypeCheckError(f'is {value.dtype}, but expected {dtype}')
160 | shape = value.shape
161 | if len(shape) != len(shape_spec.dims):
162 | raise TypeCheckError(f'has shape {shape}, but expected shape {str(shape_spec)}')
163 | mesh = None
164 |
165 | axis_env = jax._src.core.thread_local_state.trace_state.axis_env
166 | if axis_env:
167 | # We're in a shard_map/pmap/xmap context. Multiply sizes by sharding, then check sizes.
168 | # We don't actually check the sharding, because that information is lost inside a
169 | # shard_map/pmap/xmap context, but we do check the unsharded sizes are correct.
170 | mesh = {axis.name: axis.size for axis in axis_env}
171 | for orig_dim, dim_spec in zip(shape, shape_spec.dims):
172 | dim = orig_dim
173 | for axis in dim_spec.sharding:
174 | if axis not in mesh:
175 | raise TypeCheckError(f'has unknown mesh axis {axis}')
176 | axis_size = mesh[axis]
177 | dim *= axis_size
178 | check_size(dim_spec.shape, dim)
179 | else:
180 | # Check sizes
181 | for dim, dim_spec in zip(shape, shape_spec.dims):
182 | check_size(dim_spec.shape, dim)
183 |
184 | # Check sharding
185 | expected_spec = shape_spec.partition_spec()
186 | def cb(actual):
187 | if isinstance(actual, jax.sharding.SingleDeviceSharding):
188 | if any(dim_spec.sharding for dim_spec in shape_spec.dims):
189 | raise TypeCheckError(f'is fully replicated, but expected {expected_spec} is not')
190 | elif not isinstance(actual, jax.sharding.NamedSharding):
191 | if isinstance(actual, jax.sharding.Sharding):
192 | raise TypeCheckError(f'is SPMD-sharded but no axis names are available. Use `with Mesh(...):` to provide axis names for type checking.')
193 | else:
194 | raise TypeCheckError(f': unexpected object when checking sharding: {actual}')
195 | elif not _partition_spec_equiv(actual.spec, expected_spec):
196 | # TODO: when an axis size is None, recovering the NamedSharding from the PositionalSharding
197 | # is ambiguous, and JAX often takes a different approach than the user does.
198 | #
199 | # We could fix this with a more precise _partition_spec_equiv, but for now we'll just ignore it.
200 | # raise TypeCheckError(f'has sharding spec {actual.spec}, but expected {expected_spec} from {str(shape_spec)}')
201 | pass
202 | # Use tracing as a proxy for whether we're in a jit context
203 | is_tracing = jax._src.core.thread_local_state.trace_state.trace_stack
204 | if is_tracing:
205 | jax.debug.inspect_array_sharding(value, callback=cb)
206 | else:
207 | cb(value.sharding)
208 |
209 |
210 |
211 |
212 | #### Typeguard
213 | def register_with_typeguard():
214 | """Registers the shardtypes module with typeguard. Call this at the beginning of your program."""
215 | def check_array(value, origin, args, memo):
216 | if len(args) != 1 or (type(args[0]) is not str and type(args[0]) is not bytes):
217 |             raise TypeCheckError(f'has bad type signature; expected {origin.__name__}[shape], got {origin.__name__}{args}')
218 | check(origin.dtype, ShapeSpec.parse(args[0]), value)
219 |
220 | def check_pytree_dataclass(value, origin, args, memo):
221 | if not isinstance(value, origin):
222 | raise TypeCheckError(f'is not an instance of {origin}')
223 | for field in dataclasses.fields(origin):
224 | check_type_internal(getattr(value, field.name), field.type, memo)
225 |
226 | def lookup(
227 | origin, args, extras
228 | ) -> TypeCheckerCallable | None:
229 | if isinstance(origin, type) and issubclass(origin, number):
230 | return check_array
231 | if origin in _PYTREE_DATACLASSES:
232 | return check_pytree_dataclass
233 | return None
234 |
235 | checker_lookup_functions.append(lookup)
236 |
237 | #### Array types
238 | class number:
239 | def __class_getitem__(cls, x):
240 | if isinstance(x, str):
241 | x = x.encode('utf-8')
242 | return GenericAlias(cls, x)
243 |
244 | class bool_(number):
245 | dtype = jnp.bool_
246 | pass
247 |
248 | class bf16(number):
249 | dtype = jnp.bfloat16
250 | pass
251 |
252 | class f32(number):
253 | dtype = jnp.float32
254 | pass
255 |
256 | class i32(number):
257 | dtype = jnp.int32
258 | pass
259 |
260 | class u32(number):
261 | dtype = jnp.uint32
262 | pass
263 |
264 | class i8(number):
265 | dtype = jnp.int8
266 | pass
267 |
268 | class u8(number):
269 | dtype = jnp.uint8
270 | pass
271 |
272 |
273 | _PYTREE_DATACLASSES = set()
274 |
275 |
276 | def pytree_dataclass(cls):
277 | """Decorator that declares a dataclass that JAX recognizes as a PyTree."""
278 | cls = dataclass(cls)
279 |
280 | def flatten_with_keys(value):
281 | return [(k.name, getattr(value, k.name)) for k in dataclasses.fields(cls)], ()
282 |
283 | def unflatten(_aux, fields):
284 | return cls(*fields)
285 |
286 | jax.tree_util.register_pytree_with_keys(cls, flatten_with_keys, unflatten)
287 | _PYTREE_DATACLASSES.add(cls)
288 | return cls
289 |
290 | class Array:
291 | """If `cls` is an array type or a `pytree_dataclass` of array types,
292 | `Array[axes, cls]` will extend `cls` with leading axes `axes`.
293 |     For example, `Array['layers', f32['batch d_model']]` returns `f32['layers batch d_model']`.
294 | """
295 | def __class_getitem__(cls, x):
296 | axes, input_cls = x
297 | if isinstance(axes, str):
298 | axes = axes.encode('utf-8')
299 | elif isinstance(axes, bytes):
300 | pass
301 | else:
302 | raise ValueError(f"input axes to {cls} must be Union[bytes, str]")
303 |
304 | if dataclasses.is_dataclass(input_cls):
305 | extended_fields = []
306 | for fld in dataclasses.fields(input_cls):
307 | extended_type = Array[axes, fld.type]
308 | extended_fields.append((fld.name, extended_type))
309 |
310 | extended_cls = make_dataclass(input_cls.__name__, extended_fields, bases=(input_cls,))
311 | pytree_dataclass(extended_cls)
312 | return extended_cls
313 | else:
314 | number_type, shape = get_origin(input_cls), get_args(input_cls)
315 | extended_shape = (axes + b' ' + shape[0],)
316 | return GenericAlias(number_type, extended_shape)
317 |
318 | def make_partition_specs(cls):
319 | """Instantiates a pytree dataclass with a PartitionSpec at array type."""
320 | # Check for a tuple type:
321 | origin = typing.get_origin(cls)
322 | args = typing.get_args(cls)
323 | if origin is tuple:
324 | return tuple(make_partition_specs(arg) for arg in args)
325 | elif origin is not None and issubclass(origin, number):
326 | if len(args) != 1 or (type(args[0]) is not str and type(args[0]) is not bytes):
327 |             raise ValueError(f'Type annotation {cls} should be dtype[shape], got {cls}')
328 | spec = ShapeSpec.parse(args[0])
329 | return spec.partition_spec()
330 | elif dataclasses.is_dataclass(cls):
331 | values = []
332 | for field in dataclasses.fields(cls):
333 | values.append(make_partition_specs(field.type))
334 | return cls(*values)
335 |
336 |     raise ValueError(f'Unsupported type {cls} is not an array, dataclass, or tuple type')
337 |
338 |
339 | def make_shardings(cls):
340 | """Instantiates a pytree dataclass with NamedSharding at array type."""
341 | mesh = jax._src.mesh.thread_resources.env.physical_mesh
342 | return jax.tree_map(lambda spec: jax.sharding.NamedSharding(mesh, spec), make_partition_specs(cls))
343 |
344 |
345 | def typed_shard_map(f, **kwargs):
346 | """jax.shard_map, but which does not require specifying in_specs and out_specs.
347 |
348 | Instead, the function signature is used to infer the partitioning of the inputs and outputs.
349 |
350 | For example:
351 | @typed_shard_map
352 | def f(x: f32[b'batch/d len'], y: f32[b'e/d f/t']) -> f32[b'batch/d f/t']:
353 | ...
354 |
355 | """
356 | sig = inspect.signature(f)
357 |
358 | def wrapped(*args):
359 | mesh = jax._src.mesh.thread_resources.env.physical_mesh
360 | in_specs = tuple(make_partition_specs(param.annotation) for param in sig.parameters.values())
361 | out_specs = make_partition_specs(sig.return_annotation)
362 | return jax.experimental.shard_map.shard_map(typechecked(f), in_specs=in_specs, out_specs=out_specs, mesh=mesh, **kwargs)(*args)
363 |
364 | return wrapped
365 |
366 | def is_fully_sharded(spec: jax.sharding.PartitionSpec):
367 | """Returns True if the spec is fully sharded, i.e. every device axis is used in the partition spec."""
368 | axis_count = 0
369 | for axis in spec:
370 | if axis is None:
371 | continue
372 | elif isinstance(axis, str):
373 | axis_count += 1
374 | elif isinstance(axis, tuple):
375 | axis_count += len(axis)
376 | else:
377 | raise ValueError(f'Unknown axis type {axis}')
378 | return axis_count == len(jax._src.core.thread_local_state.trace_state.axis_env)
379 |
--------------------------------------------------------------------------------
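A minimal sketch of the shape-string syntax and the partition specs derived from it; the `Layer` dataclass and dimension names are made up:

```python
from shardlib.shardtypes import ShapeSpec, f32, make_partition_specs, pytree_dataclass

print(ShapeSpec.parse(b'batch/d len d_model/t').partition_spec())
# -> PartitionSpec('d', None, 't')

@pytree_dataclass
class Layer:
    w_in: f32['d_model d_ff/t']
    w_out: f32['d_ff/t d_model']

print(make_partition_specs(Layer))
# -> Layer(w_in=PartitionSpec(None, 't'), w_out=PartitionSpec('t', None))
```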
/synthetic_dataset.zarr/.zgroup:
--------------------------------------------------------------------------------
1 | {
2 | "zarr_format": 2
3 | }
--------------------------------------------------------------------------------
/synthetic_dataset.zarr/train/.zattrs:
--------------------------------------------------------------------------------
1 | {
2 | "max_token_id": 1017
3 | }
--------------------------------------------------------------------------------
/synthetic_dataset.zarr/train/.zgroup:
--------------------------------------------------------------------------------
1 | {
2 | "zarr_format": 2
3 | }
--------------------------------------------------------------------------------
/synthetic_dataset.zarr/train/encoded_tokens/.zarray:
--------------------------------------------------------------------------------
1 | {
2 | "chunks": [
3 | 4194304
4 | ],
5 | "compressor": {
6 | "blocksize": 0,
7 | "clevel": 5,
8 | "cname": "lz4",
9 | "id": "blosc",
10 | "shuffle": 2
11 | },
12 | "dtype": " self.group.attrs["max_token_id"]:
78 | self.group.attrs["max_token_id"] = chunk.max_token_id
79 | # In parallel:
80 | with concurrent.futures.ThreadPoolExecutor() as executor:
81 | executor.submit(lambda: self.encoded_tokens.append(chunk.encoded_tokens))
82 | executor.submit(lambda: self.seq_starts.append(num_tokens + chunk.seq_starts[1:]))
83 |
84 |
85 |
--------------------------------------------------------------------------------
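The `seq_starts` bookkeeping visible in the fragment above is easiest to see on plain numpy arrays. A minimal sketch with made-up values, assuming `num_tokens` is the number of tokens already stored: each appended chunk's boundaries are shifted by that offset, and the chunk's leading 0 is dropped because it coincides with the previous sentinel boundary:

```python
import numpy as np

existing_seq_starts = np.array([0, 3, 5])   # two sequences written so far, 5 tokens total
num_tokens = 5
chunk_seq_starts = np.array([0, 2, 4])      # new chunk: two sequences, 4 tokens

combined = np.concatenate([existing_seq_starts, num_tokens + chunk_seq_starts[1:]])
# -> [0, 3, 5, 7, 9]
```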
/tools/huggingface_to_flat_tokens.py:
--------------------------------------------------------------------------------
1 | """Tokenizes a Huggingface dataset and writes it to `flat-tokens` format.
2 |
3 | See `docs/flat-tokens.md` for details on the format.
4 | See `configs/c4_en.yaml` for instructions on running.
5 |
6 | TODO: we could make this much faster by sharding over multiple CPUs. Rough approach:
7 | 1) Make this script read from a shard of the Huggingface dataset.
8 | 2) At the end of this script, wait for all shards to complete, and then concatenate the zarr data.
9 | """
10 |
11 | import hydra
12 |
13 | from typing import Optional, Dict
14 | import time
15 | import numpy as np
16 | from dataclasses import dataclass
17 | from transformers import AutoTokenizer
18 | from hydra.core.config_store import ConfigStore
19 | from datasets import load_dataset
20 | from concurrent.futures import ThreadPoolExecutor
21 | import flat_tokens
22 |
23 | @dataclass
24 | class Config:
25 | output: str
26 | tokenizer: str
27 | dataset: str
28 | variant: Optional[str]
29 | max_tokens: Optional[int]
30 | write_buffer_size_in_sequences: int
31 | flat_tokens_config: flat_tokens.Config
32 | _target_: str = __name__ + ".Config"
33 |
34 |
35 | # Registering the Config class with the name 'config'.
36 | ConfigStore.instance().store(name="config_schema", node=Config)
37 |
38 |
39 | @hydra.main(config_path="configs", version_base=None)
40 | def main(config):
41 | # Create tokenizer
42 | if config.tokenizer == "bytes_utf8":
43 | def tokenize(texts):
44 | return [np.uint32(np.frombuffer(text.encode('utf-8'), np.uint8)) + 1 for text in texts]
45 | else:
46 | tokenizer = AutoTokenizer.from_pretrained(config.tokenizer)
47 | assert 0 in tokenizer.all_special_ids, "Tokenizer must have 0 as a special id"
48 | assert tokenizer.vocab_size < 1 << 31, "Tokenizer vocab size too large for uint31"
49 | def tokenize(texts):
50 | return tokenizer(texts, add_special_tokens=False)["input_ids"]
51 |
52 | def tokenize_and_concat(batch):
53 | chunk = flat_tokens.Chunk.from_ragged(tokenize(batch["text"]))
54 | # dataset.map() API requires us to return numpy tensors of the appropriate shape...
55 | return {
56 | "encoded_tokens": chunk.encoded_tokens[np.newaxis, :],
57 | "seq_starts": chunk.seq_starts[np.newaxis, :],
58 | "max_token_id": np.array(chunk.max_token_id, np.uint32)[np.newaxis],
59 | }
60 |
61 | executor = ThreadPoolExecutor()
62 |
63 | for split, mode in [("validation", "w-"), ("train", "r+")]:
64 | dataset = load_dataset(
65 | config.dataset,
66 | config.variant,
67 | streaming=True,
68 | split=split,
69 | )
70 | dataset = dataset.select_columns(["text"])
71 | dataset = dataset.map(tokenize_and_concat, batched=True, batch_size=config.write_buffer_size_in_sequences, remove_columns=["text"])
72 |
73 | # Open output
74 | dst = flat_tokens.Writer(config.output, flat_tokens.Split(split), mode, config.flat_tokens_config)
75 | dst_flush = executor.submit(lambda: None)
76 |
77 | # Write in batches
78 | flush_elapsed = 0
79 | start_time = time.time()
80 | next_update = 0
81 | seq_count = 0
82 | token_count = 0
83 | for batch in dataset:
84 | chunk = flat_tokens.Chunk(encoded_tokens=batch["encoded_tokens"], seq_starts=batch["seq_starts"], max_token_id=batch["max_token_id"])
85 | seq_count += len(chunk.seq_starts) - 1
86 | token_count += len(chunk.encoded_tokens)
87 |
88 | flush_start = time.time()
89 | dst_flush.result()
90 | dst_flush = executor.submit(dst.write, chunk)
91 | flush_elapsed += time.time() - flush_start
92 | elapsed = time.time() - start_time
93 | if elapsed > next_update:
94 | total_mib = token_count * 4 // (1024 * 1024)
95 | speed_mib_per_s = total_mib / elapsed
96 | print(f"[{int(elapsed):_}s] Processed {seq_count:_} examples, {token_count:_} tokens, {total_mib:_} MiB, {speed_mib_per_s:.2f} MiB/s. Flush time: {flush_elapsed:.2f}s")
97 | next_update = elapsed + 60
98 |
99 |         if config.max_tokens is not None and token_count >= config.max_tokens:
100 | break
101 |
102 | # Final flush
103 | dst_flush.result()
104 |
105 | print(f"Done with split '{split}': {seq_count:_} examples, {token_count:_} tokens")
106 |
107 |
108 | if __name__ == "__main__":
109 | main()
110 |
--------------------------------------------------------------------------------
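The `dst_flush` dance above is a simple double-buffering trick: while batch N is being written, batch N+1 is already being tokenized. A minimal standalone sketch; `produce_batch` and `write_batch` are hypothetical stand-ins:

```python
from concurrent.futures import ThreadPoolExecutor

def produce_batch(i):
    return list(range(i * 4, i * 4 + 4))   # stand-in for tokenization

def write_batch(batch):
    pass                                   # stand-in for dst.write(chunk)

executor = ThreadPoolExecutor()
pending = executor.submit(lambda: None)    # already-completed sentinel, like dst_flush
for i in range(3):
    batch = produce_batch(i)               # overlaps with the previous write
    pending.result()                       # wait for the previous write to finish
    pending = executor.submit(write_batch, batch)
pending.result()                           # final flush
```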
/tools/requirements.txt:
--------------------------------------------------------------------------------
1 | zarr
2 | fsspec[gcs]
3 | transformers
4 | datasets
5 | hydra-core
6 |
--------------------------------------------------------------------------------
/tools/write_synthetic_dataset.py:
--------------------------------------------------------------------------------
1 | # To run:
2 | #
3 | # ```
4 | # python write_synthetic_dataset.py --config-name=synthetic_dataset +output=synthetic_dataset.zarr
5 | # ```
6 | #
7 | # Synthetic tasks: see Section 4 of https://arxiv.org/abs/2002.09402 for some ideas.
8 | #
9 | # We do:
10 | # * Task 1: [theirs] fixed-distance copy. Requires attention position queries.
11 | # * Task 2: [theirs] fixed-distance reverse. Requires attention position queries.
12 | # * Task 3: [ours] random-distance (specified) copy. Requires variable-length position queries.
13 | # * Task 4: [ours] random-distance (not specified) copy. Requires equality matching for a prefix.
14 | # * Task 5: [ours] gaussians sampling. Requires model to learn which IDs are close to each other (in numerical order).
15 | #
16 | # Sequences begin with task ID, then have task-specific data. We avoid index 0, which indicates padding.
17 |
18 | from functools import partial
19 | import hydra
20 | from jaxtyping import Float, Int, jaxtyped, UInt32
21 | import numpy as np
22 | from typeguard import typechecked as typechecker
23 | from dataclasses import dataclass
24 | from hydra.core.config_store import ConfigStore
25 | import flat_tokens
26 |
27 | @dataclass
28 | class Config:
29 | output: str
30 | seed: int
31 | seq_len: int
32 | examples: int
33 | flat_tokens_config: flat_tokens.Config
34 | _target_: str = __name__ + '.Config'
35 |
36 |
37 |
38 | @jaxtyped(typechecker=typechecker)
39 | def copy(seq_len: int, examples: int, gen: np.random.Generator) -> UInt32[np.ndarray, 'batch seqlen']:
40 | seq = gen.integers(1, 11, (examples, (seq_len + 1) // 2), dtype=np.uint32)
41 | return np.append(seq, seq, axis=1)[:, :seq_len]
42 |
43 | @jaxtyped(typechecker=typechecker)
44 | def reverse(seq_len: int, examples: int, gen: np.random.Generator) -> UInt32[np.ndarray, 'batch seqlen']:
45 | seq = gen.integers(1, 11, (examples, (seq_len + 1) // 2), dtype=np.uint32)
46 | return np.append(seq, np.flip(seq, axis=1), axis=1)[:, :seq_len]
47 |
48 | @jaxtyped(typechecker=typechecker)
49 | def random_known_distance_copy(seq_len: int, examples: int, gen: np.random.Generator) -> UInt32[np.ndarray, 'batch seqlen']:
50 | distance = gen.integers(max(1, seq_len // 4), seq_len, (examples,), dtype=np.uint32)
51 | seq = gen.integers(1, 11, (examples, seq_len), dtype=np.uint32)
52 | indices = np.arange(seq_len - 1)[np.newaxis, :] % distance[:, np.newaxis]
53 | full_seq = seq[np.arange(examples)[:, np.newaxis], indices]
54 | assert full_seq.shape == (examples, seq_len - 1)
55 | return np.append(distance[:, np.newaxis], full_seq, axis=1)
56 |
57 | @jaxtyped(typechecker=typechecker)
58 | def random_unknown_distance_copy(seq_len: int, examples: int, gen: np.random.Generator) -> UInt32[np.ndarray, 'batch seqlen']:
59 | return random_known_distance_copy(seq_len + 1, examples, gen)[:, 1:]
60 |
61 | @jaxtyped(typechecker=typechecker)
62 | def mixture_of_gaussians(seq_len: int, examples: int, gen: np.random.Generator) -> UInt32[np.ndarray, 'batch seqlen']:
63 | centers = gen.uniform(0, 100, (examples, 3)).astype(np.float32)
64 | stddevs = gen.uniform(1, 4, (examples, 3)).astype(np.float32)
65 | sample_cluster_ids = gen.integers(0, 3, (examples, seq_len), dtype=np.uint32)
66 | batch_ids = np.arange(examples)[:, np.newaxis]
67 | sample_centers = centers[batch_ids, sample_cluster_ids]
68 | sample_stddevs = stddevs[batch_ids, sample_cluster_ids]
69 | floats = gen.normal(0, 1, (examples, seq_len)).astype(np.float32) * sample_stddevs + sample_centers
70 |     return np.clip(np.round(floats), 1, 100).astype(np.uint32)
71 |
72 | @jaxtyped(typechecker=typechecker)
73 | def synthetic_task(config: Config, gen: np.random.Generator) -> list[UInt32[np.ndarray, '...']]:
74 | task_seq_len = config.seq_len - 1
75 | examples = config.examples
76 | copy_data = copy(task_seq_len, examples, gen)
77 | reverse_data = reverse(task_seq_len, examples, gen)
78 | random_known_distance_copy_data = random_known_distance_copy(task_seq_len, examples, gen)
79 | random_unknown_distance_copy_data = random_unknown_distance_copy(task_seq_len, examples, gen)
80 | mixture_of_gaussians_data = mixture_of_gaussians(task_seq_len, examples, gen)
81 | tasks = np.asarray([copy_data, reverse_data, random_known_distance_copy_data, random_unknown_distance_copy_data, mixture_of_gaussians_data])
82 | task_id = gen.integers(1, 6, (examples,), dtype=np.uint32)
83 | targets = np.append(task_id[:, np.newaxis], tasks[task_id - 1, np.arange(examples)], axis=1)
84 | lengths = gen.integers(1, config.seq_len + 1, (examples,), dtype=np.uint32)
85 | ragged_targets = [targets[i, :lengths[i]] for i in range(examples)]
86 | return ragged_targets
87 |
88 |
89 | # Registering the Config class with the name 'config'.
90 | ConfigStore.instance().store(name="config_schema", node=Config)
91 |
92 |
93 | @hydra.main(config_path="configs", version_base=None)
94 | def main(config):
95 | config = hydra.utils.instantiate(config)
96 | gen = np.random.Generator(np.random.PCG64(config.seed))
97 |
98 | for split, mode in [(flat_tokens.Split.VALIDATION, "w-"), (flat_tokens.Split.TRAIN, "r+")]:
99 | dst = flat_tokens.Writer(config.output, split, mode, config.flat_tokens_config)
100 | examples = synthetic_task(config, gen)
101 | dst.write(flat_tokens.Chunk.from_ragged(examples))
102 |
103 | if __name__ == "__main__":
104 | main()
--------------------------------------------------------------------------------
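A quick look at what one task produces; a minimal sketch, assuming it is run from the `tools/` directory with its requirements installed, and with a made-up seed and sizes:

```python
import numpy as np
from write_synthetic_dataset import copy   # Task 1: fixed-distance copy

gen = np.random.Generator(np.random.PCG64(0))
batch = copy(6, 2, gen)
# batch.shape == (2, 6); each row looks like [a, b, c, a, b, c] with values in 1..10,
# so predicting the second half requires attending a fixed distance backwards.
```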
/train.py:
--------------------------------------------------------------------------------
1 | """Main training loop, including the model, loss function, and optimizer."""
2 | import operator
3 | import os
4 | import time
5 |
6 | import env
7 | env.set_variables()
8 | import shardlib.shardtypes as shardtypes
9 | shardtypes.register_with_typeguard()
10 | import gcsfs # Needed for clearml setup
11 |
12 | import datetime
13 | from functools import cached_property, partial
14 | from typing import Any, Optional, Tuple, Union
15 | import hydra
16 | from typeguard import typechecked
17 | from dataclasses import dataclass
18 | import jax
19 | from jax import lax
20 | from jax.sharding import PartitionSpec
21 | import jax.numpy as jnp
22 | import math
23 | from input_loader import FlatTokensParams, HuggingFaceDataParams, TokenBatch, TokenBatchParams, get_loader
24 | from shardlib.shardtypes import bf16, bool_, f32, pytree_dataclass, u32, make_shardings, Array
25 | import shardlib.shardops as shardops
26 | P = PartitionSpec
27 | import einops
28 | import jax_extra
29 | from jax_extra import fold_in_str, explicit_activation_checkpointing, save_for_backward
30 | import os
31 | import training_io
32 | from clearml import Task
33 | from jax.experimental import mesh_utils
34 | from jax.sharding import Mesh
35 | from jax.tree_util import tree_leaves
36 |
37 | PRNGKey = Any
38 |
39 | @dataclass(frozen=True)
40 | class Hparams:
41 | d_model: int
42 | n_q_per_kv: int
43 | n_kv: int
44 | d_head: int
45 | layers: int
46 | vocab: int
47 | d_ff: int
48 | rope_max_timescale: int
49 |
50 | @pytree_dataclass
51 | class TransformerLayer:
52 | ln1: f32['d_model/t/d']
53 | ln2: f32['d_model/t/d']
54 | w_q: f32['d_model/d n_q_per_kv n_kv/t d_head']
55 | w_kv: f32['2 d_model/d n_kv/t d_head']
56 | w_o: f32['d_model/d n_q_per_kv n_kv/t d_head']
57 | w_gate: f32['d_model/d d_ff/t']
58 | w_up: f32['d_model/d d_ff/t']
59 | w_down: f32['d_model/d d_ff/t']
60 |
61 | Transformer = Array['layers', TransformerLayer]
62 |
63 | @pytree_dataclass
64 | class Model:
65 | embed: f32['vocab/t d_model/d']
66 | unembed: f32['vocab/t d_model/d']
67 | transformer: Transformer
68 | final_layer_norm: f32['d_model/d/t']
69 |
70 | @staticmethod
71 | @typechecked
72 | def init(h: Hparams, rng: PRNGKey) -> 'Model':
73 | embed = jax.random.normal(jax_extra.fold_in_str(rng, 'embed'), (h.vocab, h.d_model), dtype=jnp.float32)
74 | # https://github.com/google/jax/issues/20390 for ones_like with sharding.
75 | ln1 = jnp.ones((h.layers, h.d_model), dtype=jnp.float32)
76 | ln2 = jnp.ones((h.layers, h.d_model), dtype=jnp.float32)
77 | final_layer_norm = jnp.ones((h.d_model,), dtype=jnp.float32)
78 |
79 |         # All of w_q/w_kv/w_o/w_gate/w_up/w_down/unembed use truncated_normal initializers
80 |         # with 'fan_in' scaling, i.e. variance set to 1.0/fan_in.
81 |         # The constant below is the stddev of a standard normal truncated to (-2, 2).
82 | truncated_normal_stddev = .87962566103423978
83 |
84 | # scale for tensors with d_model fan_in and truncated normal truncated to (-2, 2)
85 | d_model_scale = 1 / (math.sqrt(h.d_model) * truncated_normal_stddev)
86 |
87 | w_kv_scale = d_model_scale
88 | w_q_scale = d_model_scale / math.sqrt(h.d_head)
89 | total_head_dim = h.n_q_per_kv * h.n_kv * h.d_head
90 | w_o_scale = 1 / (math.sqrt(total_head_dim) * truncated_normal_stddev)
91 | w_up_scale = d_model_scale
92 | w_down_scale = 1 / (math.sqrt(h.d_ff) * truncated_normal_stddev)
93 | unembed_scale = d_model_scale
94 |
95 | w_q_shape = (h.layers, h.d_model, h.n_q_per_kv, h.n_kv, h.d_head)
96 | w_q = w_q_scale * jax.random.truncated_normal(fold_in_str(rng, 'w_q'), -2, 2, w_q_shape, dtype=jnp.float32)
97 | w_kv_shape = (h.layers, 2, h.d_model, h.n_kv, h.d_head)
98 | w_kv = w_kv_scale * jax.random.truncated_normal(fold_in_str(rng, 'w_kv'), -2, 2, w_kv_shape, dtype=jnp.float32)
99 | w_o_shape = w_q_shape
100 | w_o = w_o_scale * jax.random.truncated_normal(fold_in_str(rng, 'w_o'), -2, 2, w_o_shape, dtype=jnp.float32)
101 |
102 | ff_shape = (h.layers, h.d_model, h.d_ff)
103 | w_gate = w_up_scale * jax.random.truncated_normal(fold_in_str(rng, 'w_gate'), -2, 2, ff_shape, dtype=jnp.float32)
104 | w_up = w_up_scale * jax.random.truncated_normal(fold_in_str(rng, 'w_up'), -2, 2, ff_shape, dtype=jnp.float32)
105 | w_down = w_down_scale * jax.random.truncated_normal(fold_in_str(rng, 'w_down'), -2, 2, ff_shape, dtype=jnp.float32)
106 |
107 | unembed = unembed_scale * jax.random.truncated_normal(fold_in_str(rng, 'unembed'), -2, 2, (h.vocab, h.d_model), dtype=jnp.float32)
108 | arrays = Model(
109 | embed=embed,
110 | unembed=unembed,
111 | transformer=Transformer(
112 | ln1=ln1,
113 | ln2=ln2,
114 | w_q=w_q,
115 | w_kv=w_kv,
116 | w_o=w_o,
117 | w_gate=w_gate,
118 | w_up=w_up,
119 | w_down=w_down,
120 | ),
121 | final_layer_norm=final_layer_norm,
122 | )
123 | shardings = make_shardings(Model)
124 | return jax.tree.map(lax.with_sharding_constraint, arrays, shardings)
125 |
126 |
127 | @typechecked
128 | def forward_pass(self, h: Hparams, ids: u32[b'B/d L'], is_seq_start: bool_[b'B/d L']) -> f32[b'B/d L V/t']:
129 | ##### Initial embedding lookup.
130 | embed = shardops.all_gather('V/t M/d -> V/t M', jnp.bfloat16(self.embed))
131 | x = shardops.index_unreduced('[V/t] M, B/d L -> B/d L M', embed, ids)
132 | x = shardops.psum_scatter('B/d L M -> B/d L M/t', x)
133 |
134 | L = ids.shape[1]
135 | segment_ids = jnp.cumsum(is_seq_start, axis=1)
136 | segment_mask: bool_[b'B/d L L'] = segment_ids[:, :, jnp.newaxis] == segment_ids[:, jnp.newaxis, :]
137 | segment_mask: bool_[b'B/d L L 1 1'] = segment_mask[..., jnp.newaxis, jnp.newaxis] # add axes for q_per_k, num_kv_heads dimensions
138 | causal_mask: bool_[b'1 L L 1 1'] = jnp.tril(jnp.ones((L, L), dtype=jnp.bool_), 0)[jnp.newaxis, ..., jnp.newaxis, jnp.newaxis]
139 | causal_mask: bool_[b'B/d L L 1 1'] = jnp.logical_and(segment_mask, causal_mask)
140 |
141 | rope_table = RopeTable.create(L, h)
142 |
143 | ##### Transformer blocks.
144 | @explicit_activation_checkpointing
145 | @typechecked
146 | def loop_body(x: bf16[b'B/d L M/t'], layer_weights: TransformerLayer) -> Tuple[bf16[b'B/d L M/t'], Tuple[()]]:
147 | # Pre-attention RMSNorm
148 | ln1 = shardops.all_gather('M/t/d -> M', jnp.float32(layer_weights.ln1))
149 | gx = shardops.all_gather('B/d L M/t -> B/d L M', x)
150 | nx = jnp.bfloat16(rms_norm(gx) * ln1)
151 |
152 | # Attention, using Grouped Query Attention and RoPE position embeddings.
153 | w_q = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(layer_weights.w_q))
154 | q = save_for_backward(shardops.einsum_unreduced('B/d L M, M Q K/t D -> B/d L Q K/t D', nx, w_q))
155 | q = rope_table.apply('L D -> 1 L 1 1 D', q)
156 | w_kv = shardops.all_gather('2 M/d K/t D -> 2 M K/t D', jnp.bfloat16(layer_weights.w_kv))
157 | k, v = shardops.einsum_unreduced('B/d L M, k_v M K/t D -> k_v B/d L K/t D', nx, w_kv)
158 | k = save_for_backward(k)
159 | v = save_for_backward(v)
160 | k = rope_table.apply('L d -> 1 L 1 d', k)
161 | logits = shardops.einsum_unreduced(
162 | 'B/d Qlen Q K/t D, B/d Klen K/t D -> B/d Qlen Klen Q K/t', q, k, preferred_element_type=jnp.float32)
163 | logits = jnp.where(causal_mask, logits, -1e10)
164 | probs = jnp.bfloat16(jax.nn.softmax(logits, axis=2))
165 | attn_out = shardops.einsum_unreduced(
166 | 'B/d Qlen Klen Q K/t, B/d Klen K/t D -> B/d Qlen Q K/t D', probs, v)
167 | w_o = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(layer_weights.w_o))
168 | attn_out = shardops.einsum_unreduced('B/d Qlen Q K/t D, M Q K/t D -> B/d Qlen M', attn_out, w_o)
169 | attn_out = shardops.psum_scatter('B/d Qlen M -> B/d Qlen M/t', attn_out)
170 | x = save_for_backward(x + attn_out)
171 |
172 | # Pre-FFN RMSNorm
173 | ln2 = save_for_backward(shardops.all_gather('M/t/d -> M', jnp.float32(layer_weights.ln2)))
174 | gx = shardops.all_gather('B/d L M/t -> B/d L M', x)
175 | nx = jnp.bfloat16(rms_norm(gx) * ln2)
176 |
177 | # FFN, using SwiGLU
178 | w_gate = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer_weights.w_gate))
179 | gate_proj = save_for_backward(shardops.einsum_unreduced('B/d L M, M F/t -> B/d L F/t', nx, w_gate))
180 | w_up = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer_weights.w_up))
181 | up_proj = save_for_backward(shardops.einsum_unreduced('B/d L M, M F/t -> B/d L F/t', nx, w_up))
182 | y = jax.nn.swish(gate_proj) * up_proj
183 | w_down = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer_weights.w_down))
184 | ffn_out = shardops.einsum_unreduced('B/d L F/t, M F/t -> B/d L M', y, w_down)
185 | ffn_out = shardops.psum_scatter('B/d L M -> B/d L M/t', ffn_out)
186 |
187 | return jnp.bfloat16(x + ffn_out), ()
188 |
189 | x, () = jax.lax.scan(loop_body, jnp.bfloat16(x), self.transformer)
190 |
191 | ##### Final layernorm and output projection.
192 | x = shardops.all_gather('B/d L M/t -> B/d L M', x)
193 | ln = shardops.all_gather('M/t/d -> M', jnp.float32(self.final_layer_norm))
194 | x = jnp.bfloat16(rms_norm(x) * ln)
195 | unembed = shardops.all_gather('V/t M/d -> V/t M', jnp.bfloat16(self.unembed))
196 | logits = shardops.einsum_unreduced('B/d L M, V/t M -> B/d L V/t', x, unembed, preferred_element_type=jnp.float32)
197 |
198 | return logits
199 |
200 |
201 | @typechecked
202 | def loss(self, h: Hparams, batch: TokenBatch) -> f32[b'']:
203 | # Given sequence-packed targets:
204 | # [[1, 2], [3, 4, 5], [6, 7, 8, 9]]
205 | # we want inputs:
206 | # [[0, 1], [0, 3, 4], [0, 6, 7, 8]]
207 | # which we get by shifting the targets right by 1 and
208 | # masking sequence-start tokens to 0.
209 | inputs = jnp.pad(batch.targets[:, :-1], pad_width=((0, 0), (1, 0)))
210 | is_seq_start: bool_[b'batch/d len'] = batch.is_seq_start
211 | inputs: u32[b'batch/d len'] = jnp.where(is_seq_start, 0, inputs)
212 |
213 | logits: f32[b'batch/d len V/t'] = self.forward_pass(h, inputs, is_seq_start)
214 | max_logits: f32[b'batch/d len 1'] = lax.pmax(jnp.max(lax.stop_gradient(logits), axis=-1, keepdims=True), 't')
215 | logits = logits - max_logits
216 | sum_logits = lax.psum(jnp.sum(jnp.exp(logits), axis=-1, keepdims=True), 't')
217 | logsumexp = jnp.log(sum_logits)
218 | logprobs: f32[b'batch/d len V/t'] = logits - logsumexp
219 | logprobs_at_targets = shardops.index_unreduced('batch/d len [V/t], batch/d len -> batch/d len', logprobs, batch.targets)
220 | logprobs_at_targets = shardops.psum_scatter('batch/d len -> batch/d len/t', logprobs_at_targets)
221 | tokens_in_global_batch = logprobs_at_targets.size * jax.lax.psum(1, ('d', 't'))
222 | return -jnp.sum(logprobs_at_targets) / jnp.float32(tokens_in_global_batch)
223 |
224 |
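# A concrete view of the shift-and-mask in Model.loss above (a sketch with made-up values):
#
#   targets      = [[1, 2, 3, 4, 5]]   # one packed row holding sequences [1, 2] and [3, 4, 5]
#   is_seq_start = [[T, F, T, F, F]]
#   shifted      = [[0, 1, 2, 3, 4]]   # jnp.pad(targets[:, :-1], ((0, 0), (1, 0)))
#   inputs       = [[0, 1, 0, 3, 4]]   # jnp.where(is_seq_start, 0, shifted)
#
# so every position predicts its own target from earlier tokens of its own sequence only.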
225 | @pytree_dataclass
226 | class RopeTable:
227 | sin: f32['len d_head2']
228 | cos: f32['len d_head2']
229 |
230 | @staticmethod
231 | def create(max_len: int, hparams: Hparams) -> 'RopeTable':
232 | rope_max_timescale = hparams.rope_max_timescale
233 | d_head = hparams.d_head
234 | d = d_head // 2
235 | # endpoint=False is equivalent to what MaxText does. endpoint=True would be more natural, though.
236 | timescale = jnp.logspace(0, jnp.log10(jnp.float32(rope_max_timescale)), d, endpoint=False)
237 | position = jnp.arange(max_len, dtype=jnp.int32)
238 | sinusoid_inp = jnp.float32(position[:, jnp.newaxis]) / timescale[jnp.newaxis, :]
239 | sin = jnp.sin(sinusoid_inp)
240 | cos = jnp.cos(sinusoid_inp)
241 | return RopeTable(sin=sin, cos=cos)
242 |
243 | def apply(self, rearrange_spec, x):
244 | x1, x2 = jnp.split(x, 2, axis=-1)
245 | sin = einops.rearrange(self.sin, rearrange_spec)
246 | cos = einops.rearrange(self.cos, rearrange_spec)
247 | r1 = x1 * cos - x2 * sin
248 | r2 = x2 * cos + x1 * sin
249 | return jnp.append(r1, r2, axis=-1)
250 |
251 |
252 | @typechecked
253 | def rms_norm(x: bf16[b'batch/d len M']) -> bf16[b'batch/d len M']:
254 | mean2 = save_for_backward(jnp.mean(jax.lax.square(jnp.float32(x)), axis=-1, keepdims=True))
255 | return jnp.bfloat16(x * jax.lax.rsqrt(mean2 + 1e-6))
256 |
257 |
258 | @pytree_dataclass
259 | class Metrics:
260 | loss: f32[b'']
261 | learning_rate: f32[b'']
262 | grad_norm: f32[b'']
263 | raw_grad_norm: f32[b'']
264 |
265 |
266 | @dataclass(frozen=True)
267 | class TrainingHparams:
268 | adam_b1: float
269 | adam_b2: float
270 | adam_eps: float
271 | adam_eps_root: float
272 | weight_decay: float
273 | warmup_steps: int
274 | steps: int
275 | steps_for_lr: int
276 | cosine_learning_rate_final_fraction: float
277 | learning_rate: float
278 | tokens: TokenBatchParams
279 | seed: int
280 | queue: Optional[str] = None
281 |
282 | @pytree_dataclass
283 | class State:
284 | weights: Model
285 | adam_mu: Model
286 | adam_nu: Model
287 |
288 | @staticmethod
289 | def init(hparams: Hparams, rng: PRNGKey) -> 'State':
290 | weights = Model.init(hparams, rng)
291 | adam_mu = jax.tree.map(lambda p: p * 0.0, weights)
292 | adam_nu = jax.tree.map(lambda p: p * 0.0, weights)
293 | return State(weights=weights, adam_mu=adam_mu, adam_nu=adam_nu)
294 |
295 | @partial(jax.jit, static_argnums=(2, 3), donate_argnums=(0,))
296 | def training_step(state: State, step: u32[b''], h: Hparams, hparams: TrainingHparams, batch: TokenBatch) -> Tuple[Any, Metrics]:
297 | @partial(shardtypes.typed_shard_map, check_rep=False) # check_rep=False for https://github.com/google/jax/issues/20335
298 | def sharded_step(state: State, step: u32[b''], batch: TokenBatch) -> Tuple[State, Metrics]:
299 | loss, grad = jax.value_and_grad(lambda weights: weights.loss(h, batch))(state.weights)
300 | # Gradients have already been reduced across chips because the gradient of the weight `all_gather`
301 | # is weight-gradient `psum_scatter`. Loss, on the other hand, hasn't been reduced across chips: if we
302 | # did that inside the autodiff, we'd be double-reducing the loss, effectively multiplying it by the
303 | # amount of data parallelism.
304 | #
305 | # So we reduce the loss across chips _outside_ the autodiff.
306 | loss = jax.lax.psum(loss, ('d', 't'))
307 |
308 | # Other than global-norm of gradients, no other communication is needed during the weight update,
309 | # because weights and grads are already fully sharded, as checked below.
310 |
311 | # Calculate learning rate from step number.
312 | # We use linear warmup then cosine decay. See https://arxiv.org/pdf/2307.09288.pdf section 2.2
313 | warmup_lr = (jnp.float32(step) / jnp.float32(hparams.warmup_steps)) * hparams.learning_rate
314 | cosine = jnp.cos(jnp.pi * (jnp.float32(step - hparams.warmup_steps) / jnp.float32(hparams.steps_for_lr - hparams.warmup_steps)))
315 | cosine_lr = hparams.learning_rate * (hparams.cosine_learning_rate_final_fraction + (1 - hparams.cosine_learning_rate_final_fraction) * (cosine * .5 + .5))
316 | lr = jnp.where(step < hparams.warmup_steps, warmup_lr, cosine_lr)
317 |
318 | # AdamW optimizer with global gradient clipping.
319 | grad_leaves, grad_treedef = jax.tree_util.tree_flatten(grad)
320 | global_norm_square = jnp.float32(0.0)
321 | for g in grad_leaves:
322 | assert g.dtype == jnp.float32
323 | global_norm_square += jnp.sum(jax.lax.square(g))
324 | global_norm_square = jax.lax.psum(global_norm_square, ('d', 't'))
325 | global_norm = jnp.sqrt(global_norm_square)
326 | rescale = jnp.minimum(1.0, 1.0 / global_norm)
327 |
328 | new_ps = []
329 | new_mus = []
330 | new_nus = []
331 | for p, g, mu, nu, spec in zip(tree_leaves(state.weights), grad_leaves, tree_leaves(state.adam_mu), tree_leaves(state.adam_nu), tree_leaves(shardtypes.make_partition_specs(State))):
332 | assert shardtypes.is_fully_sharded(spec), 'Weight update is only correctly scaled for fully sharded weights.'
333 | # Gradient clipping
334 | g = g * rescale
335 | # Adam scaling
336 | mu = (1 - hparams.adam_b1) * g + hparams.adam_b1 * mu
337 | nu = (1 - hparams.adam_b2) * jax.lax.square(g) + hparams.adam_b2 * nu
338 | # We need step numbers to start at 1, not 0. Otherwise the bias correction produces NaN.
339 | completed_steps = step + 1
340 | mu_hat = mu / (1 - jnp.float32(hparams.adam_b1)**completed_steps)
341 | nu_hat = nu / (1 - jnp.float32(hparams.adam_b2)**completed_steps)
342 | g = mu_hat / (jnp.sqrt(nu_hat + hparams.adam_eps_root) + hparams.adam_eps)
343 | # Weight decay
344 | g += hparams.weight_decay * p
345 | # Learning rate
346 | g *= lr
347 |
348 | # Apply update
349 | new_ps.append(p - g)
350 | new_mus.append(mu)
351 | new_nus.append(nu)
352 |
353 | new_state = State(
354 | weights=jax.tree_util.tree_unflatten(grad_treedef, new_ps),
355 | adam_mu=jax.tree_util.tree_unflatten(grad_treedef, new_mus),
356 | adam_nu=jax.tree_util.tree_unflatten(grad_treedef, new_nus),
357 | )
358 | metrics = Metrics(
359 | loss=loss,
360 | learning_rate=lr,
361 | grad_norm=global_norm * rescale,
362 | raw_grad_norm=global_norm,
363 | )
364 | return new_state, metrics
365 |
366 | return sharded_step(state, step, batch)
367 |
368 |
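# The learning-rate schedule above, in isolation (a sketch; the hyperparameters are made up):
#
#   import numpy as np
#   warmup_steps, steps_for_lr, lr0, final_frac = 10, 100, 3e-4, 0.1
#   def schedule(step):
#       warmup = step / warmup_steps * lr0
#       cosine = np.cos(np.pi * (step - warmup_steps) / (steps_for_lr - warmup_steps))
#       decay = lr0 * (final_frac + (1 - final_frac) * (cosine * 0.5 + 0.5))
#       return warmup if step < warmup_steps else decay
#
#   # schedule(0) == 0.0, schedule(10) == lr0, schedule(100) == final_frac * lr0.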
369 | @dataclass(frozen=True)
370 | class Paths:
371 | root_working_dir: str
372 | model_name: str
373 |
374 | @dataclass(frozen=True)
375 | class MeshConfig:
376 | d: int
377 | t: int
378 |
379 |
380 | @dataclass(frozen=True)
381 | class Config:
382 | model: Hparams
383 | training: TrainingHparams
384 | paths: Paths
385 | num_hosts: int
386 | checkpoint_interval: int
387 | mesh: MeshConfig
388 | io: training_io.IOConfig
389 | flat_tokens: Optional[FlatTokensParams] = None
390 | hf_dataset: Optional[HuggingFaceDataParams] = None
391 |
392 | def __post_init__(self):
393 | assert self.flat_tokens is not None or self.hf_dataset is not None, 'Must provide either flat_tokens or hf_dataset.'
394 | assert not (self.flat_tokens is not None and self.hf_dataset is not None), 'Should not specify both flat_tokens and hf_dataset.'
395 |
396 | @cached_property
397 | def training_data(self) -> Union[FlatTokensParams, HuggingFaceDataParams]:
398 | return self.flat_tokens or self.hf_dataset
399 |
400 | def main_contained(config, logger):
401 | """Main program, which does not access external services except as specified by config.paths or logger."""
402 | # Use partitionable (and hopefully fusable!) RNG.
403 | #
404 | # This is slower in compute time than 'unsafe_rbg' with flag '--xla_tpu_spmd_rng_bit_generator_unsafe=true',
405 | # but hopefully faster in memory time because it's fusable.
406 | # TODO: check this is true and if not, provide our own that actually is fusable.
407 | jax.config.update('jax_threefry_partitionable', True)
408 | with Mesh(mesh_utils.create_device_mesh([config.mesh.d, config.mesh.t], jax.devices()), ('d', 't')):
409 | root_rng = jax.random.PRNGKey(config.training.seed)
410 |
411 | loader = get_loader('train', config.training_data, config.training.tokens)
412 | assert config.model.vocab > loader.max_token_id, f"vocab size {config.model.vocab} must exceed the dataset's max token id {loader.max_token_id}"
413 |
414 | model_dir = os.path.join(config.paths.root_working_dir, config.paths.model_name)
415 | training_io.mkdir(model_dir)
416 | state = jax.jit(partial(State.init, config.model))(fold_in_str(root_rng, 'init'))
417 | state, start_step = training_io.load_checkpoint_if_it_exists(model_dir, state, config.io)
418 |
419 | # Explicitly compile training step, to record XLA HLO graph.
420 | # See https://bnikolic.co.uk/blog/python/jax/2022/02/22/jax-outputgraph-rev
421 | c_training_step = training_step.lower(state, jnp.uint32(0), config.model, config.training, loader.load(0)).compile()
422 | date = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
423 | training_io.save_hlo_svg(os.path.join(model_dir, f'training_step_optimized_hlo_{date}.svg'), c_training_step)
424 |
425 | for step in range(start_step, config.training.steps):
426 | if step % config.checkpoint_interval == 0 and step > start_step:
427 | training_io.save_checkpoint(model_dir, step, state, config.io)
428 |
429 | # We profile on the second step, because the first step has a long pause for XLA
430 | # compilation and initial shuffle buffer loading.
431 | if jax.process_index() == 0 and step == start_step + 1:
432 | jax.block_until_ready(state)
433 | training_io.start_profile()
434 | profile_start = time.time()
435 |
436 | state, output = c_training_step(state, jnp.uint32(step), loader.load(step))
437 |
438 | # Run profile for two steps, to include data loading time in between them.
439 | if jax.process_index() == 0 and step == start_step + 2:
440 | jax.block_until_ready(state)
441 | profile_duration = time.time() - profile_start
442 | training_io.stop_profile(model_dir)
443 |
444 | # Print MFU, including (one step of) data loading time.
445 | print(f"Profile time: {profile_duration}s for 2 steps.")
446 | model_params = jax.tree.reduce(operator.add, jax.tree.map(lambda w: w.size, state.weights))
447 | tokens = loader.load(step).targets.size
448 | print(f'Model params: {model_params:_}')
449 | print(f'Tokens: {tokens:_}')
450 | device_flops = training_io.get_flops_per_device()
451 | num_devices = jax.device_count()
452 | print(f'MFU (projections only): {100 * (2 * 6 * model_params * tokens / (num_devices * profile_duration)) / device_flops:.2f}%')
453 |
454 | training_io.log(step, logger, output)
455 |
456 |
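# Worked example (illustrative, never called) of the MFU arithmetic printed above: the
# matmul ("projections only") cost of a transformer is roughly 6 FLOPs per parameter per
# token (2 for the forward pass, 4 for the backward pass), and profile_duration covers
# two training steps, hence the factor of 2 * 6. All numbers below are hypothetical.
def _mfu_example() -> float:
    model_params = 84e6      # e.g. an 84M-parameter model
    tokens_per_step = 0.5e6  # batch size in tokens (hypothetical)
    profile_duration = 1.0   # seconds measured across 2 steps (hypothetical)
    num_devices = 8
    device_flops = 312e12    # A100 bf16 dense peak, as a stand-in for get_flops_per_device()
    achieved_flops = 2 * 6 * model_params * tokens_per_step / (num_devices * profile_duration)
    return 100 * achieved_flops / device_flops  # ~20% MFU with these numbers
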
457 | @hydra.main(config_path='configs', version_base=None)
458 | def main(config):
459 | config = jax_extra.make_dataclass_from_dict(Config, config)
460 | if config.training.queue:
461 | task = Task.init(project_name='testing', task_name=config.paths.model_name)
462 | logger = task.get_logger()
463 | task.execute_remotely(queue_name=config.training.queue)
464 | task.launch_multi_node(config.num_hosts, wait=True)
465 | if int(os.environ['RANK']) > 0:
466 | task.set_system_tags((task.get_system_tags() or []) + ['hidden'])
467 | jax.distributed.initialize(os.environ['MASTER_ADDR'] + ':' + os.environ['MASTER_PORT'],
468 | num_processes=int(os.environ['WORLD_SIZE']),
469 | process_id=int(os.environ['RANK']))
470 | else:
471 | logger = None
472 | main_contained(config, logger)
473 |
474 |
475 | if __name__ == "__main__":
476 | main()
477 |
--------------------------------------------------------------------------------
/training_io.py:
--------------------------------------------------------------------------------
1 | """Provides IO support for training:
2 | * checkpoint save and load
3 | * metrics logging
4 | * profiling of XLA computations
5 | * reporting FLOPs per device
6 | """
7 | import jax
8 | import jax.numpy as jnp
9 | from jax.experimental import multihost_utils
10 | from typing import Tuple, Any, Optional
11 | from dataclasses import dataclass
12 | import os
13 | import fsspec
14 | import zarr
15 | from numcodecs import blosc
16 | from clearml import Logger
17 | import numpy as np
18 | import datetime
19 | import concurrent.futures
20 | import jax.profiler
21 | import tempfile
22 | import shutil
23 | from jax.lib import xla_client
24 |
25 | PyTree = Any
26 |
27 | @dataclass
28 | class IOConfig:
29 | # Max number of threads to use for IO-bound tasks like saving and loading checkpoints.
30 | # Recommendation: about 1MiB of overhead per thread is typical, so 1024 threads is reasonable for 1GiB of overhead.
31 | # Since this work is IO-bound rather than CPU-bound, it is fine to have many more threads than
32 | # CPU cores.
33 | max_io_threads: int
34 |
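# Example (illustrative): at roughly 1MiB of overhead per thread, a config like
#   IOConfig(max_io_threads=1024)
# costs about 1GiB of memory while keeping checkpoint reads and writes highly parallel.
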
35 | def log(step: int, logger: Optional[Logger], output: PyTree):
36 | """Logs the output of a training step. The output must be a PyTree of f32 arrays."""
37 | if jax.process_index() == 0:
38 | metrics_dict = {}
39 | for path, arr in jax.tree_util.tree_leaves_with_path(output):
40 | path = jax.tree_util.keystr(path)
41 | arr = jax.device_get(arr)
42 | if arr.shape == () and arr.dtype == jnp.float32:
43 | if logger:
44 | logger.report_scalar(
45 | title=path, series=path, value=arr, iteration=step)
46 | metrics_dict[path] = float(arr)
47 | elif arr.dtype == jnp.float32:
48 | if logger:
49 | logger.report_histogram(
50 | title=path, series=path, values=arr, iteration=step)
51 | else:
52 | raise ValueError(f"Output {path} has unsupported shape {arr.shape} and dtype {arr.dtype}.")
53 | now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
54 | print(f"[{now}] Step {step}: {metrics_dict}")
55 |
56 |
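# Example usage (illustrative, never called): `output` may be any PyTree of float32
# arrays; scalars become scalar plots and higher-rank arrays become histograms.
# With logger=None only the scalar metrics are printed.
def _log_example():
    fake_output = {
        'loss': jnp.float32(3.14),                     # shape () -> scalar metric
        'grad_histogram': jnp.zeros(16, jnp.float32),  # shape (16,) -> histogram
    }
    log(step=0, logger=None, output=fake_output)
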
57 | def load_checkpoint_if_it_exists(checkpoint_dir: str, state: PyTree, config: IOConfig) -> Tuple[PyTree, int]:
58 | """Loads the latest checkpoint if it exists, otherwise return the initial state.
59 |
60 | In either case, uses the sharding and PyTree structure of `state` to produce the output.
61 |
62 | Since the state may occupy a large amount of memory, this function makes sure to delete `state`
63 | before loading the checkpoint. To facilitate this, callers should ensure not to hold on to any
64 | additional references to `state` when calling this function.
65 |
66 | Returns state and step number. Step 0 is the initial state, which may or may not have been loaded
67 | from a checkpoint.
68 | """
69 | blosc.use_threads = False # Blindly following recommendation from https://zarr.readthedocs.io/en/stable/tutorial.html#parallel-computing-and-synchronization
70 | checkpoint_dir_pseudofile = fsspec.open(checkpoint_dir)
71 | fs = checkpoint_dir_pseudofile.fs
72 | checkpoint_dir_path = checkpoint_dir_pseudofile.path
73 | del checkpoint_dir_pseudofile
74 |
75 | # Check working_dir for checkpoint files.
76 | # Process index 0 selects the checkpoint, then broadcasts it to everyone else.
77 | selected_checkpoint = -1
78 | if jax.process_index() == 0:
79 | if fs.exists(checkpoint_dir_path):
80 | # fs.mkdir(checkpoint_dir, create_parents=False)
81 | checkpoint_dirs = fs.ls(checkpoint_dir_path)
82 | for c in reversed(sorted(checkpoint_dirs)):
83 | try:
84 | checkpoint_number = int(os.path.basename(c))
85 | except ValueError:
86 | continue
87 | root = zarr.open_group(zarr.storage.FSStore(c, fs=fs))
88 | if "write_completed" not in root.attrs:
89 | print(f"zarr 'write_completed' marker is missing in checkpoint {c}; skipping.")
90 | continue
91 | selected_checkpoint = checkpoint_number
92 | break
93 | selected_checkpoint = multihost_utils.broadcast_one_to_all(jnp.int32(selected_checkpoint))
94 |
95 | if selected_checkpoint == -1:
96 | print(f"No checkpoints found in {checkpoint_dir_path}, starting from the initial state.")
97 | return state, 0
98 |
99 | print(f'Found checkpoint {selected_checkpoint} in {checkpoint_dir_path}, starting from there.')
100 | return load_zarr(os.path.join(checkpoint_dir, step_to_str(selected_checkpoint)), state, config), selected_checkpoint
101 |
102 |
103 | def save_checkpoint(checkpoint_dir: str, step: int, state: PyTree, config: IOConfig):
104 | """Saves a checkpoint for the specified step number.
105 |
106 | See docs/pytree-zarr-checkpoint.md for the checkpoint format.
107 | """
108 | blosc.use_threads = False
109 | checkpoint_file = os.path.join(checkpoint_dir, step_to_str(step))
110 | if jax.process_index() == 0:
111 | # If there's already a checkpoint at this step, delete it. It might have been a partially
112 | # written checkpoint from a previous run.
113 | f = fsspec.open(checkpoint_dir)
114 | checkpoint_path = os.path.join(f.path, step_to_str(step))
115 | if f.fs.exists(checkpoint_path):
116 | f.fs.rm(checkpoint_path, recursive=True)
117 |
118 | print(f"[{datetime.datetime.now()}] Saving checkpoint {step} to {checkpoint_file}.")
119 | save_zarr(checkpoint_file, state, config)
120 | print(f"[{datetime.datetime.now()}] Finished saving checkpoint {step} to {checkpoint_file}.")
121 |
122 |
123 | def load_zarr(filename: str, state: PyTree, config: IOConfig) -> PyTree:
124 | """Loads a zarr checkpoint from disk.
125 |
126 | See docs/pytree-zarr-checkpoint.md for the checkpoint format.
127 | """
128 | root = zarr.open_group(filename, mode="r")
129 | if "write_completed" not in root.attrs:
130 | raise ValueError(f"zarr 'write_completed' marker is missing in {filename}; this checkpoint should not have been selected for loading.")
131 |
132 | def load_one(path: Tuple, prev: jax.Array) -> jax.Array:
133 | path = jax.tree_util.keystr(path)
134 | shape = prev.shape
135 | sharding = prev.sharding
136 | arr = root[path]
137 | assert arr.shape == shape, f'Expected shape {shape} but got {arr.shape} for {path} in {filename}'
138 | assert arr.dtype == prev.dtype, f'Expected dtype {prev.dtype} but got {arr.dtype} for {path} in {filename}'
139 | del prev # Deallocate memory before loading its replacement!
140 | return jax.make_array_from_callback(shape, sharding, lambda shard_index: arr[shard_index])
141 |
142 | state, treedef = jax.tree_util.tree_flatten_with_path(state)
143 | with concurrent.futures.ThreadPoolExecutor(max_workers=config.max_io_threads) as executor:
144 | state_futures = [executor.submit(load_one, path, prev) for (path, prev) in state]
145 | states = [f.result() for f in state_futures]
146 | return jax.tree_util.tree_unflatten(treedef, states)
147 |
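# Illustrative note on the pattern above (never called): jax.make_array_from_callback
# calls the callback once per addressable shard with that shard's index (a tuple of
# slices), so each host reads only the zarr chunks backing its own shards. A minimal
# single-device sketch with an in-memory numpy "source" standing in for a zarr array:
def _make_array_from_callback_example() -> jax.Array:
    from jax.sharding import SingleDeviceSharding
    source = np.arange(16, dtype=np.float32).reshape(4, 4)
    sharding = SingleDeviceSharding(jax.devices()[0])
    # With a single-device sharding the callback receives (slice(None), slice(None));
    # with a sharded layout it would receive one slice tuple per shard.
    return jax.make_array_from_callback(source.shape, sharding, lambda index: source[index])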
148 |
149 | def save_zarr(filename: str, state: PyTree, config: IOConfig):
150 | """Saves a zarr checkpoint to disk.
151 |
152 | See docs/pytree-zarr-checkpoint.md for the checkpoint format.
153 | """
154 | state, _treedef = jax.tree_util.tree_flatten_with_path(state)
155 |
156 | if jax.process_index() == 0:
157 | # Create the zarr file and all the arrays.
158 | try:
159 | root = zarr.open_group(filename, mode='w-')
160 | except zarr.errors.ContainsGroupError:
161 | raise ValueError(f"Checkpoint {filename} already exists.")
162 | for path, arr in state:
163 | path = jax.tree_util.keystr(path)
164 | chunk_shape = arr.sharding.shard_shape(arr.shape)
165 | root.empty(path, shape=arr.shape, chunks=chunk_shape, dtype=arr.dtype)
166 | multihost_utils.sync_global_devices("save_zarr_begin")
167 |
168 | root = zarr.open_group(filename, mode='r+')
169 |
170 | def save_shard(dst: zarr.Array, shard: jax.Array, index: Tuple[slice, ...]):
171 | dst[index] = np.asarray(shard)
172 |
173 | with concurrent.futures.ThreadPoolExecutor(max_workers=config.max_io_threads) as executor:
174 | for path, arr in state:
175 | path = jax.tree_util.keystr(path)
176 | dst = root[path]
177 | assert dst.chunks == arr.sharding.shard_shape(arr.shape)
178 | for shard in arr.addressable_shards:
179 | if shard.replica_id == 0:
180 | executor.submit(save_shard, dst, shard.data, shard.index)
181 |
182 | multihost_utils.sync_global_devices("save_zarr_end")
183 | if jax.process_index() == 0:
184 | root.attrs["write_completed"] = True
185 | multihost_utils.sync_global_devices("save_zarr_committed")
186 |
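# Illustrative helper (not used by training) for inspecting a checkpoint written by
# save_zarr, e.g. while debugging. The path in the commented call is hypothetical; any
# zarr-compatible store works.
def _inspect_checkpoint(path: str):
    root = zarr.open_group(path, mode='r')
    # save_zarr sets this attribute only after every shard has been written; it is what
    # load_checkpoint_if_it_exists checks when selecting a checkpoint.
    print('write_completed:', root.attrs.get('write_completed', False))
    for name, arr in root.arrays():
        # Chunks equal the per-device shard shape chosen at save time.
        print(f'{name}: shape={arr.shape} chunks={arr.chunks} dtype={arr.dtype}')

# _inspect_checkpoint('gs://my-bucket/my-model/0000001000')  # hypothetical path
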
187 | def step_to_str(step: int) -> str:
188 | """Converts a step number to a string with leading zeros.
189 |
190 | We pad up to 10 digits so that lexicographic order matches numerical order. 1e10 training steps
191 | should be enough for anyone: the biggest runs as of 2024 are probably around 1e7 tokens/batch,
192 | 1e13 tokens total, so 1e6 training steps total.
193 | """
194 | return str(step).zfill(10)
195 |
196 | _PROFILE_DIR = None
197 |
198 | def start_profile():
199 | """Starts gathering a JAX profile."""
200 | # Get fresh temporary directory
201 | global _PROFILE_DIR
202 | _PROFILE_DIR = tempfile.mkdtemp()
203 | print(f'[{datetime.datetime.now()}] Starting profile, saving to {_PROFILE_DIR}')
204 | jax.profiler.start_trace(_PROFILE_DIR, create_perfetto_trace=True)
205 |
206 | def stop_profile(working_dir: str):
207 | """Stops gathering the JAX profile and saves it to a file."""
208 | global _PROFILE_DIR
209 | jax.profiler.stop_trace()
210 | print(f'[{datetime.datetime.now()}] Finished profile, copying to {working_dir}')
211 | fsspec_put(_PROFILE_DIR + '/', working_dir + '/')
212 | shutil.rmtree(_PROFILE_DIR)
213 | print(f'[{datetime.datetime.now()}] Finished copying profile to {working_dir}')
214 | _PROFILE_DIR = None
215 |
216 |
217 | def fsspec_put(local_src: str, remote_dst: str):
218 | """Copies a file from local disk to a remote location specified by a fsspec path."""
219 | f = fsspec.open(remote_dst)
220 | fs = f.fs
221 | path = f.path
222 | del f
223 | print(f'Put {local_src} to {path}')
224 | fs.put(local_src, path, recursive=True, create_parents=True)
225 |
226 |
227 | def save_hlo_svg(filespec: str, compiled: jax.stages.Compiled):
228 | """Saves a compiled function's HLO to an SVG file."""
229 | compiled_hlo_dot = xla_client._xla.hlo_module_to_dot_graph(compiled.runtime_executable().hlo_modules()[0])
230 | with tempfile.TemporaryDirectory() as d:
231 | with open(os.path.join(d, "hlo.dot"), "w") as f:
232 | f.write(compiled_hlo_dot)
233 | hlo_orig_svg = os.path.join(d, "hlo.original.svg")
234 | hlo_svg = os.path.join(d, "hlo.svg")
235 | os.system(f"dot -Tsvg {f.name} -o{hlo_orig_svg}")
236 | # Edit the SVG to remove everything before