├── .gitignore ├── LICENSE ├── README.md ├── configs ├── base.yaml ├── c4_a100x8_1b.yaml ├── c4_a100x8_270m.yaml ├── c4_a100x8_2b.yaml ├── c4_a100x8_540m.yaml ├── c4_a100x8_84m.yaml ├── c4_a100x8_base.yaml ├── c4_a100x8x4_1b.yaml ├── c4_a100x8x4_2b.yaml ├── c4_a100x8x4_540m.yaml ├── flat_tokens_c4_a100x1_84m.yaml ├── huggingface_c4_a100x1_84m.yaml └── local_test_synthetic.yaml ├── docs ├── flat-tokens.md ├── matx.svg └── pytree-zarr-checkpoint.md ├── env.py ├── input_loader.py ├── jax_extra.py ├── requirements-cpu.txt ├── shardlib ├── shardops.py └── shardtypes.py ├── synthetic_dataset.zarr ├── .zgroup ├── train │ ├── .zattrs │ ├── .zgroup │ ├── encoded_tokens │ │ ├── 0 │ │ └── .zarray │ └── seq_starts │ │ ├── 0 │ │ └── .zarray └── validation │ ├── .zattrs │ ├── .zgroup │ ├── encoded_tokens │ ├── 0 │ └── .zarray │ └── seq_starts │ ├── 0 │ └── .zarray ├── tools ├── configs │ ├── c4_en.yaml │ ├── synthetic_dataset.yaml │ └── tinystories.yaml ├── flat_tokens.py ├── huggingface_to_flat_tokens.py ├── requirements.txt └── write_synthetic_dataset.py ├── train.py └── training_io.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Hydra 2 | outputs/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2024 MatX Inc 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | # seqax = sequence modeling + JAX 6 | 7 | seqax is a codebase for small-to-medium-scale LLM pretraining research. The entire training program---including the model implementation; optimizer; multihost FSDP and tensor parallel partitioning---is [500 lines of code](/train.py), which scales well up to ~100 GPUs or TPUs[^1] and [typically achieves good MFUs of 30-50%](#performance). 8 | 9 | [^1]: Achieving good performance at larger scale requires pipeline parallelism (which we have not yet implemented). At that scale, you may also care more about using custom kernels to further improve performance at the cost of code simplicity. 10 | 11 | seqax is written in a style that makes the important information visible, rather than being hidden behind abstractions and indirections or being inferred automatically and unpredictably. This shows up in: 12 | 13 | * **Math**. seqax implements all of the training step's math, rather than calling into external libraries. If you want to understand or change the math, it's right there! 14 | 15 | * **Memory**. All tensors that go into a model checkpoint on disk are explicit. All tensors that occupy a lot of memory, including activations saved for the backwards pass, are explicit. You can straightforwardly read the memory footprint from the source code. 16 | 17 | * **Partitioning and communication**. The partitioned layout of all tensors and operations is explicit. All interchip communication is explicit. 18 | 19 | ## Getting started 20 | 21 | ### Installation 22 | 23 | 1. Install `graphviz` from your system package manager: e.g. `brew install graphviz` or `apt install graphviz`. 24 | 2. Install Python dependencies, typically inside a virtualenv: `python -m pip install -r requirements-cpu.txt`. 25 | 26 | NOTE: the `requirements-cpu.txt` is configured for CPU-based installation. For GPU or TPU installation, you may need a different install of JAX and jaxlib. Consult the [JAX install documentation](https://jax.readthedocs.io/en/latest/installation.html). If your GPU environment has a Torch-GPU installation, you may need to switch it to a Torch-CPU installation to avoid conflicts with JAX-GPU. 27 | 28 | ### Run on CPU for local development 29 | 30 | For development and testing you can run on CPU. Typically you'd use our synthetic dataset (which is [checked into this repository](/synthetic_dataset.zarr)) or the [Huggingface data loader](#data-loaders) and you'd set XLA flags to simulate multiple devices so as to test that parallelism is working as intended: 31 | 32 | ``` 33 | XLA_FLAGS=--xla_force_host_platform_device_count=8 python -m train --config-name=local_test_synthetic +paths.model_name=synthetic_000 34 | ``` 35 | 36 | The `paths.model_name` flag specifies which subdirectory on disk (inside `/tmp`) to write model checkpoints to. You'll typically want to change this when starting a new model run. 37 | 38 | ### Run on GPUs 39 | 40 | We have configured a range of model sizes, to be trained on the C4 dataset with the Llama tokenizer. Browse the `configs/` directory to select your preferred configuration file. Each configuration file lists how to run it at the top. 41 | 42 | You typically want to set `paths.model_name` to a unique name for each distinct training run. This path specifies which subdirectory on disk to write model checkpoints to. 
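Whether you run on CPU with simulated devices or on real accelerators, it can be useful to confirm that JAX sees the device topology your config's `mesh` section expects (the product of the `d` and `t` axes is expected to match the number of devices JAX reports). The following is a minimal sketch, not part of the training code, assuming you run it in the same environment and (for CPU) set the simulation flag before importing JAX:

```python
import os

# For CPU-only experiments: must be set before `import jax`, mirroring the
# XLA_FLAGS setting shown in the command above.
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"

import jax

print(f"process {jax.process_index()} of {jax.process_count()}")
print(f"local devices: {jax.local_device_count()}, global devices: {jax.device_count()}")
# e.g. with mesh d=8, t=1 you'd expect jax.device_count() == 8.
```
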
43 | 44 | ## Performance 45 | 46 | Recent benchmark results on A100 clusters: 47 | 48 | Single-host A100x8 49 | | Model Size | MFU | 50 | |------------|-------| 51 | | 84m | 14 | 52 | | 270m | 24 | 53 | | 540m | 35 | 54 | | 1b | 41.6 | 55 | | 2b | 50.66 | 56 | 57 | On 4 A100x8 hosts connected with infiniband 58 | | Model Size | MFU | 59 | |------------|-------| 60 | | 1b | 32.4 | 61 | | 2b | 39.0 | 62 | 63 | ## Data loaders 64 | 65 | seqax can stream training data directly from Huggingface (see [example config](/configs/huggingface_c4_a100x1_84m.yaml)), or can first convert the training data to a pre-tokenized format on disk which we call [flat-tokens](/docs/flat-tokens.md) (see [example config](/configs/flat_tokens_c4_a100x1_84m.yaml)). Streaming from Huggingface allows you to quickly experiment with different datasets, but it doesn't offer an efficient way to resume training from a checkpoint after a job is aborted, and it wastes some tokens from the dataset at batch boundaries. The flat-tokens format supports efficiently resuming training from a checkpoint, uses 100% of tokens for training, and also consumes less CPU time during training. 66 | 67 | To pre-tokenize the training data, you can run [huggingface_to_flat_tokens.py](/tools/huggingface_to_flat_tokens.py). You'll need to first install the requirements in [/tools/requirements.txt](/tools/requirements.txt), and then you can invoke the command listed at the top of [/tools/configs/c4_en.yaml](/tools/configs/c4_en.yaml). On modern CPUs this script processes about 100M tokens per minute. You can limit the number of output tokens it processes with a configuration flag. 68 | 69 | ## Expressing partitioning and communication with `shardlib` 70 | 71 | seqax ships with a new library called [shardlib](/shardlib) for expressing partitioning and communication with JAX, building on the ideas and style of [jaxtyping](https://docs.kidger.site/jaxtyping/), [einops](https://einops.rocks/), [equinox](https://docs.kidger.site/equinox/), and [shard_map](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). Here we demonstrate its core ideas, to implement fully sharded data parallelism (FSDP) for a simple fully connected neural network. 72 | 73 | 74 | ```python 75 | # XLA_FLAGS=--xla_force_host_platform_device_count=8 python -m shardlib_example 76 | from shardlib.shardtypes import bf16, bool_, f32, pytree_dataclass, typed_shard_map, u32, make_shardings 77 | from shardlib import shardtypes 78 | shardtypes.register_with_typeguard() 79 | import shardlib.shardops as shardops 80 | from jax.sharding import Mesh 81 | from jax.experimental import mesh_utils 82 | import jax 83 | import jax.numpy as jnp 84 | 85 | # We set up a device mesh where 'd' refers to the "data parallel" axis. 86 | MESH = Mesh(mesh_utils.create_device_mesh([8], jax.devices()), ('d')) 87 | 88 | # At rest, weights are all sharded over the data parallel axis, making them fully sharded. 89 | # 90 | # The `hidden1/d` syntax means that second axis has size `hidden1` and is sharded over device axis `d`. 91 | # Equivalently, you can view this as saying that the per-device shape is `(in, hidden1/d)`, where `/` 92 | # indicates division. 93 | @pytree_dataclass 94 | class Weights: 95 | w1: f32['in hidden1/d'] 96 | w2: f32['hidden1 hidden2/d'] 97 | w3: f32['hidden2/d'] 98 | 99 | with MESH: 100 | # Create dummy weights. 
101 | w = Weights( 102 | w1=jnp.zeros((8, 8), dtype=jnp.float32), 103 | w2=jnp.zeros((8, 8), dtype=jnp.float32), 104 | w3=jnp.zeros((8,), dtype=jnp.float32), 105 | ) 106 | # Apply sharding to weights. The sharding specs are inferred from the type annotations on the Weights class. 107 | w = jax.tree.map(jax.device_put, w, make_shardings(Weights)) 108 | 109 | # We use `typed_shard_map` to allow us to write per-device code with explicit communication. 110 | # 111 | # Compared to untyped `jax.shard_map`, the `in_specs` and `out_specs` do not need to be specified: 112 | # they're inferred from the sharding on the function's signature. 113 | @typed_shard_map 114 | def forward_pass(x: f32[b'batch/d in'], w: Weights) -> f32[b'batch/d']: 115 | # Weights are all-gathered just prior to their use. (This is the core idea of fully-sharded data parallelism.) 116 | # The `in hidden1/d -> in hidden1` syntax expresses what this all-gather operation should do: it removes the 117 | # `d` sharding on the `hidden1` axis, resulting in a fully replicated output. 118 | w1 = shardops.all_gather('in hidden1/d -> in hidden1', w.w1) 119 | # The `einsum_unreduced` operation is a chip-local einsum. Unlike `jnp.einsum`, it supports sharding syntax, 120 | # and it performs shape checking using the current typing environment, so it will raise an error if for example 121 | # you use `batch` in two different ways within a function. 122 | # 123 | # We call this einsum "unreduced", because it does not do any cross-chip reductions, even if they are necessary. 124 | # For example, in an `a b/d, b/d c -> a c` einsum, a cross-chip reduction over the `d` sharding axis is required, 125 | # and it is the caller's responsibility to perform this reduction. 126 | y = jax.nn.relu(shardops.einsum_unreduced('batch/d in, in hidden1 -> batch/d hidden1', x, w1)) 127 | w2 = shardops.all_gather('hidden1 hidden2/d -> hidden1 hidden2', w.w2) 128 | z = jax.nn.relu(shardops.einsum_unreduced('batch/d hidden1, hidden1 hidden2 -> batch/d hidden2', y, w2)) 129 | w3 = shardops.all_gather('hidden2/d -> hidden2', w.w3) 130 | return shardops.einsum_unreduced('batch/d hidden2, hidden2 -> batch/d', z, w3) 131 | 132 | x = forward_pass(jnp.zeros((32, 8), dtype=jnp.float32), w) 133 | assert(x.shape == (32,)) 134 | ``` 135 | 136 | There are several other APIs exported by shardlib in addition to the ones demonstrated here. [Browse the code](/shardlib/) to see the full list. 137 | 138 | ## Expressing activation checkpointing using `save_for_backward` 139 | 140 | Which intermediate computations in the forwards pass are saved to HBM for later use in the backwards pass? [The default answer](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html) is: JAX saves _all_ intermediates for use in the backwards pass, but in JIT mode the XLA compiler optimizes many of these away so as to save memory. 141 | 142 | While JAX provides many sophisticated policies for making these choices, we offer a very simple one: calling `save_for_backward` causes its argument to be saved for the backwards pass. Here is an example: 143 | 144 | ```python 145 | from jax_extra import explicit_activation_checkpointing, save_for_backward 146 | 147 | # The @explicit_activation_checkpointing switches JAX from its default 148 | # policy of saving all intermediates, and instead only saves the 149 | # arguments to the annotated function, plus any intermediates marked 150 | # with `save_for_backward`. 
151 | @explicit_activation_checkpointing 152 | def forward_pass(x, w1, w2): 153 | # save_for_backward marks `y` as being saved. 154 | y = save_for_backward(x @ w1) 155 | # `z` is not saved for the backwards pass. 156 | z = jax.nn.relu(y) 157 | return z @ w2 158 | ``` 159 | 160 | ## Profiling 161 | 162 | Every training run gathers and reports performance information: 163 | * the time for two training steps (including data fetching in between them). This is written to stdout. 164 | * model FLOPS utilization (MFU) efficiency for these steps. This is written to stdout. 165 | * an XLA performance profile. This is written into the model directory at `/plugins/profile//perfetto_trace.json.gz`. 166 | * a rendered SVG of the optimized XLA computation graph. This is written into the model directory at `/training_step_optimized_hlo_.svg`. 167 | 168 | 169 | ## File formats 170 | 171 | We write checkpoints and datasets in simple file formats based on [zarr](https://zarr.dev/). See our file format specifications: 172 | * [our checkpoint format](/docs/pytree-zarr-checkpoint.md) 173 | * [our dataset format](/docs/flat-tokens.md) 174 | 175 | ## Contact 176 | 177 | `seqax` is developed by the [MatX team](https://matx.com). If you're interested in working with us, you can reach us at [founders@matx.com](mailto:founders@matx.com). 178 | 179 | 180 | ## Citing seqax 181 | 182 | To cite this repository: 183 | 184 | ``` 185 | @software{seqax2024github, 186 | author = {Reiner Pope and Vaclav Cvicek and Daniel Heinlein and Akshay Mishra and Mahdi Nazemi and Sanjit Neelam and Rachit Tibrewal}, 187 | title = {seqax = sequence modeling + {JAX}}, 188 | url = {https://github.com/MatX-inc/seqax}, 189 | year = {2024}, 190 | } 191 | ``` 192 | 193 | 194 | ## Acknowledgements 195 | 196 | seqax's implementation style was substantially inspired by [jaxtyping](https://docs.kidger.site/jaxtyping/), [einops](https://einops.rocks/), [equinox](https://docs.kidger.site/equinox/), and [shard_map](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). 197 | 198 | Thanks to [MaxText](https://github.com/google/maxtext) for demonstrating good practices for production LLM use of JAX. 199 | 200 | Thanks to the [JAX](https://github.com/google/jax) team for ongoing support and advice. 201 | 202 | Thanks to the [Google TPU Research Cloud](https://sites.research.google/trc/about/), which partially supported this work. 203 | -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | # AdamW optimizer parameters 3 | # We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 4 | adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. 5 | adam_b2: 0.95 # Exponential decay rate to track the second moment of past gradients. 6 | adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. 7 | adam_eps_root: 0. # A small constant applied to denominator inside the square root. 
8 | weight_decay: 0.1 # AdamW Weight decay 9 | # We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 10 | # Learning rate schedule has two parts: 11 | # 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction] 12 | # 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from warmup to learning_rate_schedule_steps 13 | learning_rate: 3.e-5 14 | cosine_learning_rate_final_fraction: 0.1 15 | seed: 0 16 | 17 | io: 18 | max_io_threads: 1024 -------------------------------------------------------------------------------- /configs/c4_a100x8_1b.yaml: -------------------------------------------------------------------------------- 1 | # python -m train --config-name=c4_a100x8_1b +paths.model_name=1b 2 | defaults: 3 | - c4_a100x8_base 4 | - _self_ 5 | 6 | training: 7 | warmup_steps: 37000 8 | steps: 370000 9 | steps_for_lr: 370000 10 | learning_rate: 1.0e-5 11 | tokens: 12 | batch: 64 13 | 14 | training_data: 15 | streams: 1 16 | 17 | model: 18 | d_model: 2048 19 | n_q_per_kv: 1 20 | n_kv: 16 21 | d_head: 128 22 | layers: 8 23 | d_ff: 16384 24 | vocab: 32768 25 | rope_max_timescale: 10000 26 | 27 | checkpoint_interval: 10000 -------------------------------------------------------------------------------- /configs/c4_a100x8_270m.yaml: -------------------------------------------------------------------------------- 1 | # python -m train --config-name=c4_a100x8_270m +paths.model_name=270m 2 | defaults: 3 | - c4_a100x8_base 4 | - _self_ 5 | 6 | training: 7 | warmup_steps: 9200 8 | steps: 92000 9 | steps_for_lr: 92000 10 | learning_rate: 3.0e-4 11 | 12 | model: 13 | d_model: 1024 14 | n_q_per_kv: 1 15 | n_kv: 16 16 | d_head: 128 17 | layers: 8 18 | d_ff: 8192 19 | vocab: 32768 20 | rope_max_timescale: 10000 21 | 22 | checkpoint_interval: 9200 23 | -------------------------------------------------------------------------------- /configs/c4_a100x8_2b.yaml: -------------------------------------------------------------------------------- 1 | # python -m train --config-name=c4_a100x8_2b +paths.model_name=2b 2 | defaults: 3 | - c4_a100x8_base 4 | - _self_ 5 | 6 | 7 | training: 8 | warmup_steps: 74000 9 | steps: 740000 10 | steps_for_lr: 740000 11 | learning_rate: 1.0e-5 12 | tokens: 13 | batch: 64 14 | 15 | training_data: 16 | streams: 4 17 | 18 | model: 19 | d_model: 4096 20 | n_q_per_kv: 1 21 | n_kv: 16 22 | d_head: 128 23 | layers: 8 24 | d_ff: 16384 25 | vocab: 32768 26 | rope_max_timescale: 10000 27 | 28 | checkpoint_interval: 10000 -------------------------------------------------------------------------------- /configs/c4_a100x8_540m.yaml: -------------------------------------------------------------------------------- 1 | # python -m train --config-name=c4_a100x8_540m +paths.model_name=540m 2 | defaults: 3 | - c4_a100x8_base 4 | - _self_ 5 | 6 | training: 7 | warmup_steps: 18500 8 | steps: 185000 9 | steps_for_lr: 185000 10 | learning_rate: 3.0e-4 11 | 12 | model: 13 | d_model: 2048 14 | n_q_per_kv: 1 15 | n_kv: 16 16 | d_head: 128 17 | layers: 8 18 | d_ff: 8192 19 | vocab: 32768 20 | rope_max_timescale: 10000 21 | 22 | checkpoint_interval: 4000 23 | -------------------------------------------------------------------------------- /configs/c4_a100x8_84m.yaml: -------------------------------------------------------------------------------- 1 | # python -m train --config-name=c4_a100x8_84m +paths.model_name=84m 2 | defaults: 3 | - c4_a100x8_base 4 
| - _self_ 5 | 6 | training: 7 | warmup_steps: 2600 8 | steps: 26000 9 | steps_for_lr: 26000 10 | learning_rate: 3.0e-3 11 | 12 | model: 13 | d_model: 512 14 | n_q_per_kv: 1 15 | n_kv: 8 16 | d_head: 128 17 | layers: 8 18 | d_ff: 4096 19 | vocab: 32768 20 | rope_max_timescale: 10000 21 | 22 | checkpoint_interval: 2600 23 | -------------------------------------------------------------------------------- /configs/c4_a100x8_base.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | seed: 0 3 | tokens: 4 | batch: 64 5 | len: 1024 6 | 7 | # AdamW optimizer parameters 8 | # We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 9 | adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. 10 | adam_b2: 0.95 # Exponential decay rate to track the second moment of past gradients. 11 | adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. 12 | adam_eps_root: 0. # A small constant applied to denominator inside the square root. 13 | weight_decay: 0.1 # AdamW Weight decay 14 | # We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 15 | # Learning rate schedule has two parts: 16 | # 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction] 17 | # 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from warmup to learning_rate_schedule_steps 18 | 19 | # Learning rate is not yet tuned. 20 | learning_rate: 3.e-4 21 | cosine_learning_rate_final_fraction: 0.1 22 | 23 | flat_tokens: 24 | # filespec can also be a path on a local filesystem 25 | filespec: 'gcs://path/to/your/dataset' # Fill in 26 | streams: 1 27 | read_blocks_per_shuffle_buffer: 128 28 | sequences_per_read_block: 1024 29 | seed: 0 30 | sequence_packing: true 31 | 32 | paths: 33 | # root_working_dir can also be a path on a local filesystem 34 | root_working_dir: 'gcs://path/to/your/outputs' 35 | 36 | num_hosts: 1 37 | 38 | mesh: 39 | d: 8 40 | t: 1 41 | 42 | io: 43 | max_io_threads: 1024 -------------------------------------------------------------------------------- /configs/c4_a100x8x4_1b.yaml: -------------------------------------------------------------------------------- 1 | # python -m train --config-name=c4_a100x8x4_1b +paths.model_name=1b 2 | defaults: 3 | - c4_a100x8_base 4 | - _self_ 5 | 6 | num_hosts: 4 7 | 8 | mesh: 9 | d: 32 10 | t: 1 11 | 12 | training: 13 | warmup_steps: 9250 14 | steps: 92500 15 | steps_for_lr: 92500 16 | learning_rate: 1.0e-5 17 | tokens: 18 | batch: 256 19 | 20 | training_data: 21 | streams: 4 22 | 23 | model: 24 | d_model: 2048 25 | n_q_per_kv: 1 26 | n_kv: 16 27 | d_head: 128 28 | layers: 8 29 | d_ff: 16384 30 | vocab: 32768 31 | rope_max_timescale: 10000 32 | 33 | checkpoint_interval: 2500 -------------------------------------------------------------------------------- /configs/c4_a100x8x4_2b.yaml: -------------------------------------------------------------------------------- 1 | # python -m train --config-name=c4_a100x8x4_2b +paths.model_name=2b 2 | defaults: 3 | - c4_a100x8_base 4 | - _self_ 5 | 6 | num_hosts: 4 7 | 8 | mesh: 9 | d: 32 10 | t: 1 11 | 12 | training: 13 | warmup_steps: 18500 14 | steps: 185000 15 | steps_for_lr: 185000 16 | learning_rate: 1.0e-5 17 | tokens: 18 | batch: 256 19 | 20 | model: 21 | d_model: 4096 22 | n_q_per_kv: 1 23 | n_kv: 16 24 | d_head: 128 25 | layers: 8 26 | 
d_ff: 16384 27 | vocab: 32768 28 | rope_max_timescale: 10000 29 | 30 | checkpoint_interval: 2500 -------------------------------------------------------------------------------- /configs/c4_a100x8x4_540m.yaml: -------------------------------------------------------------------------------- 1 | # python -m train --config-name=c4_a100x8x4_540m +paths.model_name=540m 2 | defaults: 3 | - c4_a100x8_base 4 | - _self_ 5 | 6 | num_hosts: 4 7 | 8 | mesh: 9 | d: 32 10 | t: 1 11 | 12 | training: 13 | warmup_steps: 4625 14 | steps: 46250 15 | steps_for_lr: 46250 16 | learning_rate: 3.0e-4 17 | tokens: 18 | batch: 256 19 | 20 | model: 21 | d_model: 2048 22 | n_q_per_kv: 1 23 | n_kv: 16 24 | d_head: 128 25 | layers: 16 26 | d_ff: 8192 27 | vocab: 32768 28 | rope_max_timescale: 10000 29 | 30 | checkpoint_interval: 2500 -------------------------------------------------------------------------------- /configs/flat_tokens_c4_a100x1_84m.yaml: -------------------------------------------------------------------------------- 1 | # python -m train --config-name=flat_tokens_c4_a100x1_84m +paths.model_name=flat_token_84m 2 | training: 3 | seed: 0 4 | tokens: 5 | batch: 64 6 | len: 1024 7 | 8 | # AdamW optimizer parameters 9 | # We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 10 | adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. 11 | adam_b2: 0.95 # Exponential decay rate to track the second moment of past gradients. 12 | adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. 13 | adam_eps_root: 0. # A small constant applied to denominator inside the square root. 14 | weight_decay: 0.1 # AdamW Weight decay 15 | # We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 16 | # Learning rate schedule has two parts: 17 | # 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction] 18 | # 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from warmup to learning_rate_schedule_steps 19 | warmup_steps: 2600 20 | steps: 26000 21 | steps_for_lr: 26000 22 | learning_rate: 3.0e-4 23 | 24 | cosine_learning_rate_final_fraction: 0.1 25 | 26 | 27 | model: 28 | d_model: 512 29 | n_q_per_kv: 1 30 | n_kv: 8 31 | d_head: 128 32 | layers: 8 33 | d_ff: 4096 34 | vocab: 32768 35 | rope_max_timescale: 10000 36 | 37 | paths: 38 | # can also be a path to GCS. IE 'gcs://your_bucket/your_output_path' 39 | root_working_dir: '~/seqax_outputs' 40 | 41 | num_hosts: 1 42 | 43 | io: 44 | max_io_threads: 1024 45 | 46 | # Define either hf_dataset or flat_tokens. Do not use both. 47 | # flat_tokens requires more setup, but is better tested and doesn't waste tokens. 48 | # Using flat_tokens requires setting up a flat_tokens dataset using the script in tools/huggingface_to_flat_tokens.py 49 | flat_tokens: 50 | filespec: 'gcs://path/to/your/dataset' # can be a path to a gcs directory, or local copy of dataset. 
51 | streams: 1 52 | read_blocks_per_shuffle_buffer: 128 53 | sequences_per_read_block: 1024 54 | seed: 0 55 | sequence_packing: true 56 | 57 | 58 | mesh: 59 | d: 1 60 | t: 1 61 | 62 | checkpoint_interval: 100 -------------------------------------------------------------------------------- /configs/huggingface_c4_a100x1_84m.yaml: -------------------------------------------------------------------------------- 1 | # python -m train --config-name=huggingface_c4_a100x1_84m +paths.model_name=hf_84m 2 | training: 3 | seed: 0 4 | tokens: 5 | batch: 64 6 | len: 1024 7 | 8 | # AdamW optimizer parameters 9 | # We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 10 | adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. 11 | adam_b2: 0.95 # Exponential decay rate to track the second moment of past gradients. 12 | adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. 13 | adam_eps_root: 0. # A small constant applied to denominator inside the square root. 14 | weight_decay: 0.1 # AdamW Weight decay 15 | # We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 16 | # Learning rate schedule has two parts: 17 | # 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction] 18 | # 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from warmup to learning_rate_schedule_steps 19 | warmup_steps: 2600 20 | steps: 26000 21 | steps_for_lr: 26000 22 | learning_rate: 3.0e-4 23 | 24 | cosine_learning_rate_final_fraction: 0.1 25 | 26 | 27 | model: 28 | d_model: 512 29 | n_q_per_kv: 1 30 | n_kv: 8 31 | d_head: 128 32 | layers: 8 33 | d_ff: 4096 34 | vocab: 32768 35 | rope_max_timescale: 10000 36 | 37 | paths: 38 | # can also be a path to GCS. IE 'gcs://your_bucket/your_output_path' 39 | root_working_dir: '~/seqax_outputs' 40 | 41 | num_hosts: 1 42 | 43 | io: 44 | max_io_threads: 1024 45 | 46 | # Define either hf_dataset or flat_tokens. Do not use both. 47 | # hf_dataset should work as long as you have the necessary permissions to access 48 | # the dataset and tokenizer. 
49 | hf_dataset: 50 | path: allenai/c4 51 | name: en 52 | num_workers: 64 53 | tokenizer: mistralai/Mistral-7B-v0.1 # may require huggingface-cli login 54 | sequences_packed_per_batch: 120 55 | 56 | 57 | 58 | mesh: 59 | d: 1 60 | t: 1 61 | 62 | checkpoint_interval: 100 -------------------------------------------------------------------------------- /configs/local_test_synthetic.yaml: -------------------------------------------------------------------------------- 1 | # Command to run on your CPU: 2 | # XLA_FLAGS=--xla_force_host_platform_device_count=8 python -m train --config-name=local_test_synthetic +paths.model_name=synthetic_000 3 | 4 | defaults: 5 | - base 6 | - _self_ 7 | 8 | training: 9 | warmup_steps: 10 10 | steps: 50 11 | steps_for_lr: 100 12 | tokens: 13 | batch: 64 14 | len: 64 15 | 16 | model: 17 | d_model: 256 18 | n_q_per_kv: 2 19 | n_kv: 2 20 | d_head: 32 21 | layers: 2 22 | vocab: 1280 23 | d_ff: 1024 24 | rope_max_timescale: 256 25 | 26 | paths: 27 | root_working_dir: '/tmp' 28 | 29 | checkpoint_interval: 10 30 | num_hosts: 1 31 | 32 | mesh: 33 | d: 4 34 | t: 2 35 | 36 | flat_tokens: 37 | filespec: 'synthetic_dataset.zarr' 38 | streams: 2 39 | read_blocks_per_shuffle_buffer: 8 40 | sequences_per_read_block: 16 41 | seed: 0 42 | sequence_packing: true 43 | -------------------------------------------------------------------------------- /docs/flat-tokens.md: -------------------------------------------------------------------------------- 1 | # `flat-tokens` data format 2 | 3 | ## Introduction 4 | 5 | The `flat-tokens` data format is a very simple data format for storing language model training data. 6 | Unlike some other dataset libraries, it supports efficient seeking after job restarts. It also 7 | supports batch size, sequence length, and "sequence packing vs not" being selected at training 8 | time. 9 | 10 | It is based on the simplest possible design: a concatenation of all tokens in the dataset, together 11 | with start indices of each sequence. 12 | 13 | ## Specification 14 | 15 | ### Flat-tokens array 16 | 17 | A *flat-tokens array* is a [`zarr` Group](https://zarr.readthedocs.io/en/stable/) of the following format: 18 | 19 | ``` 20 | arrays: { 21 | "encoded_tokens": uint32[token_count], 22 | "seq_starts": uint64[seq_count + 1], 23 | } 24 | attributes: { 25 | "max_token_id": int32 26 | } 27 | ``` 28 | 29 | That is, it has two arrays, named `encoded_tokens`, `seq_starts`. 30 | 31 | 1. The `encoded_tokens` array is a concatenation of all sequences in the dataset into a long array of tokens. 32 | There are no padding, beginning-of-sequence, or end-of-sequence tokens included. Tokens are encoded 33 | as `token_id*2+1` if they are the start of a new sequence, or `token_id*2` if not. The maximum supported `token_id` is `2^31-1`. 34 | 2. The `seq_starts` array lists (in increasing order) the indices of the `tokens` array where each 35 | sequence starts, plus one final index which equals `token_count`, indicating the end of the final 36 | sequence. 37 | 38 | Additionally, it has one attribute, named `max_token_id`. All decoded `token_id` values in `encoded_tokens` 39 | must be `<= max_token_id`. (This is intended to allow readers to quickly check that their vocabulary size is 40 | large enough for the dataset.) 41 | 42 | ### Flat-tokens dataset 43 | 44 | A *flat-tokens dataset* is a `zarr` Group with entries "train", "validation", each of which are flat-tokens arrays. 
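To make the encoding concrete, here is a small illustrative sketch (plain numpy, not one of the repository's tools) that builds the two arrays for a tiny dataset and decodes them back; it reproduces the array values shown in the example below.

```python
import numpy as np

def encode_flat_tokens(sequences):
    """Encode token sequences into (encoded_tokens, seq_starts) as specified above."""
    encoded, seq_starts = [], [0]
    for seq in sequences:
        for i, token_id in enumerate(seq):
            # 2*id+1 marks the first token of a sequence; 2*id marks all other tokens.
            encoded.append(token_id * 2 + 1 if i == 0 else token_id * 2)
        seq_starts.append(len(encoded))
    return np.asarray(encoded, dtype=np.uint32), np.asarray(seq_starts, dtype=np.uint64)

encoded_tokens, seq_starts = encode_flat_tokens([[1, 2], [3, 4, 5], [6, 7, 8]])
assert encoded_tokens.tolist() == [3, 4, 7, 8, 10, 13, 14, 16]
assert seq_starts.tolist() == [0, 2, 5, 8]

# Decoding: the token id is encoded >> 1, and the low bit marks sequence starts.
token_ids = encoded_tokens >> 1
is_seq_start = (encoded_tokens & 1) == 1
assert int(token_ids.max()) == 8  # the value to store as the max_token_id attribute
```
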
45 | 46 | ## Example 47 | 48 | The token sequences `[[1, 2], [3, 4, 5], [6, 7, 8]]` are represented in a flat-tokens array as: 49 | 50 | ``` 51 | arrays: { 52 | "tokens": [3, 4, 7, 8, 10, 13, 14, 16], 53 | "seq_starts": [0, 2, 5, 8], 54 | } 55 | attributes: { 56 | "max_token_id": 8 57 | } 58 | ``` 59 | 60 | ## Discussion 61 | 62 | This is the simplest possible format supporting the following features: 63 | * Batch size and sequence length can be chosen at training time. They are not "baked into" the format. 64 | * Data loading can be done with or without sequence packing: 65 | * Without sequence packing, we consult `seq_starts` to locate the tokens of a particular sequence, e.g. `tokens[seq_starts[1]:seq_starts[2]]` is `[7, 8, 10]`, corresponding to the tokens of sequence 1. 66 | * With sequence packing, we bypass `seq_starts` and directly consult `tokens`, e.g. for packed sequence length 4, sequence 1 is `tokens[4:8]`, i.e. `[10, 13, 14, 16]`. 67 | * O(1) random access to any sequence, packed or not. 68 | * This allows you to restart your training job and continue where you left off in the dataset, without retaining any state except for the step or sequence index where you left off. 69 | * This allows arbitrary shuffling at runtime. 70 | * Minimal disk seeks ("IO operations" on public clouds) per random access: just one disk seek for sequence-packed random access; just two disk seeks for non-packed random access. 71 | 72 | The sequence packing is designed such that no loss masking is required: every single token can be used as a target token. In the above example, if we used packed sequence length 8 (i.e. the whole dataset as one packed sequence), 73 | at training time we'd expand the tokens into the following input and target tokens: 74 | 75 | ``` 76 | { 77 | "inputs": [0, 1, 0, 3, 4, 0, 6, 7], 78 | "targets": [1, 2, 3, 4, 5, 6, 7, 8], 79 | } 80 | ``` 81 | -------------------------------------------------------------------------------- /docs/matx.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/pytree-zarr-checkpoint.md: -------------------------------------------------------------------------------- 1 | # PyTree-zarr checkpoint format 2 | 3 | For `seqax` we write checkpoints of JAX PyTrees, in a simple format documented here. 4 | 5 | ## Specification 6 | 7 | The *zarr of a PyTree* is a a [zarr Group](https://zarr.readthedocs.io/en/stable/api/hierarchy.html) with the following elements: 8 | * for each `path, array` in the [flattened PyTree](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_flatten_with_path.html#jax.tree_util.tree_flatten_with_path), the zarr Group contains `array` as a child array, with path equal to `jax.tree_util.keystr(path)` 9 | * additionally there is a zarr [attribute](https://zarr.readthedocs.io/en/stable/api/attrs.html) by name `write_completed` and value `True`. 10 | 11 | The zarr of a PyTree may be written to disk with any compression and chunk size settings. 12 | 13 | ## Discussion 14 | 15 | We use `zarr` to support parallel writers from different hosts in a fully-sharded training setup. (Parallel writers in this scenario must choose a chunk size that divides the data size per host, so as to avoid zarr race conditions during writing.) Readers of the checkpoint format do not need to be aware that it was written in parallel, as this is hidden by the zarr abstraction. 
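As a concrete illustration, a minimal single-host writer following this layout might look like the sketch below. This is an assumption-laden sketch rather than the repository's `training_io.py`: it uses zarr v2's `Group.array` API and writes each leaf as a whole array, leaving chunking and parallel-write concerns aside.

```python
import jax
import numpy as np
import zarr

def write_checkpoint(path: str, pytree) -> None:
    """Write a PyTree as a zarr Group in the format specified above (single host)."""
    root = zarr.open_group(path, mode="w")
    leaves, _treedef = jax.tree_util.tree_flatten_with_path(pytree)
    for keypath, array in leaves:
        # One child array per flattened leaf, named by jax.tree_util.keystr(path).
        root.array(jax.tree_util.keystr(keypath), np.asarray(array))
    # Set only after all arrays are written, so readers can detect partial checkpoints.
    root.attrs["write_completed"] = True
```

A reader would open the group with `zarr.open_group(path, mode="r")`, check `write_completed`, and reassemble the PyTree from the same key strings.
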
16 | 17 | We use the `write_completed` attribute to allow parallel writers to support a "two phase commit" protocol: all writers write their data chunks, then wait for a global barrier, then the "leader" writer sets the `write_completed` attribute. This protects readers from reading partially-written checkpoints. 18 | -------------------------------------------------------------------------------- /env.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def set_variables(): 4 | os.environ['XLA_FLAGS'] = ( 5 | os.environ.get('XLA_FLAGS', '') + ' ' 6 | '--xla_gpu_enable_async_collectives=true ' 7 | '--xla_gpu_enable_latency_hiding_scheduler=true ' 8 | ) 9 | os.environ.update({ 10 | "NCCL_LL128_BUFFSIZE": "-2", 11 | "NCCL_LL_BUFFSIZE": "-2", 12 | "NCCL_PROTO": "SIMPLE,LL,LL128", 13 | }) 14 | os.environ["LIBTPU_INIT_ARGS"] = "--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" 15 | -------------------------------------------------------------------------------- /input_loader.py: -------------------------------------------------------------------------------- 1 | """Input data loading from `flat-tokens` data format. 2 | 3 | See `docs/flat-tokens.md` for details on the format. 4 | 5 | We support shuffling of the input data, by the following algorithm: 6 | * there are N independent "streams" of data, each of which has disjoint data and is 7 | shuffled independently. 8 | * within each stream, we fetch a "shuffle buffer" consisting of many "read blocks" of 9 | data. We shuffle the entire buffer in memory. 10 | * the "read blocks" attached to each shuffle buffer are themselves selected randomly. 11 | 12 | This is the standard shuffling used by e.g. Huggingface Datasets. Unlike them, we run 13 | this algorithm _after_ tokenization, so we know exactly at which step number each new 14 | shuffle buffer starts at, allowing us to do instant resumes after job restarts. In our 15 | default recommended configuration, we also recommend a much larger shuffle buffer size 16 | than Huggingface Datasets, which allows for more thorough shuffling, taking advantage 17 | of the fact that a single sequence of tokens uses very little memory compared to e.g. 18 | a single image. 19 | 20 | Mosaic's StreamingDatasets library uses a similar algorithm as us, which they call py1b: 21 | https://docs.mosaicml.com/projects/streaming/en/stable/fundamentals/shuffling.html. 
22 | """ 23 | 24 | from concurrent.futures import ThreadPoolExecutor 25 | import functools 26 | from typing import Tuple, Union, Optional 27 | 28 | from typeguard import typechecked 29 | from shardlib.shardtypes import bool_, pytree_dataclass, u32 30 | import shardlib.shardtypes as shardtypes 31 | import zarr 32 | from dataclasses import dataclass 33 | import jax 34 | import numpy as np 35 | from jax.sharding import PartitionSpec as P 36 | import datetime 37 | import jax 38 | 39 | # imports for hf dataloader 40 | import numpy as onp 41 | from transformers import AutoTokenizer 42 | from torch.utils.data import DataLoader 43 | from datasets import load_dataset 44 | 45 | @dataclass(frozen=True) 46 | class TokenBatchParams: 47 | """The shape of a token batch.""" 48 | len: int 49 | batch: int 50 | 51 | 52 | @pytree_dataclass 53 | class TokenBatch: 54 | """A batch of tokens, which are typically the input to training.""" 55 | targets: u32['batch/d len'] 56 | is_seq_start: bool_['batch/d len'] 57 | 58 | 59 | 60 | 61 | @dataclass(frozen=True) 62 | class FlatTokensParams: 63 | filespec: str 64 | 65 | # A "stream" is what's attached to one independent shuffle buffer. There may be multiple 66 | # independent shuffle buffers, allowing parallelism. 67 | # 68 | # A "minipoch" (mini-epoch) is the set of sequences visited by one global refill of shuffle 69 | # buffers. The last minipoch may be shorter than others, but each stream in the last minipoch 70 | # must have the same number of read blocks, which must also be an integer. 71 | # 72 | # (To minimize discarded data on very small training sets, set streams=1 and make 73 | # sequences_per_read_block small.) 74 | # 75 | # Shuffling transforms the uint32[num_tokens] into uint32[streams, sequences, len], the 76 | # "shuffled tokens". We then form batches by a transformation on [streams, sequences]. 77 | 78 | streams: int # Recommended: maximum number of hosts you expect to use. 79 | read_blocks_per_shuffle_buffer: int # Recommended: 1 << 10. 4GiB (uncompressed) shuffle buffer. 80 | sequences_per_read_block: int # Recommended: (1 << 20) / len. 1MiB (compressed) read block. 81 | seed: int 82 | sequence_packing: bool 83 | 84 | 85 | @dataclass 86 | class _ShuffleBuffer: 87 | minipoch: int 88 | buffer: u32['Buflen len'] 89 | 90 | 91 | class ShufflingLoader: 92 | def __init__(self, split: str, params: FlatTokensParams, token_batch_params: TokenBatchParams): 93 | self.params = params 94 | self.token_batch_params = token_batch_params 95 | self.root = zarr.open_group(params.filespec, mode="r") 96 | assert split in ["train", "validation"], "Invalid split" 97 | self.encoded_tokens = self.root[split]["encoded_tokens"] 98 | self.seq_starts = self.root[split]["seq_starts"] 99 | self.max_token_id = self.root[split].attrs["max_token_id"] 100 | assert len(self.encoded_tokens.shape) == 1, "Expected 1D zarr" 101 | assert self.encoded_tokens.dtype == np.uint32, "Expected uint32 zarr" 102 | assert len(self.seq_starts.shape) == 1, "Expected 1D zarr" 103 | assert self.seq_starts.dtype == np.uint64, "Expected uint64 zarr" 104 | 105 | token_count = self.encoded_tokens.shape[0] 106 | if params.sequence_packing: 107 | self.seq_count = token_count // token_batch_params.len 108 | else: 109 | self.seq_count = self.seq_starts.shape[0] - 1 110 | 111 | # Count read blocks. 
Round it down to a multiple of streams 112 | read_block_count = self.seq_count // params.sequences_per_read_block 113 | read_block_count = (read_block_count // params.streams) * params.streams 114 | self.read_block_count = read_block_count 115 | assert read_block_count > 0, "Must have at least one read block per stream. Try shrinking streams and sequences_per_read_block." 116 | self.step_count = (read_block_count * params.sequences_per_read_block) // token_batch_params.batch 117 | # Count minipochs 118 | self.minipoch_count = _div_up(read_block_count, params.streams * params.read_blocks_per_shuffle_buffer) 119 | self.seq_indices_per_shuffle_buffer = params.read_blocks_per_shuffle_buffer * params.sequences_per_read_block 120 | # Calculate batch->stream mapping. 121 | self.batch_indices_per_stream = _div_exact(token_batch_params.batch, params.streams) 122 | # Calculate which streams and which batch indices this host is responsible for, based on the sharding. 123 | self.sharding = shardtypes.make_shardings(TokenBatch).targets 124 | streams = set() 125 | batch_indices = set() 126 | for batch_slices, _ in self.sharding.addressable_devices_indices_map((token_batch_params.batch, token_batch_params.len)).values(): 127 | batch_lo, batch_hi, batch_step = batch_slices.indices(token_batch_params.batch) 128 | for b in range(batch_lo, batch_hi, batch_step): 129 | batch_indices.add(b) 130 | streams.add(b // self.batch_indices_per_stream) 131 | self.shuffle_buffers_by_stream = {stream_index: None for stream_index in streams} 132 | self.batch_indices = sorted(batch_indices) 133 | # Shuffle read blocks 134 | assert read_block_count < 1 << 32, "Too many read blocks. Try growing sequences_per_read_block." 135 | self.read_block_ordering = _random_permutation(params.seed, read_block_count) 136 | 137 | 138 | def load(self, step: int) -> TokenBatch: 139 | assert step < self.step_count, f"Requested step {step} but dataset only supports {self.step_count} steps at batch size {self.token_batch_params.batch}." 140 | # Conceptually, we remap IDs as follows: 141 | # 1. (step, batch_index) -> (stream, seq_index_in_stream) 142 | # 2. seq_index_in_stream -> (minipoch, seq_index_in_shuffle_buffer) 143 | # 144 | # We visit all batch_indices in increasing order. Since the map batch_index->(stream, minipoch) 145 | # is monotonic (non-decreasing), we can reload the shuffle buffer for a stream whenever 146 | # we cross to a new minipoch without thrashing back and forth between adjacent minipochs. 147 | seq_by_batch_index = {} 148 | for batch_index in self.batch_indices: 149 | # 1. (step, batch_index) -> (stream, seq_index_in_stream) 150 | stream = batch_index // self.batch_indices_per_stream 151 | seq_index_in_stream = step * self.batch_indices_per_stream + (batch_index % self.batch_indices_per_stream) 152 | # 2. 
seq_index_in_stream -> (minipoch, seq_index_in_shuffle_buffer) 153 | minipoch = seq_index_in_stream // self.seq_indices_per_shuffle_buffer 154 | seq_index_in_shuffle_buffer = seq_index_in_stream % self.seq_indices_per_shuffle_buffer 155 | shuffle_buffer = self._get_shuffle_buffer(stream, minipoch) 156 | seq_by_batch_index[batch_index] = shuffle_buffer[seq_index_in_shuffle_buffer] 157 | 158 | def get_shard(indexing: Tuple[slice]) -> jax.Array: 159 | seqlen_slice = indexing[1] 160 | examples = [] 161 | for batch_index in range(*indexing[0].indices(self.token_batch_params.batch)): 162 | examples.append(seq_by_batch_index[batch_index][seqlen_slice]) 163 | return np.stack(examples) 164 | 165 | shape = (self.token_batch_params.batch, self.token_batch_params.len) 166 | encoded_tokens = jax.make_array_from_callback(shape, self.sharding, get_shard) 167 | return _decode(encoded_tokens) 168 | 169 | 170 | def _get_shuffle_buffer(self, stream: int, minipoch: int) -> _ShuffleBuffer: 171 | if self.shuffle_buffers_by_stream[stream] is None or self.shuffle_buffers_by_stream[stream].minipoch != minipoch: 172 | self.shuffle_buffers_by_stream[stream] = None # Free the underlying memory 173 | blocks_in_shuffle_buffer = self.params.read_blocks_per_shuffle_buffer 174 | if minipoch == self.minipoch_count - 1: 175 | blocks_in_shuffle_buffer = (self.read_block_count // self.params.streams) - self.params.read_blocks_per_shuffle_buffer * minipoch 176 | # We form a mapping: 177 | # (stream, minipoch, read_block_in_minipoch) -> sequential_read_block 178 | # then we map 179 | # sequential_read_block -> shuffled_read_block 180 | # using self.shuffled_read_blocks. 181 | shuffled_read_block_indices = [] 182 | for read_block_in_minipoch in range(blocks_in_shuffle_buffer): 183 | sequential_read_block = (minipoch * self.params.read_blocks_per_shuffle_buffer + read_block_in_minipoch) * self.params.streams + stream 184 | shuffled_read_block = self.read_block_ordering[sequential_read_block] 185 | shuffled_read_block_indices.append(shuffled_read_block) 186 | 187 | # Now load all of the read blocks in parallel. 188 | def load_read_block(read_block_index: int) -> u32['Buflen len']: 189 | start_seq = read_block_index * self.params.sequences_per_read_block 190 | end_seq = start_seq + self.params.sequences_per_read_block 191 | block_shape = (self.params.sequences_per_read_block, self.token_batch_params.len) 192 | if self.params.sequence_packing: 193 | flat_tokens = self.encoded_tokens[start_seq * self.token_batch_params.len : end_seq * self.token_batch_params.len] 194 | return flat_tokens.reshape(block_shape) 195 | else: 196 | seq_starts = self.seq_starts[start_seq : end_seq + 1] 197 | flat_tokens = self.encoded_tokens[seq_starts[0] : seq_starts[-1]] 198 | # Read the ragged array into a (padded) dense array. 199 | # 200 | # We pad with 1s, which decode to (0, new_sequence=true). 201 | result = np.ones(block_shape, dtype=np.uint32) 202 | for i in range(self.params.sequences_per_read_block): 203 | start = seq_starts[i] 204 | end = seq_starts[i + 1] 205 | result[i, :end - start] = flat_tokens[start:end] 206 | return result 207 | 208 | print(f'[{datetime.datetime.now()}] Loading shuffle buffer') 209 | # Loading a read block is IO-dominated work, with very little CPU time involved, so we can afford 210 | # to run a huge number of these in parallel with little concern about thrashing the CPU by having 211 | # excessively many threads doing CPU-intensive work. 
At the recommended read block sizing of 1MiB, 212 | # the memory footprint of a read block is typically bigger than the memory footprint of a CPU thread, 213 | # so we're also unlikely to waste a significant fraction of memory by having too many threads. In 214 | # net, allow a lot of threads, potentially way more than we have CPUs! Other overheads will 215 | # bite us before thread overheads do. 216 | with ThreadPoolExecutor(max_workers=len(shuffled_read_block_indices)) as executor: 217 | shuffled_read_blocks = list(executor.map(load_read_block, shuffled_read_block_indices)) 218 | shuffle_buffer = np.concatenate(shuffled_read_blocks, axis=0) 219 | print(f'[{datetime.datetime.now()}] Finished loading shuffle buffer, {shuffle_buffer.size * 4:_} bytes') 220 | 221 | # Actually shuffle it. 222 | sequences_in_shuffle_buffer = blocks_in_shuffle_buffer * self.params.sequences_per_read_block 223 | assert shuffle_buffer.shape == (sequences_in_shuffle_buffer, self.token_batch_params.len) 224 | shuffle_seed = self.params.seed + 1 + minipoch * self.params.streams + stream 225 | permutation = _random_permutation(shuffle_seed, sequences_in_shuffle_buffer) 226 | shuffle_buffer = shuffle_buffer[permutation, :] 227 | self.shuffle_buffers_by_stream[stream] = _ShuffleBuffer(minipoch, shuffle_buffer) 228 | 229 | return self.shuffle_buffers_by_stream[stream].buffer 230 | 231 | def _div_up(a: int, b: int) -> int: 232 | return (a + b - 1) // b 233 | 234 | def _div_exact(a: int, b: int) -> int: 235 | assert a % b == 0 236 | return a // b 237 | 238 | @functools.partial(jax.jit, donate_argnums=(0,)) 239 | @typechecked 240 | def _decode(encoded_tokens: u32[b'batch/d len']) -> TokenBatch: 241 | # encoded_tokens encoding: 242 | # 2*id+1 for the first token in a sequence 243 | # 2*id for other tokens in the sequence 244 | return TokenBatch( 245 | targets = encoded_tokens >> 1, 246 | is_seq_start = (encoded_tokens & 1) == 1, 247 | ) 248 | 249 | def _random_permutation(seed: int, n: int) -> u32['N']: 250 | """Same as `np.random.Generator.permutation`, but with a guarantee that it will always produce the same results for a given seed.""" 251 | assert n < 1 << 32 252 | # We do a Fisher-Yates shuffle using the Philox BitGenerator. Unlike the rest of np.random, 253 | # which is documented as potentially changing between numpy versions or even platforms on 254 | # the same version, the Philox BitGenerator is documented as stable. Likewise, we also promise 255 | # not to change the following implementation of the Fisher-Yates shuffle. 256 | # 257 | # We calculate the random numbers using `random_uint64() % n` rather than using rejection 258 | # sampling to generate numbers in range `[0, n)`. (Rejection sampling is more complicated, 259 | # because we don't know up front how many random numbers we'll need.) Our approach 260 | # introduces some bias, but it's small: since n<2^32, the bias is at most 2^-32 for each 261 | # random number generated. We're fine with this. 
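    # Concretely, `randoms[i]` below is drawn (up to the bias noted above) uniformly from
    # [0, i], and the loop then applies the Fisher-Yates swaps for i = n-1 down to 0,
    # yielding a permutation of [0, n).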
262 | randoms = np.random.Philox(seed).random_raw(n) % (np.arange(n, dtype=np.uint64) + 1) 263 | result = np.arange(n, dtype=np.uint32) 264 | for i in reversed(range(n)): 265 | j = randoms[i] 266 | tmp = result[i] 267 | result[i] = result[j] 268 | result[j] = tmp 269 | return result 270 | 271 | 272 | @dataclass(frozen=True) 273 | class HuggingFaceDataParams: 274 | path: str 275 | tokenizer: str 276 | num_workers: int 277 | sequences_packed_per_batch: int 278 | name: Optional[str] = None 279 | 280 | class HuggingFaceDataLoader: 281 | """ 282 | The HuggingFaceDataLoader is provided for convenience and ease of setup, 283 | but the flat tokens dataloader is recommended for production use. 284 | This dataset does not require running the tools/huggingface_to_flat_tokens.py 285 | to create a flat tokens dataset, and instead streams directly from huggingface. 286 | 287 | This datalaoder will waste tokens if you pack too many sequences into a batch, 288 | and does not support instant resume to an arbitrary step. 289 | """ 290 | def __init__(self, split, config: HuggingFaceDataParams, token_batch_params: TokenBatchParams): 291 | self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer) 292 | self.batch_size = token_batch_params.batch 293 | self.max_seq_len = token_batch_params.len 294 | self.sharding = shardtypes.make_shardings(TokenBatch).targets 295 | self.max_token_id = self.tokenizer.vocab_size-1 296 | assert 0 in self.tokenizer.all_special_ids, "Tokenizer must have a special 0 token" 297 | 298 | # setup an iterator over the dataset 299 | tokenize = functools.partial(self.tokenizer, padding=False, truncation=False, max_length=None, add_special_tokens=False, return_token_type_ids=False, return_attention_mask=False, return_tensors="np") 300 | dataset = load_dataset(config.path, config.name, streaming=True, split=split) 301 | tokenized = dataset.select_columns(["text"]).map(tokenize, input_columns=['text'], remove_columns=["text"]) 302 | dataloader = DataLoader(tokenized, num_workers=config.num_workers, collate_fn=self.collate, drop_last=True, batch_size=config.sequences_packed_per_batch) 303 | self.iterator = iter(dataloader) 304 | 305 | def collate(self, sequences): 306 | flat_batch = onp.zeros(self.batch_size * self.max_seq_len, onp.uint32) 307 | flat_is_start = onp.zeros(self.batch_size * self.max_seq_len, onp.bool_) 308 | start = 0 309 | for seq in sequences: 310 | seq = seq['input_ids'][0] 311 | end = min(start + len(seq), len(flat_batch)) 312 | flat_is_start[start] = True 313 | flat_batch[start:end] = seq[:end-start] 314 | start += len(seq) 315 | if start >= len(flat_batch): 316 | break 317 | shape = (self.batch_size, self.max_seq_len) 318 | return flat_batch.reshape(shape), flat_is_start.reshape(shape) 319 | 320 | def load(self, step): 321 | shape = (self.batch_size, self.max_seq_len) 322 | batch, is_start = next(self.iterator) 323 | def get_shard(x: jax.Array, indexing: Tuple[slice]) -> jax.Array: 324 | shard = x[indexing] 325 | return shard 326 | tokens = jax.make_array_from_callback(shape, self.sharding, functools.partial(get_shard, batch)) 327 | is_start = jax.make_array_from_callback(shape, self.sharding, functools.partial(get_shard, is_start)) 328 | return TokenBatch(tokens, is_start) 329 | 330 | def get_loader(split: str, config: Union[FlatTokensParams, HuggingFaceDataParams], token_batch_params: TokenBatchParams): 331 | if isinstance(config, FlatTokensParams): 332 | return ShufflingLoader(split, config, token_batch_params) 333 | elif isinstance(config, HuggingFaceDataParams): 
334 | return HuggingFaceDataLoader(split, config, token_batch_params) 335 | else: 336 | raise ValueError(f"Unknown config type {type(config)}") -------------------------------------------------------------------------------- /jax_extra.py: -------------------------------------------------------------------------------- 1 | """Extra utilities for JAX and Python.""" 2 | import jax 3 | import hashlib 4 | import jax 5 | import jax.ad_checkpoint 6 | import dataclasses 7 | from typing import Union, get_args 8 | from dataclasses import fields, is_dataclass 9 | 10 | 11 | def fold_in_str(key: jax.Array, string: str) -> jax.Array: 12 | """Returns a PRNG key derived from an initial PRNG key and a string input. 13 | 14 | Args: 15 | key: The initial PRNG key. 16 | string: The string input (e.g., 'pretrain', 'query', etc.). 17 | 18 | Returns: 19 | A PRNG key derived from the initial PRNG key and the string input. 20 | """ 21 | return jax.random.fold_in( 22 | key, int(hashlib.md5(string.encode()).hexdigest()[:8], base=16) 23 | ) 24 | 25 | def _convert(value, target_type): 26 | if value is None and target_type is not type(None): 27 | raise ValueError(f"Cannot convert None to {target_type}") 28 | elif value is None and target_type is type(None): 29 | return None 30 | elif is_dataclass(target_type): 31 | return make_dataclass_from_dict(target_type, value) 32 | else: 33 | return target_type(value) 34 | 35 | def _handle_union(name, field_value, union_types): 36 | for type_option in union_types: 37 | try: 38 | return _convert(field_value, type_option) 39 | except (TypeError, ValueError, AssertionError): 40 | continue 41 | raise ValueError(f'could not convert Union type {name} to any of {union_types}.') 42 | 43 | def make_dataclass_from_dict(cls, data): 44 | """Recursively instantiate a dataclass from a dictionary.""" 45 | if data is None: 46 | raise ValueError(f'Expected a {cls.__name__}, got None instead.') 47 | field_data = {} 48 | for field in fields(cls): 49 | field_value = data.get(field.name) 50 | if hasattr(field.type, '__origin__') and field.type.__origin__ is Union: 51 | field_data[field.name] = _handle_union(field.name, field_value, get_args(field.type)) 52 | else: 53 | try: 54 | field_data[field.name] = _convert(field_value, field.type) 55 | except (TypeError, ValueError, AssertionError): 56 | raise ValueError(f'Expected {field.type} for {cls.__name__}.{field.name}, got {type(field_value)} instead.') 57 | return cls(**field_data) 58 | 59 | def explicit_activation_checkpointing(f): 60 | """Annotates a function f to be used with save_for_backward(). 61 | 62 | Example: 63 | 64 | ``` 65 | @explicit_activation_checkpointing 66 | def foo(W1, W2, W3, x): 67 | x = jax.nn.relu(save_for_backward(W1 @ x)) 68 | x = jax.nn.relu(save_for_backward(W2 @ x)) 69 | x = W3 @ x 70 | ``` 71 | 72 | This causes the pre-ReLU activations to be saved for the backwards pass. 73 | """ 74 | # We save everything that is named. 75 | return jax.ad_checkpoint.checkpoint(f, policy=jax.checkpoint_policies.save_any_names_but_these()) 76 | 77 | def save_for_backward(x): 78 | """Saves a value for the backwards pass in a function annotated with explicit_activation_checkpointing().""" 79 | # The actual name isn't important, just the fact that it _is_ named, so that 80 | # the save_any_names_but_these() policy causes it to be saved. 
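# A minimal usage sketch (hypothetical names w1, w2, x; not part of this module), combining
# explicit_activation_checkpointing with jax.grad. Values wrapped in save_for_backward() are
# kept as residuals for the backward pass; unnamed intermediates are rematerialized:
#
#   @explicit_activation_checkpointing
#   def mlp(w1, w2, x):
#       h = jax.nn.relu(save_for_backward(w1 @ x))
#       return (w2 @ h).sum()
#
#   grads = jax.grad(mlp, argnums=(0, 1))(w1, w2, x)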
81 | return jax.ad_checkpoint.checkpoint_name(x, name='seqax_save_for_backward') -------------------------------------------------------------------------------- /requirements-cpu.txt: -------------------------------------------------------------------------------- 1 | zarr 2 | fsspec[gcs] 3 | jax[cpu]==0.4.26 4 | einops 5 | hydra-core 6 | clearml 7 | clearml-agent 8 | typeguard==4.1.5 9 | transformers # For Huggingface data loader 10 | datasets # For Huggingface data loader 11 | torch[cpu] # For Huggingface data loader -------------------------------------------------------------------------------- /shardlib/shardops.py: -------------------------------------------------------------------------------- 1 | import shardlib.shardtypes as shardtypes 2 | from jax import lax 3 | import jax.numpy as jnp 4 | import jax 5 | 6 | def all_gather(spec: str, x): 7 | """String-specified all-gather operation. 8 | 9 | For example: 10 | all_gather('A/x/y B/z C/w -> A B C/w', x) 11 | """ 12 | before, after = spec.split('->') 13 | before = shardtypes.ShapeSpec.parse(before) 14 | after = shardtypes.ShapeSpec.parse(after) 15 | shardtypes.check(x.dtype, before, x) 16 | for i, (before_dim, after_dim) in enumerate(zip(before.dims, after.dims)): 17 | # Check that after_dim.sharding is a prefix of before_dim.sharding 18 | after_n = len(after_dim.sharding) 19 | if before_dim.shape != after_dim.shape or before_dim.sharding[:after_n] != after_dim.sharding: 20 | raise ValueError(f'Cannot all-gather {before_dim} into {after_dim}') 21 | if len(before_dim.sharding) == after_n: 22 | continue 23 | x = lax.all_gather(x, tuple(before_dim.sharding[after_n:]), axis=i, tiled=True) 24 | shardtypes.check(x.dtype, after, x) 25 | return x 26 | 27 | def psum_scatter(spec: str, x): 28 | """String-specified reduce-scatter operation. 29 | 30 | For example: 31 | psum_scatter('A B C/w -> A/x/y B/z C/w', x) 32 | """ 33 | before, after = spec.split('->') 34 | before = shardtypes.ShapeSpec.parse(before) 35 | after = shardtypes.ShapeSpec.parse(after) 36 | shardtypes.check(x.dtype, before, x) 37 | for i, (before_dim, after_dim) in enumerate(zip(before.dims, after.dims)): 38 | # Check that before_dim.sharding is a prefix of after_dim.sharding 39 | before_n = len(before_dim.sharding) 40 | if before_dim.shape != after_dim.shape or after_dim.sharding[:before_n] != before_dim.sharding: 41 | raise ValueError(f'Cannot reduce-scatter {before_dim} into {after_dim}') 42 | if len(after_dim.sharding) == before_n: 43 | continue 44 | x = lax.psum_scatter(x, tuple(after_dim.sharding[before_n:]), scatter_dimension=i, tiled=True) 45 | shardtypes.check(x.dtype, after, x) 46 | return x 47 | 48 | def einsum_unreduced(spec: str, x, y, **kwargs): 49 | """Ordinary chip-local einsum, but with sharding-aware typechecking. 50 | 51 | Note that this function does not do any chip-to-chip communication. If the inputs are 52 | sharded over the contraction dimensions, the caller is responsible for reducing the result 53 | over those dimensions. For example: 54 | 55 | c = einsum_unreduced('A/x B/y, B/y C/z -> A/x/z', a, b) 56 | # c still needs to be reduced over the y axis. 57 | d = psum_scatter('A/x/z -> A/x/z/y', c) 58 | # Now the post-einsum reduction is complete. 
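    (The same pattern appears in train.py's forward pass: the attention output projection
    'B/d Qlen Q K/t D, M Q K/t D -> B/d Qlen M' contracts over the t-sharded K axis and is
    followed by psum_scatter('B/d Qlen M -> B/d Qlen M/t', ...) to complete the reduction.)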
59 | """ 60 | tmp, result = spec.split('->') 61 | lhs, rhs = tmp.split(',') 62 | lhs = shardtypes.ShapeSpec.parse(lhs) 63 | rhs = shardtypes.ShapeSpec.parse(rhs) 64 | result = shardtypes.ShapeSpec.parse(result) 65 | shardtypes.check(x.dtype, lhs, x) 66 | shardtypes.check(y.dtype, rhs, y) 67 | # Convert to jax einsum syntax, with single-letter variables. 68 | jaxspec = '' 69 | 70 | vars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' 71 | var_i = 0 72 | dim_table = {} 73 | def map_var(dim): 74 | if dim in dim_table: 75 | return dim_table[dim] 76 | nonlocal var_i 77 | if var_i >= len(vars): 78 | raise ValueError('Too many dimensions in einsum, we ran out of variables') 79 | var = vars[var_i] 80 | var_i += 1 81 | dim_table[dim] = var 82 | return var 83 | 84 | for dim in lhs.dims: 85 | jaxspec += map_var(dim) 86 | jaxspec += ',' 87 | for dim in rhs.dims: 88 | jaxspec += map_var(dim) 89 | jaxspec += '->' 90 | for dim in result.dims: 91 | jaxspec += map_var(dim) 92 | r = jnp.einsum(jaxspec, x, y, **kwargs) 93 | shardtypes.check(r.dtype, result, r) 94 | return r 95 | 96 | def index_unreduced(spec: str, table, indices): 97 | """String-specified sharded table lookup operation. 98 | 99 | For example: 100 | index_unreduced(table, indices, 'A [B/x/y] C/z, D/w A -> C/z A D/w') 101 | 102 | In this example, the integers in `indices` are used as lookup addresses into the 103 | `B` dimension of `table`, and all other dimensions (`A`, `C`, `D`) are vmapped over. 104 | 105 | This operation does not do any chip-to-chip communication, even though the table 106 | may be sharded. If the axis inside square brackets is sharded, corresponding to 107 | different table indices on different shards, a table lookup will be performed on each 108 | shard, but only one shard will return a nonzero result: the other shards, where the 109 | index is out of bounds, will return zero. The caller is required to reduce the output 110 | over the axes specified by the square brackets: in the above example, the caller must 111 | reduce over `x` and `y` axes. 112 | """ 113 | tmp, result = spec.split('->') 114 | lhs, rhs = tmp.split(',') 115 | lhs_dims = lhs.split() 116 | index_axis = None 117 | for i, dim in enumerate(lhs_dims): 118 | if dim.startswith('['): 119 | index_axis = i 120 | if not dim.endswith(']'): 121 | raise ValueError(f'Expected closing bracket in {dim}') 122 | lhs_dims[i] = dim[1:-1] 123 | break 124 | if index_axis is None: 125 | raise ValueError(f'Expected an index axis in {lhs}') 126 | 127 | lhs_dims = [shardtypes.DimSpec.parse(dim) for dim in lhs_dims] 128 | lhs_spec = shardtypes.ShapeSpec(lhs_dims) 129 | rhs_spec = shardtypes.ShapeSpec.parse(rhs) 130 | result_spec = shardtypes.ShapeSpec.parse(result) 131 | shardtypes.check(table.dtype, lhs_spec, table) 132 | shardtypes.check(indices.dtype, rhs_spec, indices) 133 | 134 | # Do the base operation on scalars, then do a sequence of vmap operations to bring it up 135 | # to the desired shape. 
136 | def base_op(table, index): 137 | len_per_chip = table.shape[0] 138 | lower_bound = len_per_chip * lax.axis_index(lhs_dims[index_axis].sharding) 139 | upper_bound = lower_bound + len_per_chip 140 | in_bounds = (lower_bound <= index) & (index < upper_bound) 141 | return jnp.where(in_bounds, table[jnp.where(in_bounds, index - lower_bound, 0)], 0) 142 | 143 | op = base_op 144 | 145 | lhs_dims_handled = [False] * len(lhs_dims) 146 | lhs_dims_handled[index_axis] = True 147 | rhs_dims_handled = [False] * len(rhs_spec.dims) 148 | for dim in reversed(result_spec.dims): 149 | try: 150 | lhs_index = lhs_dims.index(dim) 151 | lhs_vmap_axis = sum(lhs_dims_handled[:lhs_index]) 152 | assert not lhs_dims_handled[lhs_index] 153 | lhs_dims_handled[lhs_index] = True 154 | except ValueError: 155 | lhs_index = None 156 | lhs_vmap_axis = None 157 | 158 | try: 159 | rhs_index = rhs_spec.dims.index(dim) 160 | rhs_vmap_axis = sum(rhs_dims_handled[:rhs_index]) 161 | assert not rhs_dims_handled[rhs_index] 162 | rhs_dims_handled[rhs_index] = True 163 | except ValueError: 164 | rhs_index = None 165 | rhs_vmap_axis = None 166 | 167 | op = jax.vmap(op, in_axes=(lhs_vmap_axis, rhs_vmap_axis), out_axes=0) 168 | 169 | assert all(lhs_dims_handled) 170 | assert all(rhs_dims_handled) 171 | 172 | result = op(table, indices) 173 | shardtypes.check(result.dtype, result_spec, result) 174 | return result 175 | 176 | def axis_size(name: str) -> int: 177 | """Return the size of the axis with the given name.""" 178 | return jax.lax.psum(1, name) -------------------------------------------------------------------------------- /shardlib/shardtypes.py: -------------------------------------------------------------------------------- 1 | """Type annotations for JAX arrays with sharding information. 2 | 3 | # Shape checking 4 | 5 | Example: 6 | 7 | ``` 8 | import jax 9 | shardtypes.register_with_typeguard() 10 | from shardlib.shardtypes import f32 11 | from typeguard import typechecked 12 | 13 | @typechecked 14 | def center_channels(x: f32[b'batch/d channels']) -> f32[b'batch/d channels']: 15 | return x - jax.numpy.mean(x, axis=-1, keepdims=True) 16 | ``` 17 | 18 | The type syntax is `[]`, where `dtype` is imported from `shardlib.shardtypes`, 19 | and `` is a space-separated list of dimensions. Each dimension consists of a dimension 20 | name (e.g. `batch`), optionally followed by slashes and sharding axis names, e.g. `batch/d` indicates 21 | that the `batch` tensor dimension is sharded over the `d` device axis. Sharding over multiple axes 22 | is indicated by multiple axis names, e.g. `batch/d/e`. 23 | 24 | The shape string may be either a string ('foo') or a bytes object (b'foo'). Strings have special 25 | meaning in Python type annotations (they are used for forward references, and are eval'ed by typeguard), 26 | so the bytes object b'foo' is a workaround to prevent this eval'ing. 27 | 28 | Shape checking proceeds by maintaining a table of the sizes of all dimension names in a context 29 | variable, known as the shape checking scope. The first time a dimension name is encountered, 30 | its size is recorded in the current scope. Subsequent uses of the same dimension name must have 31 | the same size. Device axes (e.g. `/d`) are looked up in the currently configured JAX device mesh, 32 | to determine the size of the axis. 33 | 34 | For calls into functions or libraries, it can be useful to clear the shape checking scope, so caller 35 | and callee can use the same variable name to mean different things. 
This can be done with the `@scope` 36 | function decorator or the `with Scope():` context manager. 37 | 38 | # Using type annotations 39 | 40 | In addition to driving shape checking, type annotations can be used to drive sharding in JAX functions. 41 | See for example `typed_shard_map`, which is a simplification of JAX's `shard_map` by taking advantage 42 | of sharding in type signatures. 43 | """ 44 | import inspect 45 | import typing 46 | from collections.abc import Sequence 47 | from contextvars import ContextVar 48 | from enum import IntEnum 49 | from typing import Any, Union 50 | from typing import get_args, get_origin 51 | from typeguard import check_type_internal, typechecked 52 | import jax 53 | import jax.numpy as jnp 54 | from types import GenericAlias 55 | from typeguard import TypeCheckError, TypeCheckerCallable 56 | import dataclasses 57 | from dataclasses import dataclass, make_dataclass 58 | from typeguard import checker_lookup_functions 59 | 60 | 61 | #### State 62 | # ContextVar(dict[str, int]) 63 | _VARS = ContextVar('shardtypes._VARS', default={}) 64 | 65 | class Scope: 66 | """Context manager that clears the shape checking scope.""" 67 | def __enter__(self): 68 | self.token = _VARS.set({}) 69 | 70 | def __exit__(self, type, value, traceback): 71 | _VARS.reset(self.token) 72 | 73 | def scope(f): 74 | """Function decorator that clears the shape checking scope.""" 75 | def wrapper(*args, **kwargs): 76 | with Scope(): 77 | return f(*args, **kwargs) 78 | return wrapper 79 | 80 | 81 | def check_size(name: str, size: int): 82 | """Checks that a dimension has the expected size.""" 83 | try: 84 | value = int(name) 85 | if value != size: 86 | raise TypeCheckError(f'explicit dimension {value}: actually was {size}') 87 | except ValueError: 88 | v = _VARS.get() 89 | if name in v: 90 | if v[name] != size: 91 | raise TypeCheckError(f'dimension {name}: expected {v[name]}, got {size}') 92 | else: 93 | v[name] = size 94 | 95 | 96 | #### Shape specs 97 | @dataclass(frozen=True) 98 | class DimSpec: 99 | """Parsed result of a dimension in a shape string.""" 100 | shape: str 101 | sharding: Sequence[str] 102 | 103 | @staticmethod 104 | def parse(spec: str) -> 'DimSpec': 105 | pieces = spec.split('/') 106 | shape = pieces[0] 107 | sharding = tuple(pieces[1:]) 108 | return DimSpec(shape, sharding) 109 | 110 | def __str__(self): 111 | return '/'.join([self.shape] + list(self.sharding)) 112 | 113 | @dataclass 114 | class ShapeSpec: 115 | """Parsed result of a shape string.""" 116 | dims: Sequence[DimSpec] 117 | 118 | @staticmethod 119 | def parse(spec: Union[bytes, str]) -> 'ShapeSpec': 120 | if isinstance(spec, bytes): 121 | spec = spec.decode('utf-8') 122 | if not isinstance(spec, str): 123 | print(spec) 124 | raise ValueError('Expected a string') 125 | dims = spec.split() # Split on spaces, trimming excess space 126 | result = [] 127 | for dim in dims: 128 | result.append(DimSpec.parse(dim)) 129 | return ShapeSpec(result) 130 | 131 | def partition_spec(self) -> jax.sharding.PartitionSpec: 132 | result = [] 133 | for dim_spec in self.dims: 134 | if len(dim_spec.sharding) == 0: 135 | result.append(None) 136 | elif len(dim_spec.sharding) == 1: 137 | result.append(dim_spec.sharding[0]) 138 | else: 139 | result.append(tuple(dim_spec.sharding)) 140 | return jax.sharding.PartitionSpec(*result) 141 | 142 | def __str__(self): 143 | return ' '.join(str(dim) for dim in self.dims) 144 | 145 | #### Shape checking 146 | def _partition_spec_equiv(lhs: jax.sharding.PartitionSpec, rhs: 
jax.sharding.PartitionSpec) -> bool: 147 | if len(lhs) < len(rhs): 148 | lhs, rhs = rhs, lhs 149 | if any(l is not None for l in lhs[len(rhs):]): 150 | return False 151 | return lhs[:len(rhs)] == rhs[:] 152 | 153 | 154 | def check(dtype, shape_spec: ShapeSpec, value): 155 | """Checks that a value has the expected dtype and shape.""" 156 | if not isinstance(value, jax.Array): 157 | raise TypeCheckError('is not a jax.Array') 158 | if value.dtype != dtype: 159 | raise TypeCheckError(f'is {value.dtype}, but expected {dtype}') 160 | shape = value.shape 161 | if len(shape) != len(shape_spec.dims): 162 | raise TypeCheckError(f'has shape {shape}, but expected shape {str(shape_spec)}') 163 | mesh = None 164 | 165 | axis_env = jax._src.core.thread_local_state.trace_state.axis_env 166 | if axis_env: 167 | # We're in a shard_map/pmap/xmap context. Multiply sizes by sharding, then check sizes. 168 | # We don't actually check the sharding, because that information is lost inside a 169 | # shard_map/pmap/xmap context, but we do check the unsharded sizes are correct. 170 | mesh = {axis.name: axis.size for axis in axis_env} 171 | for orig_dim, dim_spec in zip(shape, shape_spec.dims): 172 | dim = orig_dim 173 | for axis in dim_spec.sharding: 174 | if axis not in mesh: 175 | raise TypeCheckError(f'has unknown mesh axis {axis}') 176 | axis_size = mesh[axis] 177 | dim *= axis_size 178 | check_size(dim_spec.shape, dim) 179 | else: 180 | # Check sizes 181 | for dim, dim_spec in zip(shape, shape_spec.dims): 182 | check_size(dim_spec.shape, dim) 183 | 184 | # Check sharding 185 | expected_spec = shape_spec.partition_spec() 186 | def cb(actual): 187 | if isinstance(actual, jax.sharding.SingleDeviceSharding): 188 | if any(dim_spec.sharding for dim_spec in shape_spec.dims): 189 | raise TypeCheckError(f'is fully replicated, but expected {expected_spec} is not') 190 | elif not isinstance(actual, jax.sharding.NamedSharding): 191 | if isinstance(actual, jax.sharding.Sharding): 192 | raise TypeCheckError(f'is SPMD-sharded but no axis names are available. Use `with Mesh(...):` to provide axis names for type checking.') 193 | else: 194 | raise TypeCheckError(f': unexpected object when checking sharding: {actual}') 195 | elif not _partition_spec_equiv(actual.spec, expected_spec): 196 | # TODO: when an axis size is None, recovering the NamedSharding from the PositionalSharding 197 | # is ambiguous, and JAX often takes a different approach than the user does. 198 | # 199 | # We could fix this with a more precise _partition_spec_equiv, but for now we'll just ignore it. 200 | # raise TypeCheckError(f'has sharding spec {actual.spec}, but expected {expected_spec} from {str(shape_spec)}') 201 | pass 202 | # Use tracing as a proxy for whether we're in a jit context 203 | is_tracing = jax._src.core.thread_local_state.trace_state.trace_stack 204 | if is_tracing: 205 | jax.debug.inspect_array_sharding(value, callback=cb) 206 | else: 207 | cb(value.sharding) 208 | 209 | 210 | 211 | 212 | #### Typeguard 213 | def register_with_typeguard(): 214 | """Registers the shardtypes module with typeguard. 
Call this at the beginning of your program.""" 215 | def check_array(value, origin, args, memo): 216 | if len(args) != 1 or (type(args[0]) is not str and type(args[0]) is not bytes): 217 | raise TypeCheckError(f'has bad type signature; expected {origin.__name__}[], got {origin.__name__}{args}') 218 | check(origin.dtype, ShapeSpec.parse(args[0]), value) 219 | 220 | def check_pytree_dataclass(value, origin, args, memo): 221 | if not isinstance(value, origin): 222 | raise TypeCheckError(f'is not an instance of {origin}') 223 | for field in dataclasses.fields(origin): 224 | check_type_internal(getattr(value, field.name), field.type, memo) 225 | 226 | def lookup( 227 | origin, args, extras 228 | ) -> TypeCheckerCallable | None: 229 | if isinstance(origin, type) and issubclass(origin, number): 230 | return check_array 231 | if origin in _PYTREE_DATACLASSES: 232 | return check_pytree_dataclass 233 | return None 234 | 235 | checker_lookup_functions.append(lookup) 236 | 237 | #### Array types 238 | class number: 239 | def __class_getitem__(cls, x): 240 | if isinstance(x, str): 241 | x = x.encode('utf-8') 242 | return GenericAlias(cls, x) 243 | 244 | class bool_(number): 245 | dtype = jnp.bool_ 246 | pass 247 | 248 | class bf16(number): 249 | dtype = jnp.bfloat16 250 | pass 251 | 252 | class f32(number): 253 | dtype = jnp.float32 254 | pass 255 | 256 | class i32(number): 257 | dtype = jnp.int32 258 | pass 259 | 260 | class u32(number): 261 | dtype = jnp.uint32 262 | pass 263 | 264 | class i8(number): 265 | dtype = jnp.int8 266 | pass 267 | 268 | class u8(number): 269 | dtype = jnp.uint8 270 | pass 271 | 272 | 273 | _PYTREE_DATACLASSES = set() 274 | 275 | 276 | def pytree_dataclass(cls): 277 | """Decorator that declares a dataclass that JAX recognizes as a PyTree.""" 278 | cls = dataclass(cls) 279 | 280 | def flatten_with_keys(value): 281 | return [(k.name, getattr(value, k.name)) for k in dataclasses.fields(cls)], () 282 | 283 | def unflatten(_aux, fields): 284 | return cls(*fields) 285 | 286 | jax.tree_util.register_pytree_with_keys(cls, flatten_with_keys, unflatten) 287 | _PYTREE_DATACLASSES.add(cls) 288 | return cls 289 | 290 | class Array: 291 | """If `cls` is an array type or a `pytree_dataclass` of array types, 292 | `Array[axes, cls]` will extend `cls` with leading axes `axes`. 293 | For example, `Array['layers', f32['batch d_model']] returns f32['layers batch d_model`]`. 
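    The same pattern is applied to dataclasses elsewhere in this repo, e.g.
    `Transformer = Array['layers', TransformerLayer]` in train.py, which prepends a
    `layers` axis to every field of the layer dataclass.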
294 | """ 295 | def __class_getitem__(cls, x): 296 | axes, input_cls = x 297 | if isinstance(axes, str): 298 | axes = axes.encode('utf-8') 299 | elif isinstance(axes, bytes): 300 | pass 301 | else: 302 | raise ValueError(f"input axes to {cls} must be Union[bytes, str]") 303 | 304 | if dataclasses.is_dataclass(input_cls): 305 | extended_fields = [] 306 | for fld in dataclasses.fields(input_cls): 307 | extended_type = Array[axes, fld.type] 308 | extended_fields.append((fld.name, extended_type)) 309 | 310 | extended_cls = make_dataclass(input_cls.__name__, extended_fields, bases=(input_cls,)) 311 | pytree_dataclass(extended_cls) 312 | return extended_cls 313 | else: 314 | number_type, shape = get_origin(input_cls), get_args(input_cls) 315 | extended_shape = (axes + b' ' + shape[0],) 316 | return GenericAlias(number_type, extended_shape) 317 | 318 | def make_partition_specs(cls): 319 | """Instantiates a pytree dataclass with a PartitionSpec at array type.""" 320 | # Check for a tuple type: 321 | origin = typing.get_origin(cls) 322 | args = typing.get_args(cls) 323 | if origin is tuple: 324 | return tuple(make_partition_specs(arg) for arg in args) 325 | elif origin is not None and issubclass(origin, number): 326 | if len(args) != 1 or (type(args[0]) is not str and type(args[0]) is not bytes): 327 | raise ValueError(f'Type annotation {cls} should be [], got {cls}') 328 | spec = ShapeSpec.parse(args[0]) 329 | return spec.partition_spec() 330 | elif dataclasses.is_dataclass(cls): 331 | values = [] 332 | for field in dataclasses.fields(cls): 333 | values.append(make_partition_specs(field.type)) 334 | return cls(*values) 335 | 336 | raise ValueError(f'Unsupported type {cls} is not a array, dataclass, or tuple type') 337 | 338 | 339 | def make_shardings(cls): 340 | """Instantiates a pytree dataclass with NamedSharding at array type.""" 341 | mesh = jax._src.mesh.thread_resources.env.physical_mesh 342 | return jax.tree_map(lambda spec: jax.sharding.NamedSharding(mesh, spec), make_partition_specs(cls)) 343 | 344 | 345 | def typed_shard_map(f, **kwargs): 346 | """jax.shard_map, but which does not require specifying in_specs and out_specs. 347 | 348 | Instead, the function signature is used to infer the partitioning of the inputs and outputs. 349 | 350 | For example: 351 | @typed_shard_map 352 | def f(x: f32[b'batch/d len'], y: f32[b'e/d f/t']) -> f32[b'batch/d f/t']: 353 | ... 354 | 355 | """ 356 | sig = inspect.signature(f) 357 | 358 | def wrapped(*args): 359 | mesh = jax._src.mesh.thread_resources.env.physical_mesh 360 | in_specs = tuple(make_partition_specs(param.annotation) for param in sig.parameters.values()) 361 | out_specs = make_partition_specs(sig.return_annotation) 362 | return jax.experimental.shard_map.shard_map(typechecked(f), in_specs=in_specs, out_specs=out_specs, mesh=mesh, **kwargs)(*args) 363 | 364 | return wrapped 365 | 366 | def is_fully_sharded(spec: jax.sharding.PartitionSpec): 367 | """Returns True if the spec is fully sharded, i.e. 
every device axis is used in the partition spec.""" 368 | axis_count = 0 369 | for axis in spec: 370 | if axis is None: 371 | continue 372 | elif isinstance(axis, str): 373 | axis_count += 1 374 | elif isinstance(axis, tuple): 375 | axis_count += len(axis) 376 | else: 377 | raise ValueError(f'Unknown axis type {axis}') 378 | return axis_count == len(jax._src.core.thread_local_state.trace_state.axis_env) 379 | -------------------------------------------------------------------------------- /synthetic_dataset.zarr/.zgroup: -------------------------------------------------------------------------------- 1 | { 2 | "zarr_format": 2 3 | } -------------------------------------------------------------------------------- /synthetic_dataset.zarr/train/.zattrs: -------------------------------------------------------------------------------- 1 | { 2 | "max_token_id": 1017 3 | } -------------------------------------------------------------------------------- /synthetic_dataset.zarr/train/.zgroup: -------------------------------------------------------------------------------- 1 | { 2 | "zarr_format": 2 3 | } -------------------------------------------------------------------------------- /synthetic_dataset.zarr/train/encoded_tokens/.zarray: -------------------------------------------------------------------------------- 1 | { 2 | "chunks": [ 3 | 4194304 4 | ], 5 | "compressor": { 6 | "blocksize": 0, 7 | "clevel": 5, 8 | "cname": "lz4", 9 | "id": "blosc", 10 | "shuffle": 2 11 | }, 12 | "dtype": " self.group.attrs["max_token_id"]: 78 | self.group.attrs["max_token_id"] = chunk.max_token_id 79 | # In parallel: 80 | with concurrent.futures.ThreadPoolExecutor() as executor: 81 | executor.submit(lambda: self.encoded_tokens.append(chunk.encoded_tokens)) 82 | executor.submit(lambda: self.seq_starts.append(num_tokens + chunk.seq_starts[1:])) 83 | 84 | 85 | -------------------------------------------------------------------------------- /tools/huggingface_to_flat_tokens.py: -------------------------------------------------------------------------------- 1 | """Tokenizes a Huggingface dataset and writes it to `flat-tokens` format. 2 | 3 | See `docs/flat-tokens.md` for details on the format. 4 | See `configs/c4_en.yaml` for an instructions on running. 5 | 6 | TODO: we could make this much faster by sharding over multiple CPUs. Rough approach: 7 | 1) Make this script read from a shard of the Huggingface dataset. 8 | 2) At the end of this script, wait for all shards to complete, and then concatenate the zarr data. 9 | """ 10 | 11 | import hydra 12 | 13 | from typing import Optional, Dict 14 | import time 15 | import numpy as np 16 | from dataclasses import dataclass 17 | from transformers import AutoTokenizer 18 | from hydra.core.config_store import ConfigStore 19 | from datasets import load_dataset 20 | from concurrent.futures import ThreadPoolExecutor 21 | import flat_tokens 22 | 23 | @dataclass 24 | class Config: 25 | output: str 26 | tokenizer: str 27 | dataset: str 28 | variant: Optional[str] 29 | max_tokens: Optional[int] 30 | write_buffer_size_in_sequences: int 31 | flat_tokens_config: flat_tokens.Config 32 | _target_: str = __name__ + ".Config" 33 | 34 | 35 | # Registering the Config class with the name 'config'. 
36 | ConfigStore.instance().store(name="config_schema", node=Config) 37 | 38 | 39 | @hydra.main(config_path="configs", version_base=None) 40 | def main(config): 41 | # Create tokenizer 42 | if config.tokenizer == "bytes_utf8": 43 | def tokenize(texts): 44 | return [np.uint32(np.frombuffer(text.encode('utf-8'), np.uint8)) + 1 for text in texts] 45 | else: 46 | tokenizer = AutoTokenizer.from_pretrained(config.tokenizer) 47 | assert 0 in tokenizer.all_special_ids, "Tokenizer must have 0 as a special id" 48 | assert tokenizer.vocab_size < 1 << 31, "Tokenizer vocab size too large for uint31" 49 | def tokenize(texts): 50 | return tokenizer(texts, add_special_tokens=False)["input_ids"] 51 | 52 | def tokenize_and_concat(batch): 53 | chunk = flat_tokens.Chunk.from_ragged(tokenize(batch["text"])) 54 | # dataset.map() API requires us to return numpy tensors of the appropriate shape... 55 | return { 56 | "encoded_tokens": chunk.encoded_tokens[np.newaxis, :], 57 | "seq_starts": chunk.seq_starts[np.newaxis, :], 58 | "max_token_id": np.array(chunk.max_token_id, np.uint32)[np.newaxis], 59 | } 60 | 61 | executor = ThreadPoolExecutor() 62 | 63 | for split, mode in [("validation", "w-"), ("train", "r+")]: 64 | dataset = load_dataset( 65 | config.dataset, 66 | config.variant, 67 | streaming=True, 68 | split=split, 69 | ) 70 | dataset = dataset.select_columns(["text"]) 71 | dataset = dataset.map(tokenize_and_concat, batched=True, batch_size=config.write_buffer_size_in_sequences, remove_columns=["text"]) 72 | 73 | # Open output 74 | dst = flat_tokens.Writer(config.output, flat_tokens.Split(split), mode, config.flat_tokens_config) 75 | dst_flush = executor.submit(lambda: None) 76 | 77 | # Write in batches 78 | flush_elapsed = 0 79 | start_time = time.time() 80 | next_update = 0 81 | seq_count = 0 82 | token_count = 0 83 | for batch in dataset: 84 | chunk = flat_tokens.Chunk(encoded_tokens=batch["encoded_tokens"], seq_starts=batch["seq_starts"], max_token_id=batch["max_token_id"]) 85 | seq_count += len(chunk.seq_starts) - 1 86 | token_count += len(chunk.encoded_tokens) 87 | 88 | flush_start = time.time() 89 | dst_flush.result() 90 | dst_flush = executor.submit(dst.write, chunk) 91 | flush_elapsed += time.time() - flush_start 92 | elapsed = time.time() - start_time 93 | if elapsed > next_update: 94 | total_mib = token_count * 4 // (1024 * 1024) 95 | speed_mib_per_s = total_mib / elapsed 96 | print(f"[{int(elapsed):_}s] Processed {seq_count:_} examples, {token_count:_} tokens, {total_mib:_} MiB, {speed_mib_per_s:.2f} MiB/s. 
Flush time: {flush_elapsed:.2f}s") 97 | next_update = elapsed + 60 98 | 99 | if token_count >= config.max_tokens: 100 | break 101 | 102 | # Final flush 103 | dst_flush.result() 104 | 105 | print(f"Done with split '{split}': {seq_count:_} examples, {token_count:_} tokens") 106 | 107 | 108 | if __name__ == "__main__": 109 | main() 110 | -------------------------------------------------------------------------------- /tools/requirements.txt: -------------------------------------------------------------------------------- 1 | zarr 2 | fsspec[gcs] 3 | transformers 4 | datasets 5 | hydra-core 6 | -------------------------------------------------------------------------------- /tools/write_synthetic_dataset.py: -------------------------------------------------------------------------------- 1 | # To run: 2 | # 3 | # ``` 4 | # python write_synthetic_dataset.py --config-name=synthetic_dataset +output=synthetic_dataset.zarr 5 | # ``` 6 | # 7 | # Synthetic tasks: see Section 4 of https://arxiv.org/abs/2002.09402 for some ideas. 8 | # 9 | # We do: 10 | # * Task 1: [theirs] fixed-distance copy. Requires attention position queries. 11 | # * Task 2: [theirs] fixed-distance reverse. Requires attention position queries. 12 | # * Task 3: [ours] random-distance (specified) copy. Requires variable-length position queries. 13 | # * Task 4: [ours] random-distance (not specified) copy. Requires equality matching for a prefix. 14 | # * Task 5: [ours] gaussians sampling. Requires model to learn which IDs are close to each other (in numerical order). 15 | # 16 | # Sequences begin with task ID, then have task-specific data. We avoid index 0, which indicates padding. 17 | 18 | from functools import partial 19 | import hydra 20 | from jaxtyping import Float, Int, jaxtyped, UInt32 21 | import numpy as np 22 | from typeguard import typechecked as typechecker 23 | from dataclasses import dataclass 24 | from hydra.core.config_store import ConfigStore 25 | import flat_tokens 26 | 27 | @dataclass 28 | class Config: 29 | output: str 30 | seed: int 31 | seq_len: int 32 | examples: int 33 | flat_tokens_config: flat_tokens.Config 34 | _target_: str = __name__ + '.Config' 35 | 36 | 37 | 38 | @jaxtyped(typechecker=typechecker) 39 | def copy(seq_len: int, examples: int, gen: np.random.Generator) -> UInt32[np.ndarray, 'batch seqlen']: 40 | seq = gen.integers(1, 11, (examples, (seq_len + 1) // 2), dtype=np.uint32) 41 | return np.append(seq, seq, axis=1)[:, :seq_len] 42 | 43 | @jaxtyped(typechecker=typechecker) 44 | def reverse(seq_len: int, examples: int, gen: np.random.Generator) -> UInt32[np.ndarray, 'batch seqlen']: 45 | seq = gen.integers(1, 11, (examples, (seq_len + 1) // 2), dtype=np.uint32) 46 | return np.append(seq, np.flip(seq, axis=1), axis=1)[:, :seq_len] 47 | 48 | @jaxtyped(typechecker=typechecker) 49 | def random_known_distance_copy(seq_len: int, examples: int, gen: np.random.Generator) -> UInt32[np.ndarray, 'batch seqlen']: 50 | distance = gen.integers(max(1, seq_len // 4), seq_len, (examples,), dtype=np.uint32) 51 | seq = gen.integers(1, 11, (examples, seq_len), dtype=np.uint32) 52 | indices = np.arange(seq_len - 1)[np.newaxis, :] % distance[:, np.newaxis] 53 | full_seq = seq[np.arange(examples)[:, np.newaxis], indices] 54 | assert full_seq.shape == (examples, seq_len - 1) 55 | return np.append(distance[:, np.newaxis], full_seq, axis=1) 56 | 57 | @jaxtyped(typechecker=typechecker) 58 | def random_unknown_distance_copy(seq_len: int, examples: int, gen: np.random.Generator) -> UInt32[np.ndarray, 'batch seqlen']: 59 | 
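    # Same construction as random_known_distance_copy, except the leading distance token is
    # dropped, so the model must infer the copy period from the repeated prefix (Task 4 above).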
return random_known_distance_copy(seq_len + 1, examples, gen)[:, 1:] 60 | 61 | @jaxtyped(typechecker=typechecker) 62 | def mixture_of_gaussians(seq_len: int, examples: int, gen: np.random.Generator) -> UInt32[np.ndarray, 'batch seqlen']: 63 | centers = gen.uniform(0, 100, (examples, 3)).astype(np.float32) 64 | stddevs = gen.uniform(1, 4, (examples, 3)).astype(np.float32) 65 | sample_cluster_ids = gen.integers(0, 3, (examples, seq_len), dtype=np.uint32) 66 | batch_ids = np.arange(examples)[:, np.newaxis] 67 | sample_centers = centers[batch_ids, sample_cluster_ids] 68 | sample_stddevs = stddevs[batch_ids, sample_cluster_ids] 69 | floats = gen.normal(0, 1, (examples, seq_len)).astype(np.float32) * sample_stddevs + sample_centers 70 | return np.clip(np.round(floats).astype(np.uint32), 1, 100) 71 | 72 | @jaxtyped(typechecker=typechecker) 73 | def synthetic_task(config: Config, gen: np.random.Generator) -> list[UInt32[np.ndarray, '...']]: 74 | task_seq_len = config.seq_len - 1 75 | examples = config.examples 76 | copy_data = copy(task_seq_len, examples, gen) 77 | reverse_data = reverse(task_seq_len, examples, gen) 78 | random_known_distance_copy_data = random_known_distance_copy(task_seq_len, examples, gen) 79 | random_unknown_distance_copy_data = random_unknown_distance_copy(task_seq_len, examples, gen) 80 | mixture_of_gaussians_data = mixture_of_gaussians(task_seq_len, examples, gen) 81 | tasks = np.asarray([copy_data, reverse_data, random_known_distance_copy_data, random_unknown_distance_copy_data, mixture_of_gaussians_data]) 82 | task_id = gen.integers(1, 6, (examples,), dtype=np.uint32) 83 | targets = np.append(task_id[:, np.newaxis], tasks[task_id - 1, np.arange(examples)], axis=1) 84 | lengths = gen.integers(1, config.seq_len + 1, (examples,), dtype=np.uint32) 85 | ragged_targets = [targets[i, :lengths[i]] for i in range(examples)] 86 | return ragged_targets 87 | 88 | 89 | # Registering the Config class with the name 'config'. 
90 | ConfigStore.instance().store(name="config_schema", node=Config) 91 | 92 | 93 | @hydra.main(config_path="configs", version_base=None) 94 | def main(config): 95 | config = hydra.utils.instantiate(config) 96 | gen = np.random.Generator(np.random.PCG64(config.seed)) 97 | 98 | for split, mode in [(flat_tokens.Split.VALIDATION, "w-"), (flat_tokens.Split.TRAIN, "r+")]: 99 | dst = flat_tokens.Writer(config.output, split, mode, config.flat_tokens_config) 100 | examples = synthetic_task(config, gen) 101 | dst.write(flat_tokens.Chunk.from_ragged(examples)) 102 | 103 | if __name__ == "__main__": 104 | main() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """Main training loop, including the model, loss function, and optimizer.""" 2 | import operator 3 | import os 4 | import time 5 | 6 | import env 7 | env.set_variables() 8 | import shardlib.shardtypes as shardtypes 9 | shardtypes.register_with_typeguard() 10 | import gcsfs # Needed for clearml setup 11 | 12 | import datetime 13 | from functools import cached_property, partial 14 | from typing import Any, Optional, Tuple, Union 15 | import hydra 16 | from typeguard import typechecked 17 | from dataclasses import dataclass 18 | import jax 19 | from jax import lax 20 | from jax.sharding import PartitionSpec 21 | import jax.numpy as jnp 22 | import math 23 | from input_loader import FlatTokensParams, HuggingFaceDataParams, TokenBatch, TokenBatchParams, get_loader 24 | from shardlib.shardtypes import bf16, bool_, f32, pytree_dataclass, u32, make_shardings, Array 25 | import shardlib.shardops as shardops 26 | P = PartitionSpec 27 | import einops 28 | import jax_extra 29 | from jax_extra import fold_in_str, explicit_activation_checkpointing, save_for_backward 30 | import os 31 | import training_io 32 | from clearml import Task 33 | from jax.experimental import mesh_utils 34 | from jax.sharding import Mesh 35 | from jax.tree_util import tree_leaves 36 | 37 | PRNGKey = Any 38 | 39 | @dataclass(frozen=True) 40 | class Hparams: 41 | d_model: int 42 | n_q_per_kv: int 43 | n_kv: int 44 | d_head: int 45 | layers: int 46 | vocab: int 47 | d_ff: int 48 | rope_max_timescale: int 49 | 50 | @pytree_dataclass 51 | class TransformerLayer: 52 | ln1: f32['d_model/t/d'] 53 | ln2: f32['d_model/t/d'] 54 | w_q: f32['d_model/d n_q_per_kv n_kv/t d_head'] 55 | w_kv: f32['2 d_model/d n_kv/t d_head'] 56 | w_o: f32['d_model/d n_q_per_kv n_kv/t d_head'] 57 | w_gate: f32['d_model/d d_ff/t'] 58 | w_up: f32['d_model/d d_ff/t'] 59 | w_down: f32['d_model/d d_ff/t'] 60 | 61 | Transformer = Array['layers', TransformerLayer] 62 | 63 | @pytree_dataclass 64 | class Model: 65 | embed: f32['vocab/t d_model/d'] 66 | unembed: f32['vocab/t d_model/d'] 67 | transformer: Transformer 68 | final_layer_norm: f32['d_model/d/t'] 69 | 70 | @staticmethod 71 | @typechecked 72 | def init(h: Hparams, rng: PRNGKey) -> 'Model': 73 | embed = jax.random.normal(jax_extra.fold_in_str(rng, 'embed'), (h.vocab, h.d_model), dtype=jnp.float32) 74 | # https://github.com/google/jax/issues/20390 for ones_like with sharding. 75 | ln1 = jnp.ones((h.layers, h.d_model), dtype=jnp.float32) 76 | ln2 = jnp.ones((h.layers, h.d_model), dtype=jnp.float32) 77 | final_layer_norm = jnp.ones((h.d_model,), dtype=jnp.float32) 78 | 79 | # All of wi/wq/wo/wo/w_kv use truncated_normal initializers with 'fan_in' scaling, 80 | # i.e. variance set to 1.0/fan_in. 
81 | # The constant is stddev of standard normal truncated to (-2, 2) 82 | truncated_normal_stddev = .87962566103423978 83 | 84 | # scale for tensors with d_model fan_in and truncated normal truncated to (-2, 2) 85 | d_model_scale = 1 / (math.sqrt(h.d_model) * truncated_normal_stddev) 86 | 87 | w_kv_scale = d_model_scale 88 | w_q_scale = d_model_scale / math.sqrt(h.d_head) 89 | total_head_dim = h.n_q_per_kv * h.n_kv * h.d_head 90 | w_o_scale = 1 / (math.sqrt(total_head_dim) * truncated_normal_stddev) 91 | w_up_scale = d_model_scale 92 | w_down_scale = 1 / (math.sqrt(h.d_ff) * truncated_normal_stddev) 93 | unembed_scale = d_model_scale 94 | 95 | w_q_shape = (h.layers, h.d_model, h.n_q_per_kv, h.n_kv, h.d_head) 96 | w_q = w_q_scale * jax.random.truncated_normal(fold_in_str(rng, 'w_q'), -2, 2, w_q_shape, dtype=jnp.float32) 97 | w_kv_shape = (h.layers, 2, h.d_model, h.n_kv, h.d_head) 98 | w_kv = w_kv_scale * jax.random.truncated_normal(fold_in_str(rng, 'w_kv'), -2, 2, w_kv_shape, dtype=jnp.float32) 99 | w_o_shape = w_q_shape 100 | w_o = w_o_scale * jax.random.truncated_normal(fold_in_str(rng, 'w_o'), -2, 2, w_o_shape, dtype=jnp.float32) 101 | 102 | ff_shape = (h.layers, h.d_model, h.d_ff) 103 | w_gate = w_up_scale * jax.random.truncated_normal(fold_in_str(rng, 'w_gate'), -2, 2, ff_shape, dtype=jnp.float32) 104 | w_up = w_up_scale * jax.random.truncated_normal(fold_in_str(rng, 'w_up'), -2, 2, ff_shape, dtype=jnp.float32) 105 | w_down = w_down_scale * jax.random.truncated_normal(fold_in_str(rng, 'w_down'), -2, 2, ff_shape, dtype=jnp.float32) 106 | 107 | unembed = unembed_scale * jax.random.truncated_normal(fold_in_str(rng, 'unembed'), -2, 2, (h.vocab, h.d_model), dtype=jnp.float32) 108 | arrays = Model( 109 | embed=embed, 110 | unembed=unembed, 111 | transformer=Transformer( 112 | ln1=ln1, 113 | ln2=ln2, 114 | w_q=w_q, 115 | w_kv=w_kv, 116 | w_o=w_o, 117 | w_gate=w_gate, 118 | w_up=w_up, 119 | w_down=w_down, 120 | ), 121 | final_layer_norm=final_layer_norm, 122 | ) 123 | shardings = make_shardings(Model) 124 | return jax.tree.map(lax.with_sharding_constraint, arrays, shardings) 125 | 126 | 127 | @typechecked 128 | def forward_pass(self, h: Hparams, ids: u32[b'B/d L'], is_seq_start: bool_[b'B/d L']) -> f32[b'B/d L V/t']: 129 | ##### Initial embedding lookup. 130 | embed = shardops.all_gather('V/t M/d -> V/t M', jnp.bfloat16(self.embed)) 131 | x = shardops.index_unreduced('[V/t] M, B/d L -> B/d L M', embed, ids) 132 | x = shardops.psum_scatter('B/d L M -> B/d L M/t', x) 133 | 134 | L = ids.shape[1] 135 | segment_ids = jnp.cumsum(is_seq_start, axis=1) 136 | segment_mask: bool_[b'B/d L L'] = segment_ids[:, :, jnp.newaxis] == segment_ids[:, jnp.newaxis, :] 137 | segment_mask: bool_[b'B/d L L 1 1'] = segment_mask[..., jnp.newaxis, jnp.newaxis] # add axes for q_per_k, num_kv_heads dimensions 138 | causal_mask: bool_[b'1 L L 1 1'] = jnp.tril(jnp.ones((L, L), dtype=jnp.bool_), 0)[jnp.newaxis, ..., jnp.newaxis, jnp.newaxis] 139 | causal_mask: bool_[b'B/d L L 1 1'] = jnp.logical_and(segment_mask, causal_mask) 140 | 141 | rope_table = RopeTable.create(L, h) 142 | 143 | ##### Transformer blocks. 
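# Block structure, as implemented in loop_body below: pre-attention RMSNorm, grouped-query
# attention with RoPE, residual add, pre-FFN RMSNorm, SwiGLU FFN, residual add; the block is
# applied to all layers via jax.lax.scan over the stacked TransformerLayer weights.
# Throughout, d-sharded weights are all_gather'ed just in time, contracted locally with
# einsum_unreduced, and activations are re-sharded with psum_scatter so they stay B/d ... M/t.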
144 | @explicit_activation_checkpointing 145 | @typechecked 146 | def loop_body(x: bf16[b'B/d L M/t'], layer_weights: TransformerLayer) -> Tuple[bf16[b'B/d L M/t'], Tuple[()]]: 147 | # Pre-attention RMSNorm 148 | ln1 = shardops.all_gather('M/t/d -> M', jnp.float32(layer_weights.ln1)) 149 | gx = shardops.all_gather('B/d L M/t -> B/d L M', x) 150 | nx = jnp.bfloat16(rms_norm(gx) * ln1) 151 | 152 | # Attention, using Grouped Query Attention and RoPE position embeddings. 153 | w_q = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(layer_weights.w_q)) 154 | q = save_for_backward(shardops.einsum_unreduced('B/d L M, M Q K/t D -> B/d L Q K/t D', nx, w_q)) 155 | q = rope_table.apply('L D -> 1 L 1 1 D', q) 156 | w_kv = shardops.all_gather('2 M/d K/t D -> 2 M K/t D', jnp.bfloat16(layer_weights.w_kv)) 157 | k, v = shardops.einsum_unreduced('B/d L M, k_v M K/t D -> k_v B/d L K/t D', nx, w_kv) 158 | k = save_for_backward(k) 159 | v = save_for_backward(v) 160 | k = rope_table.apply('L d -> 1 L 1 d', k) 161 | logits = shardops.einsum_unreduced( 162 | 'B/d Qlen Q K/t D, B/d Klen K/t D -> B/d Qlen Klen Q K/t', q, k, preferred_element_type=jnp.float32) 163 | logits = jnp.where(causal_mask, logits, -1e10) 164 | probs = jnp.bfloat16(jax.nn.softmax(logits, axis=2)) 165 | attn_out = shardops.einsum_unreduced( 166 | 'B/d Qlen Klen Q K/t, B/d Klen K/t D -> B/d Qlen Q K/t D', probs, v) 167 | w_o = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(layer_weights.w_o)) 168 | attn_out = shardops.einsum_unreduced('B/d Qlen Q K/t D, M Q K/t D -> B/d Qlen M', attn_out, w_o) 169 | attn_out = shardops.psum_scatter('B/d Qlen M -> B/d Qlen M/t', attn_out) 170 | x = save_for_backward(x + attn_out) 171 | 172 | # Pre-FFN RMSNorm 173 | ln2 = save_for_backward(shardops.all_gather('M/t/d -> M', jnp.float32(layer_weights.ln2))) 174 | gx = shardops.all_gather('B/d L M/t -> B/d L M', x) 175 | nx = jnp.bfloat16(rms_norm(gx) * ln2) 176 | 177 | # FFN, using SwiGLU 178 | w_gate = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer_weights.w_gate)) 179 | gate_proj = save_for_backward(shardops.einsum_unreduced('B/d L M, M F/t -> B/d L F/t', nx, w_gate)) 180 | w_up = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer_weights.w_up)) 181 | up_proj = save_for_backward(shardops.einsum_unreduced('B/d L M, M F/t -> B/d L F/t', nx, w_up)) 182 | y = jax.nn.swish(gate_proj) * up_proj 183 | w_down = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer_weights.w_down)) 184 | ffn_out = shardops.einsum_unreduced('B/d L F/t, M F/t -> B/d L M', y, w_down) 185 | ffn_out = shardops.psum_scatter('B/d L M -> B/d L M/t', ffn_out) 186 | 187 | return jnp.bfloat16(x + ffn_out), () 188 | 189 | x, () = jax.lax.scan(loop_body, jnp.bfloat16(x), self.transformer) 190 | 191 | ##### Final layernorm and output projection. 
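# Note that the returned logits remain sharded over the vocabulary axis (B/d L V/t); the
# softmax normalization in loss() below reduces over 't' explicitly via lax.pmax / lax.psum.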
192 | x = shardops.all_gather('B/d L M/t -> B/d L M', x) 193 | ln = shardops.all_gather('M/t/d -> M', jnp.float32(self.final_layer_norm)) 194 | x = jnp.bfloat16(rms_norm(x) * ln) 195 | unembed = shardops.all_gather('V/t M/d -> V/t M', jnp.bfloat16(self.unembed)) 196 | logits = shardops.einsum_unreduced('B/d L M, V/t M -> B/d L V/t', x, unembed, preferred_element_type=jnp.float32) 197 | 198 | return logits 199 | 200 | 201 | @typechecked 202 | def loss(self, h: Hparams, batch: TokenBatch) -> f32[b'']: 203 | # Given sequence-packed targets: 204 | # [[1, 2], [3, 4, 5], [6, 7, 8, 9]] 205 | # we want inputs: 206 | # [[0, 1], [0, 3, 4], [0, 6, 7, 8]] 207 | # which we get by shifting the targets right by 1 and 208 | # masking sequence-start tokens to 0. 209 | inputs = jnp.pad(batch.targets[:, :-1], pad_width=((0, 0), (1, 0))) 210 | is_seq_start: bool_[b'batch/d len'] = batch.is_seq_start 211 | inputs: u32[b'batch/d len'] = jnp.where(is_seq_start, 0, inputs) 212 | 213 | logits: f32[b'batch/d len V/t'] = self.forward_pass(h, inputs, is_seq_start) 214 | max_logits: f32[b'batch/d len 1'] = lax.pmax(jnp.max(lax.stop_gradient(logits), axis=-1, keepdims=True), 't') 215 | logits = logits - max_logits 216 | sum_logits = lax.psum(jnp.sum(jnp.exp(logits), axis=-1, keepdims=True), 't') 217 | logsumexp = jnp.log(sum_logits) 218 | logprobs: f32[b'batch/d len V/t'] = logits - logsumexp 219 | logprobs_at_targets = shardops.index_unreduced('batch/d len [V/t], batch/d len -> batch/d len', logprobs, batch.targets) 220 | logprobs_at_targets = shardops.psum_scatter('batch/d len -> batch/d len/t', logprobs_at_targets) 221 | tokens_in_global_batch = logprobs_at_targets.size * jax.lax.psum(1, ('d', 't')) 222 | return -jnp.sum(logprobs_at_targets) / jnp.float32(tokens_in_global_batch) 223 | 224 | 225 | @pytree_dataclass 226 | class RopeTable: 227 | sin: f32['len d_head2'] 228 | cos: f32['len d_head2'] 229 | 230 | @staticmethod 231 | def create(max_len: int, hparams: Hparams) -> 'RopeTable': 232 | rope_max_timescale = hparams.rope_max_timescale 233 | d_head = hparams.d_head 234 | d = d_head // 2 235 | # endpoint=False is equivalent to what MaxText does. endpoint=True would be more natural, though. 
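# Equivalently, timescale[i] = rope_max_timescale ** (i / d) for i in range(d), since
# logspace(0, log10(T), d, endpoint=False) == 10 ** (log10(T) * arange(d) / d).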
236 | timescale = jnp.logspace(0, jnp.log10(jnp.float32(rope_max_timescale)), d, endpoint=False) 237 | position = jnp.arange(max_len, dtype=jnp.int32) 238 | sinusoid_inp = jnp.float32(position[:, jnp.newaxis]) / timescale[jnp.newaxis, :] 239 | sin = jnp.sin(sinusoid_inp) 240 | cos = jnp.cos(sinusoid_inp) 241 | return RopeTable(sin=sin, cos=cos) 242 | 243 | def apply(self, rearrange_spec, x): 244 | x1, x2 = jnp.split(x, 2, axis=-1) 245 | sin = einops.rearrange(self.sin, rearrange_spec) 246 | cos = einops.rearrange(self.cos, rearrange_spec) 247 | r1 = x1 * cos - x2 * sin 248 | r2 = x2 * cos + x1 * sin 249 | return jnp.append(r1, r2, axis=-1) 250 | 251 | 252 | @typechecked 253 | def rms_norm(x: bf16[b'batch/d len M']) -> bf16[b'batch/d len M']: 254 | mean2 = save_for_backward(jnp.mean(jax.lax.square(jnp.float32(x)), axis=-1, keepdims=True)) 255 | return jnp.bfloat16(x * jax.lax.rsqrt(mean2 + 1e-6)) 256 | 257 | 258 | @pytree_dataclass 259 | class Metrics: 260 | loss: f32[b''] 261 | learning_rate: f32[b''] 262 | grad_norm: f32[b''] 263 | raw_grad_norm: f32[b''] 264 | 265 | 266 | @dataclass(frozen=True) 267 | class TrainingHparams: 268 | adam_b1: float 269 | adam_b2: float 270 | adam_eps: float 271 | adam_eps_root: float 272 | weight_decay: float 273 | warmup_steps: int 274 | steps: int 275 | steps_for_lr: int 276 | cosine_learning_rate_final_fraction: float 277 | learning_rate: float 278 | tokens: TokenBatchParams 279 | seed: int 280 | queue: Optional[str] = None 281 | 282 | @pytree_dataclass 283 | class State: 284 | weights: Model 285 | adam_mu: Model 286 | adam_nu: Model 287 | 288 | @staticmethod 289 | def init(hparams: Hparams, rng: PRNGKey) -> 'State': 290 | weights = Model.init(hparams, rng) 291 | adam_mu = jax.tree.map(lambda p: p * 0.0, weights) 292 | adam_nu = jax.tree.map(lambda p: p * 0.0, weights) 293 | return State(weights=weights, adam_mu=adam_mu, adam_nu=adam_nu) 294 | 295 | @partial(jax.jit, static_argnums=(2, 3), donate_argnums=(0,)) 296 | def training_step(state: State, step: u32[b''], h: Hparams, hparams: TrainingHparams, batch: TokenBatch) -> Tuple[Any, Metrics]: 297 | @partial(shardtypes.typed_shard_map, check_rep=False) # check_rep=False for https://github.com/google/jax/issues/20335 298 | def sharded_step(state: State, step: u32[b''], batch: TokenBatch) -> Tuple[State, Metrics]: 299 | loss, grad = jax.value_and_grad(lambda weights: weights.loss(h, batch))(state.weights) 300 | # Gradients have already been reduced across chips because the gradient of the weight `all_gather` 301 | # is weight-gradient `psum_scatter`. Loss, on the other hand, hasn't been reduced across chips: if we 302 | # did that inside the autodiff, we'd be double-reducing the loss, effectively multiplying it by the 303 | # amount of data parallelism. 304 | # 305 | # So we reduce the loss across chips _outside_ the autodiff. 306 | loss = jax.lax.psum(loss, ('d', 't')) 307 | 308 | # Other than global-norm of gradients, no other communication is needed during the weight update, 309 | # because weights and grads are already fully sharded, as checked below. 310 | 311 | # Calculate learning rate from step number. 312 | # We use linear warmup then cosine decay. 
See https://arxiv.org/pdf/2307.09288.pdf section 2.2 313 | warmup_lr = (jnp.float32(step) / jnp.float32(hparams.warmup_steps)) * hparams.learning_rate 314 | cosine = jnp.cos(jnp.pi * (jnp.float32(step - hparams.warmup_steps) / jnp.float32(hparams.steps_for_lr - hparams.warmup_steps))) 315 | cosine_lr = hparams.learning_rate * (hparams.cosine_learning_rate_final_fraction + (1 - hparams.cosine_learning_rate_final_fraction) * (cosine * .5 + .5)) 316 | lr = jnp.where(step < hparams.warmup_steps, warmup_lr, cosine_lr) 317 | 318 | # AdamW optimizer with global gradient clipping. 319 | grad_leaves, grad_treedef = jax.tree_util.tree_flatten(grad) 320 | global_norm_square = jnp.float32(0.0) 321 | for g in grad_leaves: 322 | assert g.dtype == jnp.float32 323 | global_norm_square += jnp.sum(jax.lax.square(g)) 324 | global_norm_square = jax.lax.psum(global_norm_square, ('d', 't')) 325 | global_norm = jnp.sqrt(global_norm_square) 326 | rescale = jnp.minimum(1.0, 1.0 / global_norm) 327 | 328 | new_ps = [] 329 | new_mus = [] 330 | new_nus = [] 331 | for p, g, mu, nu, spec in zip(tree_leaves(state.weights), grad_leaves, tree_leaves(state.adam_mu), tree_leaves(state.adam_nu), tree_leaves(shardtypes.make_partition_specs(State))): 332 | assert shardtypes.is_fully_sharded(spec), 'Weight update is only correctly scaled for fully sharded weights.' 333 | # Gradient clipping 334 | g = g * rescale 335 | # Adam scaling 336 | mu = (1 - hparams.adam_b1) * g + hparams.adam_b1 * mu 337 | nu = (1 - hparams.adam_b2) * jax.lax.square(g) + hparams.adam_b2 * nu 338 | # We need step numbers to start at 1, not 0. Otherwise the bias correction produces NaN. 339 | completed_steps = step + 1 340 | mu_hat = mu / (1 - jnp.float32(hparams.adam_b1)**completed_steps) 341 | nu_hat = nu / (1 - jnp.float32(hparams.adam_b2)**completed_steps) 342 | g = mu_hat / (jnp.sqrt(nu_hat + hparams.adam_eps_root) + hparams.adam_eps) 343 | # Weight decay 344 | g += hparams.weight_decay * p 345 | # Learning rate 346 | g *= lr 347 | 348 | # Apply update 349 | new_ps.append(p - g) 350 | new_mus.append(mu) 351 | new_nus.append(nu) 352 | 353 | new_state = State( 354 | weights=jax.tree_util.tree_unflatten(grad_treedef, new_ps), 355 | adam_mu=jax.tree_util.tree_unflatten(grad_treedef, new_mus), 356 | adam_nu=jax.tree_util.tree_unflatten(grad_treedef, new_nus), 357 | ) 358 | metrics = Metrics( 359 | loss=loss, 360 | learning_rate=lr, 361 | grad_norm=global_norm * rescale, 362 | raw_grad_norm=global_norm, 363 | ) 364 | return new_state, metrics 365 | 366 | return sharded_step(state, step, batch) 367 | 368 | 369 | @dataclass(frozen=True) 370 | class Paths: 371 | root_working_dir: str 372 | model_name: str 373 | 374 | @dataclass(frozen=True) 375 | class MeshConfig: 376 | d: int 377 | t: int 378 | 379 | 380 | @dataclass(frozen=True) 381 | class Config: 382 | model: Hparams 383 | training: TrainingHparams 384 | paths: Paths 385 | num_hosts: int 386 | checkpoint_interval: int 387 | mesh: MeshConfig 388 | io: training_io.IOConfig 389 | flat_tokens: Optional[FlatTokensParams] = None 390 | hf_dataset: Optional[HuggingFaceDataParams] = None 391 | 392 | def __post_init__(self): 393 | assert self.flat_tokens is not None or self.hf_dataset is not None, 'Must provide either flat_tokens or hf_dataset.' 394 | assert not (self.flat_tokens is not None and self.hf_dataset is not None), 'Should not specify both flat_tokens and hf_dataset.' 
395 | 396 | @cached_property 397 | def training_data(self) -> Union[FlatTokensParams, HuggingFaceDataParams]: 398 | return self.flat_tokens or self.hf_dataset 399 | 400 | def main_contained(config, logger): 401 | """Main program, which does not access external services except as specified by config.paths or logger.""" 402 | # Use partitionable (and hopefully fusable!) RNG. 403 | # 404 | # This is slower in compute time than 'unsafe_rbg' with flag '--xla_tpu_spmd_rng_bit_generator_unsafe=true', 405 | # but hopefully faster in memory time because it's fusable. 406 | # TODO: check this is true and if not, provide our own that actually is fusable. 407 | jax.config.update('jax_threefry_partitionable', True) 408 | with Mesh(mesh_utils.create_device_mesh([config.mesh.d, config.mesh.t], jax.devices()), ('d', 't')): 409 | root_rng = jax.random.PRNGKey(config.training.seed) 410 | 411 | loader = get_loader('train', config.training_data, config.training.tokens) 412 | assert config.model.vocab > loader.max_token_id, f"{config.model.vocab} vs {loader.max_token_id}" 413 | 414 | model_dir = os.path.join(config.paths.root_working_dir, config.paths.model_name) 415 | training_io.mkdir(model_dir) 416 | state = jax.jit(partial(State.init, config.model))(fold_in_str(root_rng, 'init')) 417 | state, start_step = training_io.load_checkpoint_if_it_exists(model_dir, state, config.io) 418 | 419 | # Explicitly compile training step, to record XLA HLO graph. 420 | # See https://bnikolic.co.uk/blog/python/jax/2022/02/22/jax-outputgraph-rev 421 | c_training_step = training_step.lower(state, jnp.uint32(0), config.model, config.training, loader.load(0)).compile() 422 | date = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 423 | training_io.save_hlo_svg(os.path.join(model_dir, f'training_step_optimized_hlo_{date}.svg'), c_training_step) 424 | 425 | for step in range(start_step, config.training.steps): 426 | if step % config.checkpoint_interval == 0 and step > start_step: 427 | training_io.save_checkpoint(model_dir, step, state, config.io) 428 | 429 | # We profile on the second step, because the first step has a long pause for XLA 430 | # compilation and initial shuffle buffer loading. 431 | if jax.process_index() == 0 and step == start_step + 1: 432 | jax.block_until_ready(state) 433 | training_io.start_profile() 434 | profile_start = time.time() 435 | 436 | state, output = c_training_step(state, jnp.uint32(step), loader.load(step)) 437 | 438 | # Run profile for two steps, to include data loading time in between them. 439 | if jax.process_index() == 0 and step == start_step + 2: 440 | jax.block_until_ready(state) 441 | profile_duration = time.time() - profile_start 442 | training_io.stop_profile(model_dir) 443 | 444 | # Print MFU, including (one step of) data loading time. 
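# The estimate below counts matmul ("projection") FLOPs only: roughly 6 * params * tokens
# per training step (2 for the forward pass, 4 for the backward pass). The extra factor of 2
# is because profile_duration spans two steps.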
445 | print(f"Profile time: {profile_duration}s for 2 steps.") 446 | model_params = jax.tree.reduce(operator.add, jax.tree.map(lambda w: w.size, state.weights)) 447 | tokens = loader.load(step).targets.size 448 | print(f'Model params: {model_params:_}') 449 | print(f'Tokens: {tokens:_}') 450 | device_flops = training_io.get_flops_per_device() 451 | num_devices = jax.device_count() 452 | print(f'MFU (projections only): {100 * (2 * 6 * model_params * tokens / (num_devices * profile_duration)) / device_flops:.2f}% MFU') 453 | 454 | training_io.log(step, logger, output) 455 | 456 | 457 | @hydra.main(config_path='configs', version_base=None) 458 | def main(config): 459 | config = jax_extra.make_dataclass_from_dict(Config, config) 460 | if config.training.queue: 461 | task = Task.init(project_name='testing', task_name=config.paths.model_name) 462 | logger = task.get_logger() 463 | task.execute_remotely(queue_name=config.training.queue) 464 | task.launch_multi_node(config.num_hosts, wait=True) 465 | if int(os.environ['RANK']) > 0: 466 | task.set_system_tags((task.get_system_tags() or []) + ['hidden']) 467 | jax.distributed.initialize(os.environ['MASTER_ADDR'] + ':' + os.environ['MASTER_PORT'], 468 | num_processes=int(os.environ['WORLD_SIZE']), 469 | process_id=int(os.environ['RANK'])) 470 | else: 471 | logger = None 472 | main_contained(config, logger) 473 | 474 | 475 | if __name__ == "__main__": 476 | main() 477 | -------------------------------------------------------------------------------- /training_io.py: -------------------------------------------------------------------------------- 1 | """Provides IO support for training: 2 | * checkpoint save and load 3 | * metrics logging 4 | * profiling of XLA computations 5 | * reporting FLOPs per device 6 | """ 7 | import jax 8 | import jax.numpy as jnp 9 | from jax.experimental import multihost_utils 10 | from typing import Tuple, Any 11 | from dataclasses import dataclass 12 | import os 13 | import fsspec 14 | import zarr 15 | from numcodecs import blosc 16 | from clearml import Logger 17 | import numpy as np 18 | import datetime 19 | import concurrent 20 | import jax.profiler 21 | import tempfile 22 | import shutil 23 | from jax.lib import xla_client 24 | 25 | PyTree = Any 26 | 27 | @dataclass 28 | class IOConfig: 29 | # Max number of threads to use for IO-bound tasks like saving and loading checkpoints. 30 | # Recommendation: about 1MiB/thread is typical, so 1024 thread is reasonable for 1GiB of overhead. 31 | # Since this work is IO-bound rather than CPU-bound, it is fine to have many more threads than 32 | # CPU cores. 33 | max_io_threads: int 34 | 35 | def log(step: int, logger: Logger, output: PyTree): 36 | """Logs the output of a training step. 
The output must be a PyTree of f32 arrays.""" 37 | if jax.process_index() == 0: 38 | metrics_dict = {} 39 | for path, arr in jax.tree_util.tree_leaves_with_path(output): 40 | path = jax.tree_util.keystr(path) 41 | arr = jax.device_get(arr) 42 | if arr.shape == () and arr.dtype == jnp.float32: 43 | if logger: 44 | logger.report_scalar( 45 | title=path, series=path, value=arr, iteration=step) 46 | metrics_dict[path] = float(arr) 47 | elif arr.dtype == jnp.float32: 48 | if logger: 49 | logger.report_histogram( 50 | title=path, series=path, values=arr, iteration=step) 51 | else: 52 | raise ValueError(f"Output {path} has unsupported shape {arr.shape} and dtype {arr.dtype}.") 53 | now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 54 | print(f"[{now}] Step {step}: {metrics_dict}") 55 | 56 | 57 | def load_checkpoint_if_it_exists(checkpoint_dir: str, state: PyTree, config: IOConfig) -> Tuple[PyTree, int]: 58 | """Loads the latest checkpoint if it exists, otherwise return the initial state. 59 | 60 | In either case, uses the sharding and PyTree structure of `state` to produce the output. 61 | 62 | Since the state may occupy a large amount of memory, this function makes sure to delete `state` 63 | before loading the checkpoint. To facilitate this, callers should ensure not to hold on to any 64 | additional references to `state` when calling this function. 65 | 66 | Returns state and step number. Step 0 is the initial state, which may or may not have been loaded 67 | from a checkpoint. 68 | """ 69 | blosc.use_threads = False # Blindly following recommendation from https://zarr.readthedocs.io/en/stable/tutorial.html#parallel-computing-and-synchronization 70 | checkpoint_dir_pseudofile = fsspec.open(checkpoint_dir) 71 | fs = checkpoint_dir_pseudofile.fs 72 | checkpoint_dir_path = checkpoint_dir_pseudofile.path 73 | del checkpoint_dir_pseudofile 74 | 75 | # Check working_dir for checkpoint files. 76 | # Process index 0 selects the checkpoint, then broadcasts it to everyone else. 77 | selected_checkpoint = -1 78 | if jax.process_index() == 0: 79 | if fs.exists(checkpoint_dir_path): 80 | # fs.mkdir(checkpoint_dir, create_parents=False) 81 | checkpoint_dirs = fs.ls(checkpoint_dir_path) 82 | for c in reversed(sorted(checkpoint_dirs)): 83 | try: 84 | checkpoint_number = int(os.path.basename(c)) 85 | except ValueError: 86 | continue 87 | root = zarr.open_group(zarr.storage.FSStore(c, fs=fs)) 88 | if "write_completed" not in root.attrs: 89 | print(f"zarr 'write_completed' marker is missing in checkpoint {c}; skipping.") 90 | continue 91 | selected_checkpoint = checkpoint_number 92 | break 93 | selected_checkpoint = multihost_utils.broadcast_one_to_all(jnp.int32(selected_checkpoint)) 94 | 95 | if selected_checkpoint == -1: 96 | print(f"No checkpoints found in {checkpoint_dir_path}, starting from initial state.") 97 | return state, 0 98 | 99 | print(f'Found checkpoint {selected_checkpoint} in {checkpoint_dir_path}, starting from there.') 100 | return load_zarr(os.path.join(checkpoint_dir, step_to_str(selected_checkpoint)), state, config), selected_checkpoint 101 | 102 | 103 | def save_checkpoint(checkpoint_dir: str, step: int, state: PyTree, config: IOConfig): 104 | """Saves a checkpoint for the specified step number. 105 | 106 | See docs/pytree-zarr-checkpoint.md for the checkpoint format. 107 | """ 108 | blosc.use_threads = False 109 | checkpoint_file = os.path.join(checkpoint_dir, step_to_str(step)) 110 | if jax.process_index() == 0: 111 | # If there's already a checkpoint at this step, delete it. 
It might have been a partially 112 | # written checkpoint from a previous run. 113 | f = fsspec.open(checkpoint_dir) 114 | checkpoint_path = os.path.join(f.path, step_to_str(step)) 115 | if f.fs.exists(checkpoint_path): 116 | f.fs.rm(checkpoint_path, recursive=True) 117 | 118 | print(f"[{datetime.datetime.now()}] Saving checkpoint {step} to {checkpoint_file}.") 119 | save_zarr(checkpoint_file, state, config) 120 | print(f"[{datetime.datetime.now()}] Finished saving checkpoint {step} to {checkpoint_file}.") 121 | 122 | 123 | def load_zarr(filename: str, state: PyTree, config: IOConfig) -> PyTree: 124 | """Loads a zarr checkpoint from disk. 125 | 126 | See docs/pytree-zarr-checkpoint.md for the checkpoint format. 127 | """ 128 | root = zarr.open_group(filename, mode="r") 129 | if "write_completed" not in root.attrs: 130 | raise ValueError(f"zarr 'write_completed' marker is missing in {filename}. Should not have selected this checkpoint to load from.") 131 | 132 | def load_one(path: Tuple, prev: jax.Array) -> jax.Array: 133 | path = jax.tree_util.keystr(path) 134 | shape = prev.shape 135 | sharding = prev.sharding 136 | arr = root[path] 137 | assert arr.shape == shape, f'Expected shape {shape} but got {arr.shape} for {path} in {filename}' 138 | assert arr.dtype == prev.dtype, f'Expected dtype {prev.dtype} but got {arr.dtype} for {path} in {filename}' 139 | del prev # Deallocate memory before loading its replacement! 140 | return jax.make_array_from_callback(shape, sharding, lambda shard_index: arr[shard_index]) 141 | 142 | state, treedef = jax.tree_util.tree_flatten_with_path(state) 143 | with concurrent.futures.ThreadPoolExecutor(max_workers=config.max_io_threads) as executor: 144 | state_futures = [executor.submit(load_one, path, prev) for (path, prev) in state] 145 | states = [f.result() for f in state_futures] 146 | return jax.tree_util.tree_unflatten(treedef, states) 147 | 148 | 149 | def save_zarr(filename: str, state: PyTree, config: IOConfig): 150 | """Saves a zarr checkpoint to disk. 151 | 152 | See docs/pytree-zarr-checkpoint.md for the checkpoint format. 153 | """ 154 | state, _treedef = jax.tree_util.tree_flatten_with_path(state) 155 | 156 | if jax.process_index() == 0: 157 | # Create the zarr file and all the arrays.
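        # Layout note: each PyTree leaf becomes one zarr array whose chunk shape equals that
        # leaf's per-device shard shape, so every JAX shard lands in exactly one zarr chunk.
        # That lets each host write its addressable shards (replica_id == 0 only) in parallel
        # without cross-host locking, and the "write_completed" attribute set at the very end
        # is the commit marker that load_checkpoint_if_it_exists checks for.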
158 | try: 159 | root = zarr.open_group(filename, mode='w-') 160 | except zarr.errors.ContainsGroupError: 161 | raise ValueError(f"Checkpoint {filename} already exists.") 162 | for path, arr in state: 163 | path = jax.tree_util.keystr(path) 164 | chunk_shape = arr.sharding.shard_shape(arr.shape) 165 | root.empty(path, shape=arr.shape, chunks=chunk_shape, dtype=arr.dtype) 166 | multihost_utils.sync_global_devices("save_zarr_begin") 167 | 168 | root = zarr.open_group(filename, mode='r+') 169 | 170 | def save_shard(dst: zarr.Array, shard: jax.Array, index: Tuple[int, ...]): 171 | dst[index] = np.asarray(shard) 172 | 173 | with concurrent.futures.ThreadPoolExecutor(max_workers=config.max_io_threads) as executor: 174 | for path, arr in state: 175 | path = jax.tree_util.keystr(path) 176 | dst = root[path] 177 | assert dst.chunks == arr.sharding.shard_shape(arr.shape) 178 | for shard in arr.addressable_shards: 179 | if shard.replica_id == 0: 180 | executor.submit(save_shard, dst, shard.data, shard.index) 181 | 182 | multihost_utils.sync_global_devices("save_zarr_end") 183 | if jax.process_index() == 0: 184 | root.attrs["write_completed"] = True 185 | multihost_utils.sync_global_devices("save_zarr_committed") 186 | 187 | def step_to_str(step: int) -> str: 188 | """Converts a step number to a string with leading zeros. 189 | 190 | We pad up to 10 digits so that lexicographic order matches numerical. 1e10 training steps 191 | should be enough for anyone: the biggest runs as of 2024 are probably around 1e7 tokens/batch, 192 | 1e13 tokens total, so 1e6 training steps total. 193 | """ 194 | return str(step).zfill(10) 195 | 196 | _PROFILE_DIR = None 197 | 198 | def start_profile(): 199 | """Starts gathering a JAX profile.""" 200 | # Get fresh temporary directory 201 | global _PROFILE_DIR 202 | _PROFILE_DIR = tempfile.mkdtemp() 203 | print(f'[{datetime.datetime.now()}] Starting profile, saving to {_PROFILE_DIR}') 204 | jax.profiler.start_trace(_PROFILE_DIR, create_perfetto_trace=True) 205 | 206 | def stop_profile(working_dir: str): 207 | """Stops gathering the JAX profile and saves it to a file.""" 208 | global _PROFILE_DIR 209 | jax.profiler.stop_trace() 210 | print(f'[{datetime.datetime.now()}] Finished profile, copying to {working_dir}') 211 | fsspec_put(_PROFILE_DIR + '/', working_dir + '/') 212 | shutil.rmtree(_PROFILE_DIR) 213 | print(f'[{datetime.datetime.now()}] Finished copying profile to {working_dir}') 214 | _PROFILE_DIR = None 215 | 216 | 217 | def fsspec_put(local_src: str, remote_dst: str): 218 | """Copies a local file or directory tree to a remote location specified by an fsspec path.""" 219 | f = fsspec.open(remote_dst) 220 | fs = f.fs 221 | path = f.path 222 | del f 223 | print(f'Put {local_src} to {path}') 224 | fs.put(local_src, path, recursive=True, create_parents=True) 225 | 226 | 227 | def save_hlo_svg(filespec: str, compiled: jax.stages.Compiled): 228 | """Saves a compiled function's HLO to an SVG file.""" 229 | compiled_hlo_dot = xla_client._xla.hlo_module_to_dot_graph(compiled.runtime_executable().hlo_modules()[0]) 230 | with tempfile.TemporaryDirectory() as d: 231 | with open(os.path.join(d, "hlo.dot"), "w") as f: 232 | f.write(compiled_hlo_dot) 233 | hlo_orig_svg = os.path.join(d, "hlo.original.svg") 234 | hlo_svg = os.path.join(d, "hlo.svg") 235 | os.system(f"dot -Tsvg {f.name} -o{hlo_orig_svg}") 236 | # Edit the SVG to remove everything before the <svg> element.
There's a bunch of hover CSS that massively slows down 237 | # rendering in Chrome and adds little value: it just highlights edges when you hover over them. 238 | with open(hlo_orig_svg, "r") as f: 239 | svg = f.read() 240 | svg = svg[svg.index("