├── .deepsource.toml
├── .gitignore
├── LICENSE
├── README.md
├── config.yaml
├── images
│   ├── axial.png
│   ├── convit.png
│   ├── mixer.png
│   ├── normformer.png
│   ├── paper_cogview_sandwich.png
│   ├── reformer.png
│   ├── revnet.png
│   ├── rmsnorm_speed.png
│   ├── sandwich.png
│   └── scalenorm.png
├── img.png
├── inference.py
├── launch_on_tensorfork.py
├── main.py
├── requirements-dev.txt
├── requirements.txt
├── run.sh
├── script
│   ├── inference.py
│   ├── launch_multiple_runs.py
│   ├── quantized_vid2tfrecord.py
│   ├── start_on_multi_host_tpu.py
│   └── text2tfrecord.py
├── setup.sh
├── setup_dev.sh
├── share_vm.py
├── src
│   ├── backend.py
│   ├── constants.py
│   ├── context.py
│   ├── data.py
│   ├── main.py
│   ├── model
│   │   ├── activate.py
│   │   ├── conv.py
│   │   ├── loss.py
│   │   ├── main.py
│   │   ├── mixer.py
│   │   ├── moe.py
│   │   ├── norm.py
│   │   └── reversible.py
│   ├── optimizer.py
│   └── utils
│       ├── checkpoint.py
│       └── wandblog.py
├── sweep.yaml
├── train_watcher.py
└── unittests
    ├── consistency
    │   └── step.py
    └── grad
        ├── activation.py
        ├── backend.py
        ├── leak.py
        ├── loss.py
        └── norm.py

/.deepsource.toml:
--------------------------------------------------------------------------------
version = 1

test_patterns = ["unittests/**"]

[[analyzers]]
name = "python"
enabled = true
dependency_file_paths = ["requirements-dev.txt"]

[analyzers.meta]
runtime_version = "3.x.x"
max_line_length = 120
type_checker = "mypy"

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
BSD 2-Clause License

Copyright (c) 2021, Luke
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Olmax

Optimized Language-Model (in JAX)

## Techniques

Olmax is a collection of techniques that are otherwise rarely used in language modeling. The sections below describe
their purpose.

### Model

#### Convolution

Most modern language models use attention as their core building block. However, works such
as [Conformer](https://arxiv.org/abs/2005.08100) noticed that combining attention with convolution, instead of
attention with feed-forward layers, leads to better evaluation performance at little overhead. Later works, such
as [Nyströmformer](https://arxiv.org/abs/2102.03902), explicitly introduced a convolution in the residual branch to
add a hard locality bias to their model. Concurrent work, such as [ALiBi](https://arxiv.org/abs/2108.12409),
[Swin](https://arxiv.org/abs/2103.14030) and [ConViT](https://arxiv.org/abs/2103.10697), pointed out that soft
inductive biases are nearly always helpful.
![convit.png](images/convit.png)
However, soft biases such as ALiBi do not leverage structured sparsity the way a convolution does: the model still
computes attention to tokens it will never attend to. Since convolution by itself already adds a decent chunk of
performance, we decided to use it as our core building block.

Another advantage convolutions have is that they scale linearly with increased sequence length, which makes them
perfect for deployment in long-context applications. Additionally, convolutions have more parameters at the same
memory cost, giving the model more freedom without increasing memory usage. For example, a 3x3 convolution has
precisely nine times as many weights as the 1x1 convolution most commonly used in a transformer, without requiring
larger intermediate states.\
In summary, convolutions provide a memory-efficient way of adding a locality bias to any model. As most domains, such
as text, image and video, have a strong locality bias, their usage is sensible.

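To make the locality bias concrete, the snippet below shows the core of such a block: a 1D convolution whose input is
left-padded so that each position only sees its own past, keeping the model autoregressive. This is a minimal sketch
with illustrative names, shapes and initialization, not the actual implementation in `src/model/conv.py`:

```python
# Minimal sketch of a causal 1D convolution block in JAX.
import jax
import jax.numpy as jnp
from jax import lax


def causal_conv1d(inputs: jax.Array, weights: jax.Array) -> jax.Array:
    """inputs: [batch, sequence, features], weights: [kernel, features_in, features_out]."""
    kernel = weights.shape[0]
    # Pad only on the left by (kernel - 1), so output position t depends on inputs [t - kernel + 1 .. t].
    padded = jnp.pad(inputs, ((0, 0), (kernel - 1, 0), (0, 0)))
    return lax.conv_general_dilated(padded, weights, window_strides=(1,), padding="VALID",
                                    dimension_numbers=("NWC", "WIO", "NWC"))


batch, sequence, features, kernel = 2, 16, 8, 5
x = jax.random.normal(jax.random.PRNGKey(0), (batch, sequence, features))
w = jax.random.normal(jax.random.PRNGKey(1), (kernel, features, features)) * 0.02
assert causal_conv1d(x, w).shape == (batch, sequence, features)
```

Because the padding is purely left-sided, no attention-style mask is needed, and stacking such blocks with larger
kernels widens the receptive field while the cost stays linear in sequence length.
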
#### Axial MLP-Mixer

Unfortunately, replacing global attention with local convolution does not yield the best performance in most cases.
Some papers, such as [ConvBERT](https://arxiv.org/abs/2008.02496), illustrated that it is possible to achieve
competitive performance with convolutions, but only after spending significant engineering effort on it and using
non-autoregressive attention. Therefore, a second, global mixing method needs to be used together with convolution.\
According to recent research, [MLP-Mixer](https://arxiv.org/abs/2105.01601) scales better than attention when training
large enough models on significant amounts of data. As language modeling is a domain with predominantly large models
and massive datasets, it seems like a good fit. Additionally, MLP-Mixer is faster to compute and requires less memory
than attention, making it a perfect fit for a memory-efficient model.
![mixer.png](images/mixer.png)
Additionally, concurrent work pointed out that axial attention performs well on images at the pixel level, while also
running quickly and with a small memory footprint of O(N^1.5) instead of O(N^2).
![axial.png](images/axial.png)
Therefore, an axial MLP-Mixer needs less memory and compute than a standard transformer while providing better
performance at scale.

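The axial trick is easiest to see in code: fold a sequence of length N into a roughly sqrt(N) x sqrt(N) grid and mix
tokens along one axis at a time, so each dense mixing layer only spans sqrt(N) positions. The sketch below is a
simplified, non-causal illustration (a real language model needs causally masked mixing weights) and does not mirror
`src/model/mixer.py`:

```python
# Minimal sketch of axial token mixing in JAX.
import jax
import jax.numpy as jnp


def axial_mix(x: jax.Array, w_row: jax.Array, w_col: jax.Array) -> jax.Array:
    """x: [batch, sequence, features] with sequence = rows * columns."""
    batch, sequence, features = x.shape
    rows = int(sequence ** 0.5)
    grid = x.reshape(batch, rows, sequence // rows, features)
    grid = jnp.einsum("brcf,rR->bRcf", grid, w_row)  # mix along the row axis
    grid = jnp.einsum("brcf,cC->brCf", grid, w_col)  # mix along the column axis
    return grid.reshape(batch, sequence, features)


x = jax.random.normal(jax.random.PRNGKey(0), (2, 64, 8))       # sequence 64 -> 8x8 grid
w_row = jax.random.normal(jax.random.PRNGKey(1), (8, 8)) * 0.1
w_col = jax.random.normal(jax.random.PRNGKey(2), (8, 8)) * 0.1
assert axial_mix(x, w_row, w_col).shape == (2, 64, 8)
```

With rows = columns = sqrt(N), each of the two mixing steps sums over sqrt(N) inputs per output element, giving the
O(N^1.5) cost quoted above instead of attention's O(N^2).
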
#### Reversible

Most commonly, large transformers use [activation checkpointing](https://arxiv.org/abs/1604.06174v2), which saves the
input to a function and recomputes its intermediate values. While activation checkpointing means that not all
intermediate values, such as attention maps, have to be stored simultaneously, the model still has to store the input
to all of these functions. However, [recent research](https://arxiv.org/abs/1707.04585) indicates that a slight
modification to the model architecture can make the model fully reversible, allowing even those "checkpoints" to be
recomputed as they're needed.
![revnet.png](images/revnet.png)
Additionally, [Reformer](https://arxiv.org/abs/2001.04451) pointed out that reversible layers have the same loss curve
as non-reversible layers when trained at scale, indicating that they fit this use case perfectly.
![reformer.png](images/reformer.png)
Using reversible layers, the network can be scaled to any depth without increasing the memory usage, except for the
increased number of parameters. This way, architectures such as [DeepNarrow](https://arxiv.org/abs/2109.10686) can be
used efficiently.

#### Normalization

For a long time, people have discussed whether the BERT-style PostNorm or the GPT-style PreNorm is best. However,
recent research, such as [CogView's SandwichLN](https://arxiv.org/abs/2105.13290)
and [NormFormer](https://openreview.net/pdf?id=GMYWzWztDx5), showed that using both PostNorm and PreNorm improves
stability and, with that, convergence.
![paper_cogview_sandwich.png](images/paper_cogview_sandwich.png)
![normformer.png](images/normformer.png)

Testing it in this codebase gives similar results to those of NormFormer, showing that SandwichLN converges
significantly better than PreNorm, reaching lower losses in less time.
![sandwich.png](images/sandwich.png)

Additionally, [RMSNorm](https://arxiv.org/abs/1910.07467), as used by
DeepMind's [Gopher-280B](https://arxiv.org/abs/2112.11446), decreases the loss by another
3% when comparing step by step. ![rmsnorm_loss.png](images/rmsnorm_loss.png)
Additionally, RMSNorm is significantly simpler and less expensive than LayerNorm,
as `RMSNorm(x, scale) = x / Sqrt(Mean(Square(x))) * scale`, while
`LayerNorm(x, scale, shift) = (x - Mean(x)) / Sqrt(Mean(Square(x)) - Square(Mean(x))) * scale + shift`. So, even
though normalization takes up only a small fraction of the total runtime, replacing LayerNorm with RMSNorm yields an
immediate 27% speedup for both training and inference.
![rmsnorm_speed.png](images/rmsnorm_speed.png)

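Translated directly into JAX, the two formulas above look as follows. This is a plain sketch: the `eps` mirrors
`norm_eps` from `config.yaml`, while the function and parameter names are illustrative rather than taken from the
tuned implementation in `src/model/norm.py`:

```python
# Direct transcription of the RMSNorm and LayerNorm formulas above.
import jax
import jax.numpy as jnp


def rms_norm(x: jax.Array, scale: jax.Array, eps: float = 1e-16) -> jax.Array:
    # A single reduction (mean of squares); no mean subtraction and no shift parameter.
    return x / jnp.sqrt(jnp.square(x).mean(-1, keepdims=True) + eps) * scale


def layer_norm(x: jax.Array, scale: jax.Array, shift: jax.Array, eps: float = 1e-16) -> jax.Array:
    # Two reductions: the mean must be computed and subtracted before the variance.
    mean = x.mean(-1, keepdims=True)
    variance = jnp.square(x).mean(-1, keepdims=True) - jnp.square(mean)
    return (x - mean) / jnp.sqrt(variance + eps) * scale + shift


x = jax.random.normal(jax.random.PRNGKey(0), (4, 256))
scale, shift = jnp.ones((256,)), jnp.zeros((256,))
assert rms_norm(x, scale).shape == layer_norm(x, scale, shift).shape == (4, 256)
```

Counting the reductions makes the speed difference plausible: RMSNorm needs one mean of squares, while LayerNorm
additionally computes, subtracts and re-centers around the mean, and carries an extra shift parameter.
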
#### MoE

### Optimizer

#### Adaptive Gradient Clipping

## Getting Started

### TensorBoard Trace

After SSHing into a TPU, it's recommended to run `bash setup.sh` as root to install the correct versions of all
libraries.\
This isn't done via a `requirements.txt`, as some libraries have to be removed while others require a very specific
installation order.

Once that's done, you can run `python3 main.py` to start whatever model is configured in `src/context.py`.\
To change the config without touching the code, you can also run `python3 main.py config.yaml` and perform your
changes in `config.yaml`.

It's possible to get a TensorBoard trace of memory and operations by changing `ctx.training.trace.do_trace` to
`true`. A trace file will then be written to the folder named in `ctx.training.trace.output_path`, creating the
folder if necessary.\
Using this trace, you can start a TensorBoard server and inspect the current model performance. Something as simple
as `tensorboard --logdir trace --host 0.0.0.0 --port 6006` works perfectly fine.\
If TensorBoard doesn't show up, it's likely that the firewall is misconfigured. One easy way to fix this is to
run `gcloud compute firewall-rules create --network=default allow-tensorboard-tcp --allow=tcp:6006`, which creates a
new firewall rule that allows anyone to access this port on the TPU (and, with that, to run up costs for whatever is
hosted there).

--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
data:
  datasets_used_per_step: 2
  interleaved_datasets: 2
  parallel_workers: 2
  path: gs://homebrewnlp-eu/the-char-pile/*
  prefetch_buffer: 2
  seed: 0
  shuffle_buffer: 0
  vocab_size: 256
dims:
  batch: 16
  depth: 64
  features: 256
  heads: 8
  inner_bottleneck_features: 128
  inner_bottleneck_kernel: 49
  intermediate: 512
  moe_intermediate: 4096
  one: 1
  outer_bottleneck_kernel: 25
  pointwise_features: 512
  pointwise_kernel: 5
  sequence: 4096
  vocab: 256
global_prefix: ''
model:
  activation_std: 0.5893595616022745
  computation_dtype: bfloat16
  conv_scale: 4.0
  conv_shift: 8.0
  leaky_relu_slope: 0.02
  norm_eps: 1.0e-16
  qrnn_frequency: 8
  rezero_lr_scale: 0.01
  storage_dtype: float32
optimizer:
  adam_beta1: 0.03
  adam_beta2: 0.003
  block_size: 512
  bottleneck_scale: 1
  epsilon: 1.0e-16
  exponential_decay: 3.0e-06
  gradient_clip: 0.001
  input_scale: 1
  learning_rate: 0.01
  moe_scale: 1
  momentum_beta: 0.1
  norm_scale: 1
  output_scale: 1
  pointwise_scale: 1
  preconditioning_compute_steps: 128
  qrnn_scale: 1
  skip_preconditioning_dim_size_gt: 1024
  start_preconditioning_step: 16
  statistics_compute_steps: 4
  warmup_end: 1024
  weight_decay: 0.001
seed: 0
training:
  checkpoint_interval: 2048
  checkpoint_load_path: ""
  checkpoint_path: gs://homebrewnlp-eu/homebrewnlp-checkpoint-deep
  device_steps: 1
  device_unroll: 1
  do_checkpoint: true
  early_stopping:
    expected_loss:
      exponent: -0.3642513
      offset: 6.165868
      scale: 39.08037
    loss_patience: 0.875
    maximum_spike_duration: 24
    maximum_spike_size: 3
    minimum_relative_loss_change: 0.003
  print_interval: 1
  steps: 65536
  trace:
    do_trace: false
    output_path: trace
    start_step: 16
    stop_step: 80
  z_loss: 0.01
wandb:
  entity: homebrewnlp
  log_frequency: 1
  median_sizes:
  - 64
  - 256
  - 1024
  percentile: 25
  project: gpt
  use_wandb: true

--------------------------------------------------------------------------------
/images/axial.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HomebrewML/Olmax/3833b72e0c312fe1b3293bf2f9b62b8452a3aa23/images/axial.png

--------------------------------------------------------------------------------
/images/convit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HomebrewML/Olmax/3833b72e0c312fe1b3293bf2f9b62b8452a3aa23/images/convit.png

--------------------------------------------------------------------------------
/images/mixer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HomebrewML/Olmax/3833b72e0c312fe1b3293bf2f9b62b8452a3aa23/images/mixer.png
-------------------------------------------------------------------------------- /images/normformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/Olmax/3833b72e0c312fe1b3293bf2f9b62b8452a3aa23/images/normformer.png -------------------------------------------------------------------------------- /images/paper_cogview_sandwich.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/Olmax/3833b72e0c312fe1b3293bf2f9b62b8452a3aa23/images/paper_cogview_sandwich.png -------------------------------------------------------------------------------- /images/reformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/Olmax/3833b72e0c312fe1b3293bf2f9b62b8452a3aa23/images/reformer.png -------------------------------------------------------------------------------- /images/revnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/Olmax/3833b72e0c312fe1b3293bf2f9b62b8452a3aa23/images/revnet.png -------------------------------------------------------------------------------- /images/rmsnorm_speed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/Olmax/3833b72e0c312fe1b3293bf2f9b62b8452a3aa23/images/rmsnorm_speed.png -------------------------------------------------------------------------------- /images/sandwich.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/Olmax/3833b72e0c312fe1b3293bf2f9b62b8452a3aa23/images/sandwich.png -------------------------------------------------------------------------------- /images/scalenorm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/Olmax/3833b72e0c312fe1b3293bf2f9b62b8452a3aa23/images/scalenorm.png -------------------------------------------------------------------------------- /img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/Olmax/3833b72e0c312fe1b3293bf2f9b62b8452a3aa23/img.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import jax 4 | import numpy as np 5 | import uvicorn 6 | from fastapi import FastAPI, HTTPException 7 | from jax import lax, numpy as jnp, random 8 | from pydantic import BaseModel 9 | from transformers import GPT2TokenizerFast 10 | 11 | from src.backend import matmul, promote_to 12 | from src.constants import ParallelAxes 13 | from src.context import Context, WhilePredictContext 14 | from src.model.main import body_ctx 15 | from src.utils.checkpoint import read_checkpoint 16 | 17 | 18 | def one_hot(inp: jax.Array, size: int) -> jax.Array: 19 | return jnp.equal(jnp.reshape(inp, inp.shape + (1,)), jnp.reshape(jnp.arange(0, size), (1,) * inp.ndim + (size,))) 20 | 21 | 22 | def cond_fn(while_ctx_dict: typing.Dict[str, typing.Any]) -> bool: 23 | wctx = WhilePredictContext(while_ctx_dict) 24 | is_eos = wctx.data == wctx.ctx.eval.eos 25 | behind_start = wctx.start_pos.reshape(-1, 1) > jnp.arange(wctx.ctx.dims.sequence).reshape(1, -1) 26 | is_eos = 
jnp.logical_and(is_eos, behind_start) 27 | is_eos = jnp.cumsum(is_eos, axis=1) 28 | eos_at_seq = (is_eos > 0).sum(0) == wctx.ctx.dims.batch 29 | eos = jnp.take_along_axis(eos_at_seq.reshape(-1), wctx.current_step.reshape(-1).astype(jnp.int32), axis=0) 30 | stop = jnp.less(wctx.current_step, wctx.stop_pos) 31 | return jnp.logical_or(eos, stop).reshape(()) 32 | 33 | 34 | def body_fn(while_ctx_dict: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]: 35 | wctx = WhilePredictContext(while_ctx_dict) 36 | 37 | out, wgt = body_ctx(wctx.ctx, wctx.data) 38 | out = (out * one_hot(wctx.current_step - 1, wctx.ctx.dims.sequence).reshape(1, -1, 1)).sum(1, keepdims=True) 39 | out = matmul(out, wgt).reshape(out.shape[0], 1, -1) 40 | out = promote_to(out, jnp.float32) 41 | out_token = lax.psum(out, ParallelAxes.model) 42 | 43 | key = random.PRNGKey((wctx.ctx.seed + wctx.current_step).astype(jnp.int32)) 44 | temp = random.uniform(key, out_token.shape, maxval=1, minval=1e-7, dtype=jnp.float32) 45 | temp = jnp.log(temp) 46 | temp = jnp.negative(temp) 47 | temp = jnp.log(temp) 48 | temp = temp * -wctx.temperature 49 | 50 | arange = lax.broadcasted_iota(jnp.int32, out_token.shape, dimension=2) 51 | sorted_out, argsort_out = lax.sort_key_val(out_token, arange) 52 | ranks = jnp.argsort(argsort_out, -1) 53 | top_k_mask = jnp.less(ranks, wctx.ctx.dims.vocab - wctx.max_tokens.reshape(-1, 1, 1)) # we want to not mask top k 54 | 55 | cumulative_probabilities = lax.rev(jnp.cumsum(lax.rev(jax.nn.softmax(out), (1,)), -1), (1,)) 56 | overflow = jnp.greater(cumulative_probabilities, wctx.max_probability_mass.reshape(-1, 1, 1)) 57 | overflow = jnp.concatenate([overflow[:, :, 1:], jnp.zeros_like(overflow[:, :, :1])], -1) 58 | top_p_mask = jnp.take_along_axis(overflow, ranks, axis=2) 59 | 60 | log_softmax = jax.nn.log_softmax(out_token) 61 | shifted_scores = jnp.abs((jnp.exp(log_softmax) * log_softmax).sum(-1, keepdims=True) - log_softmax) 62 | sorted_out, argsort_out = lax.sort_key_val(shifted_scores, arange) 63 | cumulative_probabilities = jnp.cumsum(jax.nn.softmax(jnp.take_along_axis(out_token, argsort_out, axis=2)), -1) 64 | overflow = jnp.less(cumulative_probabilities, wctx.typical_mass.reshape(-1, 1, 1)) 65 | overflow_at = overflow.sum(-1, keepdims=True).astype(jnp.int32) 66 | overflow = jnp.take_along_axis(sorted_out, overflow_at, axis=2) 67 | overflow = jnp.greater(sorted_out, overflow) 68 | overflow = jnp.concatenate([jnp.zeros_like(overflow[:, :, :1]), overflow[:, :, :-1]], -1) 69 | typical_mask = jnp.take_along_axis(overflow, jnp.argsort(argsort_out, -1), axis=2) 70 | 71 | # min_prob_mask ("top-p-x") and adaptive_mask ("top-a") are ideas taken from 72 | # https://github.com/BlinkDL/RWKV-LM/blob/4bbee4bb1a26059c6425d25c59e057891ae7c4c7/README.md 73 | softmax = jax.nn.softmax(out_token) 74 | min_prob_mask = softmax < wctx.max_probability_to_filter.reshape(-1, 1, 1) 75 | adaptive_filter = jnp.max(softmax, axis=2, keepdims=True) ** wctx.adaptive_filter_power * wctx.adaptive_filter_scale 76 | adaptive_mask = softmax < adaptive_filter 77 | 78 | out_token = out_token + temp + ((top_k_mask + top_p_mask + adaptive_mask) * min_prob_mask + typical_mask) * -1e9 79 | out_token = jnp.argmax(out_token, -1) 80 | wctx.data = jnp.where(one_hot(wctx.current_step, wctx.ctx.dims.sequence).reshape(1, -1), out_token, wctx.data) 81 | wctx.current_step += 1 82 | return wctx.serialize() 83 | 84 | 85 | def jitless_prediction_step(parameters: typing.Dict[str, jax.Array], data: jax.Array, 86 | temperature: jax.Array, max_tokens: 
jax.Array, max_probability_mass: jax.Array, 87 | typical_mass: jax.Array, max_probability_to_filter: jax.Array, 88 | adaptive_filter_power: jax.Array, adaptive_filter_scale: jax.Array, seed: jax.Array, 89 | start_pos: jax.Array, stop_pos: jax.Array) -> jax.Array: 90 | wctx = WhilePredictContext() 91 | wctx.ctx.parameters = parameters 92 | wctx.data = data 93 | wctx.temperature = temperature 94 | wctx.max_tokens = max_tokens 95 | wctx.max_probability_to_filter = max_probability_to_filter 96 | wctx.max_probability_mass = max_probability_mass 97 | wctx.adaptive_filter_power = adaptive_filter_power 98 | wctx.adaptive_filter_scale = adaptive_filter_scale 99 | wctx.typical_mass = typical_mass 100 | wctx.ctx.seed = seed 101 | wctx.start_pos = start_pos 102 | wctx.stop_pos = stop_pos 103 | wctx.current_step = jnp.min(start_pos) 104 | 105 | wctx = WhilePredictContext(lax.while_loop(cond_fn, body_fn, wctx.serialize())) 106 | 107 | return wctx.data 108 | 109 | 110 | class Inference: 111 | def __init__(self, ctx: Context): 112 | dummy_data = np.zeros((1, ctx.dims.sequence), dtype=np.int32) 113 | read_checkpoint(ctx) 114 | self.parameters = ctx.parameters 115 | 116 | partition = {k: 0 for k in ctx.parameters.keys()} 117 | self.step = jax.pmap(jitless_prediction_step, axis_name=ParallelAxes.model, 118 | in_axes=(partition, None, None, None, None, None, None, None, None, None, None, None), 119 | out_axes=None) 120 | self.ctx = ctx 121 | 122 | self.complete_jax(dummy_data, np.zeros(()), np.ones(()), np.ones(()), np.ones(()), np.ones(()), np.ones(()), 123 | np.ones(()), np.zeros(()), np.zeros(()), np.ones(())) 124 | 125 | def complete_jax(self, prompt: jnp.array, temperature: jnp.array, max_tokens: jnp.array, 126 | max_probability_mass: jnp.array, typical_mass: jax.Array, 127 | max_probability_to_filter: jax.Array, adaptive_filter_power: jax.Array, 128 | adaptive_filter_scale: jax.Array, seed: jnp.array, start_pos: jnp.array, 129 | stop_pos: jnp.array) -> jnp.array: 130 | return self.step(self.parameters, prompt, temperature, max_tokens, max_probability_mass, typical_mass, 131 | max_probability_to_filter, adaptive_filter_power, adaptive_filter_scale, seed, start_pos, 132 | stop_pos) 133 | 134 | def complete_tokens(self, prompt: jax.Array, temperature: float, max_tokens: int, max_probability_mass: float, 135 | typical_mass: float, max_probability_to_filter: float, adaptive_filter_power: float, 136 | adaptive_filter_scale: float, seed: int, length: int) -> jax.Array: 137 | tokens = jnp.pad(prompt, ((0, 0), (0, self.ctx.dims.sequence - prompt.shape[1]))) 138 | base = jnp.zeros(()) 139 | start = base + prompt.shape[1] 140 | return self.complete_jax(tokens, temperature, base + max_tokens, base + max_probability_mass, 141 | base + typical_mass, base + max_probability_to_filter, base + adaptive_filter_power, 142 | base + adaptive_filter_scale, base + seed, start, start + length) 143 | 144 | def complete(self, text: str, temperature: float = 0.5, max_tokens: int = 32, max_probability_mass: float = 0.9, 145 | typical_mass: float = 1, max_probability_to_filter: float = 1., adaptive_filter_power: float = 1, 146 | adaptive_filter_scale: float = 0, seed: int = 0, length: int = 128): 147 | tokens = jnp.asarray(np.frombuffer(text.encode(), np.uint8)).astype(jnp.int32).reshape(1, -1) 148 | out = self.complete_tokens(tokens, temperature, max_tokens, max_probability_mass, typical_mass, 149 | max_probability_to_filter, adaptive_filter_power, adaptive_filter_scale, seed, 150 | length)[0] 151 | return 
np.asarray(out).astype(np.uint8).tobytes().decode(errors='ignore')[len(text):len(text) + length] 152 | 153 | 154 | class Tokens(BaseModel): 155 | tokens: typing.List[int] # skipcq: PTC-W0052 156 | 157 | 158 | class TokenCompletion(BaseModel): 159 | token_completion: typing.List[int] 160 | 161 | 162 | class Completion(BaseModel): 163 | completion: str # skipcq: PTC-W0052 164 | 165 | 166 | class SanitizedTokens(BaseModel): 167 | tokens: typing.List[int] 168 | 169 | 170 | class CompletionInput(BaseModel): 171 | prompt: str = "" 172 | length: int = 16 173 | temperature: float = 1. 174 | max_tokens: int = 64 175 | max_probability_mass: float = 0.9 176 | typical_mass: float = 1 177 | max_probability_to_filter: float = 1 178 | adaptive_filter_power: float = 1 179 | adaptive_filter_scale: float = 0 180 | seed: int = 0 181 | error: bool = True 182 | 183 | 184 | class RestAPI: 185 | def __init__(self): 186 | self._ctx = Context() 187 | self._interface = Inference(self._ctx) 188 | if self._ctx.dims.vocab == 256: 189 | self._encode = lambda x: list(x.encode()) 190 | self._decode = lambda x: np.asarray(x).astype(np.uint8).tobytes().decode(errors='ignore') 191 | else: 192 | tokenizer = GPT2TokenizerFast.from_pretrained('gpt2') 193 | self._encode = tokenizer.encode 194 | self._decode = tokenizer.decode 195 | 196 | async def check_tokens(self, tokens: typing.List[int], error: bool = True) -> SanitizedTokens: 197 | if tokens and max(tokens) > self._ctx.dims.vocab: 198 | if error: 199 | raise HTTPException(status_code=400, detail=f"Invalid tokens sent. Tokens go up to " 200 | f"{self._ctx.dims.vocab} but received {max(tokens)}.") 201 | tokens = [t for t in tokens if t < self._ctx.dims.vocab] 202 | if len(tokens) > self._ctx.dims.sequence: 203 | if error: 204 | raise HTTPException(status_code=400, detail=f"Context too big. 
The model supports up to " 205 | f"{self._ctx.dims.sequence} tokens but received " 206 | f"{len(tokens)}.") 207 | tokens = tokens[:self._ctx.dims.sequence] 208 | return SanitizedTokens(tokens=tokens) 209 | 210 | async def encode(self, prompt: str) -> Tokens: 211 | return Tokens(tokens=self._encode(prompt)) 212 | 213 | async def decode(self, prompt: typing.List[int]) -> Completion: 214 | return Completion(completion=self._decode(prompt)) 215 | 216 | async def token_completion(self, params: CompletionInput) -> TokenCompletion: 217 | tokens = (await self.encode(params.prompt)).tokens 218 | tokens = (await self.check_tokens(tokens, params.error)).tokens 219 | tok = self._interface.complete_tokens(jnp.array(tokens).reshape(1, -1), params.temperature, params.max_tokens, 220 | params.max_probability_mass, params.typical_mass, 221 | params.max_probability_to_filter, params.adaptive_filter_power, 222 | params.adaptive_filter_scale, params.seed, params.length) 223 | tok = tok[0, len(tokens):len(tokens) + params.length].tolist() 224 | out = [] 225 | for t in tok: 226 | if t == self._ctx.eval.eos: 227 | break 228 | out.append(t) 229 | return TokenCompletion(token_completion=out) 230 | 231 | async def completion(self, params: CompletionInput) -> Completion: 232 | return await self.decode((await self.token_completion(params)).token_completion) 233 | 234 | 235 | def main(): 236 | rest_api = RestAPI() 237 | fast_api = FastAPI() 238 | 239 | for key in dir(rest_api): 240 | if key.startswith('_') or key.endswith('_'): 241 | continue 242 | fn = getattr(rest_api, key) 243 | fast_api.post('/' + key, response_model=typing.get_type_hints(fn)["return"])(fn) 244 | 245 | uvicorn.run(fast_api, host='0.0.0.0', port=62220, log_level='info', workers=1) # skipcq: BAN-B104 246 | 247 | 248 | if __name__ == '__main__': 249 | main() 250 | -------------------------------------------------------------------------------- /launch_on_tensorfork.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pathlib 4 | 5 | import wandb 6 | import yaml 7 | 8 | from src.context import WandB 9 | 10 | CONFIGS = [("europe-west4-a", 3, 250, 1), 11 | ("us-central1-a", 3, 200, 1), 12 | ("us-central1-c", 3, 15, 1), 13 | ("us-central1-c", 3, 5, 0), 14 | ("us-central1-b", 2, 150, 1), 15 | ("us-central1-c", 2, 150, 1), 16 | ("us-central1-f", 2, 150, 1), 17 | ("us-central1-a", 2, 5, 0), 18 | ("us-central1-f", 2, 25, 0), 19 | ] 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--prefix", type=str, default="homebrew", help="Name prefix for TPUs") 25 | parser.add_argument("--us-service-account", type=str, help="EMail of the service account used for american TPUs") 26 | parser.add_argument("--eu-service-account", type=str, help="EMail of the service account used for european TPUs") 27 | parser.add_argument("--use-us", default=0, type=int, help="Whether to use TPUs from the USA") 28 | parser.add_argument("--dry", default=1, type=int, help="Whether to only show what it'd do rather than doing it.") 29 | parser.add_argument("--branch", default="main", type=str, help="Branch on github to use") 30 | parser.add_argument("--cleanup", default=0, type=int, 31 | help="Instead of running something new, kill all tpus. 
1 or 0 for y/n") 32 | args = parser.parse_args() 33 | return (bool(args.use_us), bool(args.dry), args.cleanup, args.prefix, args.us_service_account, 34 | args.eu_service_account, args.branch) 35 | 36 | 37 | def main(): 38 | (use_us, dry, cleanup, base_prefix, us_service_account, eu_service_account, branch) = parse_args() 39 | 40 | if not cleanup: 41 | with open("sweep.yaml", 'r') as f: 42 | config = yaml.safe_load(f.read()) 43 | sweep = wandb.sweep(config, entity=WandB.entity, project=WandB.project) 44 | else: 45 | sweep = "" 46 | main_folder = pathlib.Path(os.path.abspath(__file__)).parent 47 | for zone, tpu_version, tpu_count, preemptible in CONFIGS: 48 | us_tpu = zone.startswith('us') 49 | if us_tpu and not use_us: 50 | continue 51 | service_account = us_service_account if us_tpu else eu_service_account 52 | prefix = zone.split('-') 53 | prefix = prefix[0][:2] + prefix[1][0] + prefix[1][-1] + prefix[2][-1] # us-central1-f -> usc1f 54 | if preemptible: 55 | prefix += "-preemptible" 56 | 57 | cmd = (f'export PYTHONPATH="{main_folder}:$PYTHONPATH" && ' 58 | f'screen -dmS "{prefix}" python3 {main_folder}/script/launch_multiple_runs.py --tpus {tpu_count} ' 59 | f'--zone {zone} --tpu-version {tpu_version} ' 60 | f'--data-path gs://homebrewnlp-{"us" if us_tpu else "eu"}/the-token-pile/ ' 61 | f'--prefix {base_prefix}-{prefix} --preemptible {preemptible} ' 62 | f'--sweep {WandB.entity}/{WandB.project}/{sweep} --cleanup {cleanup} ' 63 | f'--timeout-multiplier {len(CONFIGS)} --service-account {service_account} ' 64 | f'--branch {branch}') 65 | print(cmd) 66 | if not dry: 67 | os.system(cmd) # skipcq: BAN-B605 68 | 69 | 70 | if __name__ == '__main__': 71 | main() 72 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from src.main import main 2 | 3 | if __name__ == '__main__': 4 | main() 5 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # as installation is sensitive to its order, it's recommended to install via the provided shell scripts 2 | jax[tpu]>=0.3.10 3 | wandb 4 | smart-open[gcs] 5 | jsonpickle 6 | tensorflow==2.8.0 7 | protobuf==3.20.1 8 | tpucare 9 | sharedutils 10 | tpunicorn 11 | google-api-python-client 12 | google-cloud-tpu 13 | redis 14 | sqlalchemy 15 | psycopg2-binary 16 | opencv-python 17 | Pillow 18 | git+https://github.com/ytdl-org/youtube-dl.git 19 | google-cloud-storage 20 | oauth2client 21 | utils 22 | scipy 23 | gdown 24 | omegaconf 25 | pyparsing==2.4.7 26 | einops 27 | pytorch-lightning 28 | fastapi 29 | uvicorn 30 | pydantic 31 | transformers -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # as installation is sensitive to its order, it's recommended to install via the provided shell scripts 2 | jax[tpu]>=0.3.10 3 | wandb 4 | smart-open[gcs] 5 | jsonpickle 6 | tensorflow==2.8.0 7 | protobuf==3.20.1 8 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 # faster malloc 2 | export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=60000000000 # no numpy memory warnings 3 | export TF_CPP_MIN_LOG_LEVEL=4 # 
no dataset warnings 4 | 5 | export XRT_TPU_CONFIG="localservice;0;localhost:51011" 6 | 7 | export JAX_ENABLE_X64=1 # allow fp64 8 | export JAX_DEFAULT_DTYPE_BITS=32 # ..but don't enforce it 9 | 10 | export WANDB_WATCH="false" # workaround to wandb crashing and killing the whole run 11 | export WANDB_START_METHOD="thread" 12 | 13 | # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla.proto 14 | export XLA_FLAGS="--xla_force_host_platform_device_count=1" # We don't use TPU-CPU for ML 15 | # export XLA_FLAGS="--xla_step_marker_location=1 $XLA_FLAGS" # 0 = entry; 1 = outer while 16 | 17 | /usr/bin/env python3 main.py "$@" 18 | -------------------------------------------------------------------------------- /script/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import requests 4 | 5 | URL = "https://orbscale.com/" 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--prompt", type=str, required=True) 11 | parser.add_argument("--temperature", type=float, default=1.) 12 | parser.add_argument("--mass", type=float, default=1.) 13 | parser.add_argument("--k", type=float, default=256) 14 | parser.add_argument("--max-prob", type=float, default=1) 15 | parser.add_argument("--power", type=float, default=1) 16 | parser.add_argument("--scale", type=float, default=0) 17 | parser.add_argument("--length", type=int, default=128) 18 | parser.add_argument("--seed", type=int, default=128) 19 | args = parser.parse_args() 20 | out = requests.post(URL, 21 | json={"prompt": args.prompt, "temperature": args.temperature, "max_probability_mass": args.mass, 22 | "max_tokens": args.k, "length": args.length, 'seed': args.seed, 23 | "max_probability_to_filter": args.max_prob, 24 | "adaptive_filter_power": args.power, "adaptive_filter_scale": args.scale 25 | }) 26 | response = out.json()["completion"] 27 | print(response) 28 | 29 | 30 | if __name__ == '__main__': 31 | main() 32 | -------------------------------------------------------------------------------- /script/launch_multiple_runs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dataclasses 3 | import typing 4 | from netrc import netrc 5 | 6 | from tpucare import delete_all, exec_command, exec_on_tpu, send_to_tpu, start_multiple 7 | 8 | from src.context import DataContext 9 | 10 | _, _, wandb_key = netrc().authenticators("api.wandb.ai") 11 | OLD_DATA_PATH = DataContext.path.replace("/", "\\/")[:-1] # remove * at the end 12 | 13 | 14 | @dataclasses.dataclass 15 | class Context: 16 | zone: str 17 | host: str 18 | sweep_id: str 19 | branch: str 20 | data_path: str 21 | 22 | 23 | def start_fn(ctx: Context, worker: int): 24 | setup = f'(bash setup.sh ; sed -i "s/{OLD_DATA_PATH}/{ctx.data_path}/g" src/context.py; exit 0)' 25 | cmd = exec_command(repository="https://github.com/HomebrewNLP/HomebrewNLP-Jax", wandb_key=wandb_key, 26 | setup_command=setup, run_command=f"/home/ubuntu/.local/bin/wandb agent {ctx.sweep_id}", 27 | branch=ctx.branch) 28 | send_to_tpu(ctx.host, ctx.zone, "setup.sh", cmd, worker) 29 | exec_on_tpu(ctx.host, ctx.zone, "bash setup.sh", worker) 30 | 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("--tpus", type=int, default=1, help="How many TPUs should be launched") 35 | parser.add_argument("--tpu-version", type=int, default=3, help="Which TPU version to create (v2-8 or v3-8)") 36 | 
parser.add_argument("--prefix", type=str, default="homebrewnlp-preemptible-tuning", help="Name prefix for TPUs") 37 | parser.add_argument("--zone", type=str, default="europe-west4-a", help="GCP Zone TPUs get created in") 38 | parser.add_argument("--data-path", type=str, default="gs://ggpt4/the-char-pile/", 39 | help="Where the data is stored. Should be changed to a bucket in the correct region") 40 | parser.add_argument("--sweep", type=str, help="ID of the Weights and Biases sweep that'll be resumed") 41 | parser.add_argument("--cleanup", default=0, type=int, 42 | help="Instead of running something new, kill all tpus. 1 or 0 for y/n") 43 | parser.add_argument("--preemptible", default=1, type=int, 44 | help="Whether to create preemptible or non-preemptible TPUs") 45 | parser.add_argument("--service-account", type=str, 46 | help="Service account that controls permissions of TPU (for example, to ensure EU TPUs won't " 47 | "use US data)") 48 | parser.add_argument("--branch", type=str, help="Branch on github to use") 49 | parser.add_argument("--slices", type=int, help="How many TPU slices each TPU should have (1=>vX-8, 4=>vX-32)") 50 | args = parser.parse_args() 51 | return (args.tpus, args.tpu_version, args.prefix, args.zone, args.sweep, args.data_path, bool(args.cleanup), 52 | bool(args.preemptible), args.service_account, args.branch, args.slices) 53 | 54 | 55 | def main(): 56 | (tpus, tpu_version, prefix, zone, sweep_id, data_path, cleanup, preemptible, 57 | service_account, branch, slices) = parse_args() 58 | if cleanup: 59 | return delete_all(prefix, zone) 60 | 61 | def creation_callback(host: str, ctx: typing.Optional[Context]) -> Context: 62 | if ctx is None: 63 | return Context(zone=zone, host=host, sweep_id=sweep_id, data_path=data_path, branch=branch) 64 | return ctx 65 | 66 | start_multiple(prefix, tpu_version, zone, preemptible, service_account, slices, start_fn, creation_callback, tpus) 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /script/quantized_vid2tfrecord.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import datetime 4 | import functools 5 | import multiprocessing 6 | import os 7 | import pickle 8 | import random 9 | import shutil 10 | import sys 11 | import threading 12 | import time 13 | import traceback 14 | import typing 15 | import uuid 16 | 17 | import boto3 18 | import ffmpeg 19 | import gdown 20 | import numpy as np 21 | import requests 22 | import tensorflow as tf 23 | import torch 24 | import youtube_dl 25 | from omegaconf import OmegaConf 26 | from sharedutils import SharedEXTQueue 27 | 28 | sys.path.append("./taming-transformers") 29 | from taming.models.vqgan import GumbelVQ # skipcq: FLK-E402 30 | 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("--cpu-worker", type=int, default=multiprocessing.cpu_count(), 35 | help=f"Number of workers. 
Default is the number of CPU cores (={multiprocessing.cpu_count()})") 36 | parser.add_argument("--gpus", type=int, default=1, help="Number of GPUs to use") 37 | parser.add_argument("--model-base-path", type=str, default='/fsx/lucas', 38 | help="Where model and config should be dowloaded to") 39 | parser.add_argument("--bucket", type=str, help="Name of the S3 bucket") 40 | parser.add_argument("--prefix", type=str, help="Prefix in the bucket") 41 | parser.add_argument("--batch", type=int, default=128, help="Number of images processed per 'computation step'") 42 | parser.add_argument("--tmp-dir", type=str, help="Local directory for temporary storage") 43 | parser.add_argument("--urls", type=str, help="Directory filled with JSON files full of URLs") 44 | parser.add_argument("--fps", type=int, default=1, 45 | help="Number of (encoded) video frames per second of raw data (default=4)") 46 | parser.add_argument("--shared-memory", type=int, default=4, help="number of GB of shared memory") 47 | parser.add_argument("--tokens-per-file", type=int, default=2 ** 28, help="how big each file should roughly be") 48 | parser.add_argument("--video-downloaders", type=int, default=4, 49 | help="Number of parallel video _information_ downloaders. Videos are always downloaded in " 50 | "parallel, but downloading information about too many videos in parallel can lead to " 51 | "errors and slow things down.") 52 | args = parser.parse_args() 53 | return args.cpu_worker, args.bucket, args.prefix, args.tmp_dir, args.urls, args.fps, args.batch, args.gpus, \ 54 | args.model_base_path, args.shared_memory, args.tokens_per_file, args.video_downloaders 55 | 56 | 57 | def frame_encoder(frame): 58 | feature = {'text': tf.train.Feature(int64_list=tf.train.Int64List(value=frame))} 59 | features = tf.train.Features(feature=feature) 60 | proto = tf.train.Example(features=features) 61 | proto = proto.SerializeToString() 62 | return proto 63 | 64 | 65 | def try_except(fn: typing.Callable, default=None): 66 | def _fn(*args, **kwargs): 67 | try: 68 | return fn(*args, **kwargs) 69 | except Exception as exc: # skipcq: PYL-W0703 70 | print(r"IGNORED EXCEPTION \/\/\/") 71 | print(fn, exc) 72 | traceback.print_exc() 73 | print("IGNORED EXCEPTION /\\/\\/\\") 74 | 75 | return default 76 | 77 | return _fn 78 | 79 | 80 | def load_vqgan(config_path: str, ckpt_path: str): 81 | config = OmegaConf.load(config_path) 82 | model = GumbelVQ(**config.model.params) 83 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"] 84 | model.load_state_dict(sd, strict=False) 85 | return model.eval() 86 | 87 | 88 | @functools.partial(try_except, default=[]) 89 | def tokenize(model: GumbelVQ, frames: torch.Tensor, device: torch.device): 90 | with torch.no_grad(): 91 | batches = [model.encode(f.to(device))[2][2].detach() for f in frames] 92 | return torch.cat(batches, dim=0).flatten().cpu().tolist() 93 | 94 | 95 | @try_except 96 | def get_video_urls(youtube_getter, youtube_base: str, url: str, lock: threading.Semaphore, target_image_size: int) -> \ 97 | typing.List[dict]: 98 | # We have to lock this part because it can lead to errors if multiple thread try to scrape video Information at 99 | # the same time. 
100 | with lock: 101 | info = youtube_getter.extract_info(youtube_base + url, download=False) 102 | if info is None or 'formats' not in info: 103 | return [] 104 | video_urls = [] 105 | for f in info['formats']: 106 | width = f.get('width') 107 | height = f.get('height') 108 | url = f.get('url') 109 | ext = f.get('ext') 110 | format_note = f.get('format_note') 111 | 112 | if any(x is None for x in (width, height, url, ext, format_note)): 113 | continue 114 | if any(not x for x in (width, height, url, ext)): 115 | continue 116 | if format_note == "tiny" or width <= target_image_size or height <= target_image_size: 117 | continue 118 | video_urls.append({'width': width, 'height': height, 'ext': f['ext'], 'url': f['url']}) 119 | return sorted(video_urls, key=lambda x: (x['ext'] != 'mp4', x['width'], x['height'])) 120 | 121 | 122 | def get_video_frames(video_urls: typing.List[dict], target_image_size: int, target_fps: int) -> np.ndarray: 123 | filename = uuid.uuid4() 124 | path = str(filename) 125 | for video_url in video_urls: 126 | if os.path.exists(path): 127 | os.remove(path) 128 | 129 | url = video_url["url"] 130 | path = f"{filename}.{video_url['ext']}" 131 | 132 | try: 133 | with requests.get(url, stream=True) as r, open(path, 'wb') as f: 134 | shutil.copyfileobj(r.raw, f) 135 | except Exception: # skipcq: PYL-W0703 136 | continue # Broken URL, next might work 137 | 138 | width = round(video_url["width"] * video_url["height"] / target_image_size) 139 | try: 140 | out, _ = ffmpeg.input(path) \ 141 | .filter("scale", w=width, h=target_image_size) \ 142 | .filter("crop", w=target_image_size, h=target_image_size).filter("fps", target_fps) \ 143 | .output("pipe:", format="rawvideo", pix_fmt="rgb24", loglevel="error", preset="ultrafast", 144 | threads=target_image_size // 40) \ 145 | .run(capture_stdout=True) 146 | except ffmpeg.Error: # Broken Video, next might work 147 | continue 148 | 149 | if os.path.exists(path): 150 | os.remove(path) 151 | return np.frombuffer(out, np.uint8).reshape((-1, target_image_size, target_image_size, 3)) 152 | 153 | 154 | @functools.partial(try_except, default=0) 155 | def write_tfrecords(tokens: typing.List[int], chunk_size: int, buffer_save_dir: str, save_dir: str, tfrecord_id: int, 156 | s3_bucket): 157 | path = f"{buffer_save_dir}/{save_dir.replace('/', '_')}_{tfrecord_id}.tfrecord" 158 | count = len(tokens) 159 | residual = count % chunk_size 160 | count -= residual 161 | if not count: 162 | return 0 163 | 164 | added = 0 165 | 166 | for i in range(0, count, chunk_size): 167 | with tf.io.TFRecordWriter(path) as tf_writer: 168 | tf_writer.write(frame_encoder(tokens[i:i + chunk_size])) 169 | s3_bucket.upload_file(path, f"{save_dir.rstrip('/')}/{tfrecord_id + added:07d}.tfrecord") 170 | os.remove(path) 171 | added += 1 172 | residual_tokens = tokens[-residual:] 173 | tokens.clear() 174 | tokens.extend(residual_tokens) 175 | return added 176 | 177 | 178 | def frame_worker(work: list, worker_id: int, lock: threading.Semaphore, target_image_size: int, target_fps: int, 179 | batch_size: int, queue_export): 180 | queue = SharedEXTQueue.from_export(*queue_export) 181 | youtube_base = 'https://www.youtube.com/watch?v=' 182 | youtube_getter = youtube_dl.YoutubeDL( 183 | {'writeautomaticsub': False, 'socket_timeout': 600, "quiet": True, "verbose": False, "no_warnings": True, 184 | "ignoreerrors": True 185 | }) 186 | youtube_getter.add_default_info_extractors() 187 | random.Random(worker_id).shuffle(work) 188 | 189 | for wor in work: 190 | video_urls = 
get_video_urls(youtube_getter, youtube_base, wor, lock, target_image_size) 191 | 192 | if not video_urls: 193 | continue 194 | 195 | frames = get_video_frames(video_urls, target_image_size, target_fps) 196 | 197 | if frames is None or not frames.size: 198 | continue 199 | 200 | frames: np.ndarray = frames 201 | frames = frames[:frames.shape[0] // batch_size * batch_size] 202 | frames = frames.transpose((0, 3, 1, 2)).reshape((-1, batch_size, 3, target_image_size, target_image_size)) 203 | queue.put(frames) 204 | 205 | 206 | def worker(model: GumbelVQ, save_dir: str, download_buffer_dir: str, bucket, device: int, 207 | queue: SharedEXTQueue, procs: typing.List[multiprocessing.Process], tokens_per_file: int, 208 | padding_token: int): 209 | save_dir = f'{save_dir.rstrip("/")}/{device}' 210 | dev_str = f'cuda:{device}' 211 | device = torch.device(dev_str) 212 | torch.set_default_tensor_type('torch.FloatTensor') 213 | model = copy.deepcopy(model) 214 | model = model.to(device) 215 | total_frames = 0 216 | tokens = [] 217 | tfrecord_id = 0 218 | start_time = time.time() 219 | start = datetime.datetime.now() 220 | token_pad = len(f'{tokens_per_file:,d}') 221 | frame_pad = len(f'{tokens_per_file // 1024:,d}') 222 | while True: 223 | print(f"{dev_str} | {datetime.datetime.now()} | Tokens: {len(tokens):{token_pad},d} - " 224 | f"Frames: {total_frames:{frame_pad},d} | " 225 | f"FramesPerSecond: {total_frames / (time.time() - start_time):5.2f} - " 226 | f"Elapsed: {datetime.datetime.now() - start}", flush=True) 227 | 228 | # wait until one element exists or run is over 229 | while not queue and any(p.is_alive() for p in procs): 230 | time.sleep(1) 231 | if not any(p.is_alive() for p in procs): 232 | break 233 | frames = queue.get() 234 | frames = torch.as_tensor(frames.astype(np.float32) / 255) 235 | total_frames += frames.size(0) * frames.size(1) 236 | if tokens: 237 | tokens.append(padding_token) 238 | tokens.extend(tokenize(model, frames, device)) 239 | tfrecord_id += write_tfrecords(tokens, tokens_per_file, download_buffer_dir, save_dir, tfrecord_id, bucket) 240 | write_tfrecords(tokens, tokens_per_file, download_buffer_dir, save_dir, tfrecord_id, bucket) 241 | 242 | 243 | def main(): 244 | workers, bucket, prefix, tmp_dir, urls, fps, batch_size, gpus, model_path, shared_memory, chunk_size, \ 245 | video_downloaders = parse_args() 246 | config_path = f'{model_path}/vqgan.gumbelf8.config.yml' 247 | model_path = f'{model_path}/sber.gumbelf8.ckpt' 248 | if not os.path.exists(config_path): 249 | gdown.download('https://drive.google.com/uc?id=1WP6Li2Po8xYcQPGMpmaxIlI1yPB5lF5m', model_path, quiet=True) 250 | if not os.path.exists(config_path): 251 | gdown.download('https://drive.google.com/uc?id=1M7RvSoiuKBwpF-98sScKng0lsZnwFebR', config_path, quiet=True) 252 | os.makedirs(tmp_dir, exist_ok=True) 253 | conf = OmegaConf.load(config_path) 254 | padding_token = conf.model.params.n_embed 255 | resolution = conf.model.params.ddconfig.resolution 256 | model = load_vqgan(config_path, model_path) 257 | 258 | shared_memory = shared_memory * 1024 ** 3 # it's in GB, we have to convert it to bytes 259 | shared_frames = shared_memory // (256 ** 2 * 3 * batch_size) 260 | queue = SharedEXTQueue.from_shape([shared_frames, batch_size, 3, 256, 256]) 261 | 262 | ids = [] 263 | for path in os.listdir(urls): 264 | with open(f'{urls}/{path}', 'rb') as f: 265 | video_ids, _ = pickle.load(f) # skipcq: BAN-B301 266 | ids.extend(video_ids) 267 | 268 | ids = [ids[int(len(ids) * i / workers):int(len(ids) * (i + 1) / workers)] for 
i in range(workers)] 269 | lock = multiprocessing.Semaphore(video_downloaders) 270 | procs = [multiprocessing.Process(args=(work, worker_id, lock, resolution, fps, batch_size, queue.export()), 271 | daemon=True, target=frame_worker) for worker_id, work in enumerate(ids)] 272 | for p in procs: 273 | p.start() 274 | 275 | while not queue: # "pre-wait" to get more accurate FPS counters 276 | time.sleep(1) 277 | 278 | bucket = boto3.resource("s3").Bucket(bucket) 279 | threads = [threading.Thread(target=worker, 280 | args=(model, prefix, tmp_dir, bucket, i, queue, procs, chunk_size, padding_token), 281 | daemon=True) 282 | for i in range(gpus)] 283 | 284 | for t in threads: 285 | t.start() 286 | 287 | for p in procs + threads: 288 | p.join() 289 | 290 | queue.frame_mem.unlink() 291 | queue.frame_mem.close() 292 | 293 | 294 | if __name__ == '__main__': 295 | main() 296 | -------------------------------------------------------------------------------- /script/start_on_multi_host_tpu.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | import netrc 4 | import os 5 | import pathlib 6 | import subprocess 7 | import threading 8 | import typing 9 | 10 | from launch_multiple_runs import all_tpus 11 | 12 | 13 | def parse_args() -> typing.Tuple[str, str, str]: 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--tpu", type=str, help="Name of the TPU to use") 16 | parser.add_argument("--zone", type=str, help="Where the TPU is") 17 | parser.add_argument("--branch", type=str, help="Git branch to use") 18 | args = parser.parse_args() 19 | return args.tpu, args.zone, args.branch 20 | 21 | 22 | def install(zone: str, name: str, worker: int): 23 | base = ["gcloud", "alpha", "compute", "tpus", "tpu-vm"] 24 | args = ["--zone", zone, "--worker", str(worker)] 25 | name = f"ubuntu@{name}" 26 | 27 | if subprocess.call(base + ["scp", "exec.sh", f"{name}:~/exec.sh"] + args): 28 | return 29 | file_path = os.path.abspath(inspect.getfile(inspect.currentframe())) 30 | if subprocess.call(base + ["scp", str(pathlib.Path(file_path).parent.parent / "config.yaml"), 31 | f"{name}:~/config.yaml"] + args): 32 | return 33 | if subprocess.call(base + ["ssh", name, "--command", "bash exec.sh"] + args): 34 | return 35 | 36 | 37 | def main(): 38 | name, zone, branch = parse_args() 39 | _, _, wandb_key = netrc.netrc().authenticators("api.wandb.ai") 40 | tpu = [tpu for tpu in all_tpus(zone) if tpu['name'].split('/')[-1] == name][0] 41 | hosts = len(tpu['networkEndpoints']) 42 | with open("exec.sh", "w") as f: 43 | f.write("git clone https://github.com/HomebrewNLP/HomebrewNLP-Jax/ ; " 44 | "cd HomebrewNLP-Jax ; " 45 | "mv ../config.yaml config.yaml ; " 46 | "git fetch ; " 47 | f"git checkout {branch} ; " 48 | "git pull ; " 49 | "bash setup.sh ; " 50 | f"/home/ubuntu/.local/bin/wandb login {wandb_key} ; " 51 | "screen -dmS model bash -c 'bash run.sh; sleep 100000'") 52 | for i in range(hosts): 53 | threading.Thread(target=install, args=(zone, name, i)).start() 54 | 55 | 56 | if __name__ == '__main__': 57 | main() 58 | -------------------------------------------------------------------------------- /script/text2tfrecord.py: -------------------------------------------------------------------------------- 1 | """tokenization to bpe or character embeddings of text datasets""" 2 | 3 | import argparse 4 | import io 5 | import multiprocessing 6 | import os 7 | import shutil 8 | import time 9 | 10 | import jsonlines 11 | import requests 12 | import simdjson 
13 | import tensorflow as tf 14 | import zstandard 15 | from google.cloud import storage 16 | from transformers import GPT2TokenizerFast 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--name", type=str, default="text", 20 | help="Name of output files will be name_i.tfrecords where i is the number of the file") 21 | parser.add_argument("--procs", type=int, default=2, help="Number of processes in multiprocessing") 22 | parser.add_argument("--output_dir", type=str, default="gs://homebrewnlp-eu/the-token-pile/", 23 | help="Where to put tfrecords (in a bucket)") 24 | parser.add_argument("--int64", type=bool, default=True, help="Whether to encode as bytes or int64") 25 | parser.add_argument("--buffer_size", type=int, default=2 ** 29, help="This is a minimum size, not a maximum size. " 26 | "tfrecords will have this minimum size as well.") 27 | parser.add_argument("--separator", type=str, default=chr(4), 28 | help="separator to place between files in chunk mode." 29 | "Default is \x04 (chr(4)) in case of byte encodings, " 30 | "but should be changed to <|endoftext|> for BPE") 31 | 32 | 33 | def file_generator(args, pid, procs): 34 | base_url = 'http://eaidata.bmk.sh/data/pile/train/%s.jsonl.zst' 35 | splits = 30 36 | parse_fn = simdjson.Parser().parse 37 | tmp_name = f".tmp.download.{pid}" 38 | 39 | def _json_parser(x): 40 | return parse_fn(x.encode()).as_dict() 41 | 42 | for i in range(pid, splits, procs): 43 | with requests.get(base_url.replace("%s", str(i).zfill(2)), stream=True) as r, open(tmp_name, 'wb') as f: 44 | shutil.copyfileobj(r.raw, f) 45 | with open(tmp_name, 'rb') as f: 46 | for item in jsonlines.Reader(io.BufferedReader(zstandard.ZstdDecompressor().stream_reader(f)), 47 | loads=_json_parser): 48 | if isinstance(item, dict): 49 | item = item['text'] 50 | if isinstance(item, list): 51 | item = args.separator.join(item) 52 | yield item 53 | os.remove(tmp_name) 54 | 55 | 56 | def create_tfrecords(args, pid, procs): 57 | slash_idx = args.output_dir.find('/') 58 | bucket_name, output_dir = args.output_dir[:slash_idx], args.output_dir[slash_idx + 1:] 59 | bucket = storage.Client().get_bucket(bucket_name) 60 | join = args.separator.join 61 | prefix = f"{'int64' if args.int64 else 'bytes'}_{args.name}_" 62 | encode = (GPT2TokenizerFast.from_pretrained('gpt2') if args.int64 else str).encode 63 | 64 | files_processed = 0 65 | tfrecord_count = 0 66 | chunk = 0 67 | buffer_size = 0 68 | tokenized_files = [] 69 | 70 | last_write = start_time = time.time() 71 | 72 | for f in file_generator(args, pid, procs): 73 | buffer_size += len(f) 74 | tokenized_files.append(f) 75 | files_processed += 1 76 | 77 | if buffer_size > chunk * args.buffer_size // 4: 78 | print(f"Worker: {pid:{len(str(procs))}d} | Buffer: {buffer_size * 2 ** -20:.1f}MB | " 79 | f"Files: {files_processed} - TFrecords: {tfrecord_count} | " 80 | f"Wrote: {time.time() - last_write:.0f}s ago - Started: {time.time() - start_time:.0f}s ago", 81 | end='') 82 | chunk += 1 83 | 84 | if buffer_size > args.buffer_size: 85 | filename = f"{prefix}{tfrecord_count:_>6d}_{files_processed}_{buffer_size}.tfrecord" 86 | 87 | joined = encode(join(tokenized_files)) 88 | tokenized_files.clear() 89 | 90 | with tf.io.TFRecordWriter(filename) as writer: 91 | if args.int64: 92 | feature = {"text": tf.train.Feature(int64_list=tf.train.Int64List(value=joined))} 93 | else: 94 | feature = {"text": tf.train.Feature(bytes_list=tf.train.BytesList(value=[joined]))} 95 | tf_example = 
tf.train.Example(features=tf.train.Features(feature=feature)) 96 | writer.write(tf_example.SerializeToString()) 97 | 98 | bucket.blob(f'{output_dir}{filename}').upload_from_filename(filename) 99 | 100 | os.remove(filename) 101 | chunk = 0 102 | buffer_size = 0 103 | tfrecord_count += 1 104 | 105 | print("") 106 | 107 | last_write = time.time() 108 | 109 | 110 | def main(): 111 | args = parser.parse_args() 112 | 113 | if not args.output_dir.endswith("/"): 114 | args.output_dir = args.output_dir + "/" 115 | if not args.output_dir.startswith("gs://"): 116 | print("Output dir isn't a cloud bucket. Exiting.") 117 | return 118 | args.output_dir = args.output_dir[len('gs://'):] 119 | processes = [multiprocessing.Process(target=create_tfrecords, args=(args, pid, args.procs)) for pid in 120 | range(args.procs)] 121 | for p in processes: 122 | p.start() 123 | for p in processes: 124 | p.join() 125 | 126 | 127 | if __name__ == "__main__": 128 | main() 129 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | python3 -m pip install --upgrade pip 2 | python3 -m pip install --no-cache-dir --force-reinstall --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 3 | sudo python3 -m pip uninstall tensorboard tbp-nightly tb-nightly tensorboard-plugin-profile -y 4 | python3 -m pip install wandb smart-open[gcs] jsonpickle sharedutils 5 | python3 -m pip install --upgrade --force-reinstall tensorflow==2.8.0 protobuf==3.20.1 6 | -------------------------------------------------------------------------------- /setup_dev.sh: -------------------------------------------------------------------------------- 1 | sudo apt-get -o DPkg::Lock::Timeout=-1 update 2 | sudo apt-get -o DPkg::Lock::Timeout=-1 install -y libpq-dev python-dev python3-dev gcc libgl1-mesa-glx ffmpeg libgl-dev python3-pip git 3 | python3 -m pip install --upgrade pip 4 | python3 -m pip install --upgrade "jax[tpu]>=0.3.10" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 5 | sudo python3 -m pip uninstall tensorboard tbp-nightly tb-nightly tensorboard-plugin-profile -y 6 | python3 -m pip install wandb smart-open[gcs] jsonpickle tpunicorn google-api-python-client google-cloud-tpu redis sqlalchemy psycopg2-binary opencv-python Pillow git+https://github.com/ytdl-org/youtube-dl.git google-cloud-storage oauth2client utils scipy gdown omegaconf pyparsing==2.4.7 einops pytorch-lightning fastapi uvicorn pydantic transformers boto3 torch sharedutils 7 | python3 -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 8 | python3 -m pip install --upgrade --force-reinstall tensorflow==2.8.0 protobuf==3.20.1 9 | git clone https://github.com/CompVis/taming-transformers 10 | mv taming-transformers script/ -------------------------------------------------------------------------------- /share_vm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dataclasses 3 | import string 4 | import typing 5 | from netrc import netrc 6 | 7 | import namecheap 8 | import requests 9 | import shortuuid 10 | import tpucare 11 | from tpucare import delete_one_tpu, exec_on_tpu, start_single, tpu_ips 12 | 13 | tpucare.LOG_LEVEL = 0 14 | _, _, wandb_key = netrc().authenticators("api.wandb.ai") 15 | 16 | IP = requests.get("https://ipinfo.io/ip").text 17 | 18 | 19 | @dataclasses.dataclass 20 | class 
TPUContext: 21 | zone: str 22 | host: str 23 | ssh_key: str 24 | 25 | 26 | class Args: 27 | subdomain_prefix: str 28 | namecheap_username: str 29 | namecheap_api_key: str 30 | domain_name: str 31 | host: str 32 | tpu_version: int 33 | zone: str 34 | preemptible: bool 35 | service_account: str 36 | slices: int 37 | cleanup: int 38 | ssh_key: str 39 | 40 | 41 | def start_fn(ctx: TPUContext, worker: int): 42 | exec_on_tpu(ctx.host, ctx.zone, f"echo '{ctx.ssh_key}' >> ~/.ssh/authorized_keys", worker) 43 | 44 | 45 | def parse_args() -> Args: 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument("--host", type=str, help="Name of the TPU") 48 | parser.add_argument("--subdomain-prefix", type=str, help="like abc to get abc0.example.com and abc7.example.com") 49 | parser.add_argument("--namecheap-username", type=str, help="Username used for login on namecheap") 50 | parser.add_argument("--namecheap-api-key", type=str, 51 | help="See https://ap.www.namecheap.com/settings/tools/apiaccess/") 52 | parser.add_argument("--domain-name", type=str, help="example.com, including the .com") 53 | parser.add_argument("--ssh-key", type=str, help="like `ssh-rsa @`") 54 | parser.add_argument("--tpu-version", type=int, default=3, help="Which TPU version to create (v2-8 or v3-8)") 55 | parser.add_argument("--zone", type=str, default="europe-west4-a", help="GCP Zone TPUs get created in") 56 | parser.add_argument("--preemptible", default=1, type=int, 57 | help="Whether to create preemptible or non-preemptible TPUs") 58 | parser.add_argument("--service-account", type=str, 59 | help="Service account that controls permissions of TPU (for example, to ensure EU TPUs won't " 60 | "use US data)") 61 | parser.add_argument("--slices", default=1, type=int, 62 | help="How many TPU slices each TPU should have (1=>vX-8, 4=>vX-32)") 63 | parser.add_argument("--cleanup", default=0, type=int, 64 | help="Instead of running something new, kill all tpus. 
1 or 0 for y/n")
65 |     return parser.parse_args()
66 | 
67 | 
68 | def new_id():
69 |     return str(shortuuid.ShortUUID(alphabet=string.digits + string.ascii_lowercase).random(32))
70 | 
71 | 
72 | class CreationCallback:
73 |     def __init__(self, args: Args):
74 |         self.args = args
75 |         self.api = namecheap.Api(args.namecheap_username, args.namecheap_api_key, args.namecheap_username, IP,
76 |                                  sandbox=False, debug=False)
77 | 
78 |     def _update_ips(self, host: str):
79 |         ips = tpu_ips(host, self.args.zone)
80 |         records = self.api.domains_dns_getHosts(self.args.domain_name)
81 | 
82 |         records.extend([{"RecordType": "A", "HostName": f"{self.args.subdomain_prefix}{i}", "Address": ip,
83 |                          "MXPref": 10, "TTL": 300
84 |                          } for i, ip in enumerate(ips)])
85 |         records = [self.api._elements_names_fix(x) for x in records]  # skipcq: PYL-W0212
86 |         records = list({r["HostName"]: r for r in records}.values())  # deduplicate, and take last element
87 |         self.api.domains_dns_setHosts(self.args.domain_name, records)
88 | 
89 |     def __call__(self, host: str, ctx: typing.Optional[TPUContext]) -> TPUContext:
90 |         self._update_ips(host)
91 |         return TPUContext(zone=self.args.zone, host=host, ssh_key=self.args.ssh_key)
92 | 
93 | 
94 | def main():
95 |     args = parse_args()
96 |     if args.cleanup:
97 |         return delete_one_tpu("", args.host, args.zone)
98 | 
99 |     start_single(args.host, args.tpu_version, args.zone, args.preemptible, args.service_account, args.slices, start_fn,
100 |                  CreationCallback(args))
101 | 
102 | 
103 | if __name__ == '__main__':
104 |     main()
105 | 
-------------------------------------------------------------------------------- /src/backend.py: --------------------------------------------------------------------------------
1 | from typing import Sequence
2 | from typing import Tuple, List, Any, Optional, TypeVar, Union, Callable
3 | 
4 | import jax
5 | import jax._src.util as util
6 | import numpy as np
7 | from jax import lax, numpy as jnp, random
8 | 
9 | from src.constants import ParallelAxes
10 | from src.context import Context
11 | 
12 | INT_OR_TUPLE = Union[int, Sequence[int]]
13 | 
14 | Output = TypeVar("Output")
15 | CtxFn = TypeVar("CtxFn")
16 | 
17 | PRECISION = "highest"
18 | jax.config.update("jax_default_matmul_precision", PRECISION)
19 | 
20 | 
21 | def square_grad(fn: Callable[[jax.Array, jax.Array], jax.Array], src: jax.Array, weight: jax.Array,
22 |                 weight_sq: jax.Array):
23 |     @jax.custom_gradient
24 |     def _fn(x: jax.Array, wgt: jax.Array, _wgt_dummy: jax.Array):
25 |         def _grad(dy: jax.Array):
26 |             d_x, d_wgt = jax.vjp(fn, x, wgt)[1](dy)
27 |             _, d_wgt_sq = jax.vjp(fn, lax.square(x), wgt)[1](lax.square(dy))
28 |             return d_x, d_wgt, d_wgt_sq * x.shape[0]
29 | 
30 |         return fn(x, wgt), _grad
31 | 
32 |     return _fn(src, weight, weight_sq)
33 | 
34 | 
35 | def add_sq(name: str) -> str:
36 |     if name.endswith('_stacked'):
37 |         return name[:-len('_stacked')] + '_sq_stacked'
38 |     return name + '_sq'
39 | 
40 | 
41 | def promote_to(inp: jax.Array, dtype: jnp.dtype) -> jax.Array:
42 |     return jnp.asarray(inp, jnp.promote_types(dtype, jnp.result_type(inp)))
43 | 
44 | 
45 | def with_context(count: Optional[bool] = None):
46 |     def _inner(fn: CtxFn):
47 |         prefix_kwargs = {"appended": fn.__name__}
48 |         if count is not None:
49 |             prefix_kwargs["count"] = count
50 | 
51 |         def _fn(ctx: Context, *args, add_to_prefix: bool = True, **kwargs):
52 |             if add_to_prefix:
53 |                 ctx = ctx.add_to_prefix(**prefix_kwargs)
54 |             return fn(ctx, *args, **kwargs)
55 | 
56 |         return _fn
57 | 
58 |     return _inner
59 | 
60 | 
61 | def is_main():
62 |     return 
jax.process_index() == 0 63 | 64 | 65 | def stable_rsqrt(inp: jax.Array, eps: float) -> jax.Array: 66 | return jnp.reciprocal(jnp.maximum(jnp.sqrt(inp), eps)) 67 | 68 | 69 | def pos_dim(inp: jax.Array, dims: Sequence[int]) -> Sequence[int]: 70 | return tuple(d % inp.ndim for d in dims) 71 | 72 | 73 | def tuple_int(obj: INT_OR_TUPLE) -> Sequence[int]: 74 | if isinstance(obj, (tuple, list)): 75 | return tuple(obj) 76 | if isinstance(obj, int): 77 | return obj, # skipcq: PYL-R1707 78 | raise ValueError 79 | 80 | 81 | def is_model(param_name: str): 82 | return "/stem:" in param_name and '/optimizer' not in param_name 83 | 84 | 85 | def is_stacked(param_name: str): 86 | return param_name.endswith('_stacked') and is_model(param_name) 87 | 88 | 89 | def conv(inp: jax.Array, weight: jax.Array, padding: List[Tuple[int, int]], groups: int): 90 | ndim = weight.ndim 91 | lhs = (0, ndim - 1) + tuple(range(1, ndim - 1)) 92 | dimension_numbers = lax.ConvDimensionNumbers(lhs, (0, ndim - 1,) + tuple(range(1, ndim - 1)), lhs) 93 | return lax.conv_general_dilated(inp, weight, (1,) * (ndim - 2), padding=padding, feature_group_count=groups, 94 | dimension_numbers=dimension_numbers, precision=PRECISION) 95 | 96 | 97 | def device_id(): 98 | return (lax.psum_scatter(jnp.arange(jax.device_count()), ParallelAxes.model) / jax.device_count()).astype(jnp.int32) 99 | 100 | 101 | def dot(left: jax.Array, right: jax.Array, left_contract_dims: INT_OR_TUPLE, right_contract_dims: INT_OR_TUPLE, 102 | left_batch_dims: INT_OR_TUPLE = (), right_batch_dims: INT_OR_TUPLE = ()) -> jax.Array: 103 | dims = ((pos_dim(left, tuple_int(left_contract_dims)), pos_dim(right, tuple_int(right_contract_dims))), 104 | (pos_dim(left, tuple_int(left_batch_dims)), pos_dim(right, tuple_int(right_batch_dims)))) 105 | return lax.dot_general(left, right, dims, PRECISION) 106 | 107 | 108 | def matmul(left: jax.Array, right: jax.Array, reduced_dims=1): 109 | return dot(left, right, tuple(range(-reduced_dims, 0)), tuple(range(reduced_dims))) 110 | 111 | 112 | def prefixed_name(ctx: Context, name: str): 113 | return ctx.add_to_prefix(name, count=False).global_prefix 114 | 115 | 116 | def assign(ctx: Context, name: str, inp: jax.Array): 117 | name = prefixed_name(ctx, name) 118 | ctx.parameters[name] = inp 119 | 120 | 121 | def normal(ctx: Context, shape: Sequence[int]): 122 | ctx.prng_key, key = random.split(ctx.prng_key) 123 | return random.normal(key, shape, ctx.model.storage_dtype) 124 | 125 | 126 | def deep_replace(d, value): 127 | if isinstance(d, dict): 128 | return {k: deep_replace(v, value) for k, v in d.items()} 129 | return value 130 | 131 | 132 | def orthogonal_init(ctx: Context, shape: List[int], column_axes=(-1,)) -> jax.Array: 133 | column_axes = tuple(column_axes) 134 | axes = tuple(shape[c] for c in column_axes) 135 | n_rows, n_cols = util.prod(shape) // util.prod(axes), util.prod(axes) 136 | matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows) 137 | out, r = jnp.linalg.qr(normal(ctx, matrix_shape)) 138 | out *= lax.broadcast_to_rank(jnp.sign(jnp.diag(r)), rank=out.ndim) 139 | if n_rows < n_cols: 140 | out = out.T 141 | return jnp.reshape(out, tuple(np.delete(shape, column_axes)) + axes).astype(ctx.model.storage_dtype) 142 | 143 | 144 | def get_param(ctx: Context, name: str, shape: Optional[List[int]] = None, 145 | std: Optional[float] = None, mean: Optional[float] = None, column_axes: int = 1, 146 | scale: float = 1., post_variance_scale: float = 1, 147 | lr_scale: float = 1, dtype: Optional[jnp.dtype] = None, 148 | 
init_val: Optional[jax.Array] = None, 149 | tied: bool = False, 150 | return_sq: bool = False, 151 | add_parameter_usages: bool = True) -> Union[Tuple[jax.Array, Optional[jax.Array]], jax.Array]: 152 | if return_sq: 153 | args = [shape, std, mean, column_axes, scale, post_variance_scale, lr_scale, dtype, init_val, tied, False] 154 | out0 = get_param(ctx, name, *args) 155 | if ctx.is_initializing: 156 | return out0, None 157 | return out0, get_param(ctx, add_sq(name), *args, add_parameter_usages=False) 158 | if not tied: 159 | name = name + '_stacked' 160 | add_depth = ctx.add_depth and not tied 161 | 162 | prefix_name = prefixed_name(ctx, name) 163 | 164 | if dtype is None: 165 | computation_dtype = ctx.model.computation_dtype 166 | storage_dtype = ctx.model.storage_dtype 167 | else: 168 | computation_dtype = dtype 169 | storage_dtype = dtype 170 | 171 | if add_parameter_usages: # can't inline, because += 0 would still cause a new item (with val=0) to be created 172 | ctx.parameter_usages[prefix_name] += 1 173 | if prefix_name in ctx.parameters: 174 | return ctx.parameters[prefix_name].astype(computation_dtype) 175 | 176 | if not ctx.is_initializing and ctx.fail_on_missing_parameter: 177 | raise ValueError(f"Couldn't find parameter {prefix_name}. {ctx.name_cache=}") 178 | 179 | if init_val is not None: 180 | param = init_val * scale * post_variance_scale 181 | elif std is None and mean is None: 182 | param = orthogonal_init(ctx, shape, range(len(shape) - column_axes, len(shape))) 183 | if add_depth: 184 | param = normal(ctx, [ctx.dims.depth] * add_depth + list(shape)) * param.std() + param.mean() 185 | param *= scale * post_variance_scale 186 | else: 187 | param = normal(ctx, [ctx.dims.depth] * add_depth + list(shape)) * scale 188 | if std is not None: 189 | param *= std 190 | if mean is not None: 191 | param += mean 192 | ctx.parameter_variance[prefix_name] = lr_scale * scale 193 | assign(ctx, name, param.astype(storage_dtype)) 194 | return param.astype(computation_dtype) 195 | 196 | 197 | def default(option_1, option_2): 198 | if option_1 is None: 199 | return option_2 200 | return option_1 201 | 202 | 203 | def zero_param(ctx: Context, name: str, shape: List[int], dtype: Optional[jnp.dtype]) -> jax.Array: 204 | return get_param(ctx, name, shape, 0, 0, dtype=dtype) 205 | 206 | 207 | def loop(fn: Callable, fn_input: Any, steps: int, unroll: int = 1): 208 | return lax.scan(lambda *x: (fn(*x[:-1]), None), fn_input, None, steps, unroll=unroll)[0] 209 | 210 | 211 | typevar = TypeVar("typevar") 212 | 213 | 214 | def pattern_match(gen_fn: Callable[[int], Callable[[typevar], jax.Array]], cases: int, 215 | predicate: jax.Array, base: typevar): 216 | return lax.switch(predicate.astype(jnp.int32) % cases, [gen_fn(i) for i in range(cases)], base) 217 | -------------------------------------------------------------------------------- /src/constants.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class MomentumType(Enum): 5 | heavyball = "heavyball" 6 | nesterov = "nesterov" 7 | debiased = "debiased" 8 | ema = "ema" 9 | 10 | 11 | class ParallelAxes(Enum): 12 | model = "model_parallel" 13 | -------------------------------------------------------------------------------- /src/context.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import copy 3 | import os 4 | from typing import Any, Callable, Union, Dict, Optional, List 5 | 6 | import jax 7 | import yaml 8 | from jax 
import numpy as jnp, random 9 | 10 | 11 | class DataClass: 12 | def serialize(self): 13 | return serialize(self) 14 | 15 | 16 | def fn_if_dataclass(instance: Any, fn: Callable): 17 | return fn(instance) if isinstance(instance, (DataClass, list, tuple, dict)) else instance 18 | 19 | 20 | def serialize(instance: Union[DataClass, Dict[str, Any]]): 21 | if isinstance(instance, DataClass): 22 | attributes = {key: getattr(instance, key) for key in dir(instance) if 23 | not key.startswith('_') and not key.endswith('_')} 24 | return serialize({key: value for key, value in attributes.items() if not isinstance(value, Callable)}) 25 | if isinstance(instance, (list, tuple)): 26 | return [fn_if_dataclass(itm, serialize) for itm in instance] 27 | if isinstance(instance, dict): 28 | return {k: fn_if_dataclass(v, serialize) for k, v in instance.items()} 29 | return instance 30 | 31 | 32 | def init_class(instance: DataClass, config: Dict[str, Any]): 33 | for name in dir(instance): 34 | if name.startswith("_") or name.endswith("_") or name not in config: 35 | continue 36 | attr = getattr(instance, name) 37 | is_dataclass = isinstance(attr, DataClass) 38 | is_list = isinstance(attr, (list, tuple)) 39 | is_dict = isinstance(attr, dict) 40 | if not (is_dataclass or (is_list and isinstance(attr[0], DataClass)) or ( 41 | is_dict and isinstance(next(iter(attr.values())), DataClass))): 42 | setattr(instance, name, config[name]) 43 | continue 44 | 45 | if is_dataclass: 46 | init_class(attr, config[name]) 47 | elif is_list: 48 | setattr(instance, name, type(attr)(init_class_copy(attr[0], item) for item in config[name])) 49 | elif is_dict: 50 | base = next(iter(attr.values())) 51 | setattr(instance, name, {key: init_class_copy(base, item) for key, item in config[name].items()}) 52 | else: 53 | raise ValueError(f"Unknown type {type(attr)} with given data {config[name]}") 54 | 55 | 56 | def init_class_copy(instance: DataClass, config: Dict[str, Any]) -> DataClass: 57 | instance = copy.deepcopy(instance) 58 | init_class(instance, config) 59 | return instance 60 | 61 | 62 | class DataContext(DataClass): 63 | path: str = "gs://homebrewnlp-eu/the-char-pile/*" 64 | shuffle_buffer_gb: int = 64 65 | parallel_workers: int = 2 66 | interleaved_datasets: int = 2 67 | prefetch_buffer: int = 2 68 | seed: int = 0 69 | deterministic: bool = True 70 | datasets_used_per_step: int = 2 71 | 72 | 73 | class Dims(DataClass): 74 | batch: int = 512 75 | outer_bottleneck_kernel: int = 25 76 | inner_bottleneck_kernel: int = 49 77 | inner_bottleneck_features: int = 128 78 | pointwise_kernel: int = 5 79 | features: int = 256 80 | spatial_mixing_kernel: int = 512 81 | pointwise_features: int = 512 82 | sequence: int = 4096 83 | depth: int = 8 84 | vocab: int = 256 85 | 86 | 87 | class TensorboardTrace(DataClass): 88 | """ 89 | Defines a tensorboard profiling output (folder) on which a tensorboard can be run to measure RAM utilization and 90 | view the operation trace. 
91 | """ 92 | start_step: int = 16 93 | stop_step: int = 64 + 16 94 | do_trace: bool = False 95 | output_path: str = "trace" 96 | 97 | 98 | class WandB(DataClass): 99 | group: Optional[str] = None 100 | name: Optional[str] = None 101 | id: Optional[str] = None 102 | project: str = 'gpt' 103 | entity: str = 'homebrewnlp' 104 | median_sizes: List[int] = [64, 256, 1024] 105 | 106 | 107 | class Optimizer(DataClass): 108 | momentum_dtype: str = "float32" 109 | momentum_type: str = "debiased" # see src.constants.MomentumType for options 110 | epsilon: float = 1e-16 111 | learning_rate: float = 0.01 112 | gradient_clip: float = 0.001 113 | adam_beta1: float = 0.03 114 | adam_beta2: float = 0.003 115 | adam_beta3: float = 0.001 116 | weight_decay: float = 0.01 117 | warmup_end: int = 16384 118 | exponential_decay: float = 3e-6 119 | 120 | 121 | class Normalization(DataClass): 122 | eps: float = 1e-16 123 | 124 | 125 | class Model(DataClass): 126 | norm: Normalization = Normalization() 127 | autoregressive: bool = True 128 | conv_scale: float = 4. 129 | conv_shift: float = 8. 130 | storage_dtype: str = "float32" # valid jax.numpy.dtype 131 | computation_dtype: str = "bfloat16" 132 | 133 | 134 | class Training(DataClass): 135 | debug: bool = False 136 | checkpoint_path: str = "gs://homebrewnlp-eu/homebrewnlp-checkpoint" 137 | checkpoint_load_path: str = "" 138 | checkpoint_interval: float = 16384 139 | do_checkpoint: bool = False 140 | z_loss: float = 0.01 141 | device_steps: int = 4 142 | device_unroll: int = 1 143 | steps: int = 2 ** 16 144 | trace: TensorboardTrace = TensorboardTrace() 145 | 146 | 147 | class Evaluation(DataClass): 148 | eos: int = 4 149 | 150 | 151 | class Context(DataClass): 152 | data: DataContext = DataContext() 153 | optimizer: Optimizer = Optimizer() 154 | model: Model = Model() 155 | training: Training = Training() 156 | wandb: WandB = WandB() 157 | eval: Evaluation = Evaluation() 158 | 159 | def __init__(self, config: Optional[Dict[str, Any]] = None): 160 | self.data = DataContext() 161 | self.optimizer = Optimizer() 162 | self.model = Model() 163 | self.training = Training() 164 | self.wandb = WandB() 165 | self.dims = Dims() 166 | 167 | if config is None and 'CONFIG' in os.environ: 168 | with open(os.environ['CONFIG']) as f: 169 | cfg = f.read() 170 | config = yaml.safe_load(cfg) 171 | if config is not None: 172 | init_class(self, config) 173 | 174 | self.seed = 0 175 | self.global_prefix = '' 176 | 177 | self.name_cache: Dict[str, int] = {} 178 | self.name_cache_offsets: Dict[str, int] = {} 179 | self.parameters: Dict[str, jax.Array] = {} 180 | self.parameter_variance: Dict[str, float] = {} 181 | self.parameter_usages: Dict[str, int] = collections.defaultdict(int) 182 | self.prng_key = random.PRNGKey(self.seed) 183 | self.is_initializing = False 184 | self.fail_on_missing_parameter = True 185 | self.add_depth = False 186 | self.depth = 0 187 | 188 | def add_to_prefix(self, appended="", count=True): 189 | new = copy.copy(self) 190 | if count: 191 | appended = self.incremental_name(appended) 192 | new.global_prefix = self.global_prefix + '/' + appended 193 | return new 194 | 195 | def incremental_name(self, name): 196 | if name not in self.name_cache: 197 | self.name_cache[name] = -1 198 | self.name_cache[name] += 1 199 | return f'{name}:{self.name_cache[name]:d}' 200 | 201 | def config(self) -> dict: 202 | cfg = self.__dict__.copy() 203 | del cfg['name_cache'], cfg['parameters'], cfg['prng_key'], cfg['is_initializing'] 204 | del cfg['parameter_variance'] 205 | return 
serialize(cfg) 206 | 207 | def __str__(self): 208 | return yaml.dump(self.config(), indent=4) 209 | 210 | 211 | class WhileContext(DataClass): 212 | def __init__(self, config: Optional[Dict[str, Any]] = None): 213 | self.ctx = Context() 214 | self.current_step = jnp.ones([], dtype=jnp.uint32) 215 | self.data: Optional[jax.Array] = None 216 | 217 | if config is not None: 218 | self.ctx.parameters = config['parameters'] 219 | self.current_step = config['current_step'] 220 | self.data = config['data'] 221 | 222 | def _serialize(self) -> dict: 223 | return {'parameters': self.ctx.parameters, 'current_step': self.current_step, 'data': self.data} 224 | 225 | @property 226 | def step(self): 227 | return int(self.current_step[0]) 228 | 229 | def __call__(self, data: jax.Array): 230 | self.data = data 231 | return self 232 | 233 | 234 | class WhileTrainContext(WhileContext): 235 | def __init__(self, config: Optional[Dict[str, Any]] = None): 236 | super().__init__(config) 237 | self.scalars = jnp.zeros([2], jnp.float64) 238 | 239 | if config is not None: 240 | self.scalars = config['scalars'] 241 | self.ctx.parameter_variance = config['parameter_variance'] 242 | 243 | def serialize(self): 244 | serialized = self._serialize() 245 | serialized['scalars'] = self.scalars 246 | serialized['parameter_variance'] = self.ctx.parameter_variance 247 | return serialized 248 | 249 | 250 | class WhilePredictContext(WhileContext): 251 | def __init__(self, config: Optional[Dict[str, Any]] = None): 252 | super().__init__(config) 253 | 254 | batch_dim_size = self.ctx.dims.batch 255 | sequence_dim_size = self.ctx.dims.sequence 256 | vocab_dim_size = self.ctx.dims.vocab 257 | 258 | self.start_pos = jnp.zeros([batch_dim_size]) 259 | self.stop_pos = jnp.array([sequence_dim_size] * batch_dim_size)[0] 260 | self.temperature = jnp.zeros([batch_dim_size]) 261 | self.max_tokens = jnp.array([vocab_dim_size] * batch_dim_size) 262 | self.max_probability_mass = jnp.array([1] * batch_dim_size) 263 | self.typical_mass = jnp.array([1] * batch_dim_size) 264 | self.seed = jnp.array([0] * batch_dim_size) 265 | self.max_probability_to_filter = jnp.array([0] * batch_dim_size) 266 | self.adaptive_filter_scale = jnp.array([0] * batch_dim_size) 267 | self.adaptive_filter_power = jnp.array([1] * batch_dim_size) 268 | 269 | if config is not None: 270 | self.start_pos = config['start_pos'] 271 | self.stop_pos = config['stop_pos'] 272 | self.temperature = config['temperature'] 273 | self.max_tokens = config['max_tokens'] 274 | self.max_probability_mass = config['max_probability_mass'] 275 | self.max_probability_to_filter = config['max_probability_to_filter'] 276 | self.adaptive_filter_scale = config['adaptive_filter_scale'] 277 | self.adaptive_filter_power = config['adaptive_filter_power'] 278 | self.typical_mass = config['typical_mass'] 279 | self.ctx.seed = config['seed'] 280 | 281 | def serialize(self): 282 | serialized = self._serialize() 283 | serialized['start_pos'] = self.start_pos 284 | serialized['stop_pos'] = self.stop_pos 285 | serialized['temperature'] = self.temperature 286 | serialized['max_tokens'] = self.max_tokens 287 | serialized['max_probability_mass'] = self.max_probability_mass 288 | serialized['max_probability_to_filter'] = self.max_probability_to_filter 289 | serialized['adaptive_filter_scale'] = self.adaptive_filter_scale 290 | serialized['adaptive_filter_power'] = self.adaptive_filter_power 291 | serialized['typical_mass'] = self.typical_mass 292 | serialized['seed'] = self.ctx.seed 293 | 294 | return serialized 295 | 
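# Usage sketch (illustrative; not part of the module itself): `Context` pulls its
# hyperparameters from the YAML file named by the CONFIG environment variable, so a
# configured context can be rebuilt anywhere with, e.g.:
#     os.environ['CONFIG'] = 'config.yaml'
#     ctx = Context()  # merges config.yaml into the defaults above via init_class
#     print(ctx)       # __str__ dumps the merged configuration back out as YAML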
-------------------------------------------------------------------------------- /src/data.py: --------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from typing import Iterator
4 | 
5 | import jax
6 | import numpy as np
7 | import tensorflow as tf
8 | from tensorflow.data.experimental import AutoShardPolicy
9 | 
10 | from src.context import Context
11 | 
12 | tf1 = tf.compat.v1
13 | 
14 | 
15 | def decoder(int_string: bool, data: tf.Tensor, seed: int, context_p1: int, deterministic: bool):
16 |     """
17 |     Read a given tfrecord and build a windowed text dataset out of it.
18 |     :param int_string: whether the entire dataset is in int64 or bytes
19 |     :param data: protobuf object to decode
20 |     :param seed: rng seed
21 |     :param context_p1: context + 1 (window length in tokens, including the shifted target)
22 |     :param deterministic: whether to use sloppy interleave (fast) or deterministic
23 |         interleave (slow)
24 |     :return: tensorflow dataset of tokens
25 |     """
26 | 
27 |     def chunk(proto):
28 |         if int_string:
29 |             dat = tf1.parse_single_example(proto, {'text': tf1.VarLenFeature(tf.int64)})
30 |             dat = tf.cast(tf.sparse.to_dense(dat['text']), tf.int32)
31 |         else:
32 |             text_slice = tf1.parse_single_example(proto, {'text': tf1.FixedLenFeature([], tf.string)})['text']
33 |             dat = tf.io.decode_raw(text_slice, tf.uint8)
34 |         dat = tf.reshape(dat, (-1,))
35 |         dat = tf.slice(dat, (0,), (tf.size(dat) // context_p1 * context_p1,))
36 |         dat = tf.reshape(dat, (-1, context_p1))
37 |         dat = tf.random.shuffle(dat, seed=seed)
38 |         return tf.data.Dataset.from_tensor_slices(dat)
39 | 
40 |     return tf.data.TFRecordDataset(filenames=data).interleave(chunk, cycle_length=1, deterministic=deterministic)
41 | 
42 | 
43 | def debug_generator(ctx: Context) -> Iterator[np.ndarray]:
44 |     rstate = np.random.RandomState(0)
45 |     while True:
46 |         start = rstate.uniform(1, 2 ** 30, (ctx.training.device_steps * ctx.dims.batch,)).astype(np.int64)
47 |         multiplier = rstate.normal(size=(ctx.training.device_steps * ctx.dims.batch,)).astype(np.int64)
48 |         out = np.arange(0, ctx.dims.sequence + 1).astype(np.int64).reshape(1, -1)
49 |         out += start
50 |         yield (np.sin(out) * multiplier * ctx.dims.vocab) % ctx.dims.vocab
51 | 
52 | 
53 | def text_dataset(ctx: Context, skipped_steps: int) -> Iterator[np.ndarray]:
54 |     if ctx.training.debug:
55 |         return debug_generator(ctx)
56 | 
57 |     filenames = tf.io.gfile.glob(ctx.data.path)
58 | 
59 |     rng = random.Random(ctx.data.seed)
60 |     rng.shuffle(filenames)
61 | 
62 |     file_slice = len(filenames) / jax.process_count()
63 |     filenames = filenames[int(file_slice * jax.process_index()):int(file_slice * (jax.process_index() + 1))]
64 | 
65 |     dset = tf.data.Dataset.from_tensor_slices(filenames).repeat()
66 |     sequence_length = ctx.dims.sequence
67 |     batch_size = ctx.dims.batch
68 |     device_steps = ctx.training.device_steps
69 |     full_batch = device_steps * batch_size
70 |     sequence_length_1 = sequence_length + 1
71 |     if full_batch % ctx.data.datasets_used_per_step != 0:
72 |         raise ValueError(f"Can't use {full_batch=} with {ctx.data.datasets_used_per_step=} as "
73 |                          f"{full_batch % ctx.data.datasets_used_per_step=}. 
Ensure full_batch="
74 |                          f"{device_steps * batch_size=} is divisible by {ctx.data.datasets_used_per_step=}")
75 |     is_int64 = 'int64' in filenames[0]
76 | 
77 |     def _slice_target(x):
78 |         """
79 |         :param x: tensor
80 |         :return: Shape[Steps * Batch, Sequence + 1]
81 |         """
82 |         x = tf.reshape(x, (device_steps * batch_size, sequence_length_1))
83 |         x = tf.cast(x, tf.int32)
84 |         return x
85 | 
86 |     dset = dset.interleave(lambda x: decoder(is_int64, x, rng.randint(0, 2 ** 32), sequence_length_1,
87 |                                              ctx.data.deterministic),
88 |                            cycle_length=ctx.data.interleaved_datasets,
89 |                            num_parallel_calls=ctx.data.parallel_workers,
90 |                            deterministic=ctx.data.deterministic)
91 |     if ctx.data.shuffle_buffer_gb > 0:
92 |         buffer_size = ctx.data.shuffle_buffer_gb * 2 ** 30 // 4 // sequence_length_1  # 4 = int32
93 |         dset = dset.shuffle(buffer_size, seed=rng.randint(0, 2 ** 32))
94 |     dset = dset.batch(full_batch, deterministic=ctx.data.deterministic)
95 |     dset = dset.map(_slice_target, deterministic=ctx.data.deterministic)
96 |     if ctx.data.prefetch_buffer > 0:
97 |         dset = dset.prefetch(ctx.data.prefetch_buffer)
98 |     options = tf.data.Options()
99 |     options.deterministic = ctx.data.deterministic
100 |     options.experimental_optimization.apply_default_optimizations = True
101 |     options.experimental_optimization.filter_fusion = True
102 |     options.experimental_optimization.map_and_batch_fusion = True
103 |     options.experimental_optimization.map_and_filter_fusion = True
104 |     options.experimental_optimization.map_fusion = True
105 |     options.experimental_optimization.map_parallelization = True
106 |     options.experimental_optimization.noop_elimination = True
107 |     options.experimental_optimization.parallel_batch = True
108 |     options.experimental_optimization.shuffle_and_repeat_fusion = True
109 |     options.threading.private_threadpool_size = os.cpu_count()
110 |     options.experimental_slack = not ctx.data.deterministic
111 |     options.experimental_distribute.auto_shard_policy = AutoShardPolicy.AUTO
112 |     dset = dset.with_options(options)
113 | 
114 |     if skipped_steps:
115 |         dset = dset.skip(skipped_steps)
116 | 
117 |     return dset.as_numpy_iterator()
118 | 
-------------------------------------------------------------------------------- /src/main.py: --------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import os
4 | import time
5 | import warnings
6 | from typing import Tuple, Dict, Any, Callable, Iterator
7 | 
8 | import jax
9 | import numpy as np
10 | import wandb
11 | from jax import lax, numpy as jnp
12 | 
13 | from src.backend import add_sq, deep_replace, device_id, loop
14 | from src.constants import ParallelAxes
15 | from src.context import Context, WhileTrainContext, init_class
16 | from src.data import text_dataset
17 | from src.model.main import body_ctx, compute
18 | from src.optimizer import get_current_lr, update
19 | from src.utils.checkpoint import read_train_checkpoint, write_train_checkpoint
20 | from src.utils.wandblog import WandbLog
21 | 
22 | 
23 | def add_zeros(params: Dict[str, jax.Array]):
24 |     params.update({add_sq(k): jnp.zeros_like(v) for k, v in params.items()})
25 | 
26 | 
27 | def train_step(while_ctx_dict: Dict[str, Any]) -> Dict[str, Any]:
28 |     wctx = WhileTrainContext(while_ctx_dict)
29 |     steps = wctx.ctx.training.device_steps * jax.process_count()
30 |     grad_fn = jax.value_and_grad(compute, 0, True)
31 |     data_slice = wctx.data[wctx.current_step % steps]
32 |     params = {k: v for k, v in wctx.ctx.parameters.items() if '/optimizer' not in k}
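# `add_zeros` pairs every parameter with an all-zero `*_sq` twin (see `add_sq` in
# src/backend.py), giving `jax.value_and_grad` a slot through which `square_grad`
# returns squared-gradient statistics; `update` reads them back via `grads[add_sq(name)]`.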
add_zeros(params) 34 | scalars, grads = grad_fn(params, data_slice) 35 | update(wctx.ctx, grads, wctx.current_step) 36 | wctx.scalars += jnp.stack(scalars) / steps # higher numerical accuracy if we divide before summing 37 | wctx.current_step += 1 38 | return wctx.serialize() 39 | 40 | 41 | def jitless_step(while_ctx_dict: Dict[str, Any]) -> Dict[str, Any]: 42 | wctx = WhileTrainContext(while_ctx_dict) 43 | training = wctx.ctx.training 44 | steps = training.device_steps * jax.process_count() 45 | step_batch, sequence_p1 = wctx.data.shape 46 | 47 | # "all-to-all" / "all-concat" with jax.process_count() outputs instead of jax.device_count() outputs 48 | # init sparse tensor with 0s everywhere except for local input slice 49 | data = jnp.zeros((jax.process_count(), step_batch, sequence_p1), wctx.data.dtype) 50 | data = data.at[jax.process_index(), :, :].set(wctx.data) 51 | # same value was seen `local_device_count` times, so divide to remove implicit multiplication (int32 --> accurate) 52 | data = lax.psum(data, ParallelAxes.model).astype(wctx.data.dtype) // jax.local_device_count() 53 | 54 | # interleave samples within batch by transposing steps*process_count + batch and reshaping from (x,y).t() to x,y 55 | # process_count, steps * batch, sequence 56 | # --reshape--> batch, process_count * steps, sequence ([[0, 1, 2], [3, 4, 5]] --> [[0, 1], [2, 3], [4, 5]]) 57 | # --transpose--> process_count * steps, batch, sequence ([[0, 1], [2, 3], [4, 5]] --> [[0, 2, 4], [1, 3, 5]]) 58 | data = data.reshape(wctx.ctx.dims.batch, steps, sequence_p1).transpose(1, 0, 2) 59 | wctx.data = jnp.stack([data[:, :, :-1], data[:, :, 1:]], 1) 60 | 61 | return loop(train_step, wctx.serialize(), steps, training.device_unroll) 62 | 63 | 64 | def get_parameters(ctx: Context, inp: jax.Array): 65 | def _fn(x: jax.Array): 66 | initial_seed = ctx.seed 67 | initial_prng_key = ctx.prng_key 68 | ctx.seed += device_id() 69 | ctx.prng_key = jax.random.PRNGKey(ctx.seed) 70 | body_ctx(ctx, x) 71 | params = ctx.parameters 72 | var = ctx.parameter_variance 73 | ctx.parameters = {} 74 | ctx.prng_key = initial_prng_key 75 | ctx.seed = initial_seed 76 | ctx.parameter_variance = {} 77 | return params, lax.pmean(var, ParallelAxes.model) 78 | 79 | pmapped = jax.pmap(_fn, ParallelAxes.model, in_axes=(0,), out_axes=(0, 0), donate_argnums=(0,)) 80 | ctx.parameters, variance = pmapped(inp) 81 | ctx.parameter_variance = {name: var.mean() for name, var in variance.items()} 82 | 83 | 84 | def get_optimizer_state(ctx: Context): 85 | def _fn(parameters: Dict[str, jax.Array]): 86 | new_ctx = ctx 87 | new_ctx.parameters = {} 88 | new_ctx = copy.deepcopy(new_ctx) 89 | new_ctx.parameters = parameters.copy() 90 | add_zeros(parameters) 91 | keys = jax.random.split(jax.random.PRNGKey(0), len(parameters)) 92 | grads = {name: jax.random.uniform(key, param.shape, ctx.model.computation_dtype, 1e-6, 1e-3) 93 | for key, (name, param) in zip(keys, parameters.items())} 94 | update(new_ctx, grads, jnp.ones((), dtype=new_ctx.model.computation_dtype)) 95 | return new_ctx.parameters 96 | 97 | pmapped = jax.pmap(_fn, ParallelAxes.model, in_axes=({k: 0 for k in ctx.parameters.keys()},), out_axes=0, 98 | donate_argnums=(0,)) 99 | ctx.parameters = pmapped(ctx.parameters) 100 | 101 | 102 | def timeit(text: str, fn, *args, pad=50, **kwargs): 103 | start_time = time.time() 104 | print(f'{text}..', end='', flush=True) 105 | out = fn(*args, **kwargs) 106 | print(f"{' ' * (pad - len(text))}Took:{time.time() - start_time:9.2f}s", flush=True) 107 | return out 108 | 109 | 110 | 
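# TrainLoop keeps the host-side WhileTrainContext alive across iterations: each call
# zeroes the loss/accuracy accumulators, runs the pmapped step on one replicated batch,
# and rebuilds the context from the returned buffers.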
class TrainLoop: 111 | def __init__(self, wctx: WhileTrainContext, step: Callable): 112 | self.wctx = wctx 113 | self.step = step 114 | 115 | def __call__(self, dat: jax.Array) -> WhileTrainContext: 116 | wctx = self.wctx(dat) 117 | wctx.scalars = jnp.zeros_like(wctx.scalars) 118 | self.wctx = WhileTrainContext(self.step(wctx.serialize())) 119 | return self.wctx 120 | 121 | 122 | def replicate(x: Any) -> Any: 123 | return jax.device_put_replicated(x, jax.local_devices()) 124 | 125 | 126 | def init_data(ctx: Context, skipped_samples: int) -> Tuple[Iterator[np.ndarray], np.ndarray]: 127 | np_data = timeit("Initializing dataset", text_dataset, ctx, skipped_samples) 128 | 129 | data = map(replicate, np_data) 130 | inp = timeit("Enqueueing first batch", next, data)[:, :ctx.dims.batch, :ctx.dims.sequence] 131 | return data, inp 132 | 133 | 134 | def init_data_and_model(wctx: WhileTrainContext) -> Iterator[np.ndarray]: 135 | """Model gets loaded in-place into the `WhileTrainContext`""" 136 | if wctx.ctx.training.checkpoint_load_path: 137 | read_train_checkpoint(wctx, '[0]{100}') 138 | skipped_samples = math.ceil(wctx.step / jax.process_count() / wctx.ctx.training.device_steps) 139 | data, _ = init_data(wctx.ctx, skipped_samples) 140 | return data 141 | 142 | data, inp = init_data(wctx.ctx, 0) 143 | wctx.ctx.is_initializing = True 144 | timeit("Acquiring forward parameters", get_parameters, wctx.ctx, inp) 145 | timeit("Acquiring optimizer parameters", get_optimizer_state, wctx.ctx) 146 | wctx.ctx.is_initializing = False 147 | wctx.ctx.parameter_variance = replicate(wctx.ctx.parameter_variance) 148 | wctx.current_step = replicate(wctx.current_step) 149 | wctx.scalars = replicate(wctx.scalars) 150 | 151 | return data 152 | 153 | 154 | def dump_ctx(ctx: Context, run): 155 | with open("config.yaml", 'w') as f: 156 | f.write(str(ctx)) 157 | os.environ['CONFIG'] = 'config.yaml' 158 | run.config.update(ctx.config(), allow_val_change=True) 159 | 160 | 161 | def main(): 162 | warnings.filterwarnings("ignore", message=".*is an experimental feature and probably has bugs!.*") 163 | warnings.filterwarnings("ignore", message=".*Some donated buffers were not usable.*") 164 | 165 | wctx = WhileTrainContext() 166 | ctx = wctx.ctx 167 | 168 | run = wandb.init(project=ctx.wandb.project, entity=ctx.wandb.entity, config=ctx.config(), name=ctx.wandb.name, 169 | id=ctx.wandb.id, group=ctx.wandb.group) 170 | 171 | cfg = {} 172 | for param_name, param in run.config.items(): 173 | if '.' 
not in param_name: 174 | continue 175 | inner_cfg = cfg 176 | split_name = param_name.split(".") 177 | for s in split_name[:-1]: 178 | if s not in inner_cfg: 179 | inner_cfg[s] = {} 180 | inner_cfg = inner_cfg[s] 181 | inner_cfg[split_name[-1]] = param 182 | init_class(ctx, cfg) 183 | dump_ctx(ctx, run) 184 | 185 | wctx = WhileTrainContext() 186 | print(wctx.ctx) 187 | device_steps = wctx.ctx.training.device_steps * jax.process_count() 188 | total_steps = wctx.ctx.training.steps * device_steps 189 | tokens_processed = wctx.ctx.dims.sequence * wctx.ctx.dims.batch 190 | data = init_data_and_model(wctx) 191 | parameter_count = sum(param.size for name, param in wctx.ctx.parameters.items() if "optimizer" not in name) 192 | buffer_count = sum(param.size for name, param in wctx.ctx.parameters.items()) - parameter_count 193 | 194 | partition = deep_replace(wctx.serialize(), 0) 195 | 196 | step = jax.pmap(jitless_step, ParallelAxes.model, in_axes=(partition,), out_axes=partition, donate_argnums=(0,)) 197 | step = TrainLoop(wctx, step) 198 | 199 | print("\n") 200 | print(f"Parameters: {jax.process_count() * parameter_count:,}") 201 | print(f"Buffers: {jax.process_count() * buffer_count:,}\n\n") 202 | 203 | checkpoint_at = wctx.ctx.training.checkpoint_interval + wctx.step 204 | start_time = time.time() 205 | wblog = WandbLog(run, int(ctx.training.device_steps * jax.process_count()), parameter_count, tokens_processed) 206 | for idx, dat in enumerate(data): 207 | step_start = time.time() 208 | wctx = step(dat) 209 | current_step = int(wctx.step) 210 | lr = float(get_current_lr(wctx.ctx, wctx.current_step[0])) 211 | print(f'[{current_step:{len(str(total_steps))}d}/{total_steps}] ' 212 | f'Loss: {wctx.scalars[0, 0]:6.3f} - ' 213 | f'Accuracy: {wctx.scalars[0, 1]:8.3f} | ' 214 | f'LearningRate: {lr:.5f} | ' 215 | f'StepTime: {time.time() - step_start:10.6f}s - ' 216 | f'Rate: {tokens_processed * (current_step + 1) / (time.time() - start_time):9,.1f} Tokens/s') 217 | if wblog(wctx, current_step, lr): 218 | return 219 | if wctx.ctx.training.trace.do_trace: 220 | if idx == wctx.ctx.training.trace.start_step: 221 | jax.profiler.start_trace(wctx.ctx.training.trace.output_path) 222 | if idx == wctx.ctx.training.trace.stop_step: 223 | jax.profiler.stop_trace() 224 | if wctx.ctx.training.do_checkpoint and current_step > checkpoint_at: 225 | write_train_checkpoint(wctx) 226 | checkpoint_at += wctx.ctx.training.checkpoint_interval 227 | return 228 | 229 | 230 | if __name__ == '__main__': 231 | main() 232 | -------------------------------------------------------------------------------- /src/model/activate.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax import numpy as jnp 3 | 4 | 5 | def activate_forward(inp: jax.Array) -> jax.Array: 6 | return inp * activate_grad(inp) 7 | 8 | 9 | def activate_grad(inp: jax.Array) -> jax.Array: 10 | return jnp.where(inp < 0, 0.01, 1) 11 | 12 | 13 | def activate(inp: jax.Array) -> jax.Array: 14 | @jax.custom_gradient 15 | def _fn(x: jax.Array): 16 | return activate_forward(x), lambda dy: dy * activate_grad(x) 17 | 18 | return _fn(inp) 19 | -------------------------------------------------------------------------------- /src/model/conv.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax import numpy as jnp 3 | 4 | from src.backend import conv as lax_conv, get_param, square_grad, with_context 5 | from src.context import Context 6 | from src.model.norm import 
prenorm, scale_norm_act 7 | 8 | 9 | @with_context() 10 | def conv(ctx: Context, inp: jax.Array, conv_kernel: int, in_features: int, out_features: int, tied: bool = False): 11 | fan_in = jnp.arange(conv_kernel, 0, -1, dtype=ctx.model.storage_dtype) 12 | fan_in = (1 - 1 / (conv_kernel * ctx.model.conv_scale + ctx.model.conv_shift)) ** fan_in 13 | fan_in = fan_in / fan_in.sum() 14 | fan_in = fan_in.reshape(1, 1, -1) 15 | weight, weight_sq = get_param(ctx, "conv_weight", [out_features, conv_kernel, in_features], column_axes=2, 16 | lr_scale=fan_in, tied=tied, return_sq=True) 17 | if ctx.is_initializing: 18 | return jnp.zeros(inp.shape[:-1] + (out_features,), dtype=inp.dtype) 19 | 20 | def _conv(x, y): 21 | return lax_conv(x, y, [(conv_kernel - 1, 0)], 1) 22 | 23 | return square_grad(_conv, inp, weight, weight_sq) 24 | 25 | 26 | @prenorm 27 | @with_context() 28 | def dense_block(ctx: Context, inp: jax.Array) -> jax.Array: 29 | inp = conv(ctx, inp, ctx.dims.pointwise_kernel, ctx.dims.features, ctx.dims.pointwise_features) 30 | inp = scale_norm_act(ctx, inp, ctx.dims.pointwise_features) 31 | return conv(ctx, inp, ctx.dims.pointwise_kernel, ctx.dims.pointwise_features, ctx.dims.features) 32 | -------------------------------------------------------------------------------- /src/model/loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import jax 5 | from jax import lax, numpy as jnp 6 | 7 | from src.backend import device_id, matmul, promote_to 8 | from src.constants import ParallelAxes 9 | from src.context import Context 10 | 11 | 12 | def cross_entropy_loss(ctx: Context, src_wgt: Tuple[jax.Array, jax.Array, jax.Array], 13 | outer_tgt: jax.Array) -> Tuple[jax.Array, jax.Array]: 14 | # Forward: logsumexp(x) - x[target] 15 | # Backward: (logsumexp(x) - x[target] + logsumexp(x)^2 * z_loss).grad 16 | # -> softmax(x) - one_hot(target) + softmax(x) * logsumexp(x) * z_loss 17 | src, param, param_sq = src_wgt 18 | devices = jax.device_count() 19 | total_items = ctx.dims.batch * ctx.dims.sequence 20 | steps = ctx.dims.vocab // 128 21 | step_batch = total_items // steps 22 | local_batch = step_batch // devices 23 | 24 | def _xent_slice(carry: Tuple[jax.Array, jax.Array, jax.Array, jax.Array], 25 | x: Tuple[jax.Array, jax.Array], wgt: jax.Array): 26 | d_wgt, d_wgt_sq, loss, acc = carry 27 | inp_slice, tgt_slice = x 28 | tmp = matmul(inp_slice, wgt) 29 | tmp = promote_to(tmp, jnp.float32) 30 | tmp = lax.psum_scatter(tmp, ParallelAxes.model).reshape(local_batch, ctx.dims.vocab) 31 | lse = jax.nn.logsumexp(promote_to(tmp, jnp.float64), 1, keepdims=True) 32 | 33 | loss = loss + (lse / total_items).sum() 34 | loss = loss - (jnp.take_along_axis(tmp, tgt_slice.reshape(*tgt_slice.shape, 1), -1) / total_items).sum() 35 | acc = acc + lax.eq(lax.argmax(tmp, 1, outer_tgt.dtype), tgt_slice).sum() / total_items 36 | 37 | dy = lax.exp(tmp - (lse + math.log(total_items))) # [LocalBatch, Vocab] 38 | zloss = dy * lse * ctx.training.z_loss * 2 39 | dy = dy.at[jnp.arange(local_batch).reshape(-1, 1), tgt_slice.reshape(-1, 1)].add(-1 / total_items) 40 | dy = dy + zloss 41 | dy = dy * jax.device_count() 42 | dy = dy.astype(src.dtype) 43 | dy = lax.all_gather(dy, ParallelAxes.model) 44 | dy = dy.reshape(step_batch, ctx.dims.vocab).transpose(1, 0) 45 | dx = matmul(wgt, dy) # [Features, Vocab] @ [Vocab, StepBatch] -> [Features, StepBatch] 46 | 47 | inp_slice = inp_slice.reshape(step_batch, ctx.dims.features) 48 | d_wgt = d_wgt + matmul(dy, 
inp_slice) # [Vocab, StepBatch] @ [StepBatch, Features] -> [Vocab, Features] 49 | d_wgt_sq = d_wgt_sq + matmul(lax.square(dy), lax.square(inp_slice)) 50 | return (d_wgt, d_wgt_sq, loss.astype(jnp.float64), acc.astype(jnp.float64)), dx 51 | 52 | @jax.custom_gradient 53 | def _fn(inp: jax.Array, tgt: jax.Array, wgt: jax.Array, _wgt_sq: jax.Array): 54 | inp = inp.reshape(steps, devices, local_batch, ctx.dims.features) 55 | tgt = tgt.reshape(steps, step_batch) # [Steps, StepBatch] 56 | tgt = lax.dynamic_slice_in_dim(tgt, device_id() * local_batch, local_batch, 1) # [Steps, LocalBatch] 57 | 58 | def _slice_fn(carry, x): 59 | return _xent_slice(carry, x, wgt) 60 | 61 | init = (jnp.zeros(wgt.shape[::-1]), jnp.zeros(wgt.shape[::-1]), jnp.zeros((), dtype=jnp.float64), 62 | jnp.zeros((), dtype=jnp.float64)) 63 | (d_wgt, d_wgt_sq, loss, acc), dx = lax.scan(_slice_fn, init, (inp, tgt)) 64 | 65 | dx = dx.transpose(0, 2, 1) # [Steps, Features, StepBatch] -> [Steps, StepBatch, Features] 66 | dx = dx.reshape(ctx.dims.batch, ctx.dims.sequence, ctx.dims.features) 67 | d_wgt = d_wgt.transpose(1, 0) # [Vocab, Features] -> [Features, Vocab] 68 | d_wgt_sq = d_wgt_sq.transpose(1, 0) * ctx.dims.batch 69 | 70 | def _grad(dy: Tuple[jax.Array, None]) -> Tuple[jax.Array, None, jax.Array, jax.Array]: 71 | # dy == 1 since this is the last function before the output 72 | dy, _ = dy 73 | return (dx * dy).astype(inp.dtype), None, (d_wgt * dy).astype(wgt.dtype), (d_wgt_sq * dy).astype(wgt.dtype) 74 | 75 | loss = lax.psum(loss, ParallelAxes.model) 76 | acc = lax.psum(acc, ParallelAxes.model) 77 | return (loss, acc), _grad 78 | 79 | return _fn(src, outer_tgt, param, param_sq) 80 | -------------------------------------------------------------------------------- /src/model/main.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict, Union 2 | 3 | import jax 4 | from jax import lax, numpy as jnp 5 | 6 | from src.backend import get_param, is_model, is_stacked, square_grad, with_context 7 | from src.context import Context 8 | from src.model.conv import dense_block 9 | from src.model.loss import cross_entropy_loss 10 | from src.model.mixer import mix 11 | from src.model.moe import dense_moe 12 | from src.model.norm import scale_norm_act 13 | from src.model.reversible import FourArrays, reversible, revnet_out 14 | 15 | 16 | @with_context() 17 | def input_embed(ctx: Context, inp: jax.Array) -> jax.Array: 18 | param, param_sq = get_param(ctx, "inp_embd", [ctx.dims.vocab, ctx.dims.features], std=1 / ctx.dims.features, 19 | return_sq=True) 20 | 21 | def _fn(src, wgt): 22 | return jnp.take(wgt, src, 0) 23 | 24 | if ctx.is_initializing: 25 | return _fn(inp, param) 26 | 27 | return square_grad(_fn, inp, param, param_sq) 28 | 29 | 30 | @with_context() 31 | def block(ctx: Context, shared_params: Dict[str, jax.Array]): 32 | name_cache = ctx.name_cache 33 | 34 | def _fn(carry: FourArrays, inp: Tuple[Dict[str, jax.Array], jax.Array]): 35 | original_parameters = ctx.parameters 36 | ctx.parameters, depth = inp 37 | ctx.parameters.update(shared_params) 38 | depth = depth.reshape([]) 39 | src = [ctx.parameters] + list(carry) 40 | src = reversible(ctx, dense_block, src) 41 | src = reversible(ctx, dense_moe, src) 42 | src = reversible(ctx, dense_block, src) 43 | src = reversible(ctx, mix, src, depth) 44 | name_cache.update(ctx.name_cache) 45 | if ctx.is_initializing: 46 | return src 47 | ctx.parameters = original_parameters 48 | return src[1:], None 49 | 50 | return _fn 51 | 52 | 53 | 
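# `stem` executes `block` once eagerly during initialization so that all layer parameters
# are materialized (stacked along a leading depth axis via ctx.add_depth), then at training
# time scans the same block over ctx.dims.depth, feeding each step its slice of the
# '_stacked' parameters while broadcasting the shared (tied) ones.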
@with_context() 54 | def stem(ctx: Context, src: FourArrays) -> FourArrays: 55 | if ctx.is_initializing: 56 | ctx.add_depth = True 57 | ctx.parameters, *src = block(ctx, {})(src, (ctx.parameters, jnp.zeros([], dtype=jnp.int32))) 58 | ctx.add_depth = False 59 | return src 60 | 61 | params = {k: v for k, v in ctx.parameters.items() if is_model(k)} 62 | shared = {k: v for k, v in params.items() if not is_stacked(k)} 63 | params = {k: v for k, v in params.items() if is_stacked(k)} 64 | src, _ = lax.scan(block(ctx, shared), src, (params, jnp.arange(ctx.dims.depth)), ctx.dims.depth) 65 | return src 66 | 67 | 68 | def body_ctx(ctx: Context, src: jax.Array) -> Union[ 69 | Tuple[jax.Array, jax.Array, jax.Array], jax.Array]: 70 | src = input_embed(ctx, src) 71 | zero = jnp.zeros_like(src) 72 | src = stem(ctx, (src, zero, src, zero)) 73 | out = revnet_out(src) 74 | out = scale_norm_act(ctx, out, ctx.dims.features, act=False, weight=False) 75 | wgt, wgt_sq = get_param(ctx, "out_embd", [ctx.dims.features, ctx.dims.vocab], std=1, scale=1 / jax.device_count(), 76 | return_sq=True) 77 | if ctx.is_initializing: 78 | return out 79 | return out, wgt, wgt_sq 80 | 81 | 82 | def compute(params: Dict[str, jax.Array], inp: jax.Array) -> Tuple[jax.Array, jax.Array]: 83 | ctx = Context() 84 | ctx.parameters = params 85 | src, tgt = inp 86 | out = body_ctx(ctx, src) 87 | if ctx.is_initializing: 88 | return out 89 | return cross_entropy_loss(ctx, out, tgt) 90 | -------------------------------------------------------------------------------- /src/model/mixer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Sequence 3 | 4 | import jax 5 | from jax import numpy as jnp 6 | 7 | from src.backend import dot, get_param, pattern_match, square_grad, with_context 8 | from src.context import Context 9 | from src.model.norm import prenorm, scale_norm_act 10 | 11 | 12 | def dot_sq(src: jax.Array, weight: jax.Array, weight_sq: jax.Array, 13 | left_contract_dims: Sequence[int], right_contract_dims: Sequence[int]): 14 | def _dot(x, y): 15 | return dot(x, y, left_contract_dims=left_contract_dims, right_contract_dims=right_contract_dims) 16 | 17 | return square_grad(_dot, src, weight, weight_sq) 18 | 19 | 20 | @prenorm 21 | @with_context() 22 | def mix(ctx: Context, inp: jax.Array, depth: jax.Array) -> jax.Array: 23 | weight_shape = [ctx.dims.spatial_mixing_kernel] * 2 24 | run_type = jnp.promote_types(ctx.model.computation_dtype, jnp.float32) 25 | wgt0, wgt0_sq = get_param(ctx, "mix_0", weight_shape, return_sq=True) 26 | wgt1, wgt1_sq = get_param(ctx, "mix_1", weight_shape, return_sq=True) 27 | scale, scale_sq = get_param(ctx, "scale", [ctx.dims.features], std=0, mean=1, dtype=run_type, return_sq=True) 28 | if ctx.is_initializing: 29 | return inp 30 | 31 | original_shape = inp.shape 32 | _batch, sequence, _features = original_shape 33 | max_dims = math.ceil(math.log(sequence, ctx.dims.spatial_mixing_kernel)) 34 | original_batch = inp.shape[0] 35 | if ctx.model.autoregressive: 36 | wgt0 = jnp.triu(wgt0) 37 | wgt1 = jnp.triu(wgt1) 38 | 39 | def _get_mix_fn(current_depth: int): 40 | def _fn(x: jax.Array): 41 | batch = max(sequence // ctx.dims.spatial_mixing_kernel ** (current_depth % max_dims + 1), 1) 42 | out = x.reshape(original_batch * batch, ctx.dims.spatial_mixing_kernel, -1) 43 | inner_batch, inner_sequence, inner_features = out.shape 44 | 45 | # Shape[Batch, Sequence, Features] * Shape[Sequence, Sequence] -> Shape[Batch, Features, Sequence] 46 | out = 
dot_sq(out, wgt0, wgt0_sq, left_contract_dims=(1,), right_contract_dims=(0,)) 47 | 48 | out = out.reshape(-1, ctx.dims.features, inner_sequence) 49 | out = scale_norm_act(ctx, out, ctx.dims.features, weight=(scale, scale_sq), add_to_prefix=False, dim=1) 50 | out = out.reshape(inner_batch, inner_features, inner_sequence) 51 | 52 | # Shape[Batch, Features, Sequence] * Shape[Sequence, Sequence] -> Shape[Batch, Features, Sequence] 53 | out = dot_sq(out, wgt1, wgt1_sq, left_contract_dims=(2,), right_contract_dims=(0,)) 54 | out = out.transpose(0, 2, 1) 55 | return out.reshape(original_shape) 56 | 57 | return _fn 58 | 59 | return pattern_match(_get_mix_fn, max_dims, depth, inp) 60 | -------------------------------------------------------------------------------- /src/model/moe.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax import lax 3 | 4 | from src.backend import with_context 5 | from src.constants import ParallelAxes 6 | from src.context import Context 7 | from src.model.conv import conv 8 | from src.model.norm import prenorm, scale_norm_act 9 | 10 | 11 | def all_to_all(ctx: Context, x: jax.Array, split_axis: int, concat_axis: int) -> jax.Array: 12 | if ctx.is_initializing: 13 | return x 14 | 15 | @jax.custom_gradient 16 | def _fn(inp: jax.Array): 17 | def _grad(dy: jax.Array) -> jax.Array: 18 | return lax.all_to_all(dy, ParallelAxes.model, concat_axis, split_axis, tiled=True) 19 | 20 | return lax.all_to_all(inp, ParallelAxes.model, split_axis, concat_axis, tiled=True), _grad 21 | 22 | return _fn(x) 23 | 24 | 25 | @prenorm 26 | @with_context() 27 | def dense_moe(ctx: Context, inp: jax.Array) -> jax.Array: 28 | devices = jax.device_count() 29 | big_params = devices * ctx.dims.inner_bottleneck_features 30 | batch, sequence, features = inp.shape 31 | sequence_slice = sequence // devices 32 | 33 | inp = conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, features, ctx.dims.inner_bottleneck_features) 34 | 35 | # [Batch, Sequence, Features] -> [Batch, SequenceSlice, Features * Devices] 36 | # In essence, 1) Collect features from all devices + 2) Drop unused sequence elements 37 | if not ctx.is_initializing: 38 | inp = inp.reshape(batch, sequence_slice, devices, ctx.dims.inner_bottleneck_features) 39 | inp = all_to_all(ctx, inp, 2, 3) 40 | inp = inp.reshape(batch, sequence_slice, big_params) 41 | 42 | # Devices^2 more parameters than normal bottleneck block but only Devices-times more flops due to sparsity above 43 | inp = scale_norm_act(ctx, inp, big_params) 44 | inp = conv(ctx, inp, ctx.dims.inner_bottleneck_kernel, big_params, big_params, tied=True) 45 | inp = scale_norm_act(ctx, inp, big_params) 46 | 47 | # [Batch, SequenceSlice, Features * Devices] -> [Batch, Sequence, Features] (PixelShuffle across devices) 48 | if not ctx.is_initializing: 49 | inp = inp.reshape(batch, sequence_slice, 1, big_params) 50 | inp = all_to_all(ctx, inp, 3, 2) 51 | inp = inp.reshape(batch, sequence, ctx.dims.inner_bottleneck_features) 52 | 53 | return conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, ctx.dims.inner_bottleneck_features, features) 54 | -------------------------------------------------------------------------------- /src/model/norm.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional, Union, Callable 2 | 3 | import jax 4 | from jax import lax, numpy as jnp 5 | 6 | from src.backend import get_param, promote_to, stable_rsqrt, with_context 7 | from src.constants import 
ParallelAxes 8 | from src.context import Context 9 | from src.model.activate import activate_forward, activate_grad 10 | 11 | 12 | def prenorm(fn: Callable[[Context, jax.Array], jax.Array]): 13 | def _fn(ctx: Context, inp: jax.Array, *args) -> jax.Array: 14 | ctx = ctx.add_to_prefix("prenorm") 15 | inp = scale_norm_act(ctx, inp, ctx.dims.features, act=False) 16 | out = fn(ctx, inp, *args) 17 | return scale_norm_act(ctx, out, ctx.dims.features, act=False) 18 | 19 | return _fn 20 | 21 | 22 | def all_gather(inp: jax.Array, dim: int) -> jax.Array: 23 | @jax.custom_gradient 24 | def _fn(x): 25 | def _grad(dy): 26 | return lax.psum_scatter(dy, axis_name=ParallelAxes.model, scatter_dimension=dim, tiled=True) 27 | 28 | return lax.all_gather(x, axis_name=ParallelAxes.model, axis=dim, tiled=True), _grad 29 | 30 | return _fn(inp) 31 | 32 | 33 | def norm_forward(ctx: Context, src: jax.Array, wgt: Optional[jax.Array] = None, psum: bool = False, 34 | act: bool = True, dim: int = 2): 35 | run_type = jnp.promote_types(ctx.model.computation_dtype, jnp.float32) 36 | src_fp64 = promote_to(src, run_type) 37 | own_sum = lax.square(src_fp64).sum(dim, keepdims=True) 38 | if psum: 39 | own_sum = lax.psum(own_sum, ParallelAxes.model) 40 | std = stable_rsqrt(own_sum, ctx.model.norm.eps) 41 | out = src_fp64 * std * wgt 42 | if act: 43 | out = activate_forward(out) 44 | out = out.astype(src.dtype) 45 | if psum: 46 | out = all_gather(out, dim) 47 | return out, std 48 | 49 | 50 | @with_context() 51 | def scale_norm_act(ctx: Context, inp: jax.Array, feature_dim: int, 52 | weight: Union[bool, None, Tuple[jax.Array, jax.Array]] = None, 53 | psum: bool = False, act: bool = True, dim: int = 2) -> jax.Array: 54 | run_type = jnp.promote_types(ctx.model.computation_dtype, jnp.float32) 55 | if weight is None: 56 | weight, weight_sq = get_param(ctx, "scale", [feature_dim], std=0, mean=1, dtype=run_type, return_sq=True) 57 | elif weight is False: 58 | weight_sq = weight = 1 59 | else: 60 | weight, weight_sq = weight 61 | 62 | if ctx.is_initializing: 63 | return inp 64 | 65 | @jax.custom_gradient 66 | def _fn(src: jax.Array, wgt: jax.Array, _wgt_dummy: jax.Array): 67 | if isinstance(wgt, jax.Array): 68 | wgt = wgt.reshape((1,) * dim + (-1,) + (1,) * (src.ndim - 1 - dim)) 69 | 70 | out, std = norm_forward(ctx, src, wgt, psum, act, dim) 71 | 72 | def _grad(dy: jax.Array) -> Union[Tuple[jax.Array, jax.Array, jax.Array], Tuple[jax.Array, None, None]]: 73 | inner_src = lax.all_gather(src, ParallelAxes.model, axis=dim) if psum else src 74 | src_fp64 = promote_to(inner_src, run_type) 75 | norm_out = src_fp64 * std 76 | dy = promote_to(dy, run_type) 77 | if act: 78 | dy = dy * activate_grad(norm_out * wgt) 79 | d_normed = dy * wgt 80 | 81 | d_std = (d_normed * src_fp64).sum(dim, keepdims=True) # broadcast forward -> sum backward 82 | d_std *= std ** 3 # reciprocal + x^(1/pow) -> 1/std^2 * 1/std^(pow-1) * 1/pow 83 | d_std *= src_fp64 # x^pow -> pow * x^(pow-1), multiply fused with above 84 | dx = d_normed * std - d_std 85 | if psum: 86 | dx = lax.psum_scatter(dx, axis_name=ParallelAxes.model, scatter_dimension=dim, tiled=True) 87 | dx = dx.astype(src.dtype) 88 | 89 | if not isinstance(wgt, jax.Array): 90 | return dx, None, None 91 | 92 | summed = list(range(src.ndim)) 93 | del summed[dim] 94 | d_wgt = dy * norm_out 95 | d_wgt_sq = (lax.square(d_wgt).sum(summed) * ctx.dims.batch).reshape((-1,)).astype(run_type) 96 | d_wgt = d_wgt.sum(summed).reshape((-1,)).astype(run_type) 97 | return dx, d_wgt, d_wgt_sq 98 | 99 | return out, _grad 100 | 
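# `weight_sq` is threaded through as `_wgt_dummy` purely so that jax.custom_gradient has a
# third input slot in which to return `d_wgt_sq`, the squared-gradient statistic.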
-------------------------------------------------------------------------------- /src/model/reversible.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Callable, Dict, Tuple 3 | 4 | import jax 5 | 6 | from src.context import Context 7 | 8 | REVERSIBLE_CTX = Tuple[Dict[str, jax.Array], jax.Array, jax.Array, jax.Array, jax.Array] 9 | ReversibleFn = Callable[[Context, jax.Array], jax.Array] 10 | FourArrays = Tuple[jax.Array, jax.Array, jax.Array, jax.Array] 11 | 12 | 13 | def reversible(ctx: Context, fn: ReversibleFn, src: REVERSIBLE_CTX, *args) -> REVERSIBLE_CTX: 14 | if ctx.is_initializing: 15 | params, _x00, x01, x10, x11 = src 16 | new_ctx = ctx.add_to_prefix("reversible") 17 | new_ctx.parameters = params 18 | out = fn(new_ctx, x10, *args) 19 | ctx.parameters = new_ctx.parameters 20 | ctx.name_cache = new_ctx.name_cache 21 | ctx.prng_key = new_ctx.prng_key 22 | return new_ctx.parameters, x10, x11, out, x01 23 | 24 | name_cache = copy.deepcopy(ctx.name_cache) 25 | 26 | def base(params: Dict[str, jax.Array], inp: jax.Array, *inner_args) -> jax.Array: 27 | ctx.name_cache = copy.deepcopy(name_cache) 28 | new_ctx = ctx.add_to_prefix("reversible") 29 | new_ctx.parameters = params 30 | out = fn(new_ctx, inp, *inner_args) 31 | ctx.name_cache = new_ctx.name_cache 32 | return out 33 | 34 | @jax.custom_gradient 35 | def _fn(params: Dict[str, jax.Array], x0: jax.Array, _back_x0: jax.Array, x1: jax.Array, 36 | _back_x1: jax.Array, *inner_args): 37 | def _grad(dy): 38 | d_params_old, dy0, y0, dy1, y1 = dy 39 | x0, grad_fn = jax.vjp(base, params, y0, *inner_args) 40 | d_params, dx0, *_ = grad_fn(dy1) 41 | d_params = {k: d_params_old.get(k, 0) + d_params.get(k, 0) for k in d_params.keys()} 42 | return (d_params, dy1, y1 - x0, dx0 + dy0, y0) + (None,) * len(inner_args) 43 | 44 | out = base(params, x1, *inner_args) + x0 45 | return (params, x1, x1, out, out), _grad 46 | 47 | return _fn(*src, *args) 48 | 49 | 50 | def revnet_out(src: FourArrays) -> jax.Array: 51 | @jax.custom_gradient 52 | def _fn(x0: jax.Array, _x0_back: jax.Array, x1: jax.Array, _x1_back: jax.Array): 53 | def _grad(dy) -> FourArrays: 54 | return dy, x0, dy, x1 55 | 56 | return x0 + x1, _grad 57 | 58 | return _fn(*src) 59 |
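reversible implements the usual additive RevNet coupling: the forward pass emits (x1, f(x1) + x0), so the backward pass reconstructs the previous activation as y1 - f(y0) (the `y1 - x0` term in _grad, where that `x0` is the recomputed f(y0)) instead of storing it. The invertibility in isolation, stripped of the Context machinery:

    def forward(f, x0, x1):
        return x1, f(x1) + x0  # (y0, y1)

    def inverse(f, y0, y1):
        return y1 - f(y0), y0  # recovers (x0, x1) exactly, so activations need not be kept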
-------------------------------------------------------------------------------- /src/optimizer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict 2 | 3 | import jax 4 | from jax import lax, numpy as jnp 5 | 6 | from src.backend import add_sq, assign, default, get_param, is_stacked, stable_rsqrt, with_context 7 | from src.constants import MomentumType 8 | from src.context import Context 9 | 10 | 11 | def small_parameter(param_name: str, grad: jax.Array) -> bool: 12 | param_name = param_name.lower() 13 | is_small = any(f'{k}' in param_name for k in ("norm", "rezero")) 14 | is_small |= grad.ndim < (2 + is_stacked(param_name)) 15 | return is_small 16 | 17 | 18 | @with_context() 19 | def ema(ctx: Context, inp: jax.Array, step: jax.Array, beta: float, 20 | momentum_type: Optional[MomentumType] = None) -> jax.Array: 21 | momentum_type = default(momentum_type, ctx.optimizer.momentum_type) 22 | state = get_param(ctx, "momentum_buffer", inp.shape, dtype=ctx.optimizer.momentum_dtype, tied=True, 23 | init_val=jnp.zeros_like(inp)) 24 | if ctx.is_initializing: 25 | return state 26 | 27 | if momentum_type != MomentumType.heavyball: 28 | inp *= 1 - beta 29 | inp = inp.astype(ctx.optimizer.momentum_dtype) 30 | new_state = state * beta + inp 31 | assign(ctx, "momentum_buffer", new_state) 32 | 33 | new_state = new_state.astype(jnp.float64) 34 | if momentum_type == MomentumType.debiased: 35 | new_state = new_state / (1 - beta ** (step + 1)) 36 | 37 | if momentum_type == MomentumType.nesterov: 38 | return new_state * beta + inp 39 | return new_state 40 | 41 | 42 | def norm(param_name: str, val: jax.Array, is_squared: bool = False): 43 | if not is_squared: 44 | val = lax.square(val) 45 | if not is_stacked(param_name): 46 | return val.sum() 47 | return val.sum(tuple(range(1, val.ndim))).reshape((-1,) + (1,) * (val.ndim - 1)) 48 | 49 | 50 | def clip_norm(param_name: str, val: jax.Array, min_norm: float, is_squared: bool = False) -> jax.Array: 51 | return jnp.maximum(jnp.sqrt(norm(param_name, val, is_squared)), min_norm) 52 | 53 | 54 | def adaptive_gradient_clipping(ctx: Context, param_name: str, grad: jax.Array, is_squared: bool) -> jax.Array: 55 | grad = grad.astype(jnp.float64) 56 | grd_norm = clip_norm(param_name, grad, ctx.optimizer.epsilon, is_squared) 57 | wgt_norm = clip_norm(param_name, ctx.parameters[param_name].astype(jnp.float64), 1e-3) 58 | grad_scale = jnp.minimum(wgt_norm / grd_norm * ctx.optimizer.gradient_clip, 1) 59 | return grad * grad_scale 60 | 61 | 62 | def graft(param_name: str, magnitude: jax.Array, direction: jax.Array) -> jax.Array: 63 | return direction * jnp.sqrt(norm(param_name, magnitude) / jnp.maximum(norm(param_name, direction), 1e-16)) 64 | 65 | 66 | def tg_adam(ctx: Context, param_name: str, grad: jax.Array, tg_grad: jax.Array, step: jax.Array) -> jax.Array: 67 | ema_g = ema(ctx, grad, step, 1 - ctx.optimizer.adam_beta1) 68 | ema_gsq = ema(ctx, grad ** 2, step, 1 - ctx.optimizer.adam_beta2) 69 | ema_tgsq = ema(ctx, tg_grad, step, 1 - ctx.optimizer.adam_beta3) 70 | 71 | if ctx.is_initializing: 72 | return grad 73 | 74 | adam_update = ema_g * stable_rsqrt(ema_gsq, ctx.optimizer.epsilon) 75 | tg_update = ema_g * stable_rsqrt(ema_tgsq, ctx.optimizer.epsilon) 76 | return graft(param_name, adam_update, tg_update) 77 | 78 | 79 | def get_current_lr(ctx: Context, step: jax.Array) -> jax.Array: 80 | opt = ctx.optimizer 81 | learning_rate = opt.learning_rate 82 | learning_rate *= jnp.minimum(step, opt.warmup_end).astype(jnp.float64) 83 | learning_rate /= opt.warmup_end 84 | learning_rate *= (1 - opt.exponential_decay) ** jax.nn.relu(step.astype(jnp.float64)) 85 | return learning_rate.astype(ctx.model.storage_dtype) 86 | 87 | 88 | def update(ctx: Context, grads: Dict[str, jax.Array], step: jax.Array): 89 | outer_ctx = ctx.add_to_prefix("optimizer") 90 | lr = -get_current_lr(ctx, step) 91 | 92 | for param_name, grad in grads.items(): 93 | if "optimizer" in param_name or param_name.endswith('_sq') or param_name.endswith('_sq_stacked'): 94 | continue 95 | ctx = outer_ctx.add_to_prefix(param_name, count=False) 96 | ctx.name_cache = {} 97 | dtype = ctx.parameters[param_name].dtype 98 | parameter_lr = lr * ctx.parameter_variance.get(param_name, 1) 99 | 100 | grad = adaptive_gradient_clipping(ctx, param_name, grad, False) 101 | grad_sq = adaptive_gradient_clipping(ctx, param_name, grads[add_sq(param_name)], True) 102 | weight_update = tg_adam(ctx, param_name, grad, grad_sq, step) * parameter_lr 103 | 104 | if ctx.is_initializing: 105 | continue 106 | 107 | param = ctx.parameters[param_name].astype(jnp.float64) 108 | if not small_parameter(param_name, grad): 109 | param *= 1 + ctx.optimizer.weight_decay * parameter_lr 110 | ctx.parameters[param_name] = (param + weight_update).astype(dtype) 111 |
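graft combines the step size of one update with the step direction of another (learning-rate grafting); tg_adam uses it to give the Adam update's magnitude to the tg-normalized direction. A toy example of the rescaling:

    import jax.numpy as jnp

    magnitude = jnp.array([3.0, 4.0])   # update whose norm we trust (norm 5)
    direction = jnp.array([0.0, 10.0])  # update whose direction we trust (norm 10)
    out = direction * jnp.sqrt(jnp.square(magnitude).sum() / jnp.square(direction).sum())
    # -> [0., 5.]: the second update's direction, rescaled to the first update's norm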
-------------------------------------------------------------------------------- /src/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/kingoflolz/mesh-transformer-jax/blob/0a75ca9370576ad9d247facf6cb8e9699300e690 3 | /mesh_transformer/checkpoint.py 4 | """ 5 | import datetime 6 | import functools 7 | import io 8 | import json 9 | import multiprocessing 10 | import re 11 | import subprocess 12 | import time 13 | from typing import Any 14 | 15 | import jax 16 | import jax.numpy as jnp 17 | import numpy as np 18 | from jax.tree_util import PyTreeDef 19 | from smart_open import open as smart_open 20 | 21 | from src.backend import deep_replace, is_main 22 | from src.context import Context, WhileTrainContext 23 | 24 | UPLOAD_RETRIES = 8 25 | WCTX_VALUES = ("scalars", "current_step") 26 | TMP_PATH_ADDON = "_____TEMPORARY" 27 | GSUTIL_PATH = "/opt/google-cloud-sdk/bin/gsutil" 28 | 29 | 30 | def log(arg: str, verbose: bool): 31 | if verbose: 32 | print(datetime.datetime.now(), arg) 33 | 34 | 35 | def write_shard(weights: Any, idx: int, prefix: str, filename: str, verbose: bool): 36 | path = f"{prefix}/{jax.process_index()}/{idx}/{filename}" 37 | shard = jax.device_put(jax.tree_util.tree_map(lambda i: i[idx], weights), jax.devices("cpu")[0]) 38 | log(f"Uploading {len(shard)} objects to {path}", verbose) 39 | for _ in range(UPLOAD_RETRIES): 40 | try: 41 | with smart_open(path, "wb") as f: 42 | np.savez(f, **{str(idx): tensor for idx, tensor in enumerate(shard)}) 43 | return 44 | except: # skipcq: FLK-E722 45 | log(f"Couldn't save to {path}. Retrying now.", verbose) 46 | 47 | log(f"Saving to {path} failed {UPLOAD_RETRIES} times. Skipping this checkpoint.", True) 48 | 49 | 50 | def cmd(command: str, check: bool = True): 51 | return subprocess.run(command.split(' '), check=check) 52 | 53 | 54 | def move_checkpoint(ctx: Context, new: str): 55 | cmd(f"{GSUTIL_PATH} -m rm -r {new}", False) # ignore exit code 56 | cmd(f"{GSUTIL_PATH} -m cp -r {ctx.training.checkpoint_path} {new}") 57 | cmd(f"{GSUTIL_PATH} -m rm -r {ctx.training.checkpoint_path}") 58 | 59 | 60 | def write_checkpoint(ctx: Context, verbose: bool = True): 61 | flattened, jax_structure = jax.tree_util.tree_flatten(ctx.parameters) 62 | variance, _ = jax.tree_util.tree_flatten(ctx.parameter_variance) # same structure 63 | 64 | structure = str(jax_structure) # like "PyTreeDef({'2': {'a': *}})" 65 | structure = structure.replace('PyTreeDef', '')[1:-1] # clean up "types" 66 | structure = structure.replace(': *', ': null').replace("{'", '{"').replace("':", '":') 67 | structure = structure.replace("', ", '", ').replace(", '", ', "') # to valid JSON 68 | 69 | if is_main(): 70 | log(f"Saving structure to {ctx.training.checkpoint_path}/structure.json", verbose) 71 | for _ in range(UPLOAD_RETRIES): 72 | try: 73 | with smart_open(f"{ctx.training.checkpoint_path}/structure.json", "w") as f: # skipcq: PTC-W6004 74 | f.write(structure) 75 | break 76 | except: # skipcq: FLK-E722 77 | log("Couldn't save structure. Retrying now.", verbose) 78 | 79 | for shard in range(jax.local_device_count()): 80 | for tree, suffix in ((flattened, "parameters"), (variance, "variance")): 81 | write_shard(tree, shard, ctx.training.checkpoint_path, f"{suffix}.npz", verbose) 82 |
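For orientation, write_checkpoint and write_shard above lay a checkpoint out as follows (paths reconstructed from the code):

    {checkpoint_path}/structure.json                          # JSON-ified PyTreeDef
    {checkpoint_path}/{process_index}/{shard}/parameters.npz  # one file per local device
    {checkpoint_path}/{process_index}/{shard}/variance.npz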
83 | 84 | @functools.partial(jax.pmap, axis_name="i") 85 | def _sync(x): 86 | return jax.lax.psum(x, "i") == jax.device_count() 87 | 88 | 89 | def sync(): 90 | if not _sync(jnp.ones(jax.local_device_count()))[0]: 91 | raise ValueError 92 | 93 | 94 | def write_train_checkpoint(wctx: WhileTrainContext, verbose: bool = True): 95 | real_path = wctx.ctx.training.checkpoint_path 96 | wctx.ctx.training.checkpoint_path = real_path + TMP_PATH_ADDON 97 | 98 | write_checkpoint(wctx.ctx, verbose) 99 | for shard in range(jax.local_device_count()): 100 | for val in WCTX_VALUES: 101 | write_shard([getattr(wctx, val)], shard, wctx.ctx.training.checkpoint_path, val + ".npz", verbose) 102 | 103 | sync() # ensure that all nodes wrote their respective checkpoints before moving them collectively 104 | if is_main(): 105 | move_checkpoint(wctx.ctx, real_path) 106 | wctx.ctx.training.checkpoint_path = real_path 107 | 108 | 109 | def read_shard(checkpoint_dir): 110 | with smart_open(checkpoint_dir, "rb") as f: 111 | buf = f.read() 112 | f_io = io.BytesIO(buf) 113 | deserialized = list(np.load(f_io).items()) 114 | return [tensor for idx, tensor in sorted(deserialized, key=lambda x: int(x[0]))] 115 | 116 | 117 | def unshard(shards): 118 | unsharded = [] 119 | for all_shards in zip(*shards): 120 | x = np.stack(all_shards) 121 | if x.dtype == np.dtype('V2'): 122 | x.dtype = jnp.bfloat16 123 | unsharded.append(x) # manual jnp.asarray -> replicated; automatic (via jax.pmap) -> parallel (as before) 124 | return unsharded 125 | 126 |
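The dtype('V2') branch in unshard exists because a bfloat16 array round-trips through np.savez as raw two-byte void data when numpy does not know the dtype at load time; assigning to .dtype reinterprets the buffer in place instead of casting. A sketch of the same trick (assuming the usual two-byte layout):

    import numpy as np
    from jax import numpy as jnp

    raw = np.zeros(4, np.uint16).view(np.dtype('V2'))  # stand-in for a freshly loaded shard
    raw.dtype = jnp.bfloat16                           # zero-copy reinterpretation, no conversion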
127 | def _read_shards(path: str, structure: PyTreeDef, suffix: str): 128 | with multiprocessing.pool.ThreadPool(jax.local_device_count()) as p: 129 | start = time.time() 130 | paths = [f"{path}/{jax.process_index()}/{shard}/{suffix}.npz" for shard in range(jax.local_device_count())] 131 | shards = list(p.map(read_shard, paths)) 132 | print(f"Loading {suffix} took {time.time() - start:.2f}s") 133 | 134 | return structure.unflatten(unshard(shards)) 135 | 136 | 137 | def _overwrite(new: dict, old: dict, ignore: re.Pattern): 138 | if not old: 139 | print("No entries in old dict. Using new dict.") 140 | for key, param in new.items(): 141 | old[key] = param 142 | return 143 | 144 | print("Unknown: ", [p for p in new.keys() if p not in old and not ignore.match(p)]) 145 | print("Unfilled: ", [p for p in old.keys() if p not in new and not ignore.match(p)]) 146 | 147 | for key in old.keys(): 148 | if key in new: 149 | old[key] = new[key] 150 | 151 | 152 | def read_checkpoint(ctx: Context, ignore: str = '.*optimizer.*', load_variance: bool = False): 153 | ignore = re.compile(ignore) 154 | 155 | with smart_open(f"{ctx.training.checkpoint_load_path}/structure.json", "r") as f: 156 | structure = f.read() 157 | structure = json.loads(structure) 158 | py_structure = deep_replace(structure, jnp.zeros((1,))) 159 | _, structure = jax.tree_util.tree_flatten(py_structure) 160 | 161 | _overwrite(_read_shards(ctx.training.checkpoint_load_path, structure, "parameters"), ctx.parameters, ignore) 162 | 163 | if load_variance: 164 | py_structure = {k: v for k, v in py_structure.items() if "optimizer" not in k} # no optimizer for param-lr 165 | _, structure = jax.tree_util.tree_flatten(py_structure) 166 | _overwrite(_read_shards(ctx.training.checkpoint_load_path, structure, "variance"), ctx.parameter_variance, 167 | ignore) 168 | 169 | 170 | def read_train_checkpoint(wctx: WhileTrainContext, ignore: str = '.*optimizer.*'): 171 | _, structure = jax.tree_util.tree_flatten([jnp.zeros((1,))]) 172 | for val in WCTX_VALUES: 173 | setattr(wctx, val, _read_shards(wctx.ctx.training.checkpoint_load_path, structure, val)[0]) 174 | read_checkpoint(wctx.ctx, ignore, load_variance=True) 175 | -------------------------------------------------------------------------------- /src/utils/wandblog.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import time 3 | from typing import List 4 | 5 | import numpy as np 6 | 7 | from src.context import WhileTrainContext 8 | 9 | 10 | class WandbLog: 11 | def __init__(self, run, device_steps: int, param_count: int, tokens_per_step: int): 12 | self.start_time = time.time() 13 | self.run = run 14 | self.scalars = collections.defaultdict(list) 15 | self.device_steps = device_steps 16 | self.param_count = param_count 17 | self.tokens_per_step = tokens_per_step 18 | self.first_step = None 19 | 20 | def _log(self, prefix: str, value: float, sizes: List[int]): 21 | scalars = self.scalars[prefix] 22 | value = float(value) 23 | scalars.append(value) 24 | items = {f"{prefix}/Median{s * self.device_steps}": np.median(scalars[-s:]) for s in sizes} 25 | self.scalars[prefix] = scalars[-max(sizes):] 26 | items[f"{prefix}/Current"] = value 27 | return items 28 | 29 | def __call__(self, wctx: WhileTrainContext, step: int, current_lr) -> bool: 30 | if self.first_step is None: 31 | self.first_step = step - self.device_steps 32 | rate = (step - self.first_step) / (time.time() - self.start_time) 33 | 34 | ctx = wctx.ctx 35 | sizes = [s // self.device_steps for s in ctx.wandb.median_sizes] 36 | 37 | tokens_per_day = 3600 * 24 * rate * ctx.dims.batch * ctx.dims.sequence 38 | items = {"Optimizer/Learning Rate": current_lr, 39 | "Speed/Batches per Second": rate, 40 | "Speed/Tokens per Day": tokens_per_day, 41 | "Speed/Parameters * Tokens per Day": tokens_per_day * self.param_count, 42 | "Speed/Tokens Seen": step * self.tokens_per_step} 43 | 44 | items.update(self._log("Loss", wctx.scalars[0, 0], sizes)) 45 | items.update(self._log("Accuracy", wctx.scalars[0, 1], sizes)) 46 | 47 | self.run.log(items, step=step) 48 | 49 | return any(not np.isfinite(val) for val in wctx.scalars[0, :]) 50 |
-------------------------------------------------------------------------------- /sweep.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | method: bayes 3 | metric: 4 | name: Loss/Median1024 5 | goal: minimize 6 | command: 7 | - bash 8 | - run.sh 9 | - ${args} 10 | parameters: # See https://wandb.ai/homebrewnlp/gpt/sweeps/xuwcs6i1 for initial sweep space 11 | optimizer.learning_rate: 12 | distribution: log_uniform_values 13 | min: 0.001 14 | max: 10 15 | optimizer.adam_beta2: 16 | distribution: log_uniform_values 17 | min: 0.001 18 | max: 0.1 19 | optimizer.adam_beta1: 20 | distribution: log_uniform_values 21 | min: 0.01 22 | max: 1 23 | optimizer.momentum_beta: 24 | distribution: log_uniform_values 25 | min: 0.01 26 | max: 1 27 | dims.sizes.batch: 28 | distribution: q_log_uniform_values 29 | min: 32 30 | max: 256 31 | q: 8 32 | 33 | # New parameters 34 | optimizer.gradient_clip: 35 | distribution: log_uniform_values 36 | min: 0.0001 37 | max: 1 38 | optimizer.weight_decay: 39 | distribution: log_uniform_values 40 | min: 0.0001 41 | max: 1 42 | training.z_loss: 43 | distribution: log_uniform_values 44 | min: 0.00001 45 | max: 1
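This sweep file follows the standard W&B schema: `wandb sweep sweep.yaml` registers it and prints a sweep ID, and `wandb agent <entity>/<project>/<sweep-id>` then samples configurations and substitutes them into the `${args}` placeholder of the command above.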
-------------------------------------------------------------------------------- /train_watcher.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dataclasses 3 | import string 4 | import typing 5 | from netrc import netrc 6 | 7 | import shortuuid 8 | import tpucare 9 | import wandb 10 | import yaml 11 | from tpucare import delete_one_tpu, exec_command, exec_on_tpu, send_to_tpu, start_single 12 | 13 | from src.context import Context 14 | 15 | tpucare.LOG_LEVEL = 0 16 | _, _, wandb_key = netrc().authenticators("api.wandb.ai") 17 | 18 | 19 | @dataclasses.dataclass 20 | class TPUContext: 21 | zone: str 22 | host: str 23 | config: dict 24 | branch: str 25 | 26 | 27 | class Args: 28 | host: str 29 | tpu_version: int 30 | zone: str 31 | data_path: str 32 | preemptible: bool 33 | service_account: str 34 | branch: str 35 | slices: int 36 | storage_prefix: str 37 | config_path: str 38 | cleanup: int 39 | merge_runs: bool 40 | 41 | 42 | def start_fn(ctx: TPUContext, worker: int): 43 | setup = '(bash setup.sh ; mv ~/config.yaml ~/HomebrewNLP-Jax/config.yaml ; exit 0)' 44 | send_to_tpu(ctx.host, ctx.zone, "config.yaml", yaml.dump(ctx.config), worker) 45 | cmd = exec_command(repository="https://github.com/HomebrewNLP/HomebrewNLP-Jax", wandb_key=wandb_key, 46 | setup_command=setup, run_command="CONFIG=config.yaml bash run.sh", branch=ctx.branch, 47 | install_python=False) 48 | send_to_tpu(ctx.host, ctx.zone, "setup.sh", cmd, worker) 49 | exec_on_tpu(ctx.host, ctx.zone, "bash setup.sh", worker) 50 | 51 | 52 | def parse_args() -> Args: 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--host", type=str, help="Name of the TPU") 55 | parser.add_argument("--tpu-version", type=int, default=3, help="Which TPU version to create (v2-8 or v3-8)") 56 | parser.add_argument("--zone", type=str, default="europe-west4-a", help="GCP Zone TPUs get created in") 57 | parser.add_argument("--data-path", type=str, default="gs://ggpt4/the-char-pile/", 58 | help="Where the data is stored. Should be changed to a bucket in the correct region") 59 | parser.add_argument("--preemptible", default=1, type=int, 60 | help="Whether to create preemptible or non-preemptible TPUs") 61 | parser.add_argument("--service-account", type=str, 62 | help="Service account that controls permissions of TPU (for example, to ensure EU TPUs won't " 63 | "use US data)") 64 | parser.add_argument("--branch", type=str, default="main", help="Branch on github to use") 65 | parser.add_argument("--slices", default=1, type=int, 66 | help="How many TPU slices each TPU should have (1=>vX-8, 4=>vX-32)") 67 | parser.add_argument("--storage-prefix", type=str, help="Storage prefix to use for weights on gcloud bucket") 68 | parser.add_argument("--config-path", type=str, help="Path to config.yaml") 69 | parser.add_argument("--cleanup", default=0, type=int, 70 | help="Instead of running something new, kill all tpus. 1 or 0 for y/n") 71 | parser.add_argument("--merge-runs", default=1, type=int, 72 | help="Whether to merge all WandB runs into one logstream or keep one for each host.") 73 | return parser.parse_args() 74 | 75 | 76 | def new_id(): 77 | return str(shortuuid.ShortUUID(alphabet=string.digits + string.ascii_lowercase).random(32)) 78 | 79 | 80 | class CreationCallback: 81 | def __init__(self, args: Args): 82 | self.args = args 83 | self.restarts = 0 84 | 85 | with open(args.config_path, 'r') as f: # skipcq: PTC-W6004 86 | txt = f.read() 87 | config = yaml.safe_load(txt) 88 | cfg = Context(config) 89 | cfg.training.do_checkpoint = True 90 | cfg.data.path = args.data_path 91 | cfg.wandb.group = args.host 92 | 93 | if args.merge_runs: 94 | cfg.wandb.id = new_id() 95 | 96 | cfg.training.checkpoint_path = f'{cfg.training.checkpoint_path}-{args.storage_prefix}' 97 | self.wandb_api = wandb.Api() 98 | self.cfg = cfg 99 | 100 | def _prepare_config(self): # load checkpoint if exists and avoid overwriting logs at 1000 if already up to 1500 101 | try: 102 | run = self.wandb_api.run(f'{self.cfg.wandb.entity}/{self.cfg.wandb.project}/{self.cfg.wandb.id}') 103 | start_step = int(run.summary["_step"]) 104 | except: # skipcq: FLK-E722 105 | return # no logs yet 106 | finally: 107 | self.cfg.wandb.id = new_id() 108 | self.restarts += 1 109 | if start_step < self.cfg.training.checkpoint_interval: 110 | self.cfg.training.checkpoint_load_path = "" 111 | else: 112 | self.cfg.training.checkpoint_load_path = self.cfg.training.checkpoint_path 113 | 114 | def __call__(self, host: str, ctx: typing.Optional[TPUContext]) -> TPUContext: 115 | if ctx is not None: # every call after 0th 116 | self._prepare_config() 117 | self.cfg.wandb.name = f'{self.args.host}-{self.restarts}' 118 | print(self.cfg) 119 | return TPUContext(zone=self.args.zone, host=host, config=self.cfg.config(), branch=self.args.branch) 120 | 121 | 122 | def main(): 123 | args = parse_args() 124 | if args.cleanup: 125 | delete_one_tpu("", args.host, args.zone) 126 | return 127 | start_single(args.host, args.tpu_version, args.zone, args.preemptible, args.service_account, args.slices, start_fn, 128 | CreationCallback(args)) 129 | 130 | 131 | if __name__ == '__main__': 132 | main() 133 |
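A typical invocation, using only the flags defined in parse_args above (values illustrative): `python3 train_watcher.py --host my-tpu --tpu-version 3 --zone europe-west4-a --storage-prefix run0 --config-path config.yaml`. On every preemption, CreationCallback recreates the TPU and resumes from the last checkpoint under a fresh WandB run name.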
-------------------------------------------------------------------------------- /unittests/consistency/step.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import jax 4 | import pytest 5 | from jax import numpy as jnp 6 | 7 | from src.constants import ParallelAxes 8 | from src.context import WhileTrainContext 9 | from src.model.main import body_ctx 10 | from src.main import add_zeros 11 | 12 | 13 | def get_wctx(config: typing.Optional[typing.Dict[str, typing.Any]] = None): 14 | wctx = WhileTrainContext(config) 15 | ctx = wctx.ctx 16 | 17 | ctx.dims.batch = 16 18 | ctx.dims.spatial_mixing_kernel = 8 19 | ctx.dims.sequence = 128 20 | ctx.dims.features = 16 21 | ctx.dims.pointwise_features = 32 22 | ctx.dims.inner_bottleneck_features = 8 23 | 24 | return wctx, ctx 25 | 26 | 27 | def replicate(x: typing.Any) -> typing.Any: 28 | return jax.device_put_replicated(x, jax.local_devices()) 29 | 30 | 31 | def pmap(config: typing.Optional[typing.Dict[str, typing.Any]]): 32 | _, ctx = get_wctx() 33 | src = replicate(jnp.zeros((ctx.dims.batch, ctx.dims.sequence), dtype=jnp.int32)) 34 | name_cache = {} 35 | parameter_usages = {} 36 | 37 | def _fn(x, cfg): 38 | wctx, ctx = get_wctx(cfg) 39 | ctx.fail_on_missing_parameter = False 40 | ctx.is_initializing = config is None 41 | add_zeros(ctx.parameters) 42 | _ = body_ctx(ctx, x) 43 | for k in list(ctx.parameters.keys()): 44 | if "optimizer" in k or k.endswith('_sq') or k.endswith('_sq_stacked'): 45 | del ctx.parameters[k] 46 | name_cache.update(ctx.name_cache) 47 | parameter_usages.update(ctx.parameter_usages) 48 | return wctx.serialize() 49 | 50 | out = jax.pmap(_fn, ParallelAxes.model)(src, config) 51 | return WhileTrainContext(out), name_cache, parameter_usages 52 | 53 | 54 | class BaseTest: 55 | def __init__(self): 56 | self.export1, self.name_cache1, self.usages1 = pmap(None) 57 | self.export2, self.name_cache2, self.usages2 = pmap(self.export1.serialize()) 58 | 59 | @staticmethod 60 | def check(dict1: typing.Dict[str, typing.Any], dict2: typing.Dict[str, typing.Any], 61 | cond: typing.Callable[[str, typing.Any, typing.Dict[str, typing.Any]], bool]): 62 | wrong_in_1 = [k for k, v in dict1.items() if cond(k, v, dict2)] 63 | wrong_in_2 = [k for k, v in dict2.items() if cond(k, v, dict1)] 64 | dict1 = {k: f'{str(v)[:10]}...' if len(str(v)) > 12 else str(v) for k, v in dict1.items()} 65 | dict2 = {k: f'{str(v)[:10]}...' if len(str(v)) > 12 else str(v) for k, v in dict2.items()} 66 | print(f"{dict1=}\n{dict2=}") 67 | print() 68 | if wrong_in_1 or wrong_in_2: 69 | raise ValueError(f"{wrong_in_1=}\n{wrong_in_2=}") 70 | 71 | def is_in(self, dict1: typing.Dict[str, typing.Any], dict2: typing.Dict[str, typing.Any]): 72 | self.check(dict1, dict2, lambda k, v, d: k not in d) 73 | 74 | def same_shape(self, dict1: typing.Dict[str, typing.Any], dict2: typing.Dict[str, typing.Any]): 75 | self.is_in(dict1, dict2) 76 | self.check(dict1, dict2, lambda k, v, d: v.shape != d[k].shape) 77 | 78 | def equal(self, dict1: typing.Dict[str, typing.Any], dict2: typing.Dict[str, typing.Any]): 79 | self.is_in(dict1, dict2) 80 | self.check(dict1, dict2, lambda k, v, d: v != d[k]) 81 | 82 | def __call__(self): 83 | raise NotImplementedError 84 | 85 | 86 | class NameCache(BaseTest): 87 | def __call__(self): 88 | self.equal(self.name_cache1, self.name_cache2) 89 | 90 | 91 | class ParameterUsage(BaseTest): 92 | def __call__(self): 93 | self.equal(self.usages1, self.usages2) 94 | 95 | 96 | class ParameterShapes(BaseTest): 97 | def __call__(self): 98 | self.same_shape(self.export1.ctx.parameters, self.export2.ctx.parameters) 99 | 100 | 101 | class ParameterVariance(BaseTest): 102 | def __call__(self): 103 | self.same_shape(self.export1.ctx.parameter_variance, self.export2.ctx.parameter_variance) 104 | 105 | 106 | classes = [NameCache, ParameterUsage, ParameterShapes, ParameterVariance] 107 | 108 | 109 | @pytest.mark.parametrize("cls", classes) 110 | def test(cls: type): 111 | cls()() 112 | 113 | 114 | def main(): 115 | for cls in classes: 116 | test(cls) 117 | 118 | 119 | if __name__ == '__main__': 120 | main() 121 |
-------------------------------------------------------------------------------- /unittests/grad/activation.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from jax import numpy as jnp 3 | 4 | from src.context import Context 5 | from src.model.activate import activate, activate_forward 6 | from unittests.grad.backend import grad_fn, randn_fn, trials, sample_sizes 7 | 8 | 9 | @pytest.mark.parametrize("samples", sample_sizes) 10 | def test_grad(samples: int): # skipcq: PYL-W0640 11 | ctx = Context() 12 | ctx.is_initializing = False 13 | randn = randn_fn() 14 | for _ in range(trials): 15 | inp = randn(samples) 16 | dy = randn(samples) 17 | grad = grad_fn(dy, inp) 18 | out0, = grad(lambda x: activate(x[0])) 19 | out1, = grad(lambda x: activate_forward(x[0])) 20 | assert jnp.allclose(out0, out1) 21 | -------------------------------------------------------------------------------- /unittests/grad/backend.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import jax 4 | from jax import numpy as jnp 5 | 6 | from src.constants import ParallelAxes 7 | 8 | trials = 1 9 | sample_sizes = [2 ** 6] 10 | 11 | 12 | def randn_fn(): 13 | rng = random.Random(0) 14 | 15 | def _fn(*shape: int): 16 | seed = rng.randint(0, 2 ** 30) 17 | div = (shape[-1] * jax.device_count()) ** 0.25 18 | fn = jax.pmap( 19 | lambda x: jax.random.normal(jax.random.PRNGKey(x + seed), shape, jnp.float32).astype(jnp.float_) / div) 20 | local_devices = jax.local_device_count() 21 | seeds = jnp.arange(local_devices * jax.process_index(), local_devices * (1 + jax.process_index())) 22 | return fn(seeds) 23 | 24 | return _fn 25 | 26 | 27 | def grad_fn(dy: jax.Array, *args): 28 | def _fn(fn): 29 | return jax.pmap(jax.grad(lambda x: (fn(x) * dy).sum()), ParallelAxes.model)(args) 30 | 31 | return _fn 32 |
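grad_fn exploits the identity that the gradient of x -> sum(f(x) * dy) equals the vector-Jacobian product J^T dy, i.e. it backpropagates the cotangent dy through fn. A quick equivalence check without the pmap machinery:

    import jax
    from jax import numpy as jnp

    f = jnp.sin
    x, dy = jnp.arange(3.0), jnp.full((3,), 2.0)
    assert jnp.allclose(jax.grad(lambda v: (f(v) * dy).sum())(x), jax.vjp(f, x)[1](dy)[0])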
-------------------------------------------------------------------------------- /unittests/grad/leak.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import jax 4 | import pytest 5 | import tqdm 6 | from jax import lax, numpy as jnp 7 | 8 | from src.constants import ParallelAxes 9 | from src.context import Context 10 | from src.model.main import stem 11 | from src.model.reversible import revnet_out 12 | from unittests.grad.backend import grad_fn, randn_fn, trials 13 | 14 | 15 | def mean(x: jax.Array): 16 | return (x / x.size).sum() 17 | 18 | 19 | def initialize(samples: int): 20 | ctx = Context() 21 | ctx.dims.sequence = samples // 2 22 | ctx.dims.batch = 2 23 | return ctx, randn_fn() 24 | 25 | 26 | def randn_zero(ctx: Context, randn, zero_from: int): 27 | dy = randn(ctx.dims.batch, ctx.dims.sequence, ctx.dims.features) 28 | dy = dy[:, :, :zero_from, :] 29 | 30 | def _inner_fn(x: jax.Array): 31 | zeros = jnp.zeros((ctx.dims.batch, ctx.dims.sequence - zero_from, ctx.dims.features)) 32 | return jnp.concatenate([x, zeros], 1) 33 | 34 | return jax.pmap(_inner_fn)(dy) 35 | 36 | 37 | @pytest.mark.parametrize("samples", [8, 128]) 38 | @pytest.mark.parametrize("depth", [2, 8]) 39 | def test(samples: int, depth: int): 40 | ctx, randn = initialize(samples) 41 | ctx.is_initializing = True 42 | ctx.dims.depth = depth 43 | ctx.dims.features = 8 44 | ctx.dims.inner_bottleneck_features = 4 45 | ctx.dims.pointwise_features = 16 46 | ctx.dims.spatial_mixing_kernel = ctx.dims.sequence // 2 47 | src = randn(ctx.dims.batch, ctx.dims.sequence, ctx.dims.features).astype(jnp.bfloat16) 48 | 49 | def _fn(x: jax.Array): 50 | stem(ctx, (x, jnp.zeros_like(x), x, jnp.zeros_like(x))) 51 | params = ctx.parameters 52 | ctx.parameters = {} 53 | return params 54 | 55 | params = jax.pmap(_fn, ParallelAxes.model)(src) 56 | ctx.is_initializing = False 57 | 58 | def _inner(inp: typing.Tuple[typing.Dict[str, jax.Array], jax.Array]): 59 | params, x = inp 60 | ctx.name_cache = {} 61 | ctx.parameters = params 62 | out = stem(ctx, (x, jnp.zeros_like(x), x, jnp.zeros_like(x))) 63 | ctx.parameters = {} 64 | return revnet_out(out) 65 | 66 | for _ in range(trials): 67 | for i in tqdm.tqdm(range(1, ctx.dims.sequence + 1)): 68 | dy = randn_zero(ctx, randn, i) 69 | d_src = grad_fn(dy, params, src)(_inner)[1] 70 | d_src = lax.rev(d_src, (2,)) != 0 71 | seq_grad = d_src.sum((0, 1, 3)) > 0 72 | print(seq_grad) 73 | for j, itm in enumerate(seq_grad, 1): 74 | assert itm == (j > (ctx.dims.sequence - i)) 75 |
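In plain terms, this test checks that the stem is causal: dy is zeroed for every output position from `zero_from` onwards, so any nonzero gradient arriving at an input position at or beyond that cut would mean a later ("future") input influenced an earlier output, and the final assertion would fail.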
-------------------------------------------------------------------------------- /unittests/grad/loss.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy as np 3 | import pytest 4 | from jax import lax, numpy as jnp 5 | 6 | from src.backend import is_main, matmul 7 | from src.constants import ParallelAxes 8 | from src.context import Context 9 | from src.model.loss import cross_entropy_loss 10 | from unittests.grad.backend import grad_fn, randn_fn, sample_sizes, trials 11 | 12 | 13 | def mean(x: jax.Array): 14 | return (x / x.size).sum() 15 | 16 | 17 | def naive_loss(x, y, z_loss): 18 | tmp = lax.psum(matmul(x[0], x[1]), ParallelAxes.model) 19 | lse = jax.nn.logsumexp(tmp, -1) 20 | pos = mean(lse) 21 | neg = mean(jnp.take_along_axis(tmp.reshape(-1, tmp.shape[-1]), y.reshape(-1, 1), -1)) 22 | z_loss = mean(lse ** 2 * z_loss) 23 | z_loss = z_loss - lax.stop_gradient(z_loss) 24 | return pos - neg + z_loss 25 | 26 | 27 | def initialize(z_loss: float, samples: int): 28 | ctx = Context() 29 | ctx.training.z_loss = z_loss 30 | ctx.dims.sequence = int(samples ** 0.5) 31 | ctx.dims.batch = int(samples ** 0.5) 32 | 33 | tgt = jax.random.randint(jax.random.PRNGKey(0), (ctx.dims.batch, ctx.dims.sequence), 0, ctx.dims.vocab) 34 | 35 | return ctx, tgt, randn_fn() 36 | 37 | 38 | def statistics(name: str, var: jax.Array): 39 | _fn = jax.pmap(lambda x: (lax.pmax(x.max(), "i"), lax.pmin(x.min(), "i"), lax.pmean(x.mean(), "i"), 40 | lax.pmean(jnp.square(x - x.mean()).mean(), "i")), "i") 41 | if is_main(): 42 | vmax, vmin, vmean, vstd = [float(a[0]) for a in _fn(var)] 43 | print(f"{name}: max={vmax}, min={vmin}, mean={vmean}, std={vstd}") 44 | 45 | 46 | def general_value_test(z_loss: float, samples: int, vocab: int): # skipcq: PYL-W0640 47 | ctx, tgt, randn = initialize(z_loss, samples) 48 | ctx.dims.vocab = vocab 49 | 50 | for _ in range(trials): 51 | src = randn(ctx.dims.batch, ctx.dims.sequence, ctx.dims.features) 52 | wgt = randn(ctx.dims.features, ctx.dims.vocab) 53 | 54 | grad0 = float(jax.pmap(lambda x: cross_entropy_loss(ctx, x, tgt)[0], ParallelAxes.model)((src, wgt, wgt))[0]) 55 | grad1 = float(jax.pmap(lambda x: naive_loss(x, tgt, z_loss), ParallelAxes.model)((src, wgt, wgt))[0]) 56 | assert np.isclose(grad0, grad1) 57 | 58 | 59 | def general_grad_test(z_loss: float, samples: int, vocab: int): # skipcq: PYL-W0640 60 | ctx, tgt, randn = initialize(z_loss, samples) 61 | ctx.dims.vocab = vocab 62 | 63 | for _ in range(trials): 64 | src = randn(ctx.dims.batch, ctx.dims.sequence, ctx.dims.features) 65 | wgt = randn(ctx.dims.features, ctx.dims.vocab) 66 | dy = randn(2) 67 | grad = grad_fn(dy, src, wgt) 68 | 69 | grad0 = grad(lambda x: cross_entropy_loss(ctx, x, tgt)[0]) 70 | grad1 = grad(lambda x: naive_loss(x, tgt, z_loss)) 71 | 72 | for g0, g1 in zip(grad0, grad1): 73 | statistics("Grad0", g0) 74 | statistics("Grad1", g1) 75 | statistics("abs(Grad0 - Grad1)", jax.pmap(lambda x, y: jnp.abs(x - y), "i")(g0, g1)) 76 | statistics("abs(Grad0 / Grad1)", jax.pmap(lambda x, y: jnp.abs(x / y), "i")(g0, g1)) 77 | allclose = int(jax.pmap(lambda x, y: lax.psum(jnp.allclose(x, y).astype(jnp.float32), "i"), "i")(g0, g1)[0]) 78 | if is_main(): 79 | print(f'{allclose=}/{jax.device_count()}\n') 80 | assert allclose == jax.device_count() 81 | 82 | 83 | @pytest.mark.parametrize("z_loss", [1, 0.01, 0]) 84 | @pytest.mark.parametrize("samples", sample_sizes) 85 | def test_z_loss_value(z_loss: float, samples: int): 86 | general_value_test(z_loss, samples, 65536) 87 | 88 | @pytest.mark.parametrize("samples", sample_sizes) 89 | @pytest.mark.parametrize("vocab", [256, 65536]) 90 | def test_vocab_value(vocab: int, samples: int): 91 | general_value_test(0.01, samples, vocab) 92 | 93 | 94 | @pytest.mark.parametrize("z_loss", [1, 0.01, 0]) 95 | @pytest.mark.parametrize("samples", sample_sizes) 96 | def test_z_loss_grad(z_loss: float, samples: int): 97 | general_grad_test(z_loss, samples, 65536) 98 | 99 | 100 | @pytest.mark.parametrize("vocab", [256, 65536]) 101 | @pytest.mark.parametrize("samples", sample_sizes) 102 | def test_vocab_grad(vocab: int, samples: int): 103 | general_grad_test(0.01, samples, vocab) 104 |
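The `z_loss - lax.stop_gradient(z_loss)` line in naive_loss keeps the auxiliary z-loss in the gradient while cancelling it out of the value, so the reported loss stays directly comparable to plain cross-entropy:

    import jax
    from jax import lax

    h = lambda z: z ** 2 - lax.stop_gradient(z ** 2)
    print(h(3.0), jax.grad(h)(3.0))  # 0.0 6.0: zero value, intact gradient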
-------------------------------------------------------------------------------- /unittests/grad/norm.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import pytest 3 | from jax import numpy as jnp 4 | 5 | from src.context import Context 6 | from src.model.norm import norm_forward, scale_norm_act 7 | from unittests.grad.backend import grad_fn, randn_fn, sample_sizes, trials 8 | 9 | 10 | def general_test(act: bool, psum: bool, samples: int, dim: int): # skipcq: PYL-W0640 11 | ctx = Context() 12 | ctx.is_initializing = False 13 | randn = randn_fn() 14 | for trial in range(trials): 15 | src = randn(int(samples ** 0.5), int(samples ** 0.5), ctx.dims.features) 16 | multiplier = jax.device_count() if psum else 1 17 | out_shape = list(src.shape)[1:] 18 | out_shape[dim] *= multiplier 19 | wgt = randn(out_shape[dim]) 20 | wgt_sq = randn(out_shape[dim]) 21 | dy = randn(*out_shape) 22 | print(dy.shape, src.shape, wgt.shape) 23 | grad = grad_fn(dy, src, wgt, wgt_sq) 24 | 25 | print(trial) 26 | shape = (1,) * dim + (-1,) + (1,) * (src.ndim - 2 - dim) 27 | out0 = grad(lambda x: norm_forward(ctx, x[0], x[1].reshape(shape), bool(psum), act, dim)[0]) 28 | out1 = grad(lambda x: scale_norm_act(ctx, x[0], ctx.dims.features, (x[1], x[2]), bool(psum), act, dim)) 29 | 30 | assert jnp.allclose(out0[0], out1[0]) 31 | assert jnp.allclose(out0[1], out1[1]) 32 | 33 | 34 | @pytest.mark.parametrize("act", [True, False]) 35 | @pytest.mark.parametrize("samples", sample_sizes) 36 | def test_act(act: bool, samples: int): 37 | general_test(act, False, samples, 2) 38 | 39 | 40 | @pytest.mark.parametrize("psum", [False, True]) 41 | @pytest.mark.parametrize("samples", sample_sizes) 42 | def test_psum(psum: bool, samples: int): 43 | general_test(True, psum, samples, 2) 44 | 45 | 46 | @pytest.mark.parametrize("dim", [0, 1, 2]) 47 | @pytest.mark.parametrize("samples", sample_sizes) 48 | def test_dim(dim: int, samples: int): 49 | general_test(True, False, samples, dim) 50 | --------------------------------------------------------------------------------