├── .github └── workflows │ └── lint_python.yml ├── .gitignore ├── CITATION.bib ├── LICENSE.txt ├── README.md ├── benchmarks.md ├── colab_demo.ipynb ├── configs ├── 6B_roto_256.json └── example_config.json ├── create_finetune_tfrecords.py ├── data ├── example.train.index ├── openwebtext2_new_inputs.train.index ├── openwebtext2_new_inputs.val.index ├── openwebtext2_new_inputs_shuffled.train.index ├── pile.train.index ├── pile.val.index ├── wit_dalle.train.index └── wit_dalle.val.index ├── device_sample.py ├── device_serve.py ├── device_train.py ├── docker ├── .env ├── Dockerfile ├── README.md ├── compose-proxy.yaml ├── docker-compose.yaml ├── main.py ├── nginx_proxyvm.conf ├── ops.py ├── payloads.py └── start.sh ├── eval_harness.py ├── howto_finetune.md ├── mesh_transformer ├── TPU_cluster.py ├── __init__.py ├── build_model.py ├── checkpoint.py ├── layers.py ├── sampling.py ├── train_actor.py ├── transformer_shard.py └── util.py ├── ray_tpu.py ├── requirements.txt ├── resharding_example.py ├── scripts ├── create_serve_tpu.sh ├── deploy_server.sh ├── init_ray.sh ├── init_ray_v2.sh └── init_serve.sh ├── setup.py ├── slim_model.py ├── tasks ├── __init__.py ├── eval_harness.py └── util.py ├── tfrecord_loader.py ├── to_hf_weights.py └── train.py /.github/workflows/lint_python.yml: -------------------------------------------------------------------------------- 1 | name: lint_python 2 | on: [pull_request, push] 3 | jobs: 4 | lint_python: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v2 8 | - uses: actions/setup-python@v2 9 | - run: pip install bandit black codespell flake8 isort mypy pytest pyupgrade safety 10 | - run: bandit --recursive --skip B101 . || true # B101 is assert statements 11 | - run: black --check . || true 12 | - run: codespell 13 | - run: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 14 | - run: flake8 . --count --exit-zero --max-complexity=19 --max-line-length=245 --show-source --statistics 15 | - run: isort --check-only --profile black . || true 16 | - run: pip install -r requirements.txt 17 | - run: mypy --ignore-missing-imports . || true 18 | - run: pytest . || true 19 | - run: pytest --doctest-modules . 
|| true 20 | - run: shopt -s globstar && pyupgrade --py36-plus **/*.py || true 21 | - run: safety check 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pprof 3 | /ckpt/ 4 | -------------------------------------------------------------------------------- /CITATION.bib: -------------------------------------------------------------------------------- 1 | @misc{mesh-transformer-jax, 2 | author = {Wang, Ben}, 3 | title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}}, 4 | howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}}, 5 | year = 2021, 6 | month = May 7 | } 8 | 9 | @misc{gpt-j, 10 | author = {Wang, Ben and Komatsuzaki, Aran}, 11 | title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}}, 12 | howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}}, 13 | year = 2021, 14 | month = May 15 | } 16 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | https://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | Copyright 2021 Ben Wang 179 | 180 | Licensed under the Apache License, Version 2.0 (the "License"); 181 | you may not use this file except in compliance with the License. 182 | You may obtain a copy of the License at 183 | 184 | https://www.apache.org/licenses/LICENSE-2.0 185 | 186 | Unless required by applicable law or agreed to in writing, software 187 | distributed under the License is distributed on an "AS IS" BASIS, 188 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 189 | See the License for the specific language governing permissions and 190 | limitations under the License. 191 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Table of contents 2 | 1. [Mesh Transformer JAX](#mesh-transformer-jax) 3 | 1. [Updates](#updates) 4 | 2. [Pretrained Models](#pretrained-models) 5 | 1. [GPT-J-6B](#gpt-j-6b) 6 | 1. [Links](#links) 7 | 2. [Acknowledgments](#acknowledgments) 8 | 3. [License](#license) 9 | 4. [Model Details](#model-details) 10 | 5. [Zero-Shot Evaluations](#zero-shot-evaluations) 11 | 3. [Architecture and Usage](#architecture-and-usage) 12 | 1. [Fine-tuning](#fine-tuning) 13 | 2. [JAX Dependency](#jax-dependency) 14 | 4. [TODO](#todo) 15 | 16 | # Mesh Transformer JAX 17 | 18 | A Haiku library using the `xmap`/`pjit` operators in JAX for model parallelism of transformers. 19 | 20 | The parallelism scheme is similar to the [original Megatron-LM](https://arxiv.org/abs/1909.08053), which is efficient 21 | on TPUs due to the high-speed 2D mesh network. There is also an experimental model version which implements [ZeRO-style 22 | sharding](https://arxiv.org/abs/1910.02054).
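To make the parallelism scheme above concrete, here is a minimal, self-contained sketch of Megatron-style tensor parallelism for a single feedforward block. This is our own illustration written with `jax.pmap` and a single `psum` all-reduce rather than the `xmap`/`pjit` machinery this repo actually uses; all sizes and names are made up for the example.

```python
# Megatron-style tensor parallelism, sketched with pmap/psum (illustrative
# only; the real implementation in this repo uses xmap/pjit over a device mesh).
import jax
import jax.numpy as jnp

n_dev = jax.device_count()
d_model, d_ff, batch = 128, 512, 4  # toy sizes; d_ff is split across devices

# Column-parallel first matmul, row-parallel second matmul: each device
# holds a 1/n_dev slice of the feedforward weights.
w1 = jnp.zeros((n_dev, d_model, d_ff // n_dev))
w2 = jnp.zeros((n_dev, d_ff // n_dev, d_model))
x = jnp.broadcast_to(jnp.ones((batch, d_model)), (n_dev, batch, d_model))

def ff_block(x, w1, w2):
    h = jax.nn.gelu(x @ w1)       # local compute, no communication needed
    y = h @ w2                    # per-device partial sums of the output
    return jax.lax.psum(y, "mp")  # one all-reduce, as in Megatron-LM

y = jax.pmap(ff_block, axis_name="mp")(x, w1, w2)  # y: (n_dev, batch, d_model)
```

The point of the column/row split is that the only cross-device communication per block is the final `psum`, which is what makes this scheme efficient on the TPU mesh.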
23 | 24 | This library is designed for scalability up to approximately 40B parameters on TPUv3s, beyond which different 25 | parallelism strategies should be used. See other implementations such as 26 | [GPT-NeoX](https://github.com/EleutherAI/gpt-neox) or [DeepSpeed](https://github.com/microsoft/DeepSpeed) for that. 27 | 28 | One future direction for research is integrating this codebase with 29 | [swarm-jax](https://github.com/kingoflolz/swarm-jax), to achieve further scalability with pipeline parallelism. 30 | 31 | ## Updates 32 | 33 | **12-07-21**: Added [guide to fine-tuning](howto_finetune.md) 34 | 35 | # Pretrained Models 36 | 37 | ## GPT-J-6B 38 | 39 | A 6 billion parameter autoregressive text generation model trained on [The Pile](https://pile.eleuther.ai/). 40 | 41 | ### Links 42 | 43 | [Download slim weights (bf16 weights only, for inference, 9GB)](https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd) 44 | 45 | [Download full weights (including optimizer params, 61GB)](https://the-eye.eu/public/AI/GPT-J-6B/step_383500.tar.zstd) 46 | 47 | [Partially trained checkpoints](https://the-eye.eu/public/AI/GPT-J-6B/) 48 | 49 | [Colab demo](http://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb) 50 | 51 | [Web demo](https://6b.eleuther.ai/) 52 | 53 | [Aran's blog post](https://arankomatsuzaki.wordpress.com/2021/06/04/gpt-j/) 54 | 55 | ### Acknowledgments 56 | 57 | This project would not have been possible without compute generously provided by the 58 | [TPU Research Cloud](https://sites.research.google/trc/) with assistance from [EleutherAI](https://eleuther.ai/). 59 | 60 | Thanks to the Cloud TPU team at Google for providing early access to the Cloud TPU VM alpha 61 | ([now publicly available!](https://cloud.google.com/blog/products/compute/introducing-cloud-tpu-vms)) 62 | 63 | Thanks to everyone who has helped out in one way or another (listed alphabetically): 64 | - [Aran Komatsuzaki](https://twitter.com/arankomatsuzaki) for advice with experiment design and writing the blog posts. 65 | - [James Bradbury](https://twitter.com/jekbradbury) for valuable assistance with debugging JAX issues. 66 | - [Janko Prester](https://github.com/jprester) for creating the web demo frontend. 67 | - [Laurence Golding](https://github.com/researcher2) for adding some features to the web demo. 68 | - [Leo Gao](https://twitter.com/nabla_theta) for running zero-shot evaluations for the baseline models for the table. 69 | 70 | ### License 71 | The weights of GPT-J-6B are licensed under version 2.0 of the Apache License. 72 | 73 | ### Model Details 74 | 75 | | Hyperparameter | Value | 76 | |-------------------|--------| 77 | | n_parameters | 6,053,381,344 | 78 | | n_layers | 28* | 79 | | d_model | 4,096 | 80 | | d_ff | 16,384 | 81 | | n_heads | 16 | 82 | | d_head | 256 | 83 | | n_ctx | 2,048 | 84 | | n_vocab | 50,257 (same tokenizer as GPT-2/3) | 85 | | position encoding | [Rotary position encodings (RoPE)](https://arxiv.org/abs/2104.09864) | 86 | | RoPE dimensions | [64](https://github.com/kingoflolz/mesh-transformer-jax/blob/f2aa66e0925de6593dcbb70e72399b97b4130482/mesh_transformer/layers.py#L223) | 87 | 88 | `*` each layer consists of one feedforward block and one self-attention block 89 | 90 | The model consists of 28 layers with a model dimension of 4096, and a feedforward dimension of 16384. The model 91 | dimension is split into 16 heads, each with a dimension of 256. Rotary position encodings (RoPE) are applied to 64 92 | dimensions of each head.
The model is trained with a tokenization vocabulary of 50257, using the same set of BPEs as 93 | GPT-2/GPT-3. 94 | 95 | ### Zero-Shot Evaluations 96 | 97 | Models are roughly sorted by performance, or by FLOPs if performance numbers are not available. 98 | 99 | | Model | Weights | Training FLOPs | LAMBADA PPL ↓ | LAMBADA Acc ↑ | Winogrande ↑ | Hellaswag ↑ | PIQA ↑ | Dataset Size (GB) | 100 | |-----------------|---------|----------------|--- |--- |--- |--- |--- |-------------------| 101 | | Chance | ✔ | 0 | ~a lot | ~0% | 50% | 25% | 25% | 0 | 102 | | GPT-3-Ada‡ | ✘ | ----- | 9.95 | 51.6% | 52.9% | 43.4% | 70.5% | ----- | 103 | | GPT-2-1.5B | ✔ | ----- | 10.63 | 51.21% | 59.4% | 50.9% | 70.8% | 40 | 104 | | GPTNeo-1.3B‡ | ✔ | 3.0e21 | 7.50 | 57.2% | 55.0% | 48.9% | 71.1% | 825 | 105 | | Megatron-2.5B* | ✘ | 2.4e21 | ----- | 61.7% | ----- | ----- | ----- | 174 | 106 | | GPTNeo-2.7B‡ | ✔ | 6.8e21 | 5.63 | 62.2% | 56.5% | 55.8% | 73.0% | 825 | 107 | | GPT-3-1.3B*‡ | ✘ | 2.4e21 | 5.44 | 63.6% | 58.7% | 54.7% | 75.1% | ~800 | 108 | | GPT-3-Babbage‡ | ✘ | ----- | 5.58 | 62.4% | 59.0% | 54.5% | 75.5% | ----- | 109 | | Megatron-8.3B* | ✘ | 7.8e21 | ----- | 66.5% | ----- | ----- | ----- | 174 | 110 | | GPT-3-2.7B*‡ | ✘ | 4.8e21 | 4.60 | 67.1% | 62.3% | 62.8% | 75.6% | ~800 | 111 | | Megatron-11B† | ✔ | 1.0e22 | ----- | ----- | ----- | ----- | ----- | 161 | 112 | | **GPT-J-6B**‡ | ✔ | 1.5e22 | 3.99 | 69.7% | 65.3% | 66.1% | 76.5% | 825 | 113 | | GPT-3-6.7B*‡ | ✘ | 1.2e22 | 4.00 | 70.3% | 64.5% | 67.4% | 78.0% | ~800 | 114 | | GPT-3-Curie‡ | ✘ | ----- | 4.00 | 69.3% | 65.6% | 68.5% | 77.9% | ----- | 115 | | GPT-3-13B*‡ | ✘ | 2.3e22 | 3.56 | 72.5% | 67.9% | 70.9% | 78.5% | ~800 | 116 | | GPT-3-175B*‡ | ✘ | 3.1e23 | 3.00 | 76.2% | 70.2% | 78.9% | 81.0% | ~800 | 117 | | GPT-3-Davinci‡ | ✘ | ----- | 3.0 | 75% | 72% | 78% | 80% | ----- | 118 | | Gopher 280B* | ✘ | 6.31e23 | ----- | 74.5% | 70.1% | 79.2% | 81.8% | 1344 | 119 | | MT-NLG 530B*‡ | ✘ | ----- | ----- | 76.6% | 73.0% | 80.2% | 82.0% | ----- | 120 | 121 | `*` represents evaluation numbers reported by their respective authors; all other numbers were obtained by 122 | running the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/) either with the released 123 | weights or with API access. Due to subtle implementation differences as well as different zero-shot task framing, these 124 | might not be directly comparable. See [this blog post](https://www.eleuther.ai/research-log/gpt3-model-sizes/) for more 125 | details. 126 | 127 | `†` The Megatron-11B model provides no comparable metrics, and several implementations using the released weights do not 128 | reproduce the generation quality and evaluations (see [1](https://github.com/huggingface/transformers/pull/10301) 129 | [2](https://github.com/pytorch/fairseq/issues/2358) [3](https://github.com/pytorch/fairseq/issues/2719)). 130 | Thus, evaluation was not attempted. 131 | 132 | `‡` These models have been trained with data which contains possible test set contamination. The OpenAI GPT-3 models 133 | failed to deduplicate training data for certain test sets, while the GPT-Neo models, as well as this one, are 134 | trained on The Pile, which has not been deduplicated against any test sets. 135 | 136 | # Architecture and Usage 137 | 138 | Most scripts in this repository are designed to be run on TPUs, which under the 139 | [TPU-VM architecture](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) are virtual machines 140 | that can run arbitrary code.
Most scripts are designed to spin up a TPU, SSH into it to set up the dependencies 141 | and copy code over from the local directory, and then start a [Ray](https://github.com/ray-project/ray.git) worker 142 | which can accept RPC calls. 143 | 144 | The TPU VMs handle running model training steps and evaluation, as well as checkpoint saving and loading, while the driver Python 145 | program handles data loading and general orchestration (such as when to save checkpoints). 146 | 147 | This means that most scripts (`train.py`, `eval_harness.py`, etc.) expect to be run on a GCE virtual machine in the 148 | same region as the TPUs, to minimize RPC latency and data transfer costs. Other scripts 149 | (usually ones that don't take a `--tpu` argument, such as `device_sample.py`, `device_serve.py` or `device_train.py`) 150 | expect to be run directly on a TPU VM. The `device_*` scripts **only work on a v3-8** and not on larger pods. 151 | 152 | Furthermore, there is an example (`resharding_example.py`) of how to convert the provided checkpoints (which have 8 153 | shards in the case of GPT-J-6B) down to a smaller number of shards, e.g. when running on GPU(s). 154 | 155 | ### Fine-tuning 156 | 157 | To fine-tune the model, run `device_train.py` on a TPU VM. Using a TPU v3-8, you can fine-tune at a rate of ~5000 158 | tokens/second, which should be sufficient for small-to-medium-size datasets. 159 | 160 | Please read the [step-by-step guide](howto_finetune.md) for thorough fine-tuning instructions. 161 | 162 | ### JAX Dependency 163 | 164 | Note that this library has specific JAX version requirements. To use the v1 models (including 165 | GPT-J-6B), `jax==0.2.12` is required, which in turn depends on `jaxlib==0.1.68`. **If these versions are not installed, you will get 166 | cryptic xmap errors.** A minimal version check is sketched at the end of this README. 167 | 168 | However, to use the v2 model code (no publicly released weights), the newest JAX version can be used. 169 | # Citation 170 | 171 | To cite this repository: 172 | ``` 173 | @misc{mesh-transformer-jax, 174 | author = {Wang, Ben}, 175 | title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}}, 176 | howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}}, 177 | year = 2021, 178 | month = May 179 | } 180 | ``` 181 | 182 | To cite the weights of GPT-J-6B: 183 | ``` 184 | @misc{gpt-j, 185 | author = {Wang, Ben and Komatsuzaki, Aran}, 186 | title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}}, 187 | howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}}, 188 | year = 2021, 189 | month = May 190 | } 191 | ``` 192 | 193 | If you use this repository or any of the pretrained weights to do something cool, we would love to hear about it. 194 | Feel free to open a GitHub issue or reach out over email (in profile).
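As a quick guard against the version pitfall described in the [JAX Dependency](#jax-dependency) section above, the following is a minimal version check (our own sketch, not part of this repo) that fails fast instead of producing the cryptic xmap errors:

```python
# Illustrative helper, not part of this repository: verify the pinned
# v1 (GPT-J-6B) JAX versions before building the model.
import jax
import jaxlib

def check_v1_jax_pins():
    expected = {"jax": "0.2.12", "jaxlib": "0.1.68"}
    found = {"jax": jax.__version__, "jaxlib": jaxlib.__version__}
    for name, want in expected.items():
        if found[name] != want:
            raise RuntimeError(
                f"v1 models (including GPT-J-6B) require {name}=={want}, "
                f"but {name}=={found[name]} is installed"
            )

check_v1_jax_pins()
```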
195 | -------------------------------------------------------------------------------- /benchmarks.md: -------------------------------------------------------------------------------- 1 | Note that everything on this page is quite outdated, and is only roughly accurate when considering new features 2 | such as RoPE. 3 | 4 | # Benchmarks (v3-8) 5 | 6 | (see `tpuv38_example.py`): 7 | 8 | ## ~2.7B model 9 | ``` 10 | Initialized in 121.842s 11 | Total parameters: 2722382080 12 | Compiled in 49.0534s 13 | it: 0, loss: 20.311113357543945 14 | 15 | it: 90, loss: 3.987450361251831 16 | 100 steps in 109.385s 17 | effective flops (not including attn): 2.4466e+14 18 | ``` 19 | 20 | ## ~4.8B model 21 | ``` 22 | Initialized in 101.016s 23 | Total parameters: 4836720896 24 | Compiled in 52.7404s 25 | it: 0, loss: 4.632925987243652 26 | 27 | it: 40, loss: 3.2406811714172363 28 | 50 steps in 102.559s 29 | effective flops (not including attn): 2.31803e+14 30 | ``` 31 | 32 | ## 10B model 33 | ``` 34 | Initialized in 152.762s 35 | Total parameters: 10073579776 36 | Compiled in 92.6539s 37 | it: 0, loss: 5.3125 38 | 39 | it: 40, loss: 3.65625 40 | 50 steps in 100.235s 41 | effective flops (not including attn): 2.46988e+14 42 | ``` 43 | 44 | # Benchmarks (v3-32) 45 | 46 | (see `eval.py`): 47 | 48 | ## 6B model 49 | ``` 50 | "layers": 28, 51 | "d_model": 4096, 52 | "n_heads": 16, 53 | "n_vocab": 50400, 54 | 55 | "seq": 2048, 56 | "cores_per_replica": 8, 57 | "per_replica_batch": 1, 58 | "gradient_accumulation_steps": 8, 59 | 60 | params: 6053381856 61 | 32 iters done in 107.935s, avg 3.37298s, 0.296473/s 62 | effective flops (not including attn): 7.05692e+14 63 | MXU flops: 1.04523e+15 64 | ``` 65 | 66 | ## Note that the models below do not currently work 67 | They require a larger degree of model parallelism than is currently implemented, but benchmark numbers should be
69 | 70 | ## 13B model 71 | ``` 72 | "layers": 28, 73 | "d_model": 6144, 74 | "n_heads": 32, 75 | "n_vocab": 50400, 76 | 77 | "seq": 2048, 78 | "cores_per_replica": 16, 79 | "per_replica_batch": 1, 80 | "gradient_accumulation_steps": 16, 81 | 82 | params: 13312183008 83 | 32 iters done in 250.86s, avg 7.83937s, 0.127561/s 84 | effective flops (not including attn): 6.67727e+14 85 | MXU flops: 9.80066e+14 86 | ``` 87 | 88 | ## 23B model 89 | ``` 90 | "layers": 28, 91 | "d_model": 8192, 92 | "n_heads": 32, 93 | "n_vocab": 50400, 94 | 95 | "seq": 2048, 96 | "cores_per_replica": 32, 97 | "per_replica_batch": 1, 98 | "gradient_accumulation_steps": 32, 99 | 100 | params: 23398107360 101 | 16 iters done in 221.33s, avg 13.8331s, 0.0722902/s 102 | effective flops (not including attn): 6.65107e+14 103 | MXU flops: 9.88548e+14 104 | ``` -------------------------------------------------------------------------------- /configs/6B_roto_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": 28, 3 | "d_model": 4096, 4 | "n_heads": 16, 5 | "n_vocab": 50400, 6 | "norm": "layernorm", 7 | "pe": "rotary", 8 | "pe_rotary_dims": 64, 9 | 10 | "seq": 2048, 11 | "cores_per_replica": 8, 12 | "per_replica_batch": 1, 13 | "gradient_accumulation_steps": 16, 14 | 15 | "warmup_steps": 3000, 16 | "anneal_steps": 300000, 17 | "lr": 1.2e-4, 18 | "end_lr": 1.2e-5, 19 | "weight_decay": 0.1, 20 | "total_steps": 350000, 21 | 22 | "tpu_size": 256, 23 | 24 | "bucket": "neo-models", 25 | "model_dir": "mesh_jax_pile_6B_rotary", 26 | 27 | "train_set": "pile.train.index", 28 | "val_set": { 29 | "pile": "pile.val.index", 30 | "owt": "openwebtext2_new_inputs.val.index" 31 | }, 32 | 33 | "eval_harness_tasks": [ 34 | "lambada", 35 | "piqa", 36 | "hellaswag", 37 | "winogrande", 38 | "mathqa", 39 | "pubmedqa" 40 | ], 41 | 42 | "val_batches": 100, 43 | "val_every": 500, 44 | "ckpt_every": 500, 45 | "keep_every": 10000, 46 | 47 | "name": "GPT3_6B_pile_rotary", 48 | "wandb_project": "mesh-transformer-jax", 49 | "comment": "" 50 | } -------------------------------------------------------------------------------- /configs/example_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": 28, 3 | "d_model": 4096, 4 | "n_heads": 16, 5 | "n_vocab": 50400, 6 | "norm": "layernorm", 7 | "pe": "rotary", 8 | "pe_rotary_dims": 64, 9 | 10 | "seq": 2048, 11 | "cores_per_replica": 8, 12 | "per_replica_batch": 1, 13 | "gradient_accumulation_steps": 16, 14 | 15 | "warmup_steps": 7, 16 | "anneal_steps": 65, 17 | "lr": 5e-5, 18 | "end_lr": 1e-5, 19 | "weight_decay": 0.1, 20 | "total_steps": 72, 21 | 22 | "tpu_size": 8, 23 | 24 | "bucket": "your-bucket", 25 | "model_dir": "finetune_dir", 26 | 27 | "train_set": "example.train.index", 28 | "val_set": {}, 29 | 30 | "eval_harness_tasks": [ 31 | ], 32 | 33 | "val_batches": 0, 34 | "val_every": 80, 35 | "ckpt_every": 72, 36 | "keep_every": 72, 37 | 38 | "name": "example_model", 39 | "wandb_project": "mesh-transformer-jax", 40 | "comment": "" 41 | } -------------------------------------------------------------------------------- /create_finetune_tfrecords.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | import random 5 | 6 | from pathlib import Path 7 | from typing import List 8 | 9 | import ftfy 10 | import tensorflow as tf 11 | from lm_dataformat import Reader 12 | from transformers import GPT2TokenizerFast 13 | from 
tqdm import tqdm 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser(description=""" 18 | Converts a text dataset into the training data format expected by the model. 19 | 20 | Adapted from the script create_tfrecords.py in the gpt-neo repo. 21 | 22 | - Your text dataset: 23 | - can be provided as .txt files, or as an archive (.tar.gz, .xz, jsonl.zst). 24 | - can be one file or multiple 25 | - using a single large file may use too much memory and crash - if this occurs, split the file up into a few files 26 | - the model's end-of-text separator is added between the contents of each file 27 | - if the string '<|endoftext|>' appears inside a file, it is treated as the model's end-of-text separator (not the actual string '<|endoftext|>') 28 | - this behavior can be disabled with --treat-eot-as-text 29 | 30 | This script creates a single .tfrecords file as output 31 | - Why: the model's data loader ignores "trailing" data (< 1 batch) at the end of a .tfrecords file 32 | - this causes data loss if you have many .tfrecords files 33 | - This is probably not appropriate for very large datasets 34 | """, formatter_class=argparse.RawTextHelpFormatter) 35 | parser.add_argument( 36 | "input_path", 37 | type=str, 38 | help="Path to an input file, or a directory that contains the input files.", 39 | ) 40 | parser.add_argument("name", type=str, 41 | help="Name of output file will be {name}_{seqnum}.tfrecords, where seqnum is total sequence count") 42 | parser.add_argument("--output-dir", type=str, default="", help="Output directory (default: current directory)") 43 | 44 | cleaning_args = parser.add_argument_group('data cleaning arguments') 45 | 46 | cleaning_args.add_argument("--normalize-with-ftfy", action="store_true", help="Normalize text with ftfy") 47 | cleaning_args.add_argument("--normalize-with-wikitext-detokenize", 48 | action="store_true", help="Use wikitext detokenizer") 49 | minu_help = "Exclude repetitive documents made up of < MIN_UNIQUE_TOKENS unique tokens. These can produce large gradients." 50 | minu_help += " Set <= 0 to disable. If enabled, 200 is a good default value. (Default: 0)" 51 | cleaning_args.add_argument("--min-unique-tokens", type=int, default=0, 52 | help=minu_help) 53 | 54 | shuffle_pack_args = parser.add_argument_group('data shuffling/packing arguments') 55 | repack_ep_help = "Repeat the data N_REPACK_EPOCHS times, shuffled differently in each repetition. Recommended for multi-epoch training (set this to your intended number of epochs)." 
56 | shuffle_pack_args.add_argument("--n-repack-epochs", 57 | type=int, default=1, 58 | help=repack_ep_help 59 | ) 60 | shuffle_pack_args.add_argument("--seed", type=int, default=10, 61 | help="random seed for shuffling data (default: 10)") 62 | shuffle_pack_args.add_argument("--preserve-data-order", 63 | default=False, action="store_true", 64 | help="Disables shuffling, so the input and output data have the same order.") 65 | 66 | misc_args = parser.add_argument_group('miscellaneous arguments') 67 | misc_args.add_argument("--verbose", 68 | default=False, action="store_true", 69 | help="Prints extra information, such as the text removed by --min-unique-tokens") 70 | 71 | args = parser.parse_args() 72 | 73 | # convert input_path to pathy 74 | args.input_path = Path(args.input_path) 75 | 76 | return args 77 | 78 | 79 | def get_files(input_path: Path) -> List[str]: 80 | supported_file_types = ["jsonl.zst", ".txt", ".xz", ".tar.gz"] 81 | if input_path.is_dir(): 82 | # get all files with supported file types 83 | files = [list(Path(input_path).glob(f"*{ft}")) for ft in supported_file_types] 84 | # flatten list 85 | files = [f for sublist in files for f in sublist] 86 | assert files, f"No files with supported types found in directory: {input_path}" 87 | elif input_path.is_file(): 88 | assert any( 89 | str(input_path).endswith(f_type) for f_type in supported_file_types 90 | ), f"Input file type must be one of: {supported_file_types}" 91 | files = [input_path] 92 | else: 93 | raise FileNotFoundError(f"No such file or directory: {input_path=}") 94 | 95 | return [str(f) for f in files] 96 | 97 | 98 | def wikitext_detokenizer(string): 99 | # contractions 100 | string = string.replace("s '", "s'") 101 | string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) 102 | # number separators 103 | string = string.replace(" @-@ ", "-") 104 | string = string.replace(" @,@ ", ",") 105 | string = string.replace(" @.@ ", ".") 106 | # punctuation 107 | string = string.replace(" : ", ": ") 108 | string = string.replace(" ; ", "; ") 109 | string = string.replace(" . ", ". ") 110 | string = string.replace(" ! ", "! ") 111 | string = string.replace(" ? ", "? ") 112 | string = string.replace(" , ", ", ") 113 | # double brackets 114 | string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) 115 | string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) 116 | string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) 117 | string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) 118 | string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) 119 | # miscellaneous 120 | string = string.replace("= = = =", "====") 121 | string = string.replace("= = =", "===") 122 | string = string.replace("= =", "==") 123 | string = string.replace(" " + chr(176) + " ", chr(176)) 124 | string = string.replace(" \n", "\n") 125 | string = string.replace("\n ", "\n") 126 | string = string.replace(" N ", " 1 ") 127 | string = string.replace(" 's", "'s") 128 | 129 | return string 130 | 131 | 132 | def _int64_feature(value): 133 | """ 134 | Returns an int64_list from a bool / enum / int / uint. 
135 | """ 136 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 137 | 138 | 139 | def write_to_file(writer, data): 140 | """ 141 | writes data to tfrecord file 142 | """ 143 | feature = { 144 | "text": _int64_feature(data) 145 | } 146 | tf_example = tf.train.Example(features=tf.train.Features(feature=feature)) 147 | writer.write(tf_example.SerializeToString()) 148 | 149 | 150 | def write_tfrecord(sequences, fp): 151 | with tf.io.TFRecordWriter(fp) as writer: 152 | for seq in sequences: 153 | write_to_file(writer, seq) 154 | 155 | 156 | def split_list(l, n): 157 | # splits list/string into n size chunks 158 | return [l[i:i + n] for i in range(0, len(l), n)] 159 | 160 | 161 | def enforce_min_unique(seqs, min_unique_tokens, enc, verbose=False): 162 | for seq in tqdm(seqs, mininterval=1, smoothing=0, desc="enforce_min_unique_tokens"): 163 | if len(set(seq)) >= min_unique_tokens: 164 | yield seq 165 | elif verbose: 166 | text = enc.decode(seq) 167 | print(f"excluding with {len(set(seq))} unique tokens:\n\n{repr(text)}\n\n") 168 | 169 | 170 | def eot_splitting_generator(string_iterable, encoder): 171 | """ 172 | Given strings, splits them internally on <|endoftext|> and yields (generally more) strings 173 | """ 174 | for doc in string_iterable: 175 | for d in doc.split(encoder.eos_token): 176 | if len(d) > 0: 177 | yield d 178 | 179 | 180 | def prep_and_tokenize_generator(string_iterable, encoder, normalize_with_ftfy, normalize_with_wikitext_detokenize): 181 | """ 182 | Given strings, does data cleaning / tokenization and yields arrays of tokens 183 | """ 184 | for doc in string_iterable: 185 | if normalize_with_ftfy: # fix text with ftfy if specified 186 | doc = ftfy.fix_text(doc, normalization='NFKC') 187 | if normalize_with_wikitext_detokenize: 188 | doc = wikitext_detokenizer(doc) 189 | tokens = encoder.encode(doc) + [encoder.eos_token_id] 190 | yield tokens 191 | 192 | 193 | def file_to_tokenized_docs_generator(file_path, encoder, args): 194 | """ 195 | Given a file path, reads the file and tokenizes the contents 196 | 197 | Yields token arrays of arbitrary, unequal length 198 | """ 199 | reader = Reader(file_path) 200 | string_iterable = reader.stream_data(threaded=False) 201 | string_iterable = eot_splitting_generator(string_iterable, encoder) 202 | 203 | token_list_gen = prep_and_tokenize_generator(string_iterable, 204 | encoder, 205 | normalize_with_ftfy=args.normalize_with_ftfy, 206 | normalize_with_wikitext_detokenize=args.normalize_with_wikitext_detokenize 207 | ) 208 | return token_list_gen 209 | 210 | 211 | def read_files_to_tokenized_docs(files, args, encoder): 212 | docs = [] 213 | 214 | if args.preserve_data_order: 215 | files = sorted(files) 216 | else: 217 | random.shuffle(files) 218 | 219 | for f in tqdm(files, mininterval=10, smoothing=0, desc="reading/tokenizing files"): 220 | docs.extend(file_to_tokenized_docs_generator(f, encoder, args)) 221 | 222 | if not args.preserve_data_order: 223 | # shuffle at individual document level 224 | random.shuffle(docs) 225 | 226 | return docs 227 | 228 | 229 | def arrays_to_sequences(token_list_iterable, sequence_length=2049): 230 | """ 231 | Given token arrays of arbitrary lengths, concats/splits them into arrays of equal length 232 | 233 | Returns equal-length token arrays, followed by a a final array of trailing tokens (which may be shorter) 234 | """ 235 | accum = [] 236 | for l in token_list_iterable: 237 | accum.extend(l) 238 | 239 | if len(accum) > sequence_length: 240 | chunks = split_list(accum, 
sequence_length) 241 | yield from chunks[:-1] 242 | accum = chunks[-1] 243 | 244 | if len(accum) > 0: 245 | yield accum 246 | 247 | 248 | def chunk_and_finalize(arrays, args, encoder): 249 | sequences = list(arrays_to_sequences(arrays)) 250 | 251 | full_seqs, trailing_data = sequences[:-1], sequences[-1] 252 | 253 | if args.min_unique_tokens > 0: 254 | full_seqs = list(enforce_min_unique(full_seqs, args.min_unique_tokens, encoder, args.verbose)) 255 | 256 | if not args.preserve_data_order: 257 | random.shuffle(full_seqs) 258 | 259 | return full_seqs, trailing_data 260 | 261 | 262 | def create_tfrecords(files, args): 263 | GPT2TokenizerFast.max_model_input_sizes['gpt2'] = 1e20 # disables a misleading warning 264 | encoder = GPT2TokenizerFast.from_pretrained('gpt2') 265 | 266 | random.seed(args.seed) 267 | 268 | all_sequences_across_epochs = [] 269 | 270 | docs = read_files_to_tokenized_docs(files, args, encoder) 271 | 272 | full_seqs, trailing_data = chunk_and_finalize(docs, args, encoder) 273 | 274 | all_sequences_across_epochs.extend(full_seqs) 275 | 276 | # ep 2+ 277 | for ep_ix in range(1, args.n_repack_epochs): 278 | # re-shuffle 279 | if not args.preserve_data_order: 280 | random.shuffle(docs) 281 | full_seqs, trailing_data = chunk_and_finalize(docs, args, encoder) 282 | else: 283 | # if we're preserving data order, we can still "repack" by shifting everything 284 | # with the trailing data of the last epoch at the beginning 285 | seqs_with_prefix = [trailing_data] + full_seqs 286 | full_seqs, trailing_data = chunk_and_finalize(seqs_with_prefix, args, encoder) 287 | 288 | all_sequences_across_epochs.extend(full_seqs) 289 | 290 | # final 291 | print(f"dropped {len(trailing_data)} tokens of trailing data") 292 | 293 | total_sequence_len = len(all_sequences_across_epochs) 294 | 295 | fp = os.path.join(args.output_dir, f"{args.name}_{total_sequence_len}.tfrecords") 296 | write_tfrecord(all_sequences_across_epochs, fp) 297 | 298 | 299 | if __name__ == "__main__": 300 | args = parse_args() 301 | 302 | if args.output_dir: 303 | os.makedirs(args.output_dir, exist_ok=True) 304 | files = get_files(args.input_path) 305 | print(f"Creating TFRecords from files: {files}") 306 | 307 | results = create_tfrecords(files, args) 308 | -------------------------------------------------------------------------------- /data/example.train.index: -------------------------------------------------------------------------------- 1 | gs://your-bucket/datasets/your.tfrecords -------------------------------------------------------------------------------- /data/openwebtext2_new_inputs.train.index: -------------------------------------------------------------------------------- 1 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_0_0_100000.tfrecords 2 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_0_1_100000.tfrecords 3 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_0_2_100000.tfrecords 4 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_0_3_100000.tfrecords 5 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_0_4_100000.tfrecords 6 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_0_5_100000.tfrecords 7 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_0_6_100000.tfrecords 8 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_0_72398.tfrecords 9 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_0_7_100000.tfrecords 10 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_10_0_100000.tfrecords 11 | 
gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_10_1_100000.tfrecords 12 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_10_2_100000.tfrecords 13 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_10_3_100000.tfrecords 14 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_10_4_100000.tfrecords 15 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_10_5_100000.tfrecords 16 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_10_6_100000.tfrecords 17 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_10_7_100000.tfrecords 18 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_11_0_100000.tfrecords 19 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_11_1_100000.tfrecords 20 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_11_2_100000.tfrecords 21 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_11_3_100000.tfrecords 22 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_11_4_100000.tfrecords 23 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_11_5_100000.tfrecords 24 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_11_6_100000.tfrecords 25 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_11_7_100000.tfrecords 26 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_12_0_100000.tfrecords 27 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_12_1_100000.tfrecords 28 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_12_2_100000.tfrecords 29 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_12_3_100000.tfrecords 30 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_12_4_100000.tfrecords 31 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_12_5_100000.tfrecords 32 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_12_6_100000.tfrecords 33 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_12_7_100000.tfrecords 34 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_13_0_100000.tfrecords 35 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_13_1_100000.tfrecords 36 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_13_2_100000.tfrecords 37 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_13_3_100000.tfrecords 38 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_13_4_100000.tfrecords 39 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_13_5_100000.tfrecords 40 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_13_6_100000.tfrecords 41 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_13_7_100000.tfrecords 42 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_14_0_100000.tfrecords 43 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_14_1_100000.tfrecords 44 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_14_2_100000.tfrecords 45 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_14_3_100000.tfrecords 46 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_14_4_100000.tfrecords 47 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_14_5_100000.tfrecords 48 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_14_6_100000.tfrecords 49 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_14_7_100000.tfrecords 50 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_15_0_100000.tfrecords 51 | 
gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_15_1_100000.tfrecords 52 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_15_2_100000.tfrecords 53 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_15_3_100000.tfrecords 54 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_15_4_100000.tfrecords 55 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_15_5_100000.tfrecords 56 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_15_6_100000.tfrecords 57 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_15_7_100000.tfrecords 58 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_16_0_100000.tfrecords 59 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_16_1_100000.tfrecords 60 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_16_2_100000.tfrecords 61 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_16_3_100000.tfrecords 62 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_16_4_100000.tfrecords 63 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_16_5_100000.tfrecords 64 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_16_6_100000.tfrecords 65 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_16_7_100000.tfrecords 66 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_17_0_100000.tfrecords 67 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_17_1_100000.tfrecords 68 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_17_2_100000.tfrecords 69 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_17_3_100000.tfrecords 70 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_17_4_100000.tfrecords 71 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_17_5_100000.tfrecords 72 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_17_6_100000.tfrecords 73 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_17_7_100000.tfrecords 74 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_18_0_100000.tfrecords 75 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_18_1_100000.tfrecords 76 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_18_2_100000.tfrecords 77 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_18_3_100000.tfrecords 78 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_18_4_100000.tfrecords 79 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_18_5_100000.tfrecords 80 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_18_6_100000.tfrecords 81 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_18_7_100000.tfrecords 82 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_19_57675.tfrecords 83 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_19_65092.tfrecords 84 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_19_65145.tfrecords 85 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_19_65193.tfrecords 86 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_19_69021.tfrecords 87 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_19_69832.tfrecords 88 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_19_85363.tfrecords 89 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_19_85399.tfrecords 90 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_1_0_100000.tfrecords 91 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_1_1_100000.tfrecords 92 | 
gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_1_2_100000.tfrecords 93 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_1_3_100000.tfrecords 94 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_1_4_100000.tfrecords 95 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_1_5_100000.tfrecords 96 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_1_6_100000.tfrecords 97 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_1_7_100000.tfrecords 98 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_2_0_100000.tfrecords 99 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_2_1_100000.tfrecords 100 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_2_2_100000.tfrecords 101 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_2_3_100000.tfrecords 102 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_2_4_100000.tfrecords 103 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_2_5_100000.tfrecords 104 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_2_6_100000.tfrecords 105 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_2_7_100000.tfrecords 106 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_3_0_100000.tfrecords 107 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_3_1_100000.tfrecords 108 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_3_2_100000.tfrecords 109 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_3_3_100000.tfrecords 110 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_3_4_100000.tfrecords 111 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_3_5_100000.tfrecords 112 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_3_6_100000.tfrecords 113 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_3_7_100000.tfrecords 114 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_4_0_100000.tfrecords 115 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_4_1_100000.tfrecords 116 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_4_2_100000.tfrecords 117 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_4_3_100000.tfrecords 118 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_4_4_100000.tfrecords 119 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_4_5_100000.tfrecords 120 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_4_6_100000.tfrecords 121 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_4_7_100000.tfrecords 122 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_5_0_100000.tfrecords 123 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_5_1_100000.tfrecords 124 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_5_2_100000.tfrecords 125 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_5_3_100000.tfrecords 126 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_5_4_100000.tfrecords 127 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_5_5_100000.tfrecords 128 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_5_6_100000.tfrecords 129 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_5_7_100000.tfrecords 130 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_6_0_100000.tfrecords 131 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_6_1_100000.tfrecords 132 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_6_2_100000.tfrecords 
133 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_6_3_100000.tfrecords 134 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_6_4_100000.tfrecords 135 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_6_5_100000.tfrecords 136 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_6_6_100000.tfrecords 137 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_6_7_100000.tfrecords 138 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_7_0_100000.tfrecords 139 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_7_1_100000.tfrecords 140 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_7_2_100000.tfrecords 141 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_7_3_100000.tfrecords 142 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_7_4_100000.tfrecords 143 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_7_5_100000.tfrecords 144 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_7_6_100000.tfrecords 145 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_7_7_100000.tfrecords 146 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_8_0_100000.tfrecords 147 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_8_1_100000.tfrecords 148 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_8_2_100000.tfrecords 149 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_8_3_100000.tfrecords 150 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_8_4_100000.tfrecords 151 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_8_5_100000.tfrecords 152 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_8_6_100000.tfrecords 153 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_8_7_100000.tfrecords 154 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_9_0_100000.tfrecords 155 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_9_1_100000.tfrecords 156 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_9_2_100000.tfrecords 157 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_9_3_100000.tfrecords 158 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_9_4_100000.tfrecords 159 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_9_5_100000.tfrecords 160 | gs://neo-datasets/openwebtext2_new_inputs/train/openwebtext_9_6_100000.tfrecords -------------------------------------------------------------------------------- /data/openwebtext2_new_inputs.val.index: -------------------------------------------------------------------------------- 1 | gs://neo-datasets/openwebtext2_new_inputs/eval/openwebtext_9_7_100000.tfrecords -------------------------------------------------------------------------------- /data/openwebtext2_new_inputs_shuffled.train.index: -------------------------------------------------------------------------------- 1 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_0.tfrecords 2 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_1.tfrecords 3 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_10.tfrecords 4 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_100.tfrecords 5 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_101.tfrecords 6 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_102.tfrecords 7 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_103.tfrecords 8 | 
gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_104.tfrecords 9 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_105.tfrecords 10 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_106.tfrecords 11 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_107.tfrecords 12 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_108.tfrecords 13 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_109.tfrecords 14 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_11.tfrecords 15 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_110.tfrecords 16 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_111.tfrecords 17 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_112.tfrecords 18 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_113.tfrecords 19 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_114.tfrecords 20 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_115.tfrecords 21 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_116.tfrecords 22 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_117.tfrecords 23 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_118.tfrecords 24 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_119.tfrecords 25 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_12.tfrecords 26 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_120.tfrecords 27 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_121.tfrecords 28 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_122.tfrecords 29 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_123.tfrecords 30 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_124.tfrecords 31 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_125.tfrecords 32 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_126.tfrecords 33 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_127.tfrecords 34 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_128.tfrecords 35 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_129.tfrecords 36 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_13.tfrecords 37 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_130.tfrecords 38 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_131.tfrecords 39 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_132.tfrecords 40 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_133.tfrecords 41 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_134.tfrecords 42 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_135.tfrecords 43 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_136.tfrecords 44 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_137.tfrecords 45 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_138.tfrecords 46 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_139.tfrecords 47 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_14.tfrecords 48 | 
gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_140.tfrecords 49 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_141.tfrecords 50 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_142.tfrecords 51 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_143.tfrecords 52 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_144.tfrecords 53 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_145.tfrecords 54 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_146.tfrecords 55 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_147.tfrecords 56 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_148.tfrecords 57 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_149.tfrecords 58 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_15.tfrecords 59 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_150.tfrecords 60 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_151.tfrecords 61 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_152.tfrecords 62 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_153.tfrecords 63 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_154.tfrecords 64 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_155.tfrecords 65 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_156.tfrecords 66 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_157.tfrecords 67 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_16.tfrecords 68 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_17.tfrecords 69 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_18.tfrecords 70 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_19.tfrecords 71 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_2.tfrecords 72 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_20.tfrecords 73 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_21.tfrecords 74 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_22.tfrecords 75 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_23.tfrecords 76 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_24.tfrecords 77 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_25.tfrecords 78 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_26.tfrecords 79 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_27.tfrecords 80 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_28.tfrecords 81 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_29.tfrecords 82 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_3.tfrecords 83 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_30.tfrecords 84 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_31.tfrecords 85 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_32.tfrecords 86 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_33.tfrecords 87 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_34.tfrecords 88 | 
gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_35.tfrecords 89 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_36.tfrecords 90 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_37.tfrecords 91 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_38.tfrecords 92 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_39.tfrecords 93 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_4.tfrecords 94 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_40.tfrecords 95 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_41.tfrecords 96 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_42.tfrecords 97 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_43.tfrecords 98 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_44.tfrecords 99 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_45.tfrecords 100 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_46.tfrecords 101 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_47.tfrecords 102 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_48.tfrecords 103 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_49.tfrecords 104 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_5.tfrecords 105 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_50.tfrecords 106 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_51.tfrecords 107 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_52.tfrecords 108 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_53.tfrecords 109 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_54.tfrecords 110 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_55.tfrecords 111 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_56.tfrecords 112 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_57.tfrecords 113 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_58.tfrecords 114 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_59.tfrecords 115 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_6.tfrecords 116 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_60.tfrecords 117 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_61.tfrecords 118 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_62.tfrecords 119 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_63.tfrecords 120 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_64.tfrecords 121 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_65.tfrecords 122 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_66.tfrecords 123 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_67.tfrecords 124 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_68.tfrecords 125 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_69.tfrecords 126 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_7.tfrecords 127 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_70.tfrecords 128 | 
gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_71.tfrecords 129 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_72.tfrecords 130 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_73.tfrecords 131 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_74.tfrecords 132 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_75.tfrecords 133 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_76.tfrecords 134 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_77.tfrecords 135 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_78.tfrecords 136 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_79.tfrecords 137 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_8.tfrecords 138 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_80.tfrecords 139 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_81.tfrecords 140 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_82.tfrecords 141 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_83.tfrecords 142 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_84.tfrecords 143 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_85.tfrecords 144 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_86.tfrecords 145 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_87.tfrecords 146 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_88.tfrecords 147 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_89.tfrecords 148 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_9.tfrecords 149 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_90.tfrecords 150 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_91.tfrecords 151 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_92.tfrecords 152 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_93.tfrecords 153 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_94.tfrecords 154 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_95.tfrecords 155 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_96.tfrecords 156 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_97.tfrecords 157 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_98.tfrecords 158 | gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_99.tfrecords -------------------------------------------------------------------------------- /data/pile.val.index: -------------------------------------------------------------------------------- 1 | gs://neo-datasets/pile_new_inputs_shuffled/val/pile_0.tfrecords -------------------------------------------------------------------------------- /data/wit_dalle.val.index: -------------------------------------------------------------------------------- 1 | gs://neo-datasets/wikipedia_image_text/embedding/wit_0.tfrecords.zstd -------------------------------------------------------------------------------- /device_sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import time 4 | 5 | import jax 6 | import numpy as np 7 | import optax 8 | 9 | from 
mesh_transformer import util 10 | from mesh_transformer.checkpoint import read_ckpt 11 | from mesh_transformer.sampling import nucleaus_sample 12 | from mesh_transformer.transformer_shard import CausalTransformer 13 | import transformers 14 | from smart_open import open 15 | 16 | from mesh_transformer.util import clip_by_global_norm 17 | 18 | 19 | def parse_args(): 20 | # Parse command line arguments 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--config", type=str, default=None, help="Config file location") 23 | 24 | args = parser.parse_args() 25 | return args 26 | 27 | 28 | if __name__ == "__main__": 29 | args = parse_args() 30 | params = json.load(open(args.config)) 31 | 32 | gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1) 33 | per_replica_batch = params["per_replica_batch"] 34 | cores_per_replica = params["cores_per_replica"] 35 | 36 | assert cores_per_replica <= 8 37 | 38 | bucket = params["bucket"] 39 | model_dir = params["model_dir"] 40 | layers = params["layers"] 41 | d_model = params["d_model"] 42 | n_heads = params["n_heads"] 43 | n_vocab = params["n_vocab"] 44 | seq = params["seq"] 45 | norm = params["norm"] 46 | 47 | params["sampler"] = nucleaus_sample 48 | opt = optax.chain( 49 | optax.scale(1 / gradient_accumulation_steps), 50 | clip_by_global_norm(1), 51 | optax.scale_by_adam(), 52 | optax.additive_weight_decay(0), 53 | optax.scale(-1), 54 | optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0)) 55 | ) 56 | 57 | params["optimizer"] = opt 58 | 59 | start = time.time() 60 | print(f"jax devices: {jax.device_count()}") 61 | print(f"jax runtime initialized in {time.time() - start:.06}s") 62 | 63 | mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica) 64 | devices = np.array(jax.devices()).reshape(mesh_shape) 65 | 66 | with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f: 67 | meta = json.load(f) 68 | 69 | ckpt_step = meta["checkpoints"][-1] 70 | print(f"using checkpoint {ckpt_step}") 71 | 72 | total_batch = per_replica_batch * jax.device_count() // cores_per_replica 73 | with jax.experimental.maps.mesh(devices, ('dp', 'mp')): 74 | network = CausalTransformer(params) 75 | 76 | start = time.time() 77 | network.state = read_ckpt(network.state, f"gs://{bucket}/{model_dir}/step_{ckpt_step}/", devices.shape[1]) 78 | print(f"network loaded in {time.time() - start:.06}s") 79 | 80 | local_shards = max(jax.local_device_count() // mesh_shape[1], 1) 81 | del network.state["opt_state"] 82 | network.state = network.move_xmap(network.state, np.zeros(local_shards)) 83 | 84 | tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') 85 | 86 | while True: 87 | context = input("Type input:") 88 | tokens = tokenizer.encode(context) 89 | 90 | start = time.time() 91 | 92 | provided_ctx = len(tokens) 93 | pad_amount = seq - provided_ctx 94 | 95 | padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32) 96 | batched_tokens = np.array([padded_tokens] * total_batch) 97 | length = np.ones(total_batch, dtype=np.uint32) * len(tokens) 98 | 99 | output = network.generate(batched_tokens, length, 512, {"top_p": np.ones(total_batch) * 0.9, 100 | "temp": np.ones(total_batch) * 0.75}) 101 | 102 | for idx, o in enumerate(output[1][0][:, :, 0]): 103 | print(f"sample {idx}: {repr(tokenizer.decode(o))}") 104 | 105 | print(f"completion done in {time.time() - start:06}s") 106 | -------------------------------------------------------------------------------- /device_serve.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import threading 4 | import time 5 | from queue import Queue, Empty 6 | 7 | import jax 8 | import numpy as np 9 | import optax 10 | 11 | from mesh_transformer import util 12 | from mesh_transformer.checkpoint import read_ckpt 13 | from mesh_transformer.sampling import nucleaus_sample 14 | from mesh_transformer.transformer_shard import CausalTransformer 15 | import transformers 16 | from smart_open import open 17 | 18 | from mesh_transformer.util import clip_by_global_norm 19 | 20 | from flask import Flask, request, make_response, jsonify 21 | app = Flask(__name__) 22 | 23 | requests_queue = Queue() 24 | 25 | """ 26 | curl --header "Content-Type: application/json" \ 27 | --request POST \ 28 | --data '{"context":"eleutherai", "top_p": 0.9, "temp": 0.75}' \ 29 | http://localhost:5000/complete 30 | """ 31 | 32 | 33 | def _build_cors_prelight_response(): 34 | response = make_response() 35 | response.headers.add("Access-Control-Allow-Origin", "*") 36 | response.headers.add('Access-Control-Allow-Headers', "*") 37 | response.headers.add('Access-Control-Allow-Methods', "*") 38 | return response 39 | 40 | 41 | def _corsify_actual_response(response): 42 | response.headers.add("Access-Control-Allow-Origin", "*") 43 | return response 44 | 45 | 46 | @app.route('/complete', methods=['POST', 'OPTIONS']) 47 | def complete(): 48 | if request.method == "OPTIONS": # CORS preflight 49 | return _build_cors_prelight_response() 50 | elif request.method == "POST": # The actual request following the preflight 51 | content = request.json 52 | 53 | if requests_queue.qsize() > 100: 54 | return {"error": "queue full, try again later"} 55 | 56 | response_queue = Queue() 57 | 58 | requests_queue.put(({ 59 | "context": content["context"], 60 | "top_p": float(content["top_p"]), 61 | "temp": float(content["temp"]) 62 | }, response_queue)) 63 | 64 | return _corsify_actual_response(jsonify({"completion": response_queue.get()})) 65 | else: 66 | raise RuntimeError("Weird - don't know how to handle method {}".format(request.method)) 67 | 68 | 69 | def parse_args(): 70 | # Parse command line arguments 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument("--config", type=str, default=None, help="Config file location") 73 | 74 | args = parser.parse_args() 75 | return args 76 | 77 | 78 | if __name__ == "__main__": 79 | threading.Thread(target=app.run, kwargs={"port": 5000, "host": "0.0.0.0"}).start() 80 | 81 | args = parse_args() 82 | params = json.load(open(args.config)) 83 | 84 | gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1) 85 | per_replica_batch = params["per_replica_batch"] 86 | cores_per_replica = params["cores_per_replica"] 87 | 88 | assert cores_per_replica <= 8 89 | 90 | bucket = params["bucket"] 91 | model_dir = params["model_dir"] 92 | layers = params["layers"] 93 | d_model = params["d_model"] 94 | n_heads = params["n_heads"] 95 | n_vocab = params["n_vocab"] 96 | seq = params["seq"] 97 | norm = params["norm"] 98 | 99 | params["sampler"] = nucleaus_sample 100 | opt = optax.chain( 101 | optax.scale(1 / gradient_accumulation_steps), 102 | clip_by_global_norm(1), 103 | optax.scale_by_adam(), 104 | optax.additive_weight_decay(0), 105 | optax.scale(-1), 106 | optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0)) 107 | ) 108 | 109 | params["optimizer"] = opt 110 | 111 | start = time.time() 112 | print(f"jax devices: {jax.device_count()}") 113 | print(f"jax runtime 
initialized in {time.time() - start:.06}s") 114 | 115 | mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica) 116 | devices = np.array(jax.devices()).reshape(mesh_shape) 117 | 118 | with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f: 119 | meta = json.load(f) 120 | 121 | ckpt_step = meta["checkpoints"][-1] 122 | print(f"using checkpoint {ckpt_step}") 123 | 124 | total_batch = per_replica_batch * jax.device_count() // cores_per_replica * 8 125 | with jax.experimental.maps.mesh(devices, ('dp', 'mp')): 126 | network = CausalTransformer(params) 127 | 128 | start = time.time() 129 | network.state = read_ckpt(network.state, f"gs://{bucket}/{model_dir}/step_{ckpt_step}/", devices.shape[1]) 130 | print(f"network loaded in {time.time() - start:.06}s") 131 | 132 | local_shards = max(jax.local_device_count() // mesh_shape[1], 1) 133 | del network.state["opt_state"] 134 | network.state = network.move_xmap(network.state, np.zeros(local_shards)) 135 | 136 | tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') 137 | 138 | while True: 139 | all_ctx = [] 140 | all_top_p = [] 141 | all_temp = [] 142 | all_q = [] 143 | while len(all_ctx) < total_batch: 144 | try: 145 | o, q = requests_queue.get(block=False) 146 | all_ctx.append(o["context"]) 147 | all_top_p.append(o["top_p"]) 148 | all_temp.append(o["temp"]) 149 | all_q.append(q) 150 | except Empty: 151 | if len(all_ctx): 152 | break 153 | else: 154 | time.sleep(0.01) 155 | 156 | start = time.time() 157 | while len(all_ctx) < total_batch: 158 | all_ctx.append("whatever") 159 | all_top_p.append(1) 160 | all_temp.append(1) 161 | 162 | all_tokenized = [] 163 | all_length = [] 164 | for ctx in all_ctx: 165 | padded_tokens = np.zeros(seq).astype(np.uint32) 166 | length = 0 167 | 168 | try: 169 | tokens = tokenizer.encode(ctx) 170 | provided_ctx = len(tokens) 171 | pad_amount = seq - provided_ctx 172 | 173 | pad_amount = max(pad_amount, 0) 174 | 175 | padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)[-seq:] 176 | length = len(tokens) 177 | except: 178 | print("oops exception") 179 | 180 | all_tokenized.append(padded_tokens) 181 | all_length.append(length) 182 | 183 | output = network.generate(np.array(all_tokenized), 184 | np.array(all_length), 185 | 256, 186 | { 187 | "top_p": np.array(all_top_p), 188 | "temp": np.array(all_temp) 189 | }) 190 | 191 | for o, q in zip(output[1][0][:, :, 0], all_q): 192 | q.put(tokenizer.decode(o)) 193 | 194 | print(f"completion done in {time.time() - start:06}s") 195 | -------------------------------------------------------------------------------- /device_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import time 4 | 5 | import jax 6 | import numpy as np 7 | import optax 8 | 9 | import wandb 10 | from tqdm import tqdm 11 | 12 | 13 | from mesh_transformer import util 14 | from mesh_transformer.checkpoint import read_ckpt, write_ckpt 15 | from mesh_transformer.transformer_shard import CausalTransformer 16 | from tfrecord_loader import TFRecordNewInputs 17 | from smart_open import open 18 | from google.cloud import storage 19 | from google.cloud.exceptions import NotFound 20 | 21 | from mesh_transformer.util import clip_by_global_norm, additive_weight_decay 22 | 23 | 24 | def parse_args(): 25 | # Parse command line arguments 26 | parser = argparse.ArgumentParser(description=""" 27 | To use, download the full checkpoint archive, extract and upload to a GCS bucket, and set that as 
--tune-model-path 28 | Modify the config file: 29 | - set `model_dir` to where the checkpoints should be written during training 30 | - set `train_set`, `val_set` to index files for your data 31 | - set `tpu_size` to 8 (if on a v3-8) 32 | - set `warmup_steps`, `anneal_steps`, `lr`, `end_lr` to the lr schedule for your finetuning run 33 | - the global step will reset to 0, keep that in mind when writing your lr schedule 34 | - set `name` to specify the name of the Weights & Biases run 35 | - set `wandb_project` to specify the Weights & Biases project to log to 36 | To prepare data in the expected data format: 37 | - use the script `create_finetune_tfrecords.py` in this repo to create data in the expected format 38 | - upload the .tfrecords files to GCS 39 | - save their GCS paths to a index file under `data/`, see existing files for examples 40 | """, 41 | formatter_class=argparse.RawTextHelpFormatter) 42 | parser.add_argument("--config", type=str, default=None, help="Config file location") 43 | parser.add_argument("--tune-model-path", type=str, default=None, help="Base model to finetune") 44 | parser.add_argument("--fresh-opt", default=False, action="store_true", help="Use a newly initialized optimizer, ignoring any optimizer state saved in the base checkpoint") 45 | 46 | args = parser.parse_args() 47 | return args 48 | 49 | 50 | def save(network, step, bucket, path, mp, aux=None, keep_n=3, delete_old=True): 51 | assert path 52 | client = storage.Client() 53 | 54 | if aux is None: 55 | aux = {} 56 | 57 | try: 58 | with open(f"gs://{bucket}/{path}/meta.json", "r") as f: 59 | meta = json.load(f) 60 | except: 61 | # create metadata file 62 | with open(f"gs://{bucket}/{path}/meta.json", "w") as f: 63 | json.dump({ 64 | "step": 0, 65 | "checkpoints": [], 66 | "aux": {} 67 | }, f) 68 | 69 | # do sharded checkpoint writing 70 | start = time.time() 71 | res = [] 72 | for shard_id in range(mp): 73 | write_ckpt(network.state, f"gs://{bucket}/{path}/step_{step}/", shard_id) 74 | 75 | print(f"Wrote checkpoint in {time.time() - start:.06}s") 76 | 77 | with open(f"gs://{bucket}/{path}/meta.json", "r") as f: 78 | meta = json.load(f) 79 | 80 | meta["step"] = step 81 | meta["checkpoints"].append(step) 82 | all_aux = meta.get("aux", {}) 83 | 84 | while len(meta["checkpoints"]) > keep_n: 85 | ckpt_to_delete = meta["checkpoints"].pop(0) 86 | 87 | try: 88 | del all_aux[str(ckpt_to_delete)] 89 | except: 90 | print(f"failed to delete the aux state for {step}") 91 | 92 | if delete_old: 93 | print(f"deleting checkpoint {ckpt_to_delete}") 94 | for blob in client.list_blobs(bucket, prefix=f"{path}/step_{ckpt_to_delete}/"): 95 | # print(f"deleting {blob.name}") 96 | assert path in blob.name 97 | blob.delete() 98 | else: 99 | print(f"keeping checkpoint {ckpt_to_delete}") 100 | 101 | all_aux[step] = aux 102 | meta["aux"] = all_aux 103 | 104 | with open(f"gs://{bucket}/{path}/meta.json", "w") as f: 105 | json.dump(meta, f) 106 | 107 | 108 | def train_step(network, data): 109 | inputs = { 110 | "obs": data[:, :, :-1], 111 | "target": data[:, :, 1:], 112 | } 113 | 114 | loss, last_loss, grad_norm, grad_norm_micro = network.train(inputs) 115 | 116 | return ( 117 | np.array(loss).mean(), 118 | np.array(last_loss).mean(), 119 | np.array(grad_norm).mean(), 120 | np.array(grad_norm_micro).mean(), 121 | ) 122 | 123 | 124 | def eval_step(network, data): 125 | inputs = { 126 | "obs": data[:, :-1], 127 | "target": data[:, 1:], 128 | } 129 | 130 | out = network.eval(inputs) 131 | loss = out["loss"] 132 | 133 | return 
np.array(loss).mean() 134 | 135 | 136 | if __name__ == "__main__": 137 | args = parse_args() 138 | params = json.load(open(args.config)) 139 | 140 | gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1) 141 | per_replica_batch = params["per_replica_batch"] 142 | cores_per_replica = params["cores_per_replica"] 143 | 144 | assert cores_per_replica <= 8 145 | 146 | bucket = params["bucket"] 147 | model_dir = params["model_dir"] 148 | layers = params["layers"] 149 | d_model = params["d_model"] 150 | n_heads = params["n_heads"] 151 | n_vocab = params["n_vocab"] 152 | seq = params["seq"] 153 | norm = params["norm"] 154 | 155 | val_batches = params["val_batches"] 156 | val_every = params["val_every"] 157 | ckpt_every = params["ckpt_every"] 158 | keep_every = params["keep_every"] 159 | eval_tasks = params["eval_harness_tasks"] 160 | total_steps = params["total_steps"] 161 | 162 | pe = params["pe"] 163 | assert pe in ["fixed", "rotary", "t5"] 164 | 165 | warmup_steps = params["warmup_steps"] 166 | anneal_steps = params["anneal_steps"] 167 | lr = params["lr"] 168 | end_lr = params["end_lr"] 169 | weight_decay = params["weight_decay"] 170 | 171 | # alpha parameter for the exponential moving averages used to compute B_simple 172 | noise_scale_alpha = params.get("noise_scale_alpha", 0.01) 173 | 174 | scheduler = util.gpt3_schedule(warmup_steps, anneal_steps, lr, end_lr) 175 | 176 | opt = optax.chain( 177 | optax.scale(1 / gradient_accumulation_steps), 178 | clip_by_global_norm(1), 179 | optax.scale_by_adam(), 180 | additive_weight_decay(weight_decay), 181 | optax.scale(-1), 182 | optax.scale_by_schedule(scheduler) 183 | ) 184 | 185 | params["optimizer"] = opt 186 | 187 | start = time.time() 188 | tpu_size = jax.device_count() 189 | if tpu_size < cores_per_replica: 190 | msg = f"each shard needs a separate device, but device count ({tpu_size}) < shard count ({cores_per_replica})" 191 | raise ValueError(msg) 192 | print(f"jax devices: {tpu_size}") 193 | print(f"jax runtime initialized in {time.time() - start:.06}s") 194 | 195 | mesh_shape = (tpu_size // cores_per_replica, cores_per_replica) 196 | devices = np.array(jax.devices()).reshape(mesh_shape) 197 | 198 | # pick initial ckpt - based on tuning vs train from scratch 199 | 200 | step = 0 201 | initial_ckpt_state_path = None 202 | train_loader = None 203 | 204 | if args.tune_model_path: 205 | print('`--tune_model_path` passed: we are beginning a fine-tuning run') 206 | fine_tuning = True 207 | initial_ckpt_state_path = args.tune_model_path 208 | else: 209 | print('`--tune_model_path` not passed: we are continuing a fine-tuning run from a checkpoint (or we are not fine-tuning)') 210 | fine_tuning = False 211 | initial_ckpt_model_dir = model_dir 212 | initial_ckpt_path = f"gs://{bucket}/{initial_ckpt_model_dir}" 213 | meta_path = f"{initial_ckpt_path}/meta.json" 214 | 215 | try: 216 | with open(meta_path, "r") as f: 217 | meta = json.load(f) 218 | ckpt_step = meta["checkpoints"][-1] 219 | initial_ckpt_state_path = f"{initial_ckpt_path}/step_{ckpt_step}/" 220 | print(f"state will be restored from checkpoint {ckpt_step}") 221 | 222 | step = ckpt_step 223 | train_loader = meta['aux'][str(ckpt_step)].get("train_loader", None) 224 | except NotFound: 225 | # no checkpoint, start at zero 226 | print(f"No checkpoint to load at {initial_ckpt_path}. 
Training from scratch.") 227 | 228 | if initial_ckpt_state_path: 229 | print(f"path to load checkpoint from: {initial_ckpt_state_path}") 230 | else: 231 | print("not loading from a checkpoint") 232 | 233 | # set up datasets 234 | print("setting up datasets") 235 | 236 | train_dataset = TFRecordNewInputs(f"data/{params['train_set']}", 237 | batch_size=( 238 | gradient_accumulation_steps, 239 | per_replica_batch * tpu_size // cores_per_replica), 240 | sample_size=params['seq'], 241 | restore_state=train_loader) 242 | 243 | global_val_batch = per_replica_batch * tpu_size // cores_per_replica 244 | 245 | val_sets = {} 246 | 247 | for k, v in params["val_set"].items(): 248 | val_sets[k] = TFRecordNewInputs( 249 | f"data/{v}", batch_size=(global_val_batch,), sample_size=seq 250 | ) 251 | 252 | # tok/sec metrics 253 | sequences_per_step = gradient_accumulation_steps * (per_replica_batch * tpu_size // cores_per_replica) 254 | tokens_per_step = params['seq'] * sequences_per_step 255 | 256 | # load + run 257 | with jax.experimental.maps.mesh(devices, ('dp', 'mp')): 258 | print("initializing network") 259 | network = CausalTransformer(params) 260 | 261 | if initial_ckpt_state_path: 262 | print("loading network") 263 | if fine_tuning: 264 | # get the scheduler step stored in the just-initialized optimizer 265 | # should be zero 266 | init_sched_state = network.state["opt_state"][-1] 267 | 268 | start = time.time() 269 | network.state = read_ckpt(network.state, initial_ckpt_state_path, devices.shape[1], load_opt=(not args.fresh_opt)) 270 | 271 | if fine_tuning: 272 | # overwrite the loaded scheduler step with zeros 273 | # this makes fine-tuning use the lr schedule in 274 | network.state["opt_state"][-1] = init_sched_state 275 | 276 | print(f"network loaded in {time.time() - start:.06}s") 277 | 278 | print('compiling train fn') 279 | start = time.time() 280 | loss, last_loss, grad_norm, grad_norm_micro = train_step( 281 | network, train_dataset.get_samples() 282 | ) 283 | step += 1 284 | print(f"Train fn compiled in {time.time() - start:.06}s") 285 | 286 | print('compiling eval fn') 287 | start = time.time() 288 | for val_set in val_sets.values(): 289 | eval_step(network, val_set.get_samples()) 290 | val_set.reset() 291 | print(f"Eval fn compiled in {time.time() - start:.06}s") 292 | 293 | project = params.get("wandb_project", "mesh-transformer-jax") 294 | wandb.init(project=project, name=params["name"], config=params) 295 | 296 | G_noise_avg = None 297 | S_noise_avg = None 298 | 299 | while True: 300 | if (step % ckpt_every == 1) or step == total_steps: 301 | print(f"saving a checkpoint for step {step}") 302 | save(network, step, bucket, model_dir, 303 | mp=cores_per_replica, 304 | aux={"train_loader": train_dataset.get_state()}, 305 | delete_old=True, 306 | ) 307 | 308 | if step % val_every == 1: # 1 because we've already taken a step to compile train fn 309 | for name, val_set in val_sets.items(): 310 | val_loss = [] 311 | for i, _ in tqdm(zip(val_set.sample_once(), range(val_batches)), 312 | desc=f"validation for step {step}, set {name}", 313 | total=val_batches): 314 | val_loss.append(eval_step(network, i)) 315 | val_set.reset() 316 | 317 | val_loss = np.array(val_loss).mean() 318 | print(f"validation loss for step {step}, set {name}: {val_loss}") 319 | 320 | wandb.log({f'val/loss_{name}': float(val_loss)}, step) 321 | 322 | if step == total_steps: 323 | print("training completed!") 324 | exit() 325 | 326 | start = time.time() 327 | loss, last_loss, grad_norm, grad_norm_micro = train_step( 328 | 
network, train_dataset.get_samples() 329 | ) 330 | step += 1 331 | 332 | steps_per_sec = 1 / (time.time() - start) 333 | tokens_per_sec = tokens_per_step * steps_per_sec 334 | sequences_processed = sequences_per_step * step 335 | tokens_processed = tokens_per_step * step 336 | 337 | ### compute summary stats about the gradient 338 | 339 | # converts from grads-summed-over-microbatch (what `CasualTransformer.train` computes) 340 | # to grads-averaged-over-microbatch (what we want) 341 | # 342 | # (when taking gradient steps, the same conversion happens inside the optimizer 343 | # via optax.scale(1 / gradient_accumulation_steps)) 344 | grad_norm = grad_norm / gradient_accumulation_steps 345 | 346 | # compute G_noise and S_noise 347 | # from "An Empirical Model of Large-Batch Training" Appendix A.1 348 | # here, B_big = gradient_accumulation_steps, and B_small = 1 for convenience 349 | gbsmall = grad_norm_micro ** 2 350 | gbbig = grad_norm ** 2 351 | G_noise = (gradient_accumulation_steps * gbbig - gbsmall) / ( 352 | gradient_accumulation_steps - 1 353 | ) 354 | S_noise = (gbsmall - gbbig) / (1 - 1 / gradient_accumulation_steps) 355 | 356 | noise_scale_stats = { 357 | "noise/G_noise": G_noise, 358 | "noise/S_noise": S_noise, 359 | } 360 | 361 | # heuristic to avoid reporting G_noise in very early training when gradients are large 362 | # (these take a long time to wash out of the moving average that defines B_simple) 363 | use_step_in_noise_avgs = gbbig < 2 364 | 365 | if use_step_in_noise_avgs: 366 | # compute moving averages of G_noise and S_noise, for B_simple 367 | if G_noise_avg is None: 368 | G_noise_avg = G_noise 369 | else: 370 | G_noise_avg = (1 - noise_scale_alpha) * G_noise_avg + noise_scale_alpha * G_noise 371 | 372 | if S_noise_avg is None: 373 | S_noise_avg = S_noise 374 | else: 375 | S_noise_avg = (1 - noise_scale_alpha) * S_noise_avg + noise_scale_alpha * S_noise 376 | 377 | B_simple = S_noise_avg / G_noise_avg 378 | 379 | noise_scale_stats.update( 380 | { 381 | "noise/G_noise_avg": G_noise_avg, 382 | "noise/S_noise_avg": S_noise_avg, 383 | "noise/B_simple": B_simple, 384 | } 385 | ) 386 | 387 | wandb_stats = { 388 | "train/loss": loss, 389 | "train/last_loss": last_loss, 390 | "train/steps_per_sec": steps_per_sec, 391 | "train/tokens_per_sec": tokens_per_sec, 392 | "train/grad_norm": grad_norm, 393 | "train/learning_rate": float(scheduler(network.state["opt_state"][-1].count[0].item())), 394 | "sequences_processed": sequences_processed, 395 | "tokens_processed": tokens_processed, 396 | } 397 | wandb_stats.update(noise_scale_stats) 398 | 399 | wandb.log(wandb_stats, step) 400 | -------------------------------------------------------------------------------- /docker/.env: -------------------------------------------------------------------------------- 1 | DRYRUN=false 2 | DOMAIN=yourdomain.com 3 | SUBDOMAIN=gptj.api 4 | NGINX_PATH=./nginx_proxyvm.conf 5 | CERTS_PATH=/path/to/certs/ 6 | DNS_KEYS=/path/to/keys/ 7 | MODEL_DIR=/your/path/to/model/step_383500 8 | TPU_NAME=YOUR_TPU_NAME -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Have tested with a custom Ubuntu-1804 / Python 3.7 / Tensorflow 2.5.0 Base Image 2 | # Not tested with this image. 3 | FROM tensorflow/tensorflow:2.5.0 4 | RUN apt update && \ 5 | apt-get install git -y 6 | 7 | WORKDIR /app/ 8 | COPY . 
/app/
 9 | RUN git clone https://github.com/kingoflolz/mesh-transformer-jax && \
10 |     pip install -r mesh-transformer-jax/requirements.txt && \
11 |     pip install mesh-transformer-jax/ jax==0.2.12 && \
12 |     pip install fastapi uvicorn requests aiofiles aiohttp && \
13 |     ln -s /app/start.sh /start.sh
14 | 
15 | ENV PYTHONPATH /app:/app/mesh-transformer-jax:/usr/local/bin/python3
16 | ENV PATH $PYTHONPATH:$PATH
17 | ENV TOKENIZERS_PARALLELISM=true
18 | EXPOSE 80
19 | 
20 | CMD ["/start.sh"]
21 | 
-------------------------------------------------------------------------------- /docker/README.md: --------------------------------------------------------------------------------
 1 | # Deploying GPT-J (TPU VM) in a Docker Container
 2 | 
 3 | This PR enables you to run GPT-J for inference in a Docker container, and it has run stably over the past week.
 4 | 
 5 | The API currently requires two calls: the first is a POST call to /model/predict with the params
 6 | 
 7 | ```
 8 | context: str
 9 | temp: Optional[float] = 1.0
10 | top_p: Optional[float] = 0.9
11 | top_k: Optional[int] = 50
12 | length: Optional[int] = 256
13 | ```
14 | 
15 | This returns an ID (int) that is later used to retrieve the prediction once it is complete, rather than blocking until generation finishes.
16 | 
17 | The second call is a POST to /model/get_prediction with the `qid` that was returned (a full request example appears in the *Running the Docker Container* section below).
18 | 
19 | ## Getting Started
20 | 
21 | You will need the following:
22 | 1) TPU VM: A TPU (v2-8 works fine)
23 | 2) Proxy VM: a second VM within the same zone. This VM only serves to proxy all requests to the TPU VM, so it does not require many resources.
24 | 3) [Docker + Docker Compose](https://docs.docker.com/compose/install/) installed on both VMs
25 | 4) TPU VM and Proxy VM on the same GCP network (the default for most setups)
26 | 5) Ports 80 and 443 exposed on the Proxy VM
27 | 6) Your DNS records pointing to the external IP of the Proxy VM.
28 | 
29 | The GPT-J checkpoint should be downloaded locally, and its path set as the variable MODEL_DIR in the .env file. This directory is mounted as a volume into the container, preventing repeated re-downloads.
30 | 
31 | ## Files to Modify
32 | 
33 | `.env`
34 | 
35 | ```bash
36 | # Modify these values in .env
37 | DRYRUN=false # Set to true to skip Let's Encrypt SSL domain validation when testing
38 | DOMAIN=yourdomain.com # Set to your domain
39 | SUBDOMAIN=gptj.api # Set to your subdomain. The endpoint will be gptj.api.yourdomain.com
40 | NGINX_PATH=./nginx_proxyvm.conf # Modify this conf file with the internal IP of the TPU VM. This value will not change.
41 | CERTS_PATH=/path/to/certs/ # Set this to a path on the Proxy VM that will be used to store SSL certs. This path will automatically be created.
42 | DNS_KEYS=/path/to/keys/ # Set this to a path on the Proxy VM that will be used to store DNS keys. This path will automatically be created.
43 | MODEL_DIR=/your/path/to/model/step_383500 # Set this to the path on the TPU VM where the model is stored
44 | TPU_NAME=YOUR_TPU_NAME # Set this to the name of your TPU as created
45 | ```
46 | 
47 | `nginx_proxyvm.conf`
48 | 
49 | Modify the IP to your TPU VM's *internal* IP address.
50 | 
51 | 
52 | ## Running the Docker Container
53 | 
54 | The following assumes this directory is your working directory.
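
Once both containers are up (started with the compose commands just below), a client exercises the two-call flow against the proxy. Here is a minimal sketch in Python — the hostname follows the `.env` defaults above, and the request/response fields mirror `payloads.py`; treat the exact values as illustrative:

```python
import requests

BASE = "https://gptj.api.yourdomain.com"  # SUBDOMAIN + DOMAIN from .env

# Call 1: queue a completion; the server immediately returns a queue id
r = requests.post(f"{BASE}/model/predict",
                  json={"context": "EleutherAI is", "temp": 0.75,
                        "top_p": 0.9, "length": 64})
qid = r.json()["qid"]

# Call 2: fetch the completion for that id (blocks until generation finishes)
r = requests.post(f"{BASE}/model/get_prediction", json={"qid": qid})
print(r.json()["completion"])
```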
55 | 56 | On TPU VM: 57 | 58 | `docker-compose up -d` 59 | 60 | On Proxy VM: 61 | 62 | `docker-compose -f compose-proxy.yaml up` 63 | 64 | ## Final Notes 65 | 66 | The reason you can't directly serve from the TPU VM is because there is no control over the TPU VM's port firewalls. I go into more details about [Dockerization within TPU VMs here](https://trisongz.medium.com/accessing-your-tpus-in-docker-containers-with-tpu-vm-e944f5909dd4). 67 | 68 | The API is _very_ barebones and stripped down for simplicity. It's meant as a starting point and can be further extended for your needs. 69 | 70 | -------------------------------------------------------------------------------- /docker/compose-proxy.yaml: -------------------------------------------------------------------------------- 1 | version: "3.3" 2 | services: 3 | gptjapi: 4 | image: ghcr.io/linuxserver/swag 5 | container_name: gptjapi 6 | cap_add: 7 | - NET_ADMIN 8 | environment: 9 | - PUID=1000 10 | - PGID=1000 11 | - TZ=America/Chicago 12 | - URL=${DOMAIN} 13 | - SUBDOMAINS=${SUBDOMAIN}, 14 | - VALIDATION=http 15 | - STAGING=${DRYRUN} 16 | - ONLY_SUBDOMAINS=true 17 | volumes: 18 | - ${NGINX_PATH}:/config/nginx/proxy-confs/app.subdomain.conf 19 | - ${CERTS_PATH}:/config/etc/letsencrypt/ 20 | - ${DNS_KEYS}:/config/keys/ 21 | ports: 22 | - 443:443 23 | - 80:80 24 | restart: unless-stopped 25 | networks: 26 | - production 27 | networks: 28 | production: 29 | -------------------------------------------------------------------------------- /docker/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: "3.3" 2 | services: 3 | gptjserver: 4 | image: gptjserver 5 | container_name: gptjserver 6 | cap_add: 7 | - ALL 8 | environment: 9 | - TPU_NAME=${TPU_NAME} 10 | - TF_CPP_MIN_LOG_LEVEL=0 11 | - XRT_TPU_CONFIG="localservice;0;localhost:51011" 12 | - TF_XLA_FLAGS=--tf_xla_enable_xla_devices 13 | build: 14 | context: . 
15 | dockerfile: Dockerfile 16 | ports: 17 | - 8080:80 18 | volumes: 19 | - ${MODEL_DIR}:/app/model 20 | - /var/run/docker.sock:/var/run/docker.sock 21 | - /usr/share/tpu/:/usr/share/tpu/ 22 | - /lib/libtpu.so:/lib/libtpu.so 23 | privileged: true 24 | restart: unless-stopped 25 | devices: 26 | - "/dev:/dev" 27 | networks: 28 | - production 29 | networks: 30 | production: 31 | -------------------------------------------------------------------------------- /docker/main.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import uvicorn 4 | import traceback 5 | import logging 6 | 7 | from fastapi import FastAPI 8 | from starlette.middleware.cors import CORSMiddleware 9 | from .payloads import CompletionPayload, CompletionResponse, QueueRequest, QueueResponse 10 | from .ops import get_gptj_model 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | app = FastAPI() 15 | app.add_middleware( 16 | CORSMiddleware, 17 | allow_origins=["*"], 18 | allow_credentials=True, 19 | allow_methods=["*"], 20 | allow_headers=["*"], 21 | ) 22 | 23 | 24 | # globals 25 | MODEL_API = None 26 | 27 | 28 | @app.on_event("startup") 29 | async def startup_event(): 30 | global MODEL_API 31 | try: 32 | MODEL_API = get_gptj_model() 33 | MODEL_API.load_model() 34 | MODEL_API.start_background() 35 | except Exception as e: 36 | logger.debug(f"Model could not be loaded: {str(e)}") 37 | traceback.print_exc() 38 | 39 | 40 | @app.post("/model/get_prediction") 41 | def get_prediction(payload: QueueRequest) -> CompletionResponse: 42 | res = MODEL_API.wait_for_queue(payload.qid) 43 | return CompletionResponse(**res) 44 | 45 | 46 | @app.post("/model/predict") 47 | def model_prediction(payload: CompletionPayload) -> QueueResponse: 48 | res = MODEL_API.add_to_queue(payload) 49 | return QueueResponse(qid=res['qid']) 50 | 51 | 52 | if __name__ == "__main__": 53 | uvicorn.run(app, host="0.0.0.0", port=8000, reload=True) 54 | -------------------------------------------------------------------------------- /docker/nginx_proxyvm.conf: -------------------------------------------------------------------------------- 1 | server { 2 | listen 443; 3 | server_name gptj.api.yourdomain.com; 4 | include /config/nginx/ssl.conf; 5 | client_max_body_size 0; 6 | location / { 7 | include /config/nginx/proxy.conf; 8 | # The below IP should be your TPUVM Internal IP 9 | proxy_pass http://10.000.0.1:8080; 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /docker/ops.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import time 4 | import jax 5 | import optax 6 | import threading 7 | import numpy as np 8 | import logging 9 | from queue import Queue, Empty 10 | from jax.experimental import maps 11 | from transformers import GPT2TokenizerFast 12 | 13 | from mesh_transformer.checkpoint import read_ckpt 14 | from mesh_transformer.sampling import nucleaus_sample 15 | from mesh_transformer.transformer_shard import CausalTransformer 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | # This prevents fastapi from creating multiple models in multi-worker mode 20 | # leading to OOM crashes 21 | gptj_model = None 22 | gptj_model_lock = threading.Lock() 23 | 24 | def compile_model(): 25 | global gptj_model 26 | with gptj_model_lock: 27 | if gptj_model: 28 | return 29 | gptj_model = GPTJ() 30 | 31 | def get_gptj_model(): 32 | compile_model() 33 | return gptj_model 34 | 35 | def timer(start_time=None): 36 | if not 
start_time: 37 | return time.time() 38 | return time.time() - start_time 39 | 40 | 41 | class GPTJ: 42 | def __init__(self): 43 | self.params = { 44 | "layers": 28, 45 | "d_model": 4096, 46 | "n_heads": 16, 47 | "n_vocab": 50400, 48 | "norm": "layernorm", 49 | "pe": "rotary", 50 | "pe_rotary_dims": 64, 51 | "seq": 2048, 52 | "cores_per_replica": 8, 53 | "per_replica_batch": 1, 54 | "sampler": nucleaus_sample, 55 | "optimizer": optax.scale(0) 56 | } 57 | self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2') 58 | self.queue_ids = {} 59 | self.qidx = 0 60 | self.queue = Queue() 61 | self.network = None 62 | self.lock = threading.Lock() 63 | self._alive_time = timer() 64 | 65 | def load_model(self): 66 | if self.network: 67 | logger.info('Attempting to reload model when model is loaded. Returning') 68 | return 69 | with self.lock: 70 | logger.info('Loading Model') 71 | start = timer() 72 | logger.info(f"JAX Devices: {jax.device_count()}") 73 | logger.info(f"JAX Runtime Initialized in {timer(start):.06} secs") 74 | mesh_shape = (jax.device_count() // self.params['cores_per_replica'], self.params['cores_per_replica']) 75 | self.devices = np.array(jax.devices()).reshape(mesh_shape) 76 | self.total_batch = self.params['per_replica_batch'] * jax.device_count() // self.params['cores_per_replica'] * 8 77 | maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(self.devices, ('dp', 'mp'))) 78 | network = CausalTransformer(self.params) 79 | logger.info(f'Loading Checkpoint') 80 | network.state = read_ckpt(network.state, "/app/model/", self.devices.shape[1]) 81 | logger.info(f"GPTJ Network loaded in {timer(start):.06} secs. Total Batch Size: {self.total_batch}") 82 | del network.state["opt_state"] 83 | network.state = network.move_xmap(network.state, np.zeros(self.params['cores_per_replica'])) 84 | self.network = network 85 | 86 | 87 | def start_background(self): 88 | with self.lock: 89 | t = threading.Thread(target=self.background) 90 | t.start() 91 | 92 | def prepare_item(self, context, length=256): 93 | tokens = self.tokenizer.encode(context) 94 | logger.info(tokens) 95 | token_length = len(tokens) 96 | pad_amount = self.params['seq'] - token_length 97 | pad_amount = max(pad_amount, 0) 98 | padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)[-self.params['seq']:] 99 | return {'tokens': padded_tokens, 'length': token_length} 100 | 101 | # Single Item - Not Tested 102 | def infer(self, context, top_p=0.9, top_k=40, temp=1.0, length=256, **kwargs): 103 | item = self.prepare_item(context, length) 104 | batched_tokens = np.array([item['tokens']] * self.total_batch) 105 | batched_lengths = np.array([item['length']] * self.total_batch) 106 | start = timer() 107 | output = self.network.generate( 108 | batched_tokens, batched_lengths, length, 109 | { 110 | "top_p": np.ones(self.total_batch) * top_p, 111 | "top_k": np.ones(self.total_batch) * top_k, 112 | "temp": np.ones(self.total_batch) * temp, 113 | } 114 | ) 115 | samples = [] 116 | decoded_tokens = output[1][0] 117 | end_time = timer(start) 118 | for o in decoded_tokens[:, :, 0]: 119 | res = { 120 | 'context': context, 121 | 'completion': self.tokenizer.decode(o), 122 | 'time': end_time 123 | } 124 | samples.append(res) 125 | logger.info(f"Completion done in {end_time:06} secs") 126 | return samples 127 | 128 | def infer_batch(self, batch, **kwargs): 129 | logger.info(f'Starting Inference on Batch') 130 | batch_items = {'tokens': [], 'lengths': [], 'top_p': [], 'top_k': [], 'temp': []} 131 | max_lengths, contexts = [], [] 132 | for 
req in batch:
133 |             req = self.to_data(req)
134 |             item = self.prepare_item(req['context'], req['length'])
135 |             batch_items['tokens'].append(item['tokens'])
136 |             batch_items['lengths'].append(item['length'])
137 |             batch_items['top_p'].append(req['top_p'])
138 |             batch_items['top_k'].append(req['top_k'])
139 |             batch_items['temp'].append(req['temp'])
140 |             max_lengths.append(req['length'])
141 |             contexts.append(req['context'])
142 | 
143 |         max_length = max(max_lengths)
144 |         for key, vals in batch_items.items():
145 |             batch_items[key] = np.array(vals)
146 |         start = timer()
147 |         logger.info(f'Completed Preparing Batch')
148 |         output = self.network.generate(
149 |             batch_items['tokens'], batch_items['lengths'], max_length,
150 |             {
151 |                 "top_p": batch_items['top_p'],
152 |                 "top_k": batch_items['top_k'],
153 |                 "temp": batch_items['temp'],
154 |             }
155 |         )
156 |         logger.info(f'Completed Generation')
157 |         samples = []
158 |         end_time = timer(start)
159 |         for pred, ctx in zip(output[1][0][:, :, 0], contexts):
160 |             res = {
161 |                 'context': ctx,
162 |                 'completion': self.tokenizer.decode(pred),
163 |                 'time': end_time
164 |             }
165 |             samples.append(res)
166 |         logger.info(f"Completion done in {end_time:.2f} secs")
167 |         return samples
168 | 
169 |     def add_to_queue(self, item):
170 |         self.qidx += 1
171 |         self.queue.put({'item': self.to_data(item), 'qidx': self.qidx})
172 |         self.queue_ids[self.qidx] = Queue()
173 |         return {'qid': self.qidx}
174 | 
175 |     def wait_for_queue(self, qid):
176 |         if not self.queue_ids.get(qid):
177 |             return {'Error': 'QID not found'}
178 |         return self.queue_ids[qid].get()
179 | 
180 |     def background(self):
181 |         logger.info(f'Init Background')
182 |         maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(self.devices, ('dp', 'mp')))
183 |         while True:
184 |             batch, qids = [], []
185 |             while len(batch) < self.total_batch:  # fill up to, but never beyond, the fixed batch size
186 |                 try:
187 |                     req = self.queue.get(block=False)
188 |                     logger.info(f'Got Queue Item: {req}')
189 |                     batch.append(req['item'])
190 |                     qids.append(req['qidx'])
191 | 
192 |                 except Empty:
193 |                     if len(batch):
194 |                         break
195 |                     else:
196 |                         time.sleep(0.01)
197 |             batch_size = len(batch)
198 |             logger.info(f'Working on Batch: {batch_size} - {qids}')
199 |             while len(batch) < self.total_batch:
200 |                 batch.append(self.placeholder_item)
201 |             start = timer()
202 |             results = self.infer_batch(batch)
203 |             for res, qid in zip(results, qids):
204 |                 self.queue_ids[qid].put(res)
205 |             logger.info(f'Completed Current Batch of {batch_size} Items in {timer(start):.2f} secs')
206 | 
207 |     @property
208 |     def placeholder_item(self):
209 |         return {'context': 'nada', 'top_p': 0.9, 'top_k': 40, 'temp': 1.0, 'length': 1}
210 | 
211 |     def to_data(self, item):
212 |         try:
213 |             return {'context': item.context, 'top_p': item.top_p, 'top_k': item.top_k, 'temp': item.temp, 'length': item.length}
214 |         except AttributeError:  # plain dict payload rather than a pydantic model
215 |             return {'context': item.get('context', ''), 'top_p': item.get('top_p', 0.9), 'top_k': item.get('top_k', 40), 'temp': item.get('temp', 1.0), 'length': item.get('length', 256)}
216 | 
217 |     @property
218 |     def alive_time(self):
219 |         return timer(self._alive_time)
-------------------------------------------------------------------------------- /docker/payloads.py: --------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | from typing import Dict, Optional, Any, List
3 | 
4 | 
5 | class ModelPayload(BaseModel):
6 |     inputs: Optional[Any] = None
7 |     params: Optional[Dict] = {}
8 | 
9 | class CompletionPayload(BaseModel):
10 |     context: str
11 |     
temp: Optional[float] = 1.0 12 | top_p: Optional[float] = 0.9 13 | top_k: Optional[int] = 50 14 | length: Optional[int] = 256 15 | 16 | class CompletionResponse(BaseModel): 17 | context: str 18 | completion: str 19 | time: float 20 | 21 | class QueueResponse(BaseModel): 22 | qid: int 23 | 24 | class QueueRequest(BaseModel): 25 | qid: int 26 | -------------------------------------------------------------------------------- /docker/start.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env sh 2 | set -e 3 | 4 | if [ -f /app/app/main.py ]; then 5 | DEFAULT_MODULE_NAME=app.main 6 | elif [ -f /app/main.py ]; then 7 | DEFAULT_MODULE_NAME=main 8 | fi 9 | 10 | MODULE_NAME=${MODULE_NAME:-$DEFAULT_MODULE_NAME} 11 | VARIABLE_NAME=${VARIABLE_NAME:-app} 12 | export APP_MODULE=${APP_MODULE:-"$MODULE_NAME:$VARIABLE_NAME"} 13 | 14 | HOST=${HOST:-0.0.0.0} 15 | PORT=${PORT:-80} 16 | LOG_LEVEL=${LOG_LEVEL:-info} 17 | 18 | # If there's a prestart.sh script in the /app directory or other path specified, run it before starting 19 | PRE_START_PATH=${PRE_START_PATH:-/app/prestart.sh} 20 | echo "Checking for script in $PRE_START_PATH" 21 | if [ -f $PRE_START_PATH ] ; then 22 | echo "Running script $PRE_START_PATH" 23 | . "$PRE_START_PATH" 24 | else 25 | echo "There is no script $PRE_START_PATH" 26 | fi 27 | 28 | # Start Uvicorn with live reload 29 | exec uvicorn --reload --reload-dir /app/mesh-transformer-jax --host $HOST --port $PORT --log-level $LOG_LEVEL "$APP_MODULE" --access-log -------------------------------------------------------------------------------- /eval_harness.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | from lm_eval import evaluator, tasks 5 | 6 | from mesh_transformer.build_model import build_model 7 | from tasks import EvalHarnessAdaptor 8 | 9 | 10 | def parse_args(): 11 | # Parse command line arguments 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--tpu", type=str, help="Name of TPU to train on.") 14 | parser.add_argument("--tpu_region", type=str, help="Region of TPU to train on.") 15 | parser.add_argument("--preemptible", action="store_true") 16 | 17 | parser.add_argument("--config", type=str, default=None, help="Config file location") 18 | 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | if __name__ == "__main__": 24 | args = parse_args() 25 | params = json.load(open(args.config)) 26 | 27 | tpu_name = args.tpu 28 | region = args.tpu_region 29 | preemptible = args.preemptible 30 | 31 | gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1) 32 | per_replica_batch = params["per_replica_batch"] 33 | tpu_size = params["tpu_size"] 34 | cores_per_replica = params["cores_per_replica"] 35 | 36 | bucket = params["bucket"] 37 | model_dir = params["model_dir"] 38 | layers = params["layers"] 39 | d_model = params["d_model"] 40 | n_heads = params["n_heads"] 41 | n_vocab = params["n_vocab"] 42 | seq = params["seq"] 43 | norm = params["norm"] 44 | pe = params["pe"] 45 | 46 | total_batch = per_replica_batch * tpu_size // cores_per_replica * 4 47 | 48 | t = build_model(params, tpu_name, region, preemptible) 49 | adaptor = EvalHarnessAdaptor(t, seq, total_batch, shrink=pe != "fixed") 50 | 51 | step, aux = t.load(bucket, model_dir) 52 | t.move() 53 | 54 | results = evaluator.evaluate(adaptor, tasks.get_task_dict(["lambada", 55 | "piqa", 56 | "hellaswag", 57 | "winogrande", 58 | "mathqa", 59 | "pubmedqa", 60 | # "boolq", 61 
| # "cb", 62 | # "copa", 63 | # "multirc", 64 | # "record", 65 | # "wic", 66 | # "wsc", 67 | ]), False, 0, None) 68 | dumped = json.dumps(results, indent=2) 69 | print(dumped) 70 | 71 | results = evaluator.evaluate(adaptor, tasks.get_task_dict(["lambada_cloze", 72 | ]), False, 15, None) 73 | 74 | dumped = json.dumps(results, indent=2) 75 | print(dumped) -------------------------------------------------------------------------------- /howto_finetune.md: -------------------------------------------------------------------------------- 1 | # How to Fine-Tune GPT-J - The Basics 2 | 3 | Before anything else, you'll likely want to apply for access to the TPU Research Cloud (TRC). Combined with a Google Cloud free trial, that should allow you to do everything here for free. Once you're in TRC, you need to create a project, then with the name of the new project fill out the form that was emailed to you. Use the script `create_finetune_tfrecords.py` to prepare your data as tfrecords; I might do a separate guide on that. Another thing you might want to do is fork the mesh-transformer-jax repo to make it easier to add and modify the config files. 4 | 5 | 0. [Install the Google Cloud SDK](https://cloud.google.com/sdk/docs/install). We'll need it later. 6 | 7 | 1. If you didn't make a project and activate TPU access through TRC yet (or if you plan on paying out of pocket), [make one now](https://console.cloud.google.com/projectcreate). 8 | 9 | 2. TPUs use Google Cloud buckets for storage, go ahead and [create one now](https://console.cloud.google.com/storage/create-bucket). Make sure it's in the region the TPU VM will be; the email from TRC will tell you which region(s) you can use free TPUs in. 10 | 11 | 3. You'll need the full pretrained weights in order to fine-tune the model. [Download those here](https://mystic.the-eye.eu/public/AI/GPT-J-6B/step_383500.tar.zstd). 12 | 13 | Now that you have a bucket on the cloud and the weights on your PC, you need to upload the weights to the bucket in two steps: 14 | 15 | 4. Decompress and extract `GPT-J-6B/step_383500.tar.zstd` so you're left with the uncompressed folder containing the sharded checkpoint. 16 | 17 | 5. Open the Google Cloud SDK and run the following command, replacing the path names as appropriate: `gsutil -m cp -R LOCAL_PATH_TO/step_383500 gs://YOUR-BUCKET`. If that works, the console will show the files being uploaded. *Note: Took about 12 hours for me, uploading to the Netherlands from California; hopefully you'll have a better geographic situation than I did! I also initially made the mistake of uploading the still-packed .tar. Don't do that, TPU VMs don't have enough local storage for you to unpack it. To avoid needing to re-upload, I had to unpack it in Colab.* 18 | 19 | You'll want to upload tfrecords of your data as well, you can do that here or through the web interface, but trust me when I say you don't want to upload the nearly 70GB weights through the web interface. 20 | 21 | Note that steps 6 and 7, preparing the index and config files, can be done later on by editing the base repo in the VM's text editor. It's more efficient to instead make these changes to your own fork of the repo as follows: 22 | 23 | 6. In the data folder, create a new file `foo.train.index`, replace foo with whatever you want to refer to your dataset as. For each tfrecord in your bucket that you intend to train with, add the path as a line in the index. Make `foo.val.index` and do the same for your validation dataset (if you have one). 
See the existing files for examples. 24 | 25 | 7. Duplicate the config file `6B_roto_256.json` and rename it to something appropriate for your project. Open it up and make these edits: 26 | - `tpu_size`: Change from `256` to `8` 27 | - `bucket`: Change to your bucket 28 | - `model_dir`: Change to the directory you'd like to save your checkpoints in 29 | - `train_set` and `val_set`: Change to the index files from the last step 30 | - `eval_harness_tasks`: Can be removed if you don't plan on using the eval harness 31 | - `val_every` & `ckpt_every` & `keep_every`: Usage should be intuitive. Don't set the `foo_every` values to 0, though, or you'll get a divide-by-zero error. If you don't have a `val_set`, just set `val_every` to something higher than `total_steps`. 32 | - `val_batches`: This should equal the number of sequences in your val dataset. You can find this number at the end of the name of the .tfrecords file produced by `create_finetune_tfrecords.py`. 33 | - `name`: Change to a name for your model. 34 | - `warmup_steps`, `lr`, `end_lr`, etc.: see the *Learning Rate Notes* section at the end of the guide. 35 | 36 | 37 | 8. Push the changes to your GitHub repo. 38 | 39 | 9. Follow [this guide](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm) up to and including the step **"Connect to your Cloud TPU VM"**. 40 | 41 | At this point you should have remote access to the TPU VM! 42 | 43 | 10. In the new VM terminal, type `git clone https://github.com/kingoflolz/mesh-transformer-jax` (or, preferably, your own fork, after pushing the config and index files). 44 | 45 | 11. Move to the new directory with `cd mesh-transformer-jax` and run `pip install -r requirements.txt`. Since the requirements.txt file doesn't pin the exact JAX version required for fine-tuning, also run `pip install jax==0.2.12` and you'll be all set. 46 | 47 | 12. Finally, run `python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/`. If everything is set up correctly, this will begin the fine-tuning process. First the model has to be loaded into memory; after `loading network` appears on the console, expect about 10-15 minutes before the next step, which sets up WandB for logging. (Option 3 lets you skip that if you aren't using WandB.) A step 1 checkpoint will be saved, and the real training will start. If you have a small dataset, this will go by quickly; TPU VMs can train at a rate of ~5000 tokens/second. 48 | 49 | 13. You did it! Now don't forget any cleanup steps you need to take, such as shutting down your TPU VM or removing unneeded data from buckets, so that you don't have any unexpected charges from Google later. 50 | 51 | ## Now what? 52 | 53 | This guide is labeled "The Basics"; anything we haven't covered so far is out of scope, but go check out the rest of the repository! Try `python3 device_sample.py --config=configs/YOUR_CONFIG.json` for a basic sampling interface. Use `slim_model.py` to prepare an easier-to-deploy slim version of your new weights for inference. Experiment! 54 | 55 | ### Running with HuggingFace 56 | To use the model with PyTorch in HuggingFace's `transformers` library, you'll need to convert the weights 57 | into a format that it recognizes. This can be done using `to_hf_weights.py`. It's recommended that you run `slim_model.py` before attempting to move the weights to a PyTorch/`transformers` format. Use `python to_hf_weights.py --help` to see usage details.
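Once conversion finishes (and you've installed a GPT-J-aware build of `transformers`; see the note below), loading the result in PyTorch looks roughly like the sketch here. The output directory `./gptj-hf` is a hypothetical placeholder for whatever path you gave `to_hf_weights.py`, and the GPT-2 tokenizer is used just as `resharding_example.py` does:

```python
# A minimal sketch, assuming the converted weights were written to ./gptj-hf
# (a hypothetical path; substitute the output directory you actually used).
from transformers import AutoTokenizer, GPTJForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # GPT-J reuses the GPT-2 tokenizer
model = GPTJForCausalLM.from_pretrained("./gptj-hf")
model.eval()

inputs = tokenizer("EleutherAI is", return_tensors="pt")
output = model.generate(**inputs, do_sample=True, top_p=0.9, max_length=64)
print(tokenizer.decode(output[0]))
```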
58 | 59 | *note: as of 9/1/2021, GPT-J has been merged into the `main` branch of `transformers` but has not yet been included in a release. Run `pip install git+https://github.com/huggingface/transformers#egg=transformers` to install the current `main` branch.* 60 | 61 | ## Learning Rate Notes 62 | 63 | **Thanks to nostalgebraist for talking about this!** They're the one who explained this part on Discord; I'm just paraphrasing, really: 64 | 65 | The first thing you want to determine is how long a training epoch will be. `gradient_accumulation_steps` is your batch size; it defaults to `16`, and nostalgebraist recommends `32`. Your .tfrecord files should have a number in the file name indicating how many sequences are in the dataset. Divide that number by the batch size, and the result is how many steps are in an epoch. Now we can write the schedule. 66 | 67 | `lr` is recommended to be between `1e-5` and `5e-5`, with `end_lr` set to 1/5 or 1/10 of `lr`. 68 | `weight_decay` can remain `0.1`. `total_steps` should be at least one epoch, possibly longer if you have a validation 69 | set to track your loss against. 70 | `warmup_steps` should be 5-10% of `total_steps`, and finally `anneal_steps` should be `total_steps - warmup_steps`. 71 | (The learning rate is held at `end_lr` after `warmup_steps + anneal_steps` and training keeps going until `total_steps`, 72 | but usually you should stop once annealing is done.) 73 | 74 | To illustrate: I have a small dataset that tokenized into 1147 sequences as a .tfrecord. Dividing by a `gradient_accumulation_steps` of `16` and rounding up to ensure I use all the data gives 72 steps per epoch. I'll set `lr` to `5e-5` and `end_lr` to a fifth of that, `1e-5`; that may be too high, since it's at the top of the recommended range. I'll set `total_steps` to `72` for one epoch, since I don't have a validation set. Then I'll set `anneal_steps` to `65` and `warmup_steps` to `7`. Simple as that, but you may need to fiddle with the specifics on your own; the snippet below turns this arithmetic into code so you can sanity-check your own numbers.
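To make the worked example concrete, here is a small script that redoes the arithmetic and evaluates the resulting schedule with `gpt3_schedule` from `mesh_transformer/util.py` (the same function `build_model.py` wires into the optimizer). Everything except the print loop comes straight from the example above; run it from the repo root with JAX installed.

```python
# Sanity-check the schedule from the worked example, using the repo's own
# gpt3_schedule(warmup_steps, anneal_steps, peak_lr, end_lr).
import math
from mesh_transformer.util import gpt3_schedule

sequences = 1147                                # from the .tfrecord file name
batch = 16                                      # gradient_accumulation_steps
steps_per_epoch = math.ceil(sequences / batch)  # 72

total_steps = steps_per_epoch                   # one epoch, no validation set
warmup_steps = 7                                # roughly 10% of total_steps
anneal_steps = total_steps - warmup_steps       # 65

sch = gpt3_schedule(warmup_steps, anneal_steps, 5e-5, 1e-5)
for step in (1, warmup_steps, total_steps // 2, total_steps):
    print(f"step {step:3d}: lr = {float(sch(step)):.2e}")
```

At `warmup_steps` this prints the peak `5e-5`, and at `total_steps` it has annealed down to `end_lr = 1e-5`, which is a quick way to confirm a config before launching `device_train.py`.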
75 | -------------------------------------------------------------------------------- /mesh_transformer/TPU_cluster.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | import time 4 | 5 | import ray 6 | 7 | from typing import Callable 8 | import numpy as np 9 | 10 | from mesh_transformer.train_actor import NetworkRunner 11 | from google.cloud import storage 12 | from smart_open import open 13 | from func_timeout import func_set_timeout 14 | 15 | 16 | class TPUCluster: 17 | @func_set_timeout(1200) 18 | def __init__(self, 19 | mesh_shape, 20 | node_count, 21 | model: Callable, 22 | version=1): 23 | assert ray.is_initialized() # needs a valid ray cluster to start 24 | self.nodes = [] 25 | self.node_count = node_count 26 | self.dp, self.mp = mesh_shape 27 | self.version = version 28 | 29 | start = time.time() 30 | 31 | for i in range(node_count): 32 | self.nodes.append(NetworkRunner.options(max_concurrency=2).remote(mesh_shape, model)) 33 | 34 | for n in self.nodes: 35 | n.run.remote() 36 | 37 | params = [] 38 | for n in self.nodes: 39 | params.append(n.get_params.remote()) 40 | 41 | self.param_count = ray.get(params)[0] 42 | print(f"Ray actors created in {time.time() - start:.06}s") 43 | 44 | @func_set_timeout(600) 45 | def train(self, data): 46 | data_chunks = np.array_split(data, len(self.nodes), axis=1) 47 | 48 | res = [] 49 | for n, d in zip(self.nodes, data_chunks): 50 | res.append(n.train.remote({ 51 | "obs": d[:, :, :-1], 52 | "target": d[:, :, 1:], 53 | })) 54 | 55 | res = ray.get(res) 56 | 57 | loss = [] 58 | last_loss = [] 59 | 60 | for r in res: 61 | loss.append(r[0]) 62 | last_loss.append(r[1]) 63 | 64 | return np.array(loss).mean(), np.array(last_loss).mean() 65 | 66 | @func_set_timeout(600) 67 | def eval(self, data): 68 | if isinstance(data, dict): 69 | data_chunked = [{} for _ in self.nodes] 70 | for k, v in data.items(): 71 | v_chunks = np.array_split(v, len(self.nodes), axis=0) 72 | for idx, v_chunk in enumerate(v_chunks): 73 | data_chunked[idx][k] = v_chunk 74 | 75 | res = [] 76 | for n, d in zip(self.nodes, data_chunked): 77 | res.append(n.eval.remote(d)) 78 | 79 | total = 0 80 | correct = 0 81 | last_correct = 0 82 | 83 | total_last_loss = 0 84 | mask_loss = [] 85 | each_correct = [] 86 | 87 | for input, output in zip(data_chunked, ray.get(res)): 88 | correct_and_valid = np.logical_and(output["correct"], input["eval_mask"]) 89 | 90 | correct_tokens_count = np.sum(correct_and_valid, -1) 91 | valid_tokens_count = np.sum(input["eval_mask"], -1) 92 | 93 | correct_example = np.logical_and(valid_tokens_count == correct_tokens_count, valid_tokens_count > 0) 94 | valid_example = valid_tokens_count > 0 95 | last_correct_example = correct_and_valid[:, -1] 96 | 97 | each_correct += correct_example.tolist() 98 | 99 | total += sum(valid_example) 100 | correct += sum(correct_example) 101 | last_correct += sum(last_correct_example) 102 | total_last_loss += sum(valid_example * output["last_loss"]) 103 | 104 | valid_loss = np.sum(output["all_loss"] * input["eval_mask"], -1) 105 | mask_loss += valid_loss.tolist() 106 | 107 | return { 108 | "total": total, 109 | "correct": correct, 110 | "last_correct": last_correct, 111 | "last_loss": total_last_loss, 112 | "mask_loss": np.array(mask_loss), 113 | "each_correct": np.array(each_correct) 114 | } 115 | else: 116 | data_chunks = np.array_split(data, len(self.nodes), axis=0) 117 | 118 | res = [] 119 | for n, d in zip(self.nodes, data_chunks): 120 | 
res.append(n.eval.remote({ 121 | "obs": d[:, :-1], 122 | "target": d[:, 1:], 123 | })) 124 | 125 | return np.array([i["loss"] for i in ray.get(res)]).mean() 126 | 127 | @func_set_timeout(600) 128 | def generate(self, context, ctx_length, gen_len): 129 | context = np.array_split(context, len(self.nodes), axis=0) 130 | ctx_length = np.array_split(ctx_length, len(self.nodes), axis=0) 131 | 132 | res = [] 133 | for n, ctx, l in zip(self.nodes, context, ctx_length): 134 | res.append(n.generate.remote(( 135 | ctx, 136 | np.ones(len(ctx), dtype=np.uint32) * l, 137 | gen_len 138 | ))) 139 | 140 | return np.concatenate([i[1][0][:, :, 0] for i in ray.get(res)], axis=0) 141 | 142 | @func_set_timeout(600) 143 | def move(self): 144 | start = time.time() 145 | res = [] 146 | for node in self.nodes: 147 | res.append(node.move_params.remote()) 148 | ray.get(res) 149 | 150 | print(f"Moved weights to TPU in {time.time() - start:.06}s") 151 | 152 | @func_set_timeout(1800) 153 | def load(self, bucket, path): 154 | with open(f"gs://{bucket}/{path}/meta.json", "r") as f: 155 | meta = json.load(f) 156 | 157 | ckpt_step = meta["checkpoints"][-1] 158 | 159 | # do replicated checkpoint reading 160 | start = time.time() 161 | res = [] 162 | for node in self.nodes: 163 | res.append(node.load_ckpt.remote(f"gs://{bucket}/{path}/step_{ckpt_step}/")) 164 | 165 | # make sure they all read from the same checkpoint 166 | step = np.array(ray.get(res)) 167 | assert (step[0] == step).all() 168 | step = int(step[0]) 169 | 170 | print(f"Checkpoint@step{step} restored in {time.time() - start:.06}s") 171 | return step, meta["aux"][str(ckpt_step)] 172 | 173 | @func_set_timeout(600) 174 | def save(self, step, bucket, path, aux=None, init=False, overwrite=False, keep_n=3, delete_old=True): 175 | assert path 176 | client = storage.Client() 177 | 178 | if aux is None: 179 | aux = {} 180 | 181 | if init: 182 | # check existing checkpoint folder does not exist, and delete it if it does 183 | for blob in client.list_blobs(bucket, prefix=f"{path}/"): 184 | assert overwrite 185 | # print(f"deleting {blob.name}") 186 | assert path in blob.name 187 | blob.delete() 188 | 189 | # create metadata file 190 | with open(f"gs://{bucket}/{path}/meta.json", "w") as f: 191 | json.dump({ 192 | "step": 0, 193 | "checkpoints": [], 194 | "aux": {} 195 | }, f) 196 | 197 | # do sharded checkpoint writing 198 | start = time.time() 199 | res = [] 200 | 201 | if self.version == 1: 202 | for shard_id, node in zip(range(self.mp), itertools.cycle(self.nodes)): 203 | res.append(node.write_ckpt.remote(f"gs://{bucket}/{path}/step_{step}/", shard_id)) 204 | elif self.version == 2: 205 | for node in self.nodes: 206 | res.append(node.write_ckpt.remote(f"gs://{bucket}/{path}/step_{step}", 0)) 207 | 208 | ray.get(res) 209 | print(f"Wrote checkpoint in {time.time() - start:.06}s") 210 | 211 | with open(f"gs://{bucket}/{path}/meta.json", "r") as f: 212 | meta = json.load(f) 213 | 214 | meta["step"] = step 215 | meta["checkpoints"].append(step) 216 | all_aux = meta.get("aux", {}) 217 | 218 | while len(meta["checkpoints"]) > keep_n: 219 | ckpt_to_delete = meta["checkpoints"].pop(0) 220 | 221 | try: 222 | del all_aux[str(ckpt_to_delete)] 223 | except: 224 | print(f"failed to delete the aux state for {step}") 225 | 226 | if delete_old: 227 | print(f"deleting checkpoint {ckpt_to_delete}") 228 | for blob in client.list_blobs(bucket, prefix=f"{path}/step_{ckpt_to_delete}/"): 229 | # print(f"deleting {blob.name}") 230 | assert path in blob.name 231 | blob.delete() 232 | else: 233 | 
print(f"keeping checkpoint {ckpt_to_delete}") 234 | 235 | all_aux[step] = aux 236 | meta["aux"] = all_aux 237 | 238 | with open(f"gs://{bucket}/{path}/meta.json", "w") as f: 239 | json.dump(meta, f) 240 | -------------------------------------------------------------------------------- /mesh_transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kingoflolz/mesh-transformer-jax/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/__init__.py -------------------------------------------------------------------------------- /mesh_transformer/build_model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import multiprocessing 3 | 4 | import optax 5 | import ray 6 | 7 | from mesh_transformer import util 8 | from mesh_transformer.TPU_cluster import TPUCluster 9 | from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerV2 10 | from mesh_transformer.util import clip_by_global_norm, additive_weight_decay 11 | from ray_tpu import create_tpu, wait_til, get_connection, start_ray 12 | 13 | 14 | def build_model(params, tpu_name, region, preemptible, version=1): 15 | gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1) 16 | cores_per_replica = params["cores_per_replica"] 17 | tpu_size = params["tpu_size"] 18 | 19 | warmup_steps = params["warmup_steps"] 20 | anneal_steps = params["anneal_steps"] 21 | lr = params["lr"] 22 | end_lr = params["end_lr"] 23 | weight_decay = params["weight_decay"] 24 | 25 | assert tpu_size in [8, 32, 128, 256, 512] 26 | 27 | create_tpu(tpu_name, region, f"v3-{tpu_size}", preemptible) 28 | assert wait_til(tpu_name, region, {'state': 'READY', 'health': 'HEALTHY'}) 29 | 30 | conns = get_connection(tpu_name, region) 31 | 32 | assert len(conns) * 8 == tpu_size, "wrong size TPU for config" 33 | 34 | head_info = ray.init(include_dashboard=False, object_store_memory=10**9) 35 | address = head_info['redis_address'] 36 | 37 | with multiprocessing.pool.ThreadPool(processes=len(conns)) as p: 38 | p.map(functools.partial(start_ray, address=address, version=version), conns) 39 | 40 | opt = optax.chain( 41 | optax.scale(1 / gradient_accumulation_steps), 42 | clip_by_global_norm(1, use_psum=(version == 1)), 43 | optax.scale_by_adam(), 44 | additive_weight_decay(weight_decay), 45 | optax.scale(-1), 46 | optax.scale_by_schedule(util.gpt3_schedule(warmup_steps, anneal_steps, lr, end_lr)) 47 | ) 48 | 49 | params["optimizer"] = opt 50 | 51 | if version == 2: 52 | model_fn = functools.partial(CausalTransformerV2, params) 53 | elif version == 1: 54 | model_fn = functools.partial(CausalTransformer, params) 55 | else: 56 | raise Exception(f"Version {version} does not exist") 57 | 58 | t = TPUCluster((tpu_size // cores_per_replica, cores_per_replica), len(conns), model_fn, version=version) 59 | return t 60 | -------------------------------------------------------------------------------- /mesh_transformer/checkpoint.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import io 3 | import json 4 | import time 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | import multiprocessing 10 | 11 | import ray 12 | from smart_open import open 13 | 14 | from mesh_transformer.util import head_print 15 | 16 | pieces = 16 # how many files to split each shard across 17 | 18 | 19 | def fix_dtype(pytree): 20 | def fix(x): 21 | if x.dtype == 
np.dtype('V2'): 22 | x.dtype = jnp.bfloat16 23 | return jnp.asarray(x) 24 | 25 | return jax.tree_map(fix, pytree) 26 | 27 | 28 | @functools.partial(jax.jit, backend="cpu") 29 | def index_weights(weights, idx): 30 | cpu_device = jax.devices("cpu")[0] 31 | return jax.device_put(jax.tree_map(lambda i: i[idx], weights), cpu_device) 32 | 33 | 34 | def write(x, ckpt_dir): 35 | # start = time.time() 36 | idx, i = x 37 | file_path = ckpt_dir + f"{idx}.npz" 38 | for _ in range(3): 39 | try: 40 | with open(file_path, "wb") as f: 41 | np.savez(f, *i) 42 | # cloudpickle.dump(i, f) 43 | # print(f"written {idx} in {time.time() - start:.06}s") 44 | return 45 | except: 46 | print("save failed, trying again") 47 | 48 | print("save failed 3 times, exiting") 49 | raise Exception("save failed") 50 | 51 | 52 | def split(a, n): 53 | k, m = divmod(len(a), n) 54 | return (a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n)) 55 | 56 | 57 | def write_ckpt(pytree, dir, shard): 58 | # ckpt_dir = Path(dir) 59 | # ckpt_dir.mkdir(parents=True, exist_ok=True) 60 | 61 | flattened, structure = jax.tree_flatten(pytree) 62 | 63 | start = time.time() 64 | # cpu_flattened = jax.device_put(flattened, cpu_device) 65 | cpu_flattened = index_weights(flattened, shard) 66 | # print(f"Moved indexed in {time.time() - start:.06}s") 67 | 68 | cpu_flattened_chunked = split(cpu_flattened, pieces) 69 | 70 | # start = time.time() 71 | # cpu_float = move_weights(cpu_flattened) 72 | # print(f"changed weight types in {time.time() - start:.06}s") 73 | 74 | with multiprocessing.pool.ThreadPool(pieces) as p: 75 | write_fn = functools.partial(write, ckpt_dir=f"{dir}shard_{shard}/") 76 | 77 | start = time.time() 78 | list((p.imap_unordered(write_fn, enumerate(cpu_flattened_chunked)))) 79 | # print(f"written to gcs in {time.time() - start:.06}s") 80 | 81 | 82 | def read_shard(ckpt_dir): 83 | out = [] 84 | for idx in range(16): 85 | file_path = ckpt_dir + f"{idx}.npz" 86 | with open(file_path, "rb") as f: 87 | buf = f.read() 88 | f_io = io.BytesIO(buf) 89 | deserialized = np.load(f_io) 90 | for i in deserialized: 91 | out.append(deserialized[i]) 92 | return out 93 | 94 | 95 | def reshard(x, old_shape): 96 | if len(x.shape) == 1: 97 | # print("epoch") 98 | # print(x) 99 | out = x[0:1] 100 | 101 | elif len(x.shape) == 2: 102 | # print(f"LN/bias {x.shape}") 103 | # print(x[:, :16]) 104 | 105 | if (x[1:] == x[-1]).all(): 106 | # print("LN") 107 | if (x[1:] == 0).all() or (x[1:] == 1).all(): 108 | out = x[0:1] 109 | else: 110 | # print("shard bias") 111 | out = x[0:1] * x.shape[0] / old_shape[0] 112 | else: 113 | # print("bias") 114 | out = x.reshape(old_shape) 115 | 116 | print(out[:, :16]) 117 | 118 | elif len(x.shape) == 3: 119 | # print(f"weight {x.shape}") 120 | if x.shape[0] * x.shape[2] == old_shape[2]: 121 | # print("case 1") 122 | out = jnp.transpose(x, (1, 0, 2)).reshape(old_shape) 123 | elif x.shape[0] * x.shape[1] == old_shape[1]: 124 | # print("case 2") 125 | out = x.reshape(old_shape) 126 | else: 127 | raise Exception(f"unimplemented, {x.shape}, {old_shape}") 128 | else: 129 | raise Exception(f"unimplemented, {x}") 130 | 131 | return out 132 | 133 | 134 | def read_ckpt(pytree, dir, shards_in, shards_out=None, load_opt=True): 135 | if shards_out is None: 136 | shards_out = shards_in 137 | 138 | old_flattened, structure = jax.tree_flatten(pytree) 139 | 140 | original_opt_state = pytree["opt_state"] 141 | 142 | # TODO: figure out how to use a process pool here for more speed 143 | with multiprocessing.pool.ThreadPool(shards_in) 
as p: 144 | start = time.time() 145 | shards = list((p.imap(read_shard, [f"{dir}shard_{i}/" for i in range(shards_in)]))) 146 | print(f"read from disk/gcs in {time.time() - start:.06}s") 147 | 148 | def _unshard(shards, old_flattened): 149 | unsharded = [] 150 | 151 | for old, *all_shards in zip(old_flattened, *shards): 152 | x = np.stack(all_shards) 153 | # No idea why this is V2...? 154 | if x.dtype == np.dtype('V2'): 155 | x.dtype = jnp.bfloat16 156 | 157 | if shards_out != shards_in: 158 | x = reshard(x, old.shape) 159 | unsharded.append(x) 160 | 161 | assert x.shape == old.shape, f"Incompatible checkpoints {x.shape} vs {old.shape}" 162 | return unsharded 163 | try: 164 | unsharded = _unshard(shards, old_flattened) 165 | except AssertionError: 166 | load_opt = False # no opt to load in ckpt 167 | del pytree['opt_state'] 168 | old_flattened, structure = jax.tree_flatten(pytree) 169 | unsharded = _unshard(shards, old_flattened) 170 | 171 | loaded_pytree = jax.tree_unflatten(structure, unsharded) 172 | 173 | if not load_opt: 174 | loaded_pytree['opt_state'] = original_opt_state 175 | return loaded_pytree 176 | 177 | 178 | def read_ckpt_lowmem(pytree, dir, shards_in, shards_out=None, load_opt=True): 179 | if shards_out is None: 180 | shards_out = shards_in 181 | 182 | old_flattened, structure = jax.tree_flatten(pytree) 183 | 184 | original_opt_state = pytree["opt_state"] 185 | 186 | def _unshard(): 187 | start = time.time() 188 | unsharded = [] 189 | devices = jax.devices() 190 | device_count = len(devices) 191 | device_index = 0 192 | 193 | for file_index in range(pieces): 194 | array_keys = [*np.load(f"{dir}shard_0/{file_index}.npz").keys()] 195 | for array_index in range(len(array_keys)): 196 | unstacked = [] 197 | for shard_index in range(shards_in): 198 | npz = np.load(f"{dir}shard_{shard_index}/{file_index}.npz") 199 | array = npz[array_keys[array_index]] 200 | if array.dtype == 'V2': 201 | array.dtype = jnp.bfloat16 202 | unstacked.append(array) 203 | 204 | x = jax.device_put(jnp.stack(unstacked), device=devices[device_index % device_count]) 205 | 206 | if shards_out != shards_in: 207 | x = reshard(x, old_flattened[device_index].shape) 208 | unsharded.append(x) 209 | 210 | assert x.shape == old_flattened[device_index].shape, f"Incompatible checkpoints {x.shape} vs {old_flattened[device_index].shape}" 211 | device_index += 1 212 | 213 | print(f"read from disk/gcs in {time.time() - start:.06}s") 214 | return unsharded 215 | 216 | try: 217 | unsharded = _unshard() 218 | except AssertionError: 219 | load_opt = False # no opt to load in ckpt 220 | del pytree['opt_state'] 221 | old_flattened, structure = jax.tree_flatten(pytree) 222 | unsharded = _unshard() 223 | 224 | loaded_pytree = jax.tree_unflatten(structure, unsharded) 225 | 226 | if not load_opt: 227 | loaded_pytree['opt_state'] = original_opt_state 228 | return loaded_pytree 229 | 230 | 231 | def parallel_write(arrays, fname): 232 | # TODO: make this actually parallel 233 | with open(fname, "wb") as f: 234 | np.savez(f, *arrays) 235 | 236 | 237 | def parallel_read(old, fname, validate=True): 238 | old_vals, treedef = jax.tree_flatten(old) 239 | 240 | if "gs://" in fname: 241 | # TODO: make this actually parallel 242 | with open(fname, "rb") as f: 243 | buf = f.read() 244 | f_io = io.BytesIO(buf) 245 | loaded = np.load(f_io) 246 | else: 247 | loaded = np.load(fname, mmap_mode='r') 248 | 249 | new_vals = [] 250 | for i in loaded: 251 | new_vals.append(loaded[i]) 252 | 253 | assert len(new_vals) == len(old_vals), "Incompatible 
checkpoint" 254 | 255 | for o, n in zip(new_vals, old_vals): 256 | if validate: 257 | assert o.shape == n.shape, "Incompatible checkpoint" 258 | 259 | return jax.tree_unflatten(treedef, fix_dtype(new_vals)) 260 | 261 | 262 | def tree_flatten_with_names(pytree, is_leaf, path="", to_id=id): 263 | id_to_name = {} 264 | if getattr(pytree, "items", None): 265 | for k, v in pytree.items(): 266 | k_path = f"{path}/{k}" 267 | if is_leaf(v): 268 | id_to_name[to_id(v)] = k_path 269 | else: 270 | id_to_name = {**id_to_name, **tree_flatten_with_names(v, is_leaf=is_leaf, path=k_path)} 271 | elif getattr(pytree, "__getitem__", None): 272 | for v in pytree: 273 | if is_leaf(v): 274 | id_to_name[to_id(v)] = path 275 | else: 276 | id_to_name = {**id_to_name, **tree_flatten_with_names(v, is_leaf=is_leaf, path=path)} 277 | else: 278 | id_to_name[to_id(pytree)] = path 279 | return id_to_name 280 | 281 | 282 | def tree_leaves_with_names(pytree, to_id=id): 283 | leaves = jax.tree_leaves(pytree) 284 | is_leaf = lambda x: not isinstance(x, list) and to_id(x) in [to_id(x) for x in leaves] 285 | return tree_flatten_with_names(pytree, is_leaf) 286 | 287 | 288 | def write_ckpt_v2(model_state, dir): 289 | start = time.time() 290 | if jax.host_id() == 0: 291 | param_map = tree_leaves_with_names(model_state["params"]) 292 | opt_map = tree_leaves_with_names(model_state["opt_state"]) 293 | 294 | meta = { 295 | "total_hosts": jax.host_count(), 296 | "step": int(model_state["step"]), 297 | "param_order": [param_map[id(i)] for i in jax.tree_leaves(model_state["params"])], 298 | "opt_order": [opt_map[id(i)] for i in jax.tree_leaves(model_state["opt_state"])] 299 | } 300 | 301 | print("step:", model_state["step"]) 302 | with open(dir + "/meta.json", "w") as f: 303 | json.dump(meta, f) 304 | print(f"meta written in {time.time() - start:.06}s") 305 | 306 | start = time.time() 307 | parallel_write(jax.tree_flatten(model_state["params"])[0], dir + f"/params/shard_{jax.host_id()}.npz") 308 | head_print(f"params written in {time.time() - start:.06}s") 309 | 310 | start = time.time() 311 | parallel_write(jax.tree_flatten(model_state["opt_state"])[0], dir + f"/opt_state/shard_{jax.host_id()}.npz") 312 | head_print(f"opt_state written in {time.time() - start:.06}s") 313 | 314 | 315 | def read_sharded_v2(state, dir, checkpoint_hosts, state_shard): 316 | files_per_host = checkpoint_hosts // jax.host_count() 317 | 318 | assert files_per_host >= 1, "can't restore model to larger pod than was trained on (yet)" 319 | assert jax.host_count() * files_per_host == checkpoint_hosts, "weird host count" 320 | 321 | if files_per_host == 1: 322 | head_print("using fast path of checkpoint restore (save shards == read shards)") 323 | parallel_read(state, dir + f"/shard_{jax.host_id()}.npz") 324 | 325 | @ray.remote 326 | def read_remote(old, fname): 327 | return parallel_read(old, fname, validate=False) 328 | 329 | start_idx = files_per_host * jax.host_id() 330 | 331 | skeleton = jax.tree_map(lambda x: jnp.zeros_like(x, shape=()), state) # a full pytree just to carry dtypes 332 | 333 | refs = [ 334 | read_remote.remote(skeleton, f"{dir}/shard_{i}.npz") 335 | for i in range(start_idx, start_idx + files_per_host) 336 | ] 337 | 338 | values = ray.get(refs) 339 | 340 | def all_array_equal(iterator): 341 | try: 342 | iterator = iter(iterator) 343 | first = next(iterator) 344 | return all(jnp.array_equal(first, rest) for rest in iterator) 345 | except StopIteration: 346 | return True 347 | 348 | def reshard_v2(old, shard_strategy, *new_values): 349 | 
rep_dim_count = shard_strategy.count(None) 350 | total_dim_count = len(shard_strategy) 351 | 352 | # head_print("old.shape", old.shape) 353 | # head_print("shard_strategy", shard_strategy) 354 | 355 | assert len(old.shape) == total_dim_count 356 | 357 | if rep_dim_count == total_dim_count: 358 | # fully replicated 359 | assert all_array_equal(new_values) 360 | return fix_dtype(new_values[0]) 361 | 362 | shard_dim = [idx for idx, dim in enumerate(shard_strategy) if dim is not None and "mp" in dim] 363 | 364 | # only support sharding in 1d for now 365 | assert len(shard_dim) == 1 366 | shard_dim = shard_dim[0] 367 | 368 | ret_val = jnp.concatenate(fix_dtype(new_values), axis=shard_dim) 369 | assert old.shape == ret_val.shape 370 | 371 | return jax.device_put(ret_val, jax.devices("cpu")[0]) 372 | 373 | # head_print("state", jax.tree_structure(state)) 374 | # head_print("state_shard", jax.tree_structure(state_shard)) 375 | # head_print("values", jax.tree_structure(values[0])) 376 | 377 | return jax.tree_multimap(reshard_v2, *([state, state_shard] + values)) 378 | 379 | 380 | def load_ckpt_v2(model_state, dir, state_shard, load_opt): 381 | start = time.time() 382 | with open(dir + "meta.json", "r") as f: 383 | meta = json.load(f) 384 | 385 | ckpt_hosts = meta["total_hosts"] 386 | 387 | head_print(f"meta loaded in {time.time() - start:.06}s") 388 | 389 | new_state = { 390 | "step": np.array([meta["step"]]), 391 | } 392 | 393 | start = time.time() 394 | new_state["params"] = read_sharded_v2(model_state["params"], 395 | dir + "params", 396 | ckpt_hosts, 397 | state_shard["params"]) 398 | head_print(f"params loaded in {time.time() - start:.06}s") 399 | 400 | if not load_opt: 401 | return new_state 402 | 403 | start = time.time() 404 | new_state["opt_state"] = read_sharded_v2(model_state["opt_state"], 405 | dir + "opt_state", 406 | ckpt_hosts, 407 | state_shard["opt_state"]) 408 | head_print(f"opt_state loaded in {time.time() - start:.06}s") 409 | 410 | return new_state 411 | -------------------------------------------------------------------------------- /mesh_transformer/sampling.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | 5 | # takes in a logit distribution, softmax and then sample 6 | def softmax_sample(key, logits, _, temp=1): 7 | return jax.random.categorical(key, logits/temp, -1).astype(jnp.uint32), None 8 | 9 | 10 | def nucleaus_filter(logits, top_p=0.9, top_k=None): 11 | sorted_logits = jnp.sort(logits)[:, ::-1] # sort descending 12 | sorted_indices = jnp.argsort(logits)[:, ::-1] 13 | cumulative_probs = jnp.cumsum(jax.nn.softmax(sorted_logits), axis=-1) 14 | 15 | if top_k is not None: 16 | # Keep only top_k tokens 17 | indices_range = jnp.arange(len(sorted_indices[0])) 18 | indices_range = jnp.stack([indices_range] * len(sorted_indices), axis=0) 19 | 20 | sorted_indices_to_remove = jnp.where(indices_range >= top_k, sorted_indices, 0) 21 | 22 | _, indices_to_remove = jax.lax.sort_key_val(sorted_indices, sorted_indices_to_remove) 23 | 24 | logit_mask = 1e10 * indices_to_remove 25 | 26 | logits -= logit_mask 27 | 28 | # Remove tokens with cumulative probability above a threshold 29 | sorted_indices_to_remove = cumulative_probs > top_p 30 | sorted_indices_to_remove = jnp.concatenate((jnp.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove), axis=-1)[:, :-1] 31 | 32 | _, indices_to_remove = jax.lax.sort_key_val(sorted_indices, sorted_indices_to_remove) 33 | 34 | logit_mask = 1e10 * 
indices_to_remove 35 | 36 | logits -= logit_mask 37 | 38 | return logits 39 | 40 | 41 | def nucleaus_sample(key, logits, _, top_p=0.9, temp=1, top_k=None): 42 | logits = nucleaus_filter(logits, top_p, top_k=top_k) 43 | 44 | return softmax_sample(key, logits, None, temp=temp) 45 | 46 | 47 | if __name__ == "__main__": 48 | import numpy as np 49 | logits = np.array([[-2, -1, 0, 0.8, 0, 0.1, 0.3, 0.4, 0.5, 0.6, 0.7, -3]]) 50 | print(nucleaus_filter(logits)) -------------------------------------------------------------------------------- /mesh_transformer/train_actor.py: -------------------------------------------------------------------------------- 1 | import ray 2 | import time 3 | import numpy as np 4 | from queue import Queue 5 | 6 | from mesh_transformer.util import head_print 7 | 8 | 9 | @ray.remote(resources={"tpu": 1}) 10 | class NetworkRunner(object): 11 | def __init__(self, mesh_shape, network_builder): 12 | self.mesh_shape = mesh_shape 13 | self.network_builder = network_builder 14 | 15 | self.input_q = Queue(maxsize=1) 16 | self.output_q = Queue(maxsize=1) 17 | 18 | def run(self): 19 | print(f"jax runtime initialization starting") 20 | import jax 21 | from jax.experimental.maps import thread_resources, ResourceEnv, Mesh 22 | import haiku as hk 23 | # jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True 24 | 25 | thread_resources.env = ResourceEnv(Mesh(np.empty((), dtype=object), ()), ()) 26 | 27 | start = time.time() 28 | jax.devices() 29 | 30 | import warnings 31 | warnings.filterwarnings("ignore") 32 | warnings.filterwarnings("ignore", category=ResourceWarning) 33 | 34 | if jax.host_id() == 0: 35 | warnings.filterwarnings("default") 36 | 37 | head_print(f"jax devices: {jax.device_count()}") 38 | head_print(f"jax runtime initialized in {time.time() - start:.06}s") 39 | devices = np.array(jax.devices()).reshape(self.mesh_shape) 40 | 41 | with jax.experimental.maps.mesh(devices, ('dp', 'mp')): 42 | start = time.time() 43 | network = self.network_builder() 44 | head_print(f"Initialized in {time.time() - start:.06}s") 45 | 46 | while True: 47 | operation, input = self.input_q.get() 48 | if operation == "train": 49 | self.output_q.put(network.train(input)) 50 | elif operation == "eval": 51 | self.output_q.put(network.eval(input)) 52 | elif operation == "generate": 53 | self.output_q.put(network.generate(*input)) 54 | elif operation == "write_ckpt": 55 | path, shard = input 56 | network.write_ckpt(path, shard) 57 | self.output_q.put(None) 58 | elif operation == "load_ckpt": 59 | network.load_ckpt(input) 60 | self.output_q.put(network.state["step"][0]) 61 | elif operation == "get_params": 62 | self.output_q.put(hk.data_structures.tree_size(network.state['params'])) 63 | elif operation == "move_params": 64 | # only needed for inference, otherwise first train step does this 65 | local_shards = max(jax.local_device_count() // self.mesh_shape[1], 1) 66 | 67 | # delete the optimizer states otherwise it OOMs for some reason 68 | # TODO: use ShardedDeviceArray or something to get around this for bigger models 69 | del network.state["opt_state"] 70 | network.state = network.move_xmap(network.state, np.zeros(local_shards)) 71 | self.output_q.put(None) 72 | else: 73 | raise Exception("Not implemented") 74 | 75 | def get_params(self): 76 | self.input_q.put(("get_params", None)) 77 | return self.output_q.get() 78 | 79 | def train(self, sample): 80 | self.input_q.put(("train", sample)) 81 | return self.output_q.get() 82 | 83 | def eval(self, sample): 84 | self.input_q.put(("eval", sample)) 85 | 
return self.output_q.get() 86 | 87 | def generate(self, input): 88 | self.input_q.put(("generate", input)) 89 | return self.output_q.get() 90 | 91 | def write_ckpt(self, path, shard): 92 | self.input_q.put(("write_ckpt", (path, shard))) 93 | return self.output_q.get() 94 | 95 | def load_ckpt(self, path): 96 | self.input_q.put(("load_ckpt", path)) 97 | return self.output_q.get() 98 | 99 | def move_params(self): 100 | self.input_q.put(("move_params", None)) 101 | return self.output_q.get() 102 | -------------------------------------------------------------------------------- /mesh_transformer/util.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax.experimental.pjit import with_sharding_constraint 4 | from optax import AdditiveWeightDecayState, GradientTransformation, OptState 5 | 6 | 7 | # same as with_sharding_constraint but doesn't fail if run outside of pjit/mesh context 8 | def maybe_shard(x, resource): 9 | try: 10 | return with_sharding_constraint(x, resource) 11 | except ValueError as e: 12 | print(e) 13 | return x 14 | 15 | 16 | def gpt3_schedule(warmup_steps, 17 | total_steps, 18 | peak_lr, 19 | end_lr): 20 | def sch(step): 21 | warmup_pct = jnp.clip(step, 0, warmup_steps) / warmup_steps 22 | anneal_pct = jnp.clip(step - warmup_steps, 0, total_steps) / total_steps 23 | 24 | return warmup_pct * peak_lr - (peak_lr - end_lr) * (1 - jnp.cos(jnp.pi * anneal_pct)) / 2 25 | 26 | return sch 27 | 28 | 29 | def global_norm(updates, use_psum=True): 30 | pre_sqrt = sum([jnp.sum(jnp.square(x)) for x in jax.tree_leaves(updates)]) 31 | if use_psum: 32 | pre_sqrt = jax.lax.psum(pre_sqrt, "shard") 33 | return jnp.sqrt(pre_sqrt) 34 | 35 | 36 | class ClipByGlobalNormState(OptState): 37 | """The `clip_by_global_norm` transformation is stateless.""" 38 | 39 | 40 | def clip_by_global_norm(max_norm, use_psum=True) -> GradientTransformation: 41 | """Clip updates using their global norm. 42 | 43 | References: 44 | [Pascanu et al, 2012](https://arxiv.org/abs/1211.5063) 45 | 46 | Args: 47 | max_norm: the maximum global norm for an update. 48 | 49 | Returns: 50 | An (init_fn, update_fn) tuple. 51 | """ 52 | 53 | def init_fn(_): 54 | return ClipByGlobalNormState() 55 | 56 | def update_fn(updates, state, params=None): 57 | del params 58 | g_norm = global_norm(updates, use_psum=use_psum) 59 | trigger = g_norm < max_norm 60 | updates = jax.tree_map( 61 | lambda t: jnp.where(trigger, t, (t / g_norm) * max_norm), updates) 62 | return updates, state 63 | 64 | return GradientTransformation(init_fn, update_fn) 65 | 66 | 67 | def additive_weight_decay(weight_decay: float = 0.0) -> GradientTransformation: 68 | """Add parameter scaled by `weight_decay`, to all parameters with more than one dim (i.e. exclude ln, bias etc) 69 | 70 | Args: 71 | weight_decay: a scalar weight decay rate. 72 | 73 | Returns: 74 | An (init_fn, update_fn) tuple. 
75 | """ 76 | 77 | def init_fn(_): 78 | return AdditiveWeightDecayState() 79 | 80 | def update_fn(updates, state, params): 81 | updates = jax.tree_multimap(lambda g, p: g + weight_decay * p * (len(g.shape) > 1), updates, params) 82 | return updates, state 83 | 84 | return GradientTransformation(init_fn, update_fn) 85 | 86 | 87 | def to_f32(t): 88 | return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t) 89 | 90 | 91 | def to_bf16(t): 92 | return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t) 93 | 94 | 95 | def to_f16(t): 96 | return jax.tree_map(lambda x: x.astype(jnp.float16) if x.dtype == jnp.float32 else x, t) 97 | 98 | 99 | # identity in forward pass, psum in backward 100 | @jax.custom_vjp 101 | def f_psum(x): 102 | return x 103 | 104 | 105 | def f_psum_fwd(x): 106 | return f_psum(x), None 107 | 108 | 109 | def f_psum_bwd(_, g): 110 | return jax.lax.psum(g, "shard"), 111 | 112 | 113 | f_psum.defvjp(f_psum_fwd, f_psum_bwd) 114 | 115 | 116 | # identity in forward pass, pmean in backward 117 | @jax.custom_vjp 118 | def f_pmean(x): 119 | return x 120 | 121 | 122 | def f_pmean_fwd(x): 123 | return f_psum(x), None 124 | 125 | 126 | def f_pmean_bwd(_, g): 127 | return jax.lax.pmean(g, "shard"), 128 | 129 | 130 | f_pmean.defvjp(f_pmean_fwd, f_pmean_bwd) 131 | 132 | 133 | # psum in forward pass, identity in backward 134 | @jax.custom_vjp 135 | def g_psum(x): 136 | return jax.lax.psum(x, "shard") 137 | 138 | 139 | def g_psum_fwd(x): 140 | return g_psum(x), None 141 | 142 | 143 | def g_psum_bwd(_, g): 144 | return g, 145 | 146 | 147 | g_psum.defvjp(g_psum_fwd, g_psum_bwd) 148 | 149 | 150 | def shard_axis(x, axis_size, axis_name): 151 | # in_shape = x.shape 152 | assert x.shape[0] % axis_size == 0 153 | 154 | x = x.reshape((axis_size, -1) + x.shape[1:]) 155 | 156 | x = x[jax.lax.axis_index(axis_name)] 157 | # print("shard out", x.shape, "in", in_shape) 158 | 159 | # assert np.prod(x.shape) * axis_size == np.prod(in_shape) 160 | 161 | return x 162 | 163 | 164 | def unshard_axis(x, axis_name): 165 | # in_shape = x.shape 166 | x = jax.lax.all_gather(x, axis_name) 167 | 168 | x = x.reshape((-1, ) + x.shape[2:]) 169 | 170 | # assert x.shape[-1] == 4096 171 | # print("unshard out", x.shape, "in", in_shape) 172 | return x 173 | 174 | 175 | # print but only on the first node 176 | def head_print(*args, **kwargs): 177 | if jax.host_id() == 0: 178 | print(*args, **kwargs) 179 | 180 | 181 | if __name__ == "__main__": 182 | sch = gpt3_schedule(1_000, 20_000, 1e-4, 1e-5) 183 | 184 | for i in range(150): 185 | i = i * 200 186 | print(i, sch(i)) 187 | -------------------------------------------------------------------------------- /ray_tpu.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | import subprocess 4 | import time 5 | 6 | import glob 7 | import requests 8 | from fabric import Connection 9 | 10 | 11 | @functools.lru_cache() 12 | def get_bearer(): 13 | return subprocess.check_output("gcloud auth print-access-token", shell=True).decode("utf-8").strip() 14 | 15 | 16 | @functools.lru_cache() 17 | def get_project(): 18 | return subprocess.check_output("gcloud config list --format 'value(core.project)'", shell=True).decode( 19 | "utf-8").strip() 20 | 21 | 22 | def create_tpu( 23 | name, 24 | zone, 25 | type, 26 | preemptible, 27 | ): 28 | headers = { 29 | 'Authorization': f'Bearer {get_bearer()}', 30 | 'Content-Type': 'application/json', 31 | } 32 | 33 | try: 34 | 
status = check_tpu(name, zone) 35 | 36 | if status["state"] not in ["CREATING", "READY"]: 37 | print("deleting TPU") 38 | delete_tpu(name, zone) 39 | 40 | while True: 41 | try: 42 | print("deleting check") 43 | print(check_tpu(name, zone)["state"]) 44 | 45 | time.sleep(1) 46 | except: 47 | break 48 | except: 49 | pass 50 | 51 | params = ( 52 | ('node_id', name), 53 | ) 54 | 55 | data = {"accelerator_type": 56 | type, 57 | "runtime_version": 58 | 'v2-alpha', 59 | "network_config": 60 | {"enable_external_ips": True}, 61 | } 62 | 63 | if preemptible: 64 | data["schedulingConfig"] = {"preemptible": True} 65 | 66 | response = requests.post(f'https://tpu.googleapis.com/v2alpha1/projects/{get_project()}/locations/{zone}/nodes', 67 | headers=headers, params=params, json=data) 68 | 69 | print(response.json()) 70 | 71 | return response.status_code == 200 72 | 73 | 74 | def check_tpu(name, zone): 75 | headers = { 76 | 'Authorization': f'Bearer {get_bearer()}', 77 | } 78 | 79 | response = requests.get( 80 | f'https://tpu.googleapis.com/v2alpha1/projects/{get_project()}/locations/{zone}/nodes/{name}', 81 | headers=headers) 82 | 83 | return response.json() 84 | 85 | 86 | def delete_tpu(name, zone): 87 | headers = { 88 | 'Authorization': f'Bearer {get_bearer()}', 89 | } 90 | 91 | response = requests.delete( 92 | f'https://tpu.googleapis.com/v2alpha1/projects/{get_project()}/locations/{zone}/nodes/{name}', 93 | headers=headers) 94 | 95 | return response.json() 96 | 97 | 98 | def wait_til(name, zone, state): 99 | while True: 100 | ret = check_tpu(name, zone) 101 | 102 | print("wait_til check") 103 | print(ret) 104 | 105 | matches = True 106 | for k, expected_v in state.items(): 107 | if k not in ret: 108 | matches = False 109 | continue 110 | if ret[k] != expected_v: 111 | matches = False 112 | 113 | if "error" in ret: 114 | return False 115 | 116 | if ret["state"] == "TERMINATED": 117 | return False 118 | 119 | if matches: 120 | return True 121 | 122 | time.sleep(1) 123 | 124 | 125 | def get_connection( 126 | name, 127 | zone, 128 | ): 129 | info = check_tpu(name, zone) 130 | outputs = [] 131 | for i in info["networkEndpoints"]: 132 | outputs.append(Connection(i["ipAddress"], 133 | connect_kwargs={ 134 | "key_filename": os.path.expanduser('~/.ssh/google_compute_engine'), })) 135 | return outputs 136 | 137 | 138 | def start_ray(conn, address, version=1): 139 | conn.sudo('rm -rf *.py') 140 | conn.sudo('rm -rf mesh_transformer') 141 | 142 | for i in glob.glob("*.py"): 143 | conn.put(i, "") 144 | 145 | conn.run("mkdir mesh_transformer -p") 146 | 147 | for i in glob.glob("mesh_transformer/*.py"): 148 | conn.put(i, "mesh_transformer/") 149 | 150 | conn.sudo('python3 setup.py install', hide=True) 151 | 152 | if version == 2: 153 | conn.put("scripts/init_ray_v2.sh", "/tmp/ray-tpu.sh") 154 | else: 155 | conn.put("scripts/init_ray.sh", "/tmp/ray-tpu.sh") 156 | conn.sudo('chmod +x /tmp/ray-tpu.sh', hide=True) 157 | conn.sudo('/tmp/ray-tpu.sh', hide=True) 158 | try: 159 | conn.run('ray stop -f', hide=True) 160 | except: 161 | pass 162 | 163 | time.sleep(1) 164 | 165 | conn.run(f"TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD={32 * 1024**3} ray start --address={address} --resources='" + '{"tpu": 1}\' --include-dashboard False', hide=True) 166 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy~=1.19.5 2 | tqdm>=4.45.0 3 | wandb>=0.11.2 4 | einops~=0.3.0 5 | requests~=2.25.1 6 | 
fabric~=2.6.0 7 | optax==0.0.9 8 | dm-haiku==0.0.5 9 | git+https://github.com/EleutherAI/lm-evaluation-harness/ 10 | ray[default]==1.4.1 11 | jax~=0.2.12 12 | Flask~=1.1.2 13 | cloudpickle~=1.3.0 14 | tensorflow-cpu~=2.6.0 15 | google-cloud-storage~=1.36.2 16 | transformers 17 | smart_open[gcs] 18 | func_timeout 19 | ftfy 20 | fastapi 21 | uvicorn 22 | lm_dataformat 23 | pathy 24 | -------------------------------------------------------------------------------- /resharding_example.py: -------------------------------------------------------------------------------- 1 | # This was tested with an RTX 3090, peak memory usage is approximately 22.4GB during inference, and 19GB when loading the model 2 | # The following environment variables were also used: XLA_PYTHON_CLIENT_PREALLOCATE=false XLA_PYTHON_CLIENT_ALLOCATOR=platform 3 | 4 | import time 5 | 6 | import jax 7 | from jax.experimental import maps 8 | import numpy as np 9 | import optax 10 | import transformers 11 | 12 | from mesh_transformer.checkpoint import read_ckpt 13 | from mesh_transformer.sampling import nucleaus_sample 14 | from mesh_transformer.transformer_shard import CausalTransformer 15 | 16 | params = { 17 | "layers": 28, 18 | "d_model": 4096, 19 | "n_heads": 16, 20 | "n_vocab": 50400, 21 | "norm": "layernorm", 22 | "pe": "rotary", 23 | "pe_rotary_dims": 64, 24 | "early_cast": True, 25 | "seq": 2048, 26 | "cores_per_replica": 1, # only running on one GPU 27 | "per_replica_batch": 1, 28 | } 29 | 30 | per_replica_batch = params["per_replica_batch"] 31 | cores_per_replica = params["cores_per_replica"] 32 | seq = params["seq"] 33 | 34 | 35 | params["sampler"] = nucleaus_sample 36 | 37 | # here we "remove" the optimizer parameters from the model (as we don't need them for inference) 38 | params["optimizer"] = optax.scale(0) 39 | 40 | devices = np.array([jax.devices()[0]]).reshape((1, 1)) 41 | maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp'))) 42 | 43 | tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') 44 | 45 | network = CausalTransformer(params) 46 | 47 | start = time.time() 48 | 49 | # here we load a checkpoint which was written with 8 shards into 1 shard 50 | network.state = read_ckpt(network.state, "step_383500/", 8, shards_out=cores_per_replica) 51 | 52 | # move the state to CPU/system memory so it's not duplicated by xmap 53 | network.state = jax.device_put(network.state, jax.devices("cpu")[0]) 54 | 55 | def infer(context, top_k=40, top_p=0.9, temp=1.0, gen_len=512): 56 | tokens = tokenizer.encode(context) 57 | 58 | provided_ctx = len(tokens) 59 | pad_amount = seq - provided_ctx 60 | 61 | padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32) 62 | batched_tokens = np.array([padded_tokens] * per_replica_batch) 63 | length = np.ones(per_replica_batch, dtype=np.uint32) * len(tokens) 64 | 65 | start = time.time() 66 | output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(per_replica_batch) * top_p, "top_k": top_k is not None and (np.ones(per_replica_batch, dtype=np.int32) * top_k) or None, "temp": np.ones(per_replica_batch) * temp}) 67 | 68 | samples = [] 69 | decoded_tokens = output[1][0] 70 | 71 | for o in decoded_tokens[:, :, 0]: 72 | samples.append(tokenizer.decode(o)) 73 | 74 | print(f"completion done in {time.time() - start:06}s") 75 | return samples 76 | 77 | 78 | infer("EleutherAI is") 79 | -------------------------------------------------------------------------------- /scripts/create_serve_tpu.sh: 
-------------------------------------------------------------------------------- 1 | gcloud alpha compute tpus tpu-vm create "$1" --zone europe-west4-a --accelerator-type v3-8 --version v2-alpha 2 | sleep 120 3 | gcloud alpha compute tpus tpu-vm ssh "$1" --zone europe-west4-a --command="$(< scripts/deploy_server.sh)" -------------------------------------------------------------------------------- /scripts/deploy_server.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | rm -r mesh-transformer-jax || true 5 | 6 | git clone https://github.com/kingoflolz/mesh-transformer-jax 7 | pip install -r mesh-transformer-jax/requirements.txt 8 | pip install mesh-transformer-jax/ jax==0.2.12 9 | 10 | pushd mesh-transformer-jax || exit 11 | screen -d -m python3 device_serve.py --config configs/6B_roto_256.json -------------------------------------------------------------------------------- /scripts/init_ray.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | sudo /usr/bin/docker-credential-gcr configure-docker 5 | 6 | sudo docker rm libtpu || true 7 | sudo docker create --name libtpu gcr.io/cloud-tpu-v2-images/libtpu:libtpu_20210518_RC00 "/bin/bash" && sudo docker cp libtpu:libtpu.so /lib 8 | 9 | # this locks the python executable down to hopefully stop it from being fiddled with... 10 | screen -d -m python -c 'import time; time.sleep(999999999)' 11 | 12 | # initializes jax and installs ray on cloud TPUs 13 | sudo pip install --upgrade jaxlib==0.1.67 jax==0.2.12 ray[default]==1.5.1 fabric dataclasses optax git+https://github.com/deepmind/dm-haiku tqdm cloudpickle smart_open[gcs] einops func_timeout -------------------------------------------------------------------------------- /scripts/init_ray_v2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | # this locks the python executable down to hopefully stop it from being fiddled with...
5 | screen -d -m python -c 'import time; time.sleep(999999999)' 6 | 7 | # initializes jax and installs ray on cloud TPUs 8 | sudo pip install "jax[tpu]>=0.2.18" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 9 | sudo pip install --upgrade ray[default]==1.5.1 fabric dataclasses optax git+https://github.com/deepmind/dm-haiku tqdm cloudpickle smart_open[gcs] einops func_timeout -------------------------------------------------------------------------------- /scripts/init_serve.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | # initializes jax and installs ray on cloud TPUs 5 | sudo pip install --upgrade jaxlib jax==0.2.12 ray==1.2.0 fabric dataclasses optax git+https://github.com/deepmind/dm-haiku tqdm cloudpickle smart_open[gcs] einops func_timeout transformers flask -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='mesh_transformer', 5 | version='0.0.0', 6 | packages=find_packages(include=['mesh_transformer', 'mesh_transformer.*']) 7 | ) 8 | -------------------------------------------------------------------------------- /slim_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import time 4 | 5 | import jax 6 | import numpy as np 7 | import optax 8 | 9 | from mesh_transformer import util 10 | from mesh_transformer.checkpoint import read_ckpt, write_ckpt 11 | from mesh_transformer.transformer_shard import CausalTransformer 12 | from smart_open import open 13 | 14 | from mesh_transformer.util import clip_by_global_norm, to_bf16, to_f16 15 | 16 | 17 | def parse_args(): 18 | # Parse command line arguments 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--config", type=str, default=None, help="Config file location") 21 | parser.add_argument("--ckpt-step", type=int, default=-1, help="Step number of the checkpoint to convert (if not specified, converts the most recent checkpoint)") 22 | parser.add_argument("--f16", default=False, action="store_true", help="Convert to float16 (instead of bfloat16)") 23 | 24 | args = parser.parse_args() 25 | return args 26 | 27 | 28 | if __name__ == "__main__": 29 | args = parse_args() 30 | params = json.load(open(args.config)) 31 | convert_fn = to_f16 if args.f16 else to_bf16 32 | 33 | cores_per_replica = params["cores_per_replica"] 34 | 35 | assert cores_per_replica <= 8 36 | 37 | bucket = params["bucket"] 38 | model_dir = params["model_dir"] 39 | 40 | params["optimizer"] = optax.chain( 41 | optax.scale(1), 42 | clip_by_global_norm(1), 43 | optax.scale_by_adam(), 44 | optax.additive_weight_decay(0), 45 | optax.scale(-1), 46 | optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0)) 47 | ) 48 | 49 | start = time.time() 50 | print(f"jax devices: {jax.device_count()}") 51 | print(f"jax runtime initialized in {time.time() - start:.06}s") 52 | 53 | mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica) 54 | devices = np.array(jax.devices()).reshape(mesh_shape) 55 | 56 | with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f: 57 | meta = json.load(f) 58 | 59 | if args.ckpt_step > -1: 60 | ckpt_step = args.ckpt_step 61 | else: 62 | ckpt_step = meta["checkpoints"][-1] 63 | print(f"using checkpoint {ckpt_step}") 64 | 65 | with jax.experimental.maps.mesh(devices, 
('dp', 'mp')): 66 | network = CausalTransformer(params) 67 | 68 | start = time.time() 69 | network.state = read_ckpt(network.state, f"gs://{bucket}/{model_dir}/step_{ckpt_step}/", devices.shape[1]) 70 | print(f"network loaded in {time.time() - start:.06}s") 71 | 72 | start = time.time() 73 | del network.state["opt_state"] 74 | 75 | network.state["params"] = convert_fn(network.state["params"]) 76 | print(f"network converted in {time.time() - start:.06}s") 77 | 78 | suffix = "_slim_f16" if args.f16 else "_slim" 79 | 80 | for i in range(cores_per_replica): 81 | write_ckpt(network.state, f"gs://{bucket}/{model_dir}{suffix}/step_{ckpt_step}/", i) 82 | print(f"written shard {i}") 83 | -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from tasks.eval_harness import EvalHarnessAdaptor -------------------------------------------------------------------------------- /tasks/eval_harness.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import transformers 4 | from lm_eval.base import LM 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | from tasks.util import sample_batch, shrink_seq 9 | import multiprocessing 10 | import ftfy 11 | 12 | tokenizer = None 13 | 14 | 15 | def process_init(): 16 | global tokenizer 17 | tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') 18 | tokenizer.model_max_length = int(1e30) 19 | tokenizer.pad_token = "<|endoftext|>" 20 | 21 | assert tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373] 22 | 23 | 24 | def process_request(x, seq): 25 | global tokenizer 26 | 27 | ctx, cont = x 28 | 29 | ctx_tokens = tokenizer.encode("<|endoftext|>" + ftfy.fix_text(ctx, normalization="NFKC")) 30 | cont_tokens = tokenizer.encode(ftfy.fix_text(cont, normalization="NFKC")) 31 | 32 | all_tokens = ctx_tokens + cont_tokens 33 | all_tokens = np.array(all_tokens)[-seq:] # truncate sequence at seq length 34 | 35 | provided_ctx = len(all_tokens) - 1 36 | pad_amount = seq - provided_ctx 37 | 38 | return { 39 | "obs": np.pad(all_tokens[:-1], ((0, pad_amount),), constant_values=50256), 40 | "target": np.pad(all_tokens[1:], ((0, pad_amount),), constant_values=50256), 41 | "ctx_length": seq, 42 | "eval_mask": np.logical_and( 43 | np.arange(0, seq) > len(all_tokens) - len(cont_tokens) - 2, 44 | np.arange(0, seq) < len(all_tokens) - 1 45 | ), 46 | } 47 | 48 | 49 | class EvalHarnessAdaptor(LM): 50 | def greedy_until(self, requests): 51 | raise Exception("unimplemented") 52 | 53 | def loglikelihood_rolling(self, requests): 54 | raise Exception("unimplemented") 55 | 56 | def __init__(self, tpu_cluster, seq, batch, shrink, min_seq=None): 57 | super().__init__() 58 | self.tpu = tpu_cluster 59 | self.seq = seq 60 | self.batch = batch 61 | self.shrink = shrink 62 | self.min_seq = min_seq 63 | 64 | self.pool = multiprocessing.Pool(initializer=process_init) 65 | process_init() 66 | 67 | def convert_requests(self, requests): 68 | return self.pool.imap(partial(process_request, seq=self.seq), requests) 69 | 70 | def loglikelihood(self, requests): 71 | output = [] 72 | 73 | r = self.convert_requests(requests) 74 | zero_example = process_request(requests[0], self.seq) 75 | 76 | for b in tqdm(sample_batch(r, self.batch, zero_example), 77 | desc="LM eval harness", 78 | total=len(requests) // self.batch): 79 | 80 | if self.shrink: 81 | b = shrink_seq(b, min_seq=self.min_seq) 82 | 
83 | out = self.tpu.eval(b) 84 | 85 | for loss, correct in zip(out["mask_loss"], out["each_correct"]): 86 | output.append((float(-loss), bool(correct))) 87 | 88 | return output 89 | -------------------------------------------------------------------------------- /tasks/util.py: -------------------------------------------------------------------------------- 1 | from itertools import zip_longest 2 | 3 | import numpy as np 4 | 5 | 6 | def grouper(n, iterable, fillvalue): 7 | "grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx" 8 | args = [iter(iterable)] * n 9 | return zip_longest(fillvalue=fillvalue, *args) 10 | 11 | 12 | # divide the seq length by 2 until it would truncate actual context 13 | def shrink_seq(examples, min_seq=None): 14 | length = examples["obs"].shape[-1] 15 | 16 | new_length = length // 2 17 | 18 | if min_seq is not None: 19 | if new_length < min_seq: 20 | return examples 21 | 22 | max_length = np.max(examples["eval_mask"] * np.arange(0, length)) + 1 23 | 24 | if max_length < new_length: 25 | examples["obs"] = examples["obs"][:, :new_length] 26 | examples["target"] = examples["target"][:, :new_length] 27 | examples["eval_mask"] = examples["eval_mask"][:, :new_length] 28 | 29 | return shrink_seq(examples, min_seq=min_seq) 30 | else: 31 | return examples 32 | 33 | 34 | def sample_batch(examples, bs, zero_example_shape): 35 | zero_example = { 36 | "obs": np.zeros_like(zero_example_shape["obs"]), 37 | "target": np.zeros_like(zero_example_shape["target"]), 38 | "eval_mask": np.zeros_like(zero_example_shape["eval_mask"]), 39 | "ctx_length": 0, 40 | } 41 | 42 | for batch in grouper(bs, examples, zero_example): 43 | batch_flattened = { 44 | "obs": [], 45 | "target": [], 46 | "eval_mask": [], 47 | "ctx_length": [], 48 | } 49 | 50 | for sample in batch: 51 | batch_flattened["obs"].append(sample["obs"]) 52 | batch_flattened["target"].append(sample["target"]) 53 | batch_flattened["eval_mask"].append(sample["eval_mask"]) 54 | batch_flattened["ctx_length"].append(sample["ctx_length"]) 55 | 56 | batch_flattened["obs"] = np.array(batch_flattened["obs"]) 57 | batch_flattened["target"] = np.array(batch_flattened["target"]) 58 | batch_flattened["eval_mask"] = np.array(batch_flattened["eval_mask"]) 59 | batch_flattened["ctx_length"] = np.array(batch_flattened["ctx_length"]) 60 | 61 | yield batch_flattened 62 | -------------------------------------------------------------------------------- /tfrecord_loader.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import tensorflow as tf 3 | import numpy as np 4 | from transformers import GPT2TokenizerFast 5 | import itertools 6 | 7 | 8 | class TFRecordLoader: 9 | def __init__(self, index_fname, batch_size, parse_fn, map_fn=None, restore_state=None): 10 | if restore_state is not None: 11 | self.file_idx = restore_state["file_idx"] 12 | self.file_idx_init = False 13 | self.used = restore_state["used"] 14 | else: 15 | self.file_idx = 0 16 | self.file_idx_init = True 17 | self.used = [] 18 | 19 | self.index = open(index_fname).read().splitlines() 20 | self.clean_index = list(filter(lambda x: x not in self.used, self.index)) 21 | self.bs = batch_size 22 | # self.seq = sample_size 23 | self.parse_fn = parse_fn 24 | 25 | if map_fn: 26 | self.map_fn = map_fn 27 | else: 28 | self.map_fn = lambda x: x 29 | 30 | self.sample_fn = self.sample_once() 31 | 32 | def reset(self): 33 | self.file_idx = 0 34 | self.file_idx_init = True 35 | self.used = [] 36 | 37 | self.clean_index = list(filter(lambda x: x not in self.used, 
self.index)) 38 | self.sample_fn = self.sample_once() 39 | 40 | def sample_once(self): 41 | for i in self.clean_index: 42 | compression = "ZLIB" if "zstd" in i else "" 43 | 44 | file = tf.data.TFRecordDataset(i, compression_type=compression).map(self.parse_fn, num_parallel_calls=tf.data.AUTOTUNE) 45 | file = file.apply(tf.data.experimental.dense_to_ragged_batch(np.prod(self.bs), drop_remainder=True)) 46 | file = file.prefetch(10) 47 | 48 | for file_idx, data in enumerate(file): 49 | data = jax.tree_map(lambda x: x.numpy(), data) 50 | data = self.map_fn(data) 51 | 52 | if not self.file_idx_init and file_idx <= self.file_idx: 53 | if file_idx % 1000 == 0: 54 | print(f"skipping to batch {self.file_idx}, currently at {file_idx}") 55 | continue 56 | self.file_idx_init = True 57 | self.file_idx = file_idx 58 | yield jax.tree_map(lambda x: x.reshape(self.bs + x.shape[1:]), data) 59 | self.used.append(i) 60 | self.file_idx = 0 61 | 62 | # this loops infinitely, use .sample_once to get an iterator for validation 63 | def get_samples(self): 64 | try: 65 | return next(self.sample_fn) 66 | except StopIteration: 67 | self.reset() 68 | return self.get_samples() 69 | 70 | def get_state(self): 71 | return { 72 | "used": self.used, 73 | "file_idx": self.file_idx 74 | } 75 | 76 | 77 | class TFRecordNewInputs(TFRecordLoader): 78 | def __init__(self, index_fname, batch_size, sample_size, restore_state=None): 79 | def tf_parse(example_proto): 80 | features = { 81 | "text": tf.io.VarLenFeature(tf.int64) 82 | } 83 | parsed_features = tf.io.parse_single_example(example_proto, features) 84 | 85 | return tf.cast(tf.sparse.to_dense(tf.sparse.reorder(parsed_features["text"])), tf.uint32) 86 | 87 | super().__init__(index_fname, batch_size, tf_parse, restore_state=restore_state) 88 | 89 | 90 | class TFRecordWIT(TFRecordLoader): 91 | def __init__(self, index_fname, batch_size, restore_state=None, text_tokens=256): 92 | self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") 93 | self.tokenizer.pad_token = "<|endoftext|>" 94 | self.tokenizer.add_special_tokens({'sep_token': '<|sep|>', 'pad_token': '<|pad|>'}) 95 | 96 | def map_fn(example): 97 | tokenizer = self.tokenizer 98 | 99 | def decode(x): 100 | return tokenizer(["<|endoftext|>" + i.decode() for i in x])["input_ids"] 101 | 102 | texts = [ 103 | decode(example["context_page_description"]), 104 | decode(example["context_section_description"]), 105 | decode(example["caption_reference_description"]), 106 | decode(example["caption_alt_text_description"]), 107 | decode(example["caption_attribution_description"]), 108 | ] 109 | 110 | output = [] 111 | 112 | for text, dalle in zip(zip(*texts), example["dalle"]): 113 | all_text = list(itertools.chain(*text))[-text_tokens+1:] 114 | 115 | all_text += [tokenizer.pad_token_id] * ((text_tokens - 1) - len(all_text)) 116 | 117 | assert len(all_text) == text_tokens - 1 118 | 119 | all_tokens = all_text + [tokenizer.sep_token_id] + list(dalle + tokenizer.vocab_size + 1) 120 | output.append(all_tokens) 121 | 122 | return np.array(output) 123 | 124 | def tf_parse(example_proto): 125 | features = { 126 | "page_title": tf.io.FixedLenFeature([], tf.string), 127 | "section_title": tf.io.FixedLenFeature([], tf.string), 128 | "hierarchical_section_title": tf.io.FixedLenFeature([], tf.string), 129 | "caption_reference_description": tf.io.FixedLenFeature([], tf.string), 130 | "caption_attribution_description": tf.io.FixedLenFeature([], tf.string), 131 | "caption_alt_text_description": tf.io.FixedLenFeature([], tf.string), 132 | 
"mime_type": tf.io.FixedLenFeature([], tf.string), 133 | "context_page_description": tf.io.FixedLenFeature([], tf.string), 134 | "context_section_description": tf.io.FixedLenFeature([], tf.string), 135 | 136 | "dalle": tf.io.FixedLenFeature([1024], tf.int64), 137 | } 138 | 139 | parsed_features = tf.io.parse_single_example(example_proto, features) 140 | 141 | return parsed_features 142 | 143 | super().__init__(index_fname, batch_size, tf_parse, map_fn, restore_state=restore_state) 144 | 145 | 146 | if __name__ == "__main__": 147 | # d = TFRecordNewInputs("data/pile.val.index", (8, 32), 2048) 148 | # for idx, i in enumerate(d.sample_once()): 149 | # print(i) 150 | # break 151 | 152 | d = TFRecordWIT("data/wit_dalle.train.index", (8, 32)) 153 | for idx, i in enumerate(d.sample_once()): 154 | print(i) 155 | break 156 | 157 | print() 158 | -------------------------------------------------------------------------------- /to_hf_weights.py: -------------------------------------------------------------------------------- 1 | #### 2 | # python to_hf_weights.py --input-ckpt ./step_383500 --config ./configs/6B_roto_256.json --output-path ./gpt-j-6B --cpu --dtype fp32 3 | #### 4 | 5 | import argparse 6 | import io 7 | import multiprocessing 8 | import time 9 | import warnings 10 | import os 11 | import re 12 | from typing import Iterable, List, Union 13 | import json 14 | 15 | import jax 16 | import jax.numpy as jnp 17 | from jax.experimental import maps 18 | from pathy import FluidPath, Pathy 19 | import numpy as np 20 | import optax 21 | import torch 22 | 23 | from tqdm import tqdm 24 | 25 | from mesh_transformer.transformer_shard import CausalTransformer 26 | 27 | # xla: tell jax to not pre allocate all device memory 28 | # and only allocate memory as needed. 29 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 30 | os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" 31 | 32 | DEBUG = False 33 | 34 | parser = argparse.ArgumentParser( 35 | description=( 36 | "Used to turn a sharded trained gpt-j checkpoint into pytorch hugging face format." 37 | "This script works best on a slimmed checkpoint (full checkpoints can be used but require ~100gb of ram)." 38 | "Currently, weights must be split into 8 shards for this to work." 39 | "All paths can be local or google cloud storage paths. S3 paths supported as well with `pip install pathy[s3]`." 40 | "Can be run on tpu, or on gpu with `pip install --upgrade jax==0.2.12 jaxlib==0.1.67+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html`" 41 | ) 42 | ) 43 | parser.add_argument( 44 | "--input-ckpt", 45 | type=str, 46 | required=True, 47 | help='path to model checkpoint folder. Google storage can be used with "gs://bucket/path/step_{n}" format.', 48 | metavar="path", 49 | ) 50 | parser.add_argument( 51 | "--config", type=str, required=True, help="Config file location", metavar="path" 52 | ) 53 | parser.add_argument( 54 | "--output-path", 55 | required=True, 56 | type=str, 57 | help='Full path to save checkpoint to. Google storage can be used with "gs://bucket/path" format.', 58 | ) 59 | parser.add_argument( 60 | "--debug", 61 | action="store_true", 62 | help="Verbose printing.", 63 | ) 64 | parser.add_argument( 65 | "--cpu", 66 | action="store_true", 67 | help="Run resharding on cpu instead of searching for jax device (i.e. gpu/tpu). Will default to cpu if jax wasn't installed with `+cuda110` option", 68 | ) 69 | parser.add_argument( 70 | "--dtype", 71 | type=str, 72 | default="fp16", 73 | help="One of fp32, fp16 or bf16. 
Default=fp16. WARNING: Experimental. Make sure to check weights after conversion to make sure dtype information is retained.", 74 | ) 75 | 76 | 77 | def process_args( 78 | input_ckpt: Union[FluidPath, str], 79 | config: Union[FluidPath, str], 80 | output_path: Union[FluidPath, str], 81 | dtype: str = "fp16", 82 | cpu: bool = False, 83 | **kwargs, 84 | ): 85 | # validate paths and turn them into Pathy paths. 86 | input_ckpt = Pathy.fluid(str(input_ckpt)) 87 | assert input_ckpt.is_dir(), f'no such directory "{input_ckpt}"' 88 | config = Pathy.fluid(str(config)) 89 | assert config.is_file(), f'no such file "{config}"' 90 | first_shard = input_ckpt / "shard_0" 91 | assert first_shard.is_dir(), f'no shards found at "{input_ckpt}"' 92 | 93 | output_path = Pathy.fluid(str(output_path)) 94 | output_path.mkdir(exist_ok=True) 95 | 96 | # make sure dtype is valid 97 | assert dtype in {"fp16", "fp32", "bf16"} 98 | np_dtype = np.float16 99 | torch_dtype = torch.float16 100 | if dtype != "fp16": 101 | warnings.warn( 102 | "WARNING: Dtype support other than fp16 is Experimental. Make sure to check weights after conversion to make sure dtype information is retained." 103 | ) 104 | if dtype == "bf16": 105 | # np doesn't have bfloat16 so float32 is used to retain information before converting to torch. 106 | np_dtype = np.float32 107 | torch_dtype = torch.bfloat16 108 | elif dtype == "fp32": 109 | np_dtype = np.float32 110 | torch_dtype = torch.float32 111 | 112 | # tell jax to run on cpu instead of gpu/tpu 113 | if cpu: 114 | jax.config.update("jax_platform_name", "cpu") 115 | 116 | return input_ckpt, config, output_path, np_dtype, torch_dtype 117 | 118 | 119 | def tree_flatten_with_names(pytree, is_leaf, path="", to_id=id): 120 | id_to_name = {} 121 | if getattr(pytree, "items", None): 122 | for k, v in pytree.items(): 123 | k_path = f"{path}/{k}" 124 | if is_leaf(v): 125 | id_to_name[to_id(v)] = k_path 126 | else: 127 | id_to_name = { 128 | **id_to_name, 129 | **tree_flatten_with_names(v, is_leaf=is_leaf, path=k_path), 130 | } 131 | elif getattr(pytree, "__getitem__", None): 132 | for v in pytree: 133 | if is_leaf(v): 134 | id_to_name[to_id(v)] = path 135 | else: 136 | id_to_name = { 137 | **id_to_name, 138 | **tree_flatten_with_names(v, is_leaf=is_leaf, path=path), 139 | } 140 | else: 141 | id_to_name[to_id(pytree)] = path 142 | return id_to_name 143 | 144 | 145 | def tree_leaves_with_names(pytree, to_id=id): 146 | leaves = jax.tree_leaves(pytree) 147 | is_leaf = lambda x: not isinstance(x, list) and to_id(x) in [ 148 | to_id(x) for x in leaves 149 | ] 150 | return tree_flatten_with_names(pytree, is_leaf) 151 | 152 | 153 | def get_tree_leaves_names_reduced(pytree) -> List[str]: 154 | 155 | leaves_ids = tree_leaves_with_names(pytree, to_id=id) 156 | leaves = jax.tree_leaves(pytree) 157 | return [leaves_ids[id(l)] for l in leaves] 158 | 159 | 160 | layer_2_hf_inner_module_id = { 161 | "linear": "attn.q_proj", 162 | "linear_1": "attn.v_proj", 163 | "linear_2": "attn.k_proj", 164 | "linear_3": "attn.out_proj", 165 | "linear_4": "mlp.fc_in", 166 | "linear_5": "mlp.fc_out", 167 | "replicated_layer_norm": "ln_1", 168 | } 169 | 170 | projection_layer_2_hf_id_start = { 171 | "linear": "lm_head", 172 | "replicated_layer_norm": "transformer.ln_f", 173 | } 174 | 175 | 176 | def leave_name_to_hf_layer_id(leaf_name: str): 177 | if not leaf_name.startswith("/params"): 178 | if leaf_name == "/step": 179 | return None 180 | else: 181 | raise NotImplementedError(f"Unknown leaf name: {leaf_name}") 182 | 183 | match = 
re.search(
184 |         r"\/params\/causal_transformer_shard\/~\/(?P<module_name>.*)\/~\/(?P<layer_name>.*)\/(?P<wb>.*)",
185 |         leaf_name,
186 |     )
187 | 
188 |     assert match, f'couldn\'t match pattern against: "{leaf_name}"'
189 | 
190 |     layer_name = match["layer_name"]
191 |     module_name = match["module_name"]
192 |     wb = match["wb"]
193 | 
194 |     if wb in {"w", "scale"}:
195 |         weight_or_bias = "weight"
196 |     elif wb in {"b", "offset"}:
197 |         weight_or_bias = "bias"
198 |     else:
199 |         raise NotImplementedError(
200 |             f"unknown weight/bias type identifier \"{wb}\" at end of: '{leaf_name}'"
201 |         )
202 | 
203 |     # switch statement based on top level module name
204 |     if module_name == "embedding_shard":
205 |         hf_id = f"transformer.wte.{weight_or_bias}"
206 | 
207 |     elif module_name.startswith("layer"):
208 |         module_index = int(module_name.split("_")[-1])
209 |         hf_inner_module_id = layer_2_hf_inner_module_id[layer_name]
210 |         hf_id = f"transformer.h.{module_index}.{hf_inner_module_id}.{weight_or_bias}"
211 | 
212 |     elif module_name == "projection_shard":
213 |         hf_id = f"{projection_layer_2_hf_id_start[layer_name]}.{weight_or_bias}"
214 | 
215 |     else:
216 |         raise NotImplementedError(
217 |             f"unknown leaf module type \"{module_name}\" in: '{leaf_name}'"
218 |         )
219 | 
220 |     if DEBUG:
221 |         print(f"{leaf_name} \n\t -> {hf_id}")
222 | 
223 |     return hf_id
224 | 
225 | 
226 | # TODO(nijkamp): rewrite this mess
227 | def reshard(x, old_shape, do_shard_ln, do_shard_bias):
228 |     # reshards using numpy arrays so as to not fill up jax memory
229 |     if len(x.shape) == 1:
230 |         out = np.array(x[0:1])
231 | 
232 |     elif len(x.shape) == 2:
233 |         if do_shard_ln:
234 |             out = np.array(x[0:1])
235 |         elif do_shard_bias:
236 |             out = np.reshape(np.sum(x, axis=0), old_shape)
237 |         else:
238 |             out = x.reshape(old_shape)
239 | 
240 |     elif len(x.shape) == 3:
241 |         if x.shape[0] * x.shape[2] == old_shape[2]:
242 |             out = np.transpose(x, (1, 0, 2)).reshape(old_shape)
243 |         elif x.shape[0] * x.shape[1] == old_shape[1]:
244 |             out = np.reshape(x, old_shape)
245 |         else:
246 |             raise NotImplementedError(f"unimplemented, {x.shape}, {old_shape}")
247 |     else:
248 |         raise NotImplementedError(f"unimplemented, {x}")
249 |     return out
250 | 
251 | 
252 | def read_npz(fpath: FluidPath):
253 |     # read npz file of ndarrays
254 |     with fpath.open("rb") as f:
255 |         buf = f.read()
256 |         f_io = io.BytesIO(buf)
257 |         deserialized = np.load(
258 |             f_io,
259 |         )
260 |         assert isinstance(
261 |             deserialized, np.lib.npyio.NpzFile
262 |         ), f"Not an npz file {type(deserialized)=} {f=}"
263 |         # arrays are only loaded when accessed. 
So we need to access them before returning 264 | arrays = [] 265 | for i in deserialized: 266 | arr = deserialized[i] 267 | assert isinstance(arr, np.ndarray), f"Not a np.ndarray {type(arr)=} {f=}" 268 | arrays.append(arr) 269 | return arrays 270 | 271 | 272 | def read_file_shards( 273 | ckpt_dir: FluidPath, fname: str, shards_in: int 274 | ) -> List[List[np.ndarray]]: 275 | # read same file like "12.npz" across all shard directories 276 | with multiprocessing.pool.ThreadPool(shards_in) as p: 277 | return list( 278 | p.imap( 279 | read_npz, 280 | [ckpt_dir / f"shard_{i}" / fname for i in range(shards_in)], 281 | ) 282 | ) 283 | 284 | 285 | def lazy_read_ckpt_shards( 286 | ckpt_dir: FluidPath, shards_in: int, pieces: int = 16, reverse: bool = True 287 | ): 288 | for i in range(pieces): 289 | # iterate through files in direction of choice 290 | fname = f"{(pieces-1) - i}.npz" if reverse else f"{i}.npz" 291 | if DEBUG: 292 | print(f"reading from {fname}") 293 | file_shards = read_file_shards(ckpt_dir, fname, shards_in) 294 | 295 | # iterate over layers in file returning all shards for each 296 | file_shards = list(zip(*file_shards)) 297 | if reverse: 298 | file_shards = reversed(file_shards) 299 | yield from file_shards 300 | 301 | 302 | def unshard_leave( 303 | leave_shards: Iterable[np.ndarray], 304 | leave_name: str, 305 | old_shape: List[int], 306 | np_dtype=np.float16, 307 | ): 308 | # reshard all leave shards into single shard. 309 | 310 | # stack leave shards into single np.ndarray 311 | x = np.stack(leave_shards) 312 | # assert isinstance(x, jnp.ndarray) 313 | 314 | # As far as i can tell, this just re labels the dtype of arrays 315 | # labeled with "V2" dtype. In theory, V2 was just an alias for bfloat16 316 | # which needs to be relabeled in order for it to be understood. 317 | if x.dtype == np.dtype("V2"): 318 | x.dtype = jnp.bfloat16 319 | 320 | if DEBUG: 321 | print(f"RESHARDING: {leave_name=} {x.shape=} {old_shape=}") # type: ignore 322 | 323 | # transform sharded array to match old_shape 324 | x = reshard( 325 | x, 326 | old_shape, 327 | do_shard_bias=leave_name.endswith("embedding_shard/~/linear/b") 328 | or leave_name.endswith("linear_5/b"), 329 | do_shard_ln=leave_name.endswith("replicated_layer_norm/offset") 330 | or leave_name.endswith("replicated_layer_norm/scale"), 331 | ) 332 | assert ( 333 | x.shape == old_shape 334 | ), f"Incompatible checkpoints {x.shape} vs {old_shape} {leave_name}" 335 | return x.astype(np_dtype) 336 | 337 | 338 | def save_pytree_as_hf( 339 | pytree, 340 | input_ckpt: FluidPath, 341 | shards_in: int, 342 | output_path: FluidPath, 343 | n_layers: int = 28, 344 | np_dtype: type = np.float16, 345 | torch_dtype: torch.dtype = torch.float16, 346 | n_seq: int = 2048, 347 | ): 348 | # Loads layers and names in reverse order to avoid loading unneeded opt_state layers 349 | # that are at the front of full (i.e. not slim) models. 350 | 351 | old_leave_shapes = [old.shape for old in jax.tree_flatten(pytree)[0]] 352 | leave_names = get_tree_leaves_names_reduced(pytree) 353 | del pytree 354 | 355 | assert len(old_leave_shapes) == len( 356 | leave_names 357 | ), f"{len(old_leave_shapes)=} {len(leave_names)=}" 358 | # get generator that emits all shards of leaves from npz files in reverse order 359 | loaded_shards_in = lazy_read_ckpt_shards(input_ckpt, shards_in, reverse=True) 360 | 361 | print("Reading and transforming layers/shards. 
This may take a while.") 362 | 363 | hf_checkpoint = {} 364 | wte_first = None # saves first instance of a wte weight in order to combine it with the second. 365 | # Reverse iteration to grab leave_names and old leaves from the back 366 | for i in tqdm( 367 | reversed(range(len(leave_names))), 368 | desc="Reading/Transforming Layers", 369 | total=len(leave_names), 370 | ): 371 | 372 | # load next shard with correstponding leave name and old shape 373 | x = next(loaded_shards_in) 374 | leave_name = leave_names[i] 375 | old_shape = old_leave_shapes[i] 376 | hf_layer_id = leave_name_to_hf_layer_id(leave_name) 377 | 378 | # If leave is not needed in hf model (/step') 379 | if not hf_layer_id: 380 | continue 381 | 382 | x = unshard_leave(x, leave_name, old_shape, np_dtype=np_dtype) 383 | # remove first empty dimension and transpose. 384 | x = torch.tensor(x.squeeze(0), dtype=torch_dtype).T 385 | 386 | # wte embedding weights/bias need to be combined since hf model has no wte.embedding.bias 387 | if hf_layer_id.startswith("transformer.wte"): 388 | # un/re-transpose since wte weight is only leave that shouldn't be transposed 389 | x = x.T 390 | # store first weight/bias then skip saving 391 | if wte_first is None: 392 | wte_first = x 393 | continue 394 | # combine second wte bias/weight with first then move on to saving with weight name 395 | else: 396 | x = x + wte_first 397 | hf_layer_id = "transformer.wte.weight" 398 | 399 | # save params as single file with proper hf id mapped to them in save_map 400 | hf_checkpoint[hf_layer_id] = x 401 | 402 | # add attention bias layers 403 | attn_bias_weights = torch.tril(torch.ones((n_seq, n_seq), dtype=torch.bool)).view( 404 | 1, 1, n_seq, n_seq 405 | ) 406 | attn_masked_bias_weights = torch.tensor(-1e9, dtype=torch_dtype) 407 | 408 | for i in range(n_layers): 409 | hf_checkpoint[f"transformer.h.{i}.attn.bias"] = attn_bias_weights 410 | hf_checkpoint[f"transformer.h.{i}.attn.masked_bias"] = attn_masked_bias_weights 411 | 412 | torch.save(hf_checkpoint, (output_path / "pytorch_model.bin").open(mode="wb")) 413 | 414 | 415 | def save_config_to_hf_format(params: dict, torch_dtype: str, output_path: FluidPath): 416 | 417 | config = { 418 | "activation_function": "gelu_new", 419 | "architectures": ["GPTJForCausalLM"], 420 | "attn_pdrop": 0.0, 421 | "bos_token_id": 50256, 422 | "embd_pdrop": 0.0, 423 | "eos_token_id": 50256, 424 | "gradient_checkpointing": False, 425 | "initializer_range": 0.02, 426 | "layer_norm_epsilon": 1e-05, 427 | "model_type": "gptj", 428 | "n_embd": params["d_model"], 429 | "n_head": params["n_heads"], 430 | "n_layer": params["layers"], 431 | "n_positions": params["seq"], 432 | "rotary_dim": params["pe_rotary_dims"], 433 | "summary_activation": None, 434 | "summary_first_dropout": 0.1, 435 | "summary_proj_to_labels": True, 436 | "summary_type": "cls_index", 437 | "summary_use_proj": True, 438 | "transformers_version": "4.10.0.dev0", 439 | "tokenizer_class": "GPT2Tokenizer", 440 | "task_specific_params": { 441 | "text-generation": {"do_sample": True, "temperature": 1.0, "max_length": 50} 442 | }, 443 | "torch_dtype": str(torch_dtype).split(".")[-1], 444 | "use_cache": True, 445 | "vocab_size": params["n_vocab"], 446 | } 447 | 448 | with (output_path / "config.json").open("w") as f: 449 | json.dump(config, f, indent=2) 450 | 451 | 452 | def save_sharded_to_hf_format( 453 | input_ckpt: Union[FluidPath, str], 454 | params: dict, 455 | output_path: Union[FluidPath, str], 456 | cpu: bool = False, 457 | dtype: str = "fp16", 458 | ): 459 | 460 | 
devices = np.array([jax.devices()[0]]).reshape((1, 1)) 461 | with maps.mesh(devices, ("dp", "mp")): 462 | params_local = params.copy() 463 | params_local["cores_per_replica"] = maps.thread_resources.env.shape["mp"] 464 | network = CausalTransformer(params_local) 465 | 466 | save_pytree_as_hf( 467 | network.state, 468 | input_ckpt=input_ckpt, 469 | shards_in=params["cores_per_replica"], 470 | output_path=output_path, 471 | n_layers=params["layers"], 472 | np_dtype=np_dtype, 473 | torch_dtype=torch_dtype, 474 | n_seq=params["seq"], 475 | ) 476 | 477 | 478 | if __name__ == "__main__": 479 | args = vars(parser.parse_args()) 480 | 481 | DEBUG = args["debug"] 482 | start = time.time() 483 | 484 | input_ckpt, config, output_path, np_dtype, torch_dtype = process_args(**args) 485 | params = json.load(open(config)) 486 | params["optimizer"] = optax.scale(0) 487 | 488 | save_sharded_to_hf_format(input_ckpt, params, output_path, np_dtype, torch_dtype) 489 | save_config_to_hf_format(params, torch_dtype, output_path) 490 | print( 491 | f"HF weights created in {(time.time() - start):.0f}s \"{args['output_path']}\"" 492 | ) 493 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import time 4 | 5 | import numpy as np 6 | import wandb 7 | from tqdm import tqdm 8 | 9 | from mesh_transformer.build_model import build_model 10 | from lm_eval import evaluator, tasks 11 | from tasks.eval_harness import EvalHarnessAdaptor 12 | from tfrecord_loader import TFRecordNewInputs 13 | import multiprocessing 14 | 15 | 16 | def parse_args(): 17 | # Parse command line arguments 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--tpu", type=str, help="Name of TPU to train on.") 20 | parser.add_argument("--tpu_region", type=str, help="Region of TPU to train on.") 21 | parser.add_argument("--preemptible", action="store_true") 22 | 23 | parser.add_argument("--config", type=str, default=None, help="Config file location") 24 | 25 | parser.add_argument("--new", action="store_true", help="If set, deletes previous checkpoint, if it exists, and " 26 | "starts a new training run") 27 | 28 | parser.add_argument("--version", type=int, default=1, help="Choose which model version to use") 29 | 30 | args = parser.parse_args() 31 | return args 32 | 33 | 34 | if __name__ == "__main__": 35 | # huggingface tokenizers gets very angry if you fork 36 | multiprocessing.set_start_method("spawn") 37 | 38 | args = parse_args() 39 | params = json.load(open(args.config)) 40 | 41 | if args.new: 42 | print(f"Starting experiment {params['name']} from scratch! 
" 43 | f"all data in gs://{params['bucket']}/{params['model_dir']}/ will be deleted") 44 | input("Hit enter to continue") 45 | 46 | tpu_name = args.tpu 47 | region = args.tpu_region 48 | preemptible = args.preemptible 49 | clean_start = args.new 50 | 51 | gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1) 52 | per_replica_batch = params["per_replica_batch"] 53 | tpu_size = params["tpu_size"] 54 | cores_per_replica = params["cores_per_replica"] 55 | 56 | bucket = params["bucket"] 57 | model_dir = params["model_dir"] 58 | layers = params["layers"] 59 | d_model = params["d_model"] 60 | n_heads = params["n_heads"] 61 | n_vocab = params["n_vocab"] 62 | seq = params["seq"] 63 | norm = params["norm"] 64 | 65 | val_batches = params["val_batches"] 66 | val_every = params["val_every"] 67 | ckpt_every = params["ckpt_every"] 68 | keep_every = params["keep_every"] 69 | eval_tasks = params["eval_harness_tasks"] 70 | total_steps = params["total_steps"] 71 | 72 | pe = params["pe"] 73 | assert pe in ["fixed", "rotary", "t5"] 74 | 75 | t = build_model(params, tpu_name, region, preemptible, version=args.version) 76 | 77 | try: 78 | t.save(0, bucket, model_dir, init=True, overwrite=clean_start) 79 | step = 0 80 | train_load_restore = None 81 | except Exception as e: 82 | print(f"Save failed with error {e}, trying to load instead...", e) 83 | step, aux = t.load(bucket, model_dir) 84 | train_load_restore = aux.get("train_loader", None) 85 | 86 | if train_load_restore is None: 87 | print("Failed to restore train loader state") 88 | 89 | train_dataset = TFRecordNewInputs(f"data/{params['train_set']}", 90 | batch_size=( 91 | gradient_accumulation_steps, 92 | per_replica_batch * tpu_size // cores_per_replica), 93 | sample_size=params['seq'], 94 | restore_state=train_load_restore) 95 | 96 | global_val_batch = int(per_replica_batch * tpu_size // cores_per_replica * params.get("val_batch_multiplier", 1)) 97 | 98 | val_sets = {} 99 | 100 | for k, v in params['val_set'].items(): 101 | val_sets[k] = TFRecordNewInputs(f"data/{v}", 102 | batch_size=(global_val_batch,), 103 | sample_size=seq) 104 | 105 | # use dynamic seq length unless pe is fixed 106 | adaptor = EvalHarnessAdaptor(t, 107 | seq, 108 | global_val_batch, 109 | shrink=pe != "fixed", 110 | min_seq=1024 if args.version == 2 else None) # work around suboptimal pjit layout 111 | 112 | start = time.time() 113 | t.train(train_dataset.get_samples()) 114 | print(f"Train fn compiled in {time.time() - start:.06}s") 115 | 116 | start = time.time() 117 | for val_set in val_sets.values(): 118 | t.eval(val_set.get_samples()) 119 | print(f"Eval fn compiled in {time.time() - start:.06}s") 120 | 121 | project = params.get("wandb_project", "mesh-transformer-jax") 122 | wandb.init(project=project, entity="eleutherai", name=params["name"], config=params) 123 | 124 | eval_task_dict = tasks.get_task_dict(eval_tasks) 125 | 126 | pbar = tqdm(initial=step, total=total_steps, desc="Training progress") 127 | 128 | while True: 129 | loss, last_loss = t.train(train_dataset.get_samples()) 130 | wandb.log({'train/loss': loss, 'train/last_loss': last_loss}, step) 131 | 132 | if (step % ckpt_every == 0 and step) or step == total_steps: 133 | t.save(step, bucket, model_dir, 134 | aux={"train_loader": train_dataset.get_state()}, 135 | init=False, 136 | delete_old=step % keep_every != 0) 137 | 138 | if step == total_steps: 139 | print("training completed!") 140 | exit() 141 | 142 | if step % val_every == 0: 143 | for name, val_set in val_sets.items(): 144 | val_loss = [] 
145 | for i, _ in tqdm(zip(val_set.sample_once(), range(val_batches)), 146 | desc=f"validation for step {step}, set {name}", 147 | total=val_batches): 148 | val_loss.append(t.eval(i)) 149 | val_loss = np.array(val_loss).mean() 150 | print(f"validation loss for step {step}, set {name}: {val_loss}") 151 | 152 | wandb.log({f'val/loss_{name}': float(val_loss)}, step) 153 | 154 | results = evaluator.evaluate(adaptor, eval_task_dict, False, 0, None) 155 | 156 | flat_results = {} 157 | 158 | for task_name, task_res in results["results"].items(): 159 | version = results["versions"][task_name] 160 | for metric_name, metric_res in task_res.items(): 161 | flat_results[f"{task_name}-v{version}/{metric_name}"] = float(metric_res) 162 | 163 | dumped = json.dumps(results, indent=2) 164 | print(f"step {step} val results: {dumped}") 165 | wandb.log(flat_results, step) 166 | step += 1 167 | 168 | pbar.set_postfix({'loss': loss, 'last_loss': last_loss}) 169 | pbar.update() 170 | --------------------------------------------------------------------------------
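A note on using the output of to_hf_weights.py: the script writes pytorch_model.bin and config.json in the Hugging Face layout. The sketch below is a minimal, hypothetical usage example rather than a file from this repository; "./gpt-j-6B" is the --output-path from the example invocation at the top of to_hf_weights.py, and it assumes a transformers release that ships GPTJForCausalLM (the generated config targets 4.10.0.dev0; stable support arrived in later versions).

```python
# Minimal sketch, assuming `transformers` with GPTJForCausalLM is installed and
# that to_hf_weights.py was run with --output-path ./gpt-j-6B (hypothetical path).
import torch
from transformers import GPT2Tokenizer, GPTJForCausalLM

model = GPTJForCausalLM.from_pretrained("./gpt-j-6B", torch_dtype=torch.float16)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # GPT-J reuses the GPT-2 vocabulary

prompt = tokenizer("EleutherAI is", return_tensors="pt")
with torch.no_grad():
    # sampling settings mirror the task_specific_params written into config.json
    out = model.generate(**prompt, do_sample=True, temperature=1.0, max_length=50)
print(tokenizer.decode(out[0]))
```

If the conversion was run with --dtype fp32, drop the torch_dtype argument or pass torch.float32 instead.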