├── .github └── workflows │ └── lint_python.yml ├── .gitignore ├── CITATION.bib ├── LICENSE.txt ├── README.md ├── arithmetic └── gen_sums.py ├── benchmarks.md ├── colab_demo.ipynb ├── commonsenseqa ├── dev_rand_split.jsonl ├── prompts.txt ├── prompts_answer_key.txt ├── prompts_direct.txt ├── prompts_direct_answer_key.txt ├── test_rand_split_no_answers.jsonl └── train_rand_split.jsonl ├── configs ├── iterative_base │ └── base.json └── qa_base.json ├── create_finetune_tfrecords.py ├── data └── qa.val.index ├── device_inference.py ├── device_serve.py ├── device_train.py ├── docker ├── .env ├── Dockerfile ├── README.md ├── compose-proxy.yaml ├── docker-compose.yaml ├── main.py ├── nginx_proxyvm.conf ├── ops.py ├── payloads.py └── start.sh ├── eval_harness.py ├── gsm ├── dev_rand_split.jsonl ├── prompts.txt ├── prompts_answer_key.txt ├── prompts_direct_answer_key.txt └── train_rand_split.jsonl ├── howto_finetune.md ├── iteration_train.py ├── mesh_transformer ├── TPU_cluster.py ├── __init__.py ├── build_model.py ├── checkpoint.py ├── layers.py ├── sampling.py ├── train_actor.py ├── transformer_shard.py └── util.py ├── ray_tpu.py ├── requirements.txt ├── resharding_example.py ├── scripts ├── create_serve_tpu.sh ├── deploy_server.sh ├── init_ray.sh ├── init_ray_v2.sh └── init_serve.sh ├── setup.py ├── slim_model.py ├── tasks ├── __init__.py ├── eval_harness.py └── util.py ├── tfrecord_loader.py ├── to_hf_weights.py └── train.py /.github/workflows/lint_python.yml: -------------------------------------------------------------------------------- 1 | name: lint_python 2 | on: [pull_request, push] 3 | jobs: 4 | lint_python: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v2 8 | - uses: actions/setup-python@v2 9 | - run: pip install bandit black codespell flake8 isort mypy pytest pyupgrade safety 10 | - run: bandit --recursive --skip B101 . || true # B101 is assert statements 11 | - run: black --check . || true 12 | - run: codespell 13 | - run: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 14 | - run: flake8 . --count --exit-zero --max-complexity=19 --max-line-length=245 --show-source --statistics 15 | - run: isort --check-only --profile black . || true 16 | - run: pip install -r requirements.txt 17 | - run: mypy --ignore-missing-imports . || true 18 | - run: pytest . || true 19 | - run: pytest --doctest-modules . 
|| true 20 | - run: shopt -s globstar && pyupgrade --py36-plus **/*.py || true 21 | - run: safety check 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pprof 3 | /ckpt/ 4 | wandb 5 | data/* 6 | !data/qa.val.index 7 | configs/* 8 | !configs/iterative_base/base.json 9 | !configs/iterative_base 10 | 11 | commonsenseqa/* 12 | !commonsenseqa/*.py 13 | !commonsenseqa/prompts*.txt 14 | !commonsenseqa/*.jsonl 15 | 16 | gsm/* 17 | !gsm/*.py 18 | !gsm/*.jsonl 19 | !gsm/prompts*.txt 20 | 21 | arithmetic/* 22 | !arithmetic/gen_sums.py 23 | iterative*.txt 24 | .vscode/settings.json 25 | !configs/qa_base.json 26 | result_logs 27 | result_logs/* 28 | !configs/qa_base.json 29 | -------------------------------------------------------------------------------- /CITATION.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{ 2 | zelikman2022star, 3 | title={{ST}aR: Bootstrapping Reasoning With Reasoning}, 4 | author={Eric Zelikman and Yuhuai Wu and Jesse Mu and Noah Goodman}, 5 | booktitle={Advances in Neural Information Processing Systems}, 6 | editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho}, 7 | year={2022}, 8 | url={https://openreview.net/forum?id=_3ELRdg2sgI} 9 | } 10 | 11 | @misc{mesh-transformer-jax, 12 | author = {Wang, Ben}, 13 | title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}}, 14 | howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}}, 15 | year = 2021, 16 | month = May 17 | } 18 | 19 | @misc{gpt-j, 20 | author = {Wang, Ben and Komatsuzaki, Aran}, 21 | title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}}, 22 | howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}}, 23 | year = 2021, 24 | month = May 25 | } 26 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | https://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | Copyright 2021 Ben Wang 179 | 180 | Licensed under the Apache License, Version 2.0 (the "License"); 181 | you may not use this file except in compliance with the License. 182 | You may obtain a copy of the License at 183 | 184 | https://www.apache.org/licenses/LICENSE-2.0 185 | 186 | Unless required by applicable law or agreed to in writing, software 187 | distributed under the License is distributed on an "AS IS" BASIS, 188 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 189 | See the License for the specific language governing permissions and 190 | limitations under the License. 191 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Table of contents 2 | 1. [STaR](#star) 3 | 2. [Mesh Transformer JAX](#mesh-transformer-jax) 4 | 1. [Updates](#updates) 5 | 3. [Pretrained Models](#pretrained-models) 6 | 1. [GPT-J-6B](#gpt-j-6b) 7 | 1. [Links](#links) 8 | 2. [Acknowledgments](#acknowledgments) 9 | 3. [License](#license) 10 | 4. [Model Details](#model-details) 11 | 5. [Zero-Shot Evaluations](#zero-shot-evaluations) 12 | 4. [Architecture and Usage](#architecture-and-usage) 13 | 1. [Fine-tuning](#fine-tuning) 14 | 2. [JAX Dependency](#jax-dependency) 15 | 5. 
[TODO](#todo) 16 | 17 | # STaR 18 | Code for [STaR: Bootstrapping Reasoning With Reasoning (NeurIPS 2022)](https://openreview.net/forum?id=_3ELRdg2sgI). This library is built on top of [mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax) and incorporates masked training from [this repo](https://github.com/albertqjiang/mesh-transformer-jax). In order to run it, launch `iteration_train.py` with any desired arguments. The README is left mostly unchanged, as `iteration_train` largely wraps around `device_train.py`, `device_inference.py`, and `create_finetune_tfrecords.py`. 19 | 20 | # Mesh Transformer JAX 21 | 22 | A haiku library using the `xmap`/`pjit` operators in JAX for model parallelism of transformers. 23 | 24 | The parallelism scheme is similar to the [original Megatron-LM](https://arxiv.org/abs/1909.08053), which is efficient 25 | on TPUs due to the high speed 2d mesh network. There is also an experimental model version which implements [ZeRo style 26 | sharding](https://arxiv.org/abs/1910.02054). 27 | 28 | This library is designed for scalability up to approximately 40B parameters on TPUv3s, beyond which different 29 | parallelism strategies should be used. See other implementations such as 30 | [GPT-NeoX](https://github.com/EleutherAI/gpt-neox) or [DeepSpeed](https://github.com/microsoft/DeepSpeed) for that. 31 | 32 | One future direction for research is integrating this codebase with 33 | [swarm-jax](https://github.com/kingoflolz/swarm-jax), to achieve further scalability with pipeline parallelism. 34 | 35 | ## Updates 36 | 37 | **12-07-21**: Added [guide to fine tuning](howto_finetune.md) 38 | 39 | # Pretrained Models 40 | 41 | ## GPT-J-6B 42 | 43 | A 6 billion parameter, autoregressive text generation model trained on [The Pile](https://pile.eleuther.ai/). 44 | 45 | ### Links 46 | 47 | [Slim weights (bf16 weights only, for inference, 9GB)](https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd) 48 | 49 | [Full weights (including optimizer params, 61GB)](https://the-eye.eu/public/AI/GPT-J-6B/step_383500.tar.zstd) 50 | 51 | [Colab demo](http://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb) 52 | 53 | [Web demo](https://6b.eleuther.ai/) 54 | 55 | [Aran's blog post](https://arankomatsuzaki.wordpress.com/2021/06/04/gpt-j/) 56 | 57 | ### Acknowledgments 58 | 59 | This project would not have been possible without compute generously provided by the 60 | [TPU Research Cloud](https://sites.research.google/trc/) with assistance from [EleutherAI](https://eleuther.ai/). 61 | 62 | Thanks to the Cloud TPU team at Google for providing early access to the Cloud TPU VM alpha 63 | ([now publicly available!](https://cloud.google.com/blog/products/compute/introducing-cloud-tpu-vms)) 64 | 65 | Thanks to everyone who have helped out one way or another (listed alphabetically): 66 | - [Aran Komatsuzaki](https://twitter.com/arankomatsuzaki) for advice with experiment design and writing the blog posts. 67 | - [James Bradbury](https://twitter.com/jekbradbury) for valuable assistance with debugging JAX issues. 68 | - [Janko Prester](https://github.com/jprester) for creating the web demo frontend. 69 | - [Laurence Golding](https://github.com/researcher2) for adding some features to the web demo. 70 | - [Leo Gao](https://twitter.com/nabla_theta) for running zero shot evaluations for the baseline models for the table. 71 | 72 | ### License 73 | The weights of GPT-J-6B are licensed under version 2.0 of the Apache License. 
74 | 75 | ### Model Details 76 | 77 | | Hyperparameter | Value | 78 | |-------------------|--------| 79 | | n_parameters | 6,053,381,344 | 80 | | n_layers | 28* | 81 | | d_model | 4,096 | 82 | | d_ff | 16,384 | 83 | | n_heads | 16 | 84 | | d_head | 256 | 85 | | n_ctx | 2,048 | 86 | | n_vocab | 50,257 (same tokenizer as GPT-2/3) | 87 | | position encoding | [Rotary position encodings (RoPE)](https://arxiv.org/abs/2104.09864) | 88 | | RoPE dimensions | [64](https://github.com/kingoflolz/mesh-transformer-jax/blob/f2aa66e0925de6593dcbb70e72399b97b4130482/mesh_transformer/layers.py#L223) | 89 | 90 | `*` each layer consists of one feedforward block and one self attention block 91 | 92 | The model consists of 28 layers with a model dimension of 4096, and a feedforward dimension of 16384. The model 93 | dimension is split into 16 heads, each with a dimension of 256. Rotary position encodings (RoPE) was applied to 64 94 | dimensions of each head. The model is trained with a tokenization vocabulary of 50257, using the same set of BPEs as 95 | GPT-2/GPT-3. 96 | 97 | ### Zero-Shot Evaluations 98 | 99 | Models roughly sorted by performance, or by FLOPs if not available. 100 | 101 | | Model | Weights | Training FLOPs | LAMBADA PPL ↓ | LAMBADA Acc ↑ | Winogrande ↑ | Hellaswag ↑ | PIQA ↑ | Dataset Size (GB) | 102 | |-----------------|---------|----------------|--- |--- |--- |--- |--- |-------------------| 103 | | Chance | ✔ | 0 | ~a lot | ~0% | 50% | 25% | 25% | 0 | 104 | | GPT-3-Ada‡ | ✘ | ----- | 9.95 | 51.6% | 52.9% | 43.4% | 70.5% | ----- | 105 | | GPT-2-1.5B | ✔ | ----- | 10.63 | 51.21% | 59.4% | 50.9% | 70.8% | 40 | 106 | | GPTNeo-1.3B‡ | ✔ | 3.0e21 | 7.50 | 57.2% | 55.0% | 48.9% | 71.1% | 825 | 107 | | Megatron-2.5B* | ✘ | 2.4e21 | ----- | 61.7% | ----- | ----- | ----- | 174 | 108 | | GPTNeo-2.7B‡ | ✔ | 6.8e21 | 5.63 | 62.2% | 56.5% | 55.8% | 73.0% | 825 | 109 | | GPT-3-1.3B*‡ | ✘ | 2.4e21 | 5.44 | 63.6% | 58.7% | 54.7% | 75.1% | ~800 | 110 | | GPT-3-Babbage‡ | ✘ | ----- | 5.58 | 62.4% | 59.0% | 54.5% | 75.5% | ----- | 111 | | Megatron-8.3B* | ✘ | 7.8e21 | ----- | 66.5% | ----- | ----- | ----- | 174 | 112 | | GPT-3-2.7B*‡ | ✘ | 4.8e21 | 4.60 | 67.1% | 62.3% | 62.8% | 75.6% | ~800 | 113 | | Megatron-11B† | ✔ | 1.0e22 | ----- | ----- | ----- | ----- | ----- | 161 | 114 | | **GPT-J-6B**‡ | ✔ | 1.5e22 | 3.99 | 69.7% | 65.3% | 66.1% | 76.5% | 825 | 115 | | GPT-3-6.7B*‡ | ✘ | 1.2e22 | 4.00 | 70.3% | 64.5% | 67.4% | 78.0% | ~800 | 116 | | GPT-3-Curie‡ | ✘ | ----- | 4.00 | 69.3% | 65.6% | 68.5% | 77.9% | ----- | 117 | | GPT-3-13B*‡ | ✘ | 2.3e22 | 3.56 | 72.5% | 67.9% | 70.9% | 78.5% | ~800 | 118 | | GPT-3-175B*‡ | ✘ | 3.1e23 | 3.00 | 76.2% | 70.2% | 78.9% | 81.0% | ~800 | 119 | | GPT-3-Davinci‡ | ✘ | ----- | 3.0 | 75% | 72% | 78% | 80% | ----- | 120 | | Gopher 230B* | ✘ | 6.31E+23 | ----- | 74.50% | 70.10% | 79.20% | 81.80% | 1344 | 121 | | MT-NLG 530B*‡ | ✘ | ----- | ----- | 76.6% | 73.0% | 80.2% | 82.0% | ----- | 122 | 123 | `*` represents evaluation numbers reported by their respective authors, all other numbers are provided by 124 | running the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/) either with the released 125 | weights or with API access. Due to subtle implementation differences as well as different zero shot task framing, these 126 | might not be directly comparable. See [this blog post](https://www.eleuther.ai/research-log/gpt3-model-sizes/) for more 127 | details. 
128 | 129 | `†` The Megatron-11B model provides no comparable metrics, and several implementations using the released weights do not 130 | reproduce the generation quality and evaluations (see [1](https://github.com/huggingface/transformers/pull/10301) 131 | [2](https://github.com/pytorch/fairseq/issues/2358) [3](https://github.com/pytorch/fairseq/issues/2719)). 132 | Thus, evaluation was not attempted. 133 | 134 | `‡` These models have been trained with data which contains possible test set contamination. The OpenAI GPT-3 models 135 | failed to deduplicate training data for certain test sets, while the GPT-Neo models, as well as this one, are 136 | trained on The Pile, which has not been deduplicated against any test sets. 137 | 138 | # Architecture and Usage 139 | 140 | Most scripts in this repository are designed to be run on TPUs, which under the 141 | [TPU-VM architecture](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) are virtual machines 142 | which can run arbitrary code. Most scripts are designed to spin up a TPU, SSH into it to set up the dependencies 143 | and copy code over from the local directory, and then start a [Ray](https://github.com/ray-project/ray.git) worker 144 | which can accept RPC calls. 145 | 146 | The TPU VMs handle running model training steps and evaluation, as well as checkpoint saving and loading, while the driver Python 147 | program handles data loading and general orchestration (such as deciding when to save checkpoints). 148 | 149 | This means that most scripts (`train.py`, `eval_harness.py`, etc.) expect to be running on a GCE virtual machine in the 150 | same region as the TPUs, to minimize RPC latency and data transfer cost. Other scripts 151 | (usually ones which don't take a `--tpu` argument, such as `device_inference.py`, `device_serve.py` or `device_train.py`) 152 | expect to be run directly on a TPU VM. The device_* scripts **only work on a v3-8** and not on larger pods. 153 | 154 | Furthermore, there is an example (`resharding_example.py`) of how to convert the provided checkpoints (which have 8 155 | shards in the case of GPT-J-6B) down to a smaller number, for example when running on GPU(s). 156 | 157 | ### Fine-tuning 158 | 159 | To fine-tune the model, run `device_train.py` on a TPU VM. Using a TPU v3-8, you can fine-tune at a rate of ~5000 160 | tokens/second, which should be sufficient for small-to-medium-size datasets. 161 | 162 | Please read the [step-by-step guide](howto_finetune.md) for thorough fine-tuning instructions. 163 | 164 | ### JAX Dependency 165 | 166 | Note that this library has specific JAX version requirements. Specifically, to use the v1 models (including 167 | GPT-J 6B), `jax==0.2.12` is required. This in turn depends on `jaxlib==0.1.68`. **If this is not done, you will get 168 | cryptic xmap errors.** 169 | 170 | However, to use the v2 model code (no publicly released weights), the newest JAX version can be used. 171 | # Citation 172 | 173 | To cite this repository: 174 | ``` 175 | @inproceedings{ 176 | zelikman2022star, 177 | title={{ST}aR: Bootstrapping Reasoning With Reasoning}, 178 | author={Eric Zelikman and Yuhuai Wu and Jesse Mu and Noah Goodman}, 179 | booktitle={Advances in Neural Information Processing Systems}, 180 | editor={Alice H. 
Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho}, 181 | year={2022}, 182 | url={https://openreview.net/forum?id=_3ELRdg2sgI} 183 | } 184 | ``` 185 | 186 | To cite the base repository: 187 | ``` 188 | @misc{mesh-transformer-jax, 189 | author = {Wang, Ben}, 190 | title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}}, 191 | howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}}, 192 | year = 2021, 193 | month = May 194 | } 195 | ``` 196 | 197 | To cite the weights of GPT-J-6B: 198 | ``` 199 | @misc{gpt-j, 200 | author = {Wang, Ben and Komatsuzaki, Aran}, 201 | title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}}, 202 | howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}}, 203 | year = 2021, 204 | month = May 205 | } 206 | ``` 207 | 208 | If you use this repository or any of the pretrained weights to do something cool, we would love to hear about it. 209 | Feel free to open a github issue or reach out over email (in profile). 210 | 211 | # TODO 212 | - [x] disentangle heads and shards 213 | - [x] test/benchmark on TPU 214 | - [x] implement gradient checkpointing 215 | - [x] fix initialization 216 | - [x] mixed precision 217 | - [x] deal with preemptible TPUs 218 | - [x] test and validate generation 219 | - [x] shard activations instead of replicating for memory efficiency (in v2) 220 | - [x] support ZeRO style sharding (in v2) 221 | -------------------------------------------------------------------------------- /arithmetic/gen_sums.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | 4 | def to_list(num): 5 | list_str = list(str(num)) 6 | return list_str 7 | 8 | def to_str(num): 9 | return " ".join(to_list(num)) 10 | 11 | # n_samples = 10000 12 | n_samples = 2000 13 | context_examples = 50 14 | examples_per_prompt = 1 15 | # examples_per_prompt = 10 16 | 17 | val_start = 6 18 | max_digits = 10 19 | # include_scrachpad = True 20 | include_scrachpad = False 21 | # fixed_examples = True 22 | fixed_examples = False 23 | randomized_digits = False 24 | 25 | def gen_examples(n_examples, digits, randomized_digits=False, min_digit=1, max_digit=8): 26 | complete_str = [] 27 | for _ in range(n_examples): 28 | if randomized_digits: 29 | digits = random.randrange(min_digit + 1, max_digit) 30 | first_number = random.randrange(int(10 ** digits), int(10 ** (digits + 1))) 31 | second_number = random.randrange(int(10 ** digits), int(10 ** (digits + 1))) 32 | input_sum = f'{to_str(first_number)} + {to_str(second_number)}' 33 | resultant_str = f'Input:\n{input_sum}\nTarget:\n' 34 | if include_scrachpad: 35 | scratch_pad = f'\n{input_sum} , C: 0\n' 36 | carry = 0 37 | running_sum = '' 38 | initial = True 39 | for first_digit, second_digit in reversed(list(zip( 40 | to_list(first_number), to_list(second_number) 41 | ))): 42 | dig_sum = int(first_digit) + int(second_digit) + carry 43 | if not initial: 44 | scratch_pad += f'{first_digit} + {second_digit} , {running_sum}C: {carry}\n' 45 | carry = int(dig_sum >= 10) 46 | running_sum = f'{dig_sum % 10} {running_sum}' 47 | initial = False 48 | scratch_pad += f', {running_sum}C: {carry}\n' 49 | scratch_pad += f'{carry} {running_sum}'.strip() + '\n' 50 | scratch_pad += '\n' 51 | resultant_str += scratch_pad 52 | resultant_str += to_str(first_number + second_number) 53 | resultant_str += '\n\n' 54 | complete_str.append(resultant_str) 55 | return complete_str 56 | 57 | 58 | if 
fixed_examples and randomized_digits: 59 | few_shot_str = gen_examples(examples_per_prompt - 1, None, randomized_digits, 1, val_start - 1) 60 | for digits in range(max_digits): 61 | folder_name = 'val' if digits + 1 >= val_start else 'train' 62 | if include_scrachpad: 63 | folder_name += '_scratch' 64 | if fixed_examples and not randomized_digits: 65 | few_shot_str = gen_examples(context_examples, digits, randomized_digits) 66 | else: 67 | folder_name += '_direct' 68 | if not os.path.isdir(folder_name): 69 | os.mkdir(folder_name) 70 | complete_str = '' 71 | for _ in range(n_samples): 72 | n_gen_examples = 1 if fixed_examples else examples_per_prompt 73 | if fixed_examples: 74 | complete_str += "".join(random.sample(few_shot_str, examples_per_prompt - 1)) 75 | complete_str += "".join(gen_examples(n_gen_examples, digits)) 76 | complete_str += '<|endoftext|>' 77 | 78 | with open(f'{folder_name}/{digits + 1}.txt', 'w') as f: 79 | f.write(complete_str) 80 | 81 | if digits + 2 == val_start or digits + 1 == max_digits: 82 | os.system(f'python3 ../create_finetune_tfrecords.py ./{folder_name}/ arithmetic_{folder_name}') 83 | -------------------------------------------------------------------------------- /benchmarks.md: -------------------------------------------------------------------------------- 1 | Note that everything on this page is quite outdated and only roughly accurate when considering new features 2 | such as RoPE. 3 | 4 | # Benchmarks (v3-8) 5 | 6 | (see `tpuv38_example.py`): 7 | 8 | ## ~2.7B model 9 | ``` 10 | Initialized in 121.842s 11 | Total parameters: 2722382080 12 | Compiled in 49.0534s 13 | it: 0, loss: 20.311113357543945 14 | 15 | it: 90, loss: 3.987450361251831 16 | 100 steps in 109.385s 17 | effective flops (not including attn): 2.4466e+14 18 | ``` 19 | 20 | ## ~4.8B model 21 | ``` 22 | Initialized in 101.016s 23 | Total parameters: 4836720896 24 | Compiled in 52.7404s 25 | it: 0, loss: 4.632925987243652 26 | 27 | it: 40, loss: 3.2406811714172363 28 | 50 steps in 102.559s 29 | effective flops (not including attn): 2.31803e+14 30 | ``` 31 | 32 | ## 10B model 33 | ``` 34 | Initialized in 152.762s 35 | Total parameters: 10073579776 36 | Compiled in 92.6539s 37 | it: 0, loss: 5.3125 38 | 39 | it: 40, loss: 3.65625 40 | 50 steps in 100.235s 41 | effective flops (not including attn): 2.46988e+14 42 | ``` 43 | 44 | # Benchmarks (v3-32) 45 | 46 | (see `eval.py`): 47 | 48 | ## 6B model 49 | ``` 50 | "layers": 28, 51 | "d_model": 4096, 52 | "n_heads": 16, 53 | "n_vocab": 50400, 54 | 55 | "seq": 2048, 56 | "cores_per_replica": 8, 57 | "per_replica_batch": 1, 58 | "gradient_accumulation_steps": 8, 59 | 60 | params: 6053381856 61 | 32 iters done in 107.935s, avg 3.37298s, 0.296473/s 62 | effective flops (not including attn): 7.05692e+14 63 | MXU flops: 1.04523e+15 64 | ``` 65 | 66 | ## Note that the models below do not currently work 67 | They require a larger degree of model parallelism than is currently implemented, but benchmark numbers should be 68 | reasonably representative. 
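As a rough cross-check (a sketch, not part of the repo), the "effective flops" figures in these logs appear to line up with the usual `6 × parameters × tokens/second` estimate. The snippet below reproduces the 6B figure from the log above; the replica count is an assumption (32 cores on a v3-32 divided by `cores_per_replica: 8` gives 4 replicas).

```
# Hedged sanity check of the 6B "effective flops" figure above, assuming it
# follows the common 6 * params * tokens-per-second training-FLOPs estimate.
params = 6_053_381_856        # "params: 6053381856" from the 6B log
steps_per_sec = 0.296473      # "0.296473/s" from the 6B log
seq = 2048
per_replica_batch = 1
grad_accum_steps = 8
replicas = 32 // 8            # assumed: v3-32 cores / cores_per_replica

tokens_per_sec = steps_per_sec * replicas * per_replica_batch * grad_accum_steps * seq
print(f"effective flops: {6 * params * tokens_per_sec:.5e}")  # ~7.06e+14 vs. reported 7.05692e+14
```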
69 | 70 | ## 13B model 71 | ``` 72 | "layers": 28, 73 | "d_model": 6144, 74 | "n_heads": 32, 75 | "n_vocab": 50400, 76 | 77 | "seq": 2048, 78 | "cores_per_replica": 16, 79 | "per_replica_batch": 1, 80 | "gradient_accumulation_steps": 16, 81 | 82 | params: 13312183008 83 | 32 iters done in 250.86s, avg 7.83937s, 0.127561/s 84 | effective flops (not including attn): 6.67727e+14 85 | MXU flops: 9.80066e+14 86 | ``` 87 | 88 | ## 23B model 89 | ``` 90 | "layers": 28, 91 | "d_model": 8192, 92 | "n_heads": 32, 93 | "n_vocab": 50400, 94 | 95 | "seq": 2048, 96 | "cores_per_replica": 32, 97 | "per_replica_batch": 1, 98 | "gradient_accumulation_steps": 32, 99 | 100 | params: 23398107360 101 | 16 iters done in 221.33s, avg 13.8331s, 0.0722902/s 102 | effective flops (not including attn): 6.65107e+14 103 | MXU flops: 9.88548e+14 104 | ``` -------------------------------------------------------------------------------- /commonsenseqa/prompts.txt: -------------------------------------------------------------------------------- 1 | Q: What do people use to absorb extra ink from a fountain pen? 2 | Answer Choices: 3 | (a) shirt pocket 4 | (b) calligrapher's hand 5 | (c) inkwell 6 | (d) desk drawer 7 | (e) blotter 8 | A: The answer must be used to absorb extra ink. Blotters are designed to absorb liquids. Therefore, the answer is blotter (e). 9 | 10 | Q: What home entertainment equipment requires cable? 11 | Answer Choices: 12 | (a) radio shack 13 | (b) substation 14 | (c) television 15 | (d) cabinet 16 | (e) desk 17 | A: The answer must require cable. Cable is used to provide satellite channels to televisions. Therefore, the answer is television (c). 18 | 19 | Q: The fox walked from the city into the forest, what was it looking for? 20 | Answer Choices: 21 | (a) pretty flowers 22 | (b) hen house 23 | (c) natural habitat 24 | (d) storybook 25 | (e) dense forest 26 | A: The answer must be a reason for a fox to go into the forest. The forest is a fox's natural habitat. Therefore, the answer is natural habitat (c). 27 | 28 | Q: Sammy wanted to go to where the people were. Where might he go? 29 | Answer Choices: 30 | (a) populated areas 31 | (b) race track 32 | (c) desert 33 | (d) apartment 34 | (e) roadblock 35 | A: The answer must be a place with many people. Populated areas, by definition, have a lot of people. Therefore, the answer is populated areas (a). 36 | 37 | Q: Where do you put your grapes just before checking out? 38 | Answer Choices: 39 | (a) mouth 40 | (b) grocery cart 41 | (c) super market 42 | (d) fruit basket 43 | (e) fruit market 44 | A: The answer should be the place where grocery items are placed before checking out. Of the above choices, grocery cart makes the most sense for holding grocery items. Therefore, the answer is grocery cart (b). 45 | 46 | Q: Google Maps and other highway and street GPS services have replaced what? 47 | Answer Choices: 48 | (a) united states 49 | (b) mexico 50 | (c) countryside 51 | (d) atlas 52 | (e) oceans 53 | A: The answer must be something that used to do what Google Maps and GPS services do, which is give directions. Atlases were also used to give directions. Therefore, the answer is atlas (d). 54 | 55 | Q: Before getting a divorce, what did the wife feel who was doing all the work? 56 | Answer Choices: 57 | (a) harder 58 | (b) anguish 59 | (c) bitterness 60 | (d) tears 61 | (e) sadness 62 | A: The answer should be a feeling which would cause someone who was doing all the work to get divorced. 
If someone feels bitter towards their spouse, they are likely to want a divorce. Therefore, the answer is bitterness (c). -------------------------------------------------------------------------------- /commonsenseqa/prompts_answer_key.txt: -------------------------------------------------------------------------------- 1 | Q: What do people use to absorb extra ink from a fountain pen? 2 | Answer Choices: 3 | (a) shirt pocket 4 | (b) calligrapher's hand 5 | (c) inkwell 6 | (d) desk drawer 7 | (e) blotter (CORRECT) 8 | A: The answer must be used to absorb extra ink. Blotters are designed to absorb liquids. Therefore, the answer is blotter (e). 9 | 10 | Q: What home entertainment equipment requires cable? 11 | Answer Choices: 12 | (a) radio shack 13 | (b) substation 14 | (c) television (CORRECT) 15 | (d) cabinet 16 | (e) desk 17 | A: The answer must require cable. Cable is used to provide satellite channels to televisions. Therefore, the answer is television (c). 18 | 19 | Q: The fox walked from the city into the forest, what was it looking for? 20 | Answer Choices: 21 | (a) pretty flowers 22 | (b) hen house 23 | (c) natural habitat (CORRECT) 24 | (d) storybook 25 | (e) dense forest 26 | A: The answer must be a reason for a fox to go into the forest. The forest is a fox's natural habitat. Therefore, the answer is natural habitat (c). 27 | 28 | Q: Sammy wanted to go to where the people were. Where might he go? 29 | Answer Choices: 30 | (a) populated areas (CORRECT) 31 | (b) race track 32 | (c) desert 33 | (d) apartment 34 | (e) roadblock 35 | A: The answer must be a place with many people. Populated areas, by definition, have a lot of people. Therefore, the answer is populated areas (a). 36 | 37 | Q: Where do you put your grapes just before checking out? 38 | Answer Choices: 39 | (a) mouth 40 | (b) grocery cart (CORRECT) 41 | (c) super market 42 | (d) fruit basket 43 | (e) fruit market 44 | A: The answer should be the place where grocery items are placed before checking out. Of the above choices, grocery cart makes the most sense for holding grocery items. Therefore, the answer is grocery cart (b). 45 | 46 | Q: Google Maps and other highway and street GPS services have replaced what? 47 | Answer Choices: 48 | (a) united states 49 | (b) mexico 50 | (c) countryside 51 | (d) atlas (CORRECT) 52 | (e) oceans 53 | A: The answer must be something that used to do what Google Maps and GPS services do, which is give directions. Atlases were also used to give directions. Therefore, the answer is atlas (d). 54 | 55 | Q: Before getting a divorce, what did the wife feel who was doing all the work? 56 | Answer Choices: 57 | (a) harder 58 | (b) anguish 59 | (c) bitterness (CORRECT) 60 | (d) tears 61 | (e) sadness 62 | A: The answer should be a feeling which would cause someone who was doing all the work to get divorced. If someone feels bitter towards their spouse, they are likely to want a divorce. Therefore, the answer is bitterness (c). -------------------------------------------------------------------------------- /commonsenseqa/prompts_direct.txt: -------------------------------------------------------------------------------- 1 | Q: What do people use to absorb extra ink from a fountain pen? 2 | Answer Choices: 3 | (a) shirt pocket 4 | (b) calligrapher's hand 5 | (c) inkwell 6 | (d) desk drawer 7 | (e) blotter 8 | A: (e). 9 | 10 | Q: What home entertainment equipment requires cable? 
11 | Answer Choices: 12 | (a) radio shack 13 | (b) substation 14 | (c) television 15 | (d) cabinet 16 | (e) desk 17 | A: (c). 18 | 19 | Q: The fox walked from the city into the forest, what was it looking for? 20 | Answer Choices: 21 | (a) pretty flowers 22 | (b) hen house 23 | (c) natural habitat 24 | (d) storybook 25 | (e) dense forest 26 | A: (c). 27 | 28 | Q: Sammy wanted to go to where the people were. Where might he go? 29 | Answer Choices: 30 | (a) populated areas 31 | (b) race track 32 | (c) desert 33 | (d) apartment 34 | (e) roadblock 35 | A: (a). 36 | 37 | Q: Where do you put your grapes just before checking out? 38 | Answer Choices: 39 | (a) mouth 40 | (b) grocery cart 41 | (c) super market 42 | (d) fruit basket 43 | (e) fruit market 44 | A: (b). 45 | 46 | Q: Google Maps and other highway and street GPS services have replaced what? 47 | Answer Choices: 48 | (a) united states 49 | (b) mexico 50 | (c) countryside 51 | (d) atlas 52 | (e) oceans 53 | A: (d). 54 | 55 | Q: Before getting a divorce, what did the wife feel who was doing all the work? 56 | Answer Choices: 57 | (a) harder 58 | (b) anguish 59 | (c) bitterness 60 | (d) tears 61 | (e) sadness 62 | A: (c). -------------------------------------------------------------------------------- /commonsenseqa/prompts_direct_answer_key.txt: -------------------------------------------------------------------------------- 1 | Q: What do people use to absorb extra ink from a fountain pen? 2 | Answer Choices: 3 | (a) shirt pocket 4 | (b) calligrapher's hand 5 | (c) inkwell 6 | (d) desk drawer 7 | (e) blotter (CORRECT) 8 | A: (e). 9 | 10 | Q: What home entertainment equipment requires cable? 11 | Answer Choices: 12 | (a) radio shack 13 | (b) substation 14 | (c) television (CORRECT) 15 | (d) cabinet 16 | (e) desk 17 | A: (c). 18 | 19 | Q: The fox walked from the city into the forest, what was it looking for? 20 | Answer Choices: 21 | (a) pretty flowers 22 | (b) hen house 23 | (c) natural habitat (CORRECT) 24 | (d) storybook 25 | (e) dense forest 26 | A: (c). 27 | 28 | Q: Sammy wanted to go to where the people were. Where might he go? 29 | Answer Choices: 30 | (a) populated areas (CORRECT) 31 | (b) race track 32 | (c) desert 33 | (d) apartment 34 | (e) roadblock 35 | A: (a). 36 | 37 | Q: Where do you put your grapes just before checking out? 38 | Answer Choices: 39 | (a) mouth 40 | (b) grocery cart (CORRECT) 41 | (c) super market 42 | (d) fruit basket 43 | (e) fruit market 44 | A: (b). 45 | 46 | Q: Google Maps and other highway and street GPS services have replaced what? 47 | Answer Choices: 48 | (a) united states 49 | (b) mexico 50 | (c) countryside 51 | (d) atlas (CORRECT) 52 | (e) oceans 53 | A: (d). 54 | 55 | Q: Before getting a divorce, what did the wife feel who was doing all the work? 56 | Answer Choices: 57 | (a) harder 58 | (b) anguish 59 | (c) bitterness (CORRECT) 60 | (d) tears 61 | (e) sadness 62 | A: (c). 
-------------------------------------------------------------------------------- /configs/iterative_base/base.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": 28, 3 | "d_model": 4096, 4 | "n_heads": 16, 5 | "n_vocab": 50400, 6 | "norm": "layernorm", 7 | "pe": "rotary", 8 | "pe_rotary_dims": 64, 9 | "seq": 2048, 10 | "cores_per_replica": 8, 11 | "per_replica_batch": 1, 12 | "gradient_accumulation_steps": 2, 13 | "warmup_steps": 100, 14 | "anneal_steps": 300000, 15 | "lr": 1e-06, 16 | "end_lr": 1e-06, 17 | "weight_decay": 0.0, 18 | "total_steps": 383500, 19 | "tpu_size": 8, 20 | "bucket": "checkpoint-bucket", 21 | "model_dir": "full_qa_4", 22 | "train_set": "qa_train_4.index", 23 | "val_set": { 24 | "index": "qa.val.index" 25 | }, 26 | "eval_harness_tasks": [ 27 | "lambada", 28 | "piqa", 29 | "hellaswag", 30 | "winogrande", 31 | "mathqa", 32 | "pubmedqa" 33 | ], 34 | "val_batches": 100, 35 | "val_every": 20, 36 | "ckpt_every": 20, 37 | "keep_every": 10000, 38 | "name": "full_6", 39 | "wandb_project": "full_6", 40 | "comment": "", 41 | "target_save_folder": "commonsenseqa/iterative_full/iterative_full_0", 42 | "target_save": "commonsenseqa/iterative_negative_fast/iterative_negative_fast_0/iterative_negative_fast_0.txt" 43 | } 44 | -------------------------------------------------------------------------------- /configs/qa_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": 28, 3 | "d_model": 4096, 4 | "n_heads": 16, 5 | "n_vocab": 50400, 6 | "norm": "layernorm", 7 | "pe": "rotary", 8 | "pe_rotary_dims": 64, 9 | "seq": 1536, 10 | "cores_per_replica": 8, 11 | "per_replica_batch": 1, 12 | "gradient_accumulation_steps": 8, 13 | "warmup_steps": 100, 14 | "anneal_steps": 300000, 15 | "lr": 1e-06, 16 | "end_lr": 1e-06, 17 | "weight_decay": 0.0, 18 | "total_steps": 383500, 19 | "tpu_size": 8, 20 | "p_rationalization": 1.0, 21 | "bucket": "checkpoint-bucket", 22 | "model_dir": "full_qa_4", 23 | "train_set": "qa_train_4.index", 24 | "val_set": { 25 | "index": "qa.val.index" 26 | }, 27 | "eval_harness_tasks": [ 28 | "lambada", 29 | "piqa", 30 | "hellaswag", 31 | "winogrande", 32 | "mathqa", 33 | "pubmedqa" 34 | ], 35 | "val_batches": 100, 36 | "val_every": 10000, 37 | "ckpt_every": 10000, 38 | "keep_every": 10000, 39 | "name": "slow_grow_full_epoch_0", 40 | "wandb_project": "full_6", 41 | "comment": "", 42 | "target_save_folder": "commonsenseqa/iterative_full/iterative_full_0", 43 | "target_save": "commonsenseqa/slow_grow_full_epoch/slow_grow_full_epoch_0/slow_grow_full_epoch_0.txt" 44 | } 45 | -------------------------------------------------------------------------------- /create_finetune_tfrecords.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | import random 5 | 6 | from pathlib import Path 7 | from typing import List 8 | 9 | import ftfy 10 | import tensorflow as tf 11 | from lm_dataformat import Reader 12 | from transformers import GPT2TokenizerFast 13 | from tqdm import tqdm 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser(description=""" 18 | Converts a text dataset into the training data format expected by the model. 19 | 20 | Adapted from the script create_tfrecords.py in the gpt-neo repo. 21 | 22 | - Your text dataset: 23 | - can be provided as .txt files, or as an archive (.tar.gz, .xz, jsonl.zst). 
24 | - can be one file or multiple 25 | - using a single large file may use too much memory and crash - if this occurs, split the file up into a few files 26 | - the model's end-of-text separator is added between the contents of each file 27 | - if the string '<|endoftext|>' appears inside a file, it is treated as the model's end-of-text separator (not the actual string '<|endoftext|>') 28 | - this behavior can be disabled with --treat-eot-as-text 29 | 30 | This script creates a single .tfrecords file as output 31 | - Why: the model's data loader ignores "trailing" data (< 1 batch) at the end of a .tfrecords file 32 | - this causes data loss if you have many .tfrecords files 33 | - This is probably not appropriate for very large datasets 34 | """, formatter_class=argparse.RawTextHelpFormatter) 35 | parser.add_argument( 36 | "input_path", 37 | type=str, 38 | help="Path to an input file, or a directory that contains the input files.", 39 | ) 40 | parser.add_argument("name", type=str, 41 | help="Name of output file will be {name}_{seqnum}.tfrecords, where seqnum is total sequence count") 42 | parser.add_argument("--output-dir", type=str, default="", help="Output directory (default: current directory)") 43 | 44 | cleaning_args = parser.add_argument_group('data cleaning arguments') 45 | 46 | cleaning_args.add_argument("--normalize-with-ftfy", action="store_true", help="Normalize text with ftfy") 47 | cleaning_args.add_argument("--normalize-with-wikitext-detokenize", 48 | action="store_true", help="Use wikitext detokenizer") 49 | minu_help = "Exclude repetitive documents made up of < MIN_UNIQUE_TOKENS unique tokens. These can produce large gradients." 50 | minu_help += " Set <= 0 to disable. If enabled, 200 is a good default value. (Default: 0)" 51 | cleaning_args.add_argument("--min-unique-tokens", type=int, default=0, 52 | help=minu_help) 53 | 54 | shuffle_pack_args = parser.add_argument_group('data shuffling/packing arguments') 55 | repack_ep_help = "Repeat the data N_REPACK_EPOCHS times, shuffled differently in each repetition. Recommended for multi-epoch training (set this to your intended number of epochs)." 
56 | shuffle_pack_args.add_argument("--n-repack-epochs", 57 | type=int, default=1, 58 | help=repack_ep_help 59 | ) 60 | shuffle_pack_args.add_argument("--seed", type=int, default=10, 61 | help="random seed for shuffling data (default: 10)") 62 | shuffle_pack_args.add_argument("--preserve-data-order", 63 | default=False, action="store_true", 64 | help="Disables shuffling, so the input and output data have the same order.") 65 | 66 | misc_args = parser.add_argument_group('miscellaneous arguments') 67 | misc_args.add_argument("--verbose", 68 | default=False, action="store_true", 69 | help="Prints extra information, such as the text removed by --min-unique-tokens") 70 | 71 | args = parser.parse_args() 72 | 73 | # convert input_path to pathy 74 | args.input_path = Path(args.input_path) 75 | 76 | return args 77 | 78 | 79 | def get_files(input_path: Path) -> List[str]: 80 | supported_file_types = ["jsonl.zst", ".txt", ".xz", ".tar.gz"] 81 | if input_path.is_dir(): 82 | # get all files with supported file types 83 | files = [list(Path(input_path).glob(f"*{ft}")) for ft in supported_file_types] 84 | # flatten list 85 | files = [f for sublist in files for f in sublist] 86 | assert files, f"No files with supported types found in directory: {input_path}" 87 | elif input_path.is_file(): 88 | assert any( 89 | str(input_path).endswith(f_type) for f_type in supported_file_types 90 | ), f"Input file type must be one of: {supported_file_types}" 91 | files = [input_path] 92 | else: 93 | raise FileNotFoundError(f"No such file or directory: {input_path=}") 94 | 95 | return [str(f) for f in files] 96 | 97 | 98 | def wikitext_detokenizer(string): 99 | # contractions 100 | string = string.replace("s '", "s'") 101 | string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) 102 | # number separators 103 | string = string.replace(" @-@ ", "-") 104 | string = string.replace(" @,@ ", ",") 105 | string = string.replace(" @.@ ", ".") 106 | # punctuation 107 | string = string.replace(" : ", ": ") 108 | string = string.replace(" ; ", "; ") 109 | string = string.replace(" . ", ". ") 110 | string = string.replace(" ! ", "! ") 111 | string = string.replace(" ? ", "? ") 112 | string = string.replace(" , ", ", ") 113 | # double brackets 114 | string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) 115 | string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) 116 | string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) 117 | string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) 118 | string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) 119 | # miscellaneous 120 | string = string.replace("= = = =", "====") 121 | string = string.replace("= = =", "===") 122 | string = string.replace("= =", "==") 123 | string = string.replace(" " + chr(176) + " ", chr(176)) 124 | string = string.replace(" \n", "\n") 125 | string = string.replace("\n ", "\n") 126 | string = string.replace(" N ", " 1 ") 127 | string = string.replace(" 's", "'s") 128 | 129 | return string 130 | 131 | 132 | def _int64_feature(value): 133 | """ 134 | Returns an int64_list from a bool / enum / int / uint. 
135 | """ 136 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 137 | 138 | 139 | def write_to_file(writer, data): 140 | """ 141 | writes data to tfrecord file 142 | """ 143 | feature = { 144 | "text": _int64_feature(data) 145 | } 146 | tf_example = tf.train.Example(features=tf.train.Features(feature=feature)) 147 | writer.write(tf_example.SerializeToString()) 148 | 149 | 150 | def write_tfrecord(sequences, fp): 151 | with tf.io.TFRecordWriter(fp) as writer: 152 | for seq in sequences: 153 | write_to_file(writer, seq) 154 | 155 | 156 | def split_list(l, n): 157 | # splits list/string into n size chunks 158 | return [l[i:i + n] for i in range(0, len(l), n)] 159 | 160 | 161 | def enforce_min_unique(seqs, min_unique_tokens, enc, verbose=False): 162 | for seq in tqdm(seqs, mininterval=1, smoothing=0, desc="enforce_min_unique_tokens"): 163 | if len(set(seq)) >= min_unique_tokens: 164 | yield seq 165 | elif verbose: 166 | text = enc.decode(seq) 167 | print(f"excluding with {len(set(seq))} unique tokens:\n\n{repr(text)}\n\n") 168 | 169 | 170 | def eot_splitting_generator(string_iterable, encoder): 171 | """ 172 | Given strings, splits them internally on <|endoftext|> and yields (generally more) strings 173 | """ 174 | for doc in string_iterable: 175 | for d in doc.split(encoder.eos_token): 176 | if len(d) > 0: 177 | yield d 178 | 179 | 180 | def prep_and_tokenize_generator(string_iterable, encoder, normalize_with_ftfy, normalize_with_wikitext_detokenize): 181 | """ 182 | Given strings, does data cleaning / tokenization and yields arrays of tokens 183 | """ 184 | for doc in string_iterable: 185 | if normalize_with_ftfy: # fix text with ftfy if specified 186 | doc = ftfy.fix_text(doc, normalization='NFKC') 187 | if normalize_with_wikitext_detokenize: 188 | doc = wikitext_detokenizer(doc) 189 | tokens = encoder.encode(doc) + [encoder.eos_token_id] 190 | yield tokens 191 | 192 | 193 | def file_to_tokenized_docs_generator(file_path, encoder, args): 194 | """ 195 | Given a file path, reads the file and tokenizes the contents 196 | 197 | Yields token arrays of arbitrary, unequal length 198 | """ 199 | reader = Reader(file_path) 200 | string_iterable = reader.stream_data(threaded=False) 201 | string_iterable = eot_splitting_generator(string_iterable, encoder) 202 | 203 | token_list_gen = prep_and_tokenize_generator(string_iterable, 204 | encoder, 205 | normalize_with_ftfy=args.normalize_with_ftfy, 206 | normalize_with_wikitext_detokenize=args.normalize_with_wikitext_detokenize 207 | ) 208 | return token_list_gen 209 | 210 | 211 | def read_files_to_tokenized_docs(files, args, encoder): 212 | docs = [] 213 | 214 | if args.preserve_data_order: 215 | files = sorted(files) 216 | else: 217 | random.shuffle(files) 218 | 219 | for f in tqdm(files, mininterval=10, smoothing=0, desc="reading/tokenizing files"): 220 | docs.extend(file_to_tokenized_docs_generator(f, encoder, args)) 221 | 222 | if not args.preserve_data_order: 223 | # shuffle at individual document level 224 | random.shuffle(docs) 225 | 226 | return docs 227 | 228 | 229 | def arrays_to_sequences(token_list_iterable, sequence_length=1537): 230 | """ 231 | Given token arrays of arbitrary lengths, concats/splits them into arrays of equal length 232 | 233 | Returns equal-length token arrays, followed by a a final array of trailing tokens (which may be shorter) 234 | """ 235 | accum = [] 236 | for l in token_list_iterable: 237 | accum.extend(l) 238 | 239 | if len(accum) > sequence_length: 240 | chunks = split_list(accum, 
sequence_length) 241 | yield from chunks[:-1] 242 | accum = chunks[-1] 243 | 244 | if len(accum) > 0: 245 | yield accum 246 | 247 | 248 | def chunk_and_finalize(arrays, args, encoder): 249 | sequences = list(arrays_to_sequences(arrays)) 250 | 251 | full_seqs, trailing_data = sequences[:-1], sequences[-1] 252 | 253 | if args.min_unique_tokens > 0: 254 | full_seqs = list(enforce_min_unique(full_seqs, args.min_unique_tokens, encoder, args.verbose)) 255 | 256 | if not args.preserve_data_order: 257 | random.shuffle(full_seqs) 258 | 259 | return full_seqs, trailing_data 260 | 261 | 262 | def create_tfrecords(files, args): 263 | GPT2TokenizerFast.max_model_input_sizes['gpt2'] = 1e20 # disables a misleading warning 264 | encoder = GPT2TokenizerFast.from_pretrained('gpt2') 265 | 266 | random.seed(args.seed) 267 | 268 | all_sequences_across_epochs = [] 269 | 270 | docs = read_files_to_tokenized_docs(files, args, encoder) 271 | 272 | full_seqs, trailing_data = chunk_and_finalize(docs, args, encoder) 273 | 274 | all_sequences_across_epochs.extend(full_seqs) 275 | 276 | # ep 2+ 277 | for ep_ix in range(1, args.n_repack_epochs): 278 | # re-shuffle 279 | if not args.preserve_data_order: 280 | random.shuffle(docs) 281 | full_seqs, trailing_data = chunk_and_finalize(docs, args, encoder) 282 | else: 283 | # if we're preserving data order, we can still "repack" by shifting everything 284 | # with the trailing data of the last epoch at the beginning 285 | seqs_with_prefix = [trailing_data] + full_seqs 286 | full_seqs, trailing_data = chunk_and_finalize(seqs_with_prefix, args, encoder) 287 | 288 | all_sequences_across_epochs.extend(full_seqs) 289 | 290 | # final 291 | print(f"dropped {len(trailing_data)} tokens of trailing data") 292 | 293 | total_sequence_len = len(all_sequences_across_epochs) 294 | 295 | fp = os.path.join(args.output_dir, f"{args.name}.tfrecords") 296 | write_tfrecord(all_sequences_across_epochs, fp) 297 | 298 | 299 | if __name__ == "__main__": 300 | args = parse_args() 301 | 302 | if args.output_dir: 303 | os.makedirs(args.output_dir, exist_ok=True) 304 | files = get_files(args.input_path) 305 | print(f"Creating TFRecords from files: {files}") 306 | 307 | results = create_tfrecords(files, args) 308 | -------------------------------------------------------------------------------- /data/qa.val.index: -------------------------------------------------------------------------------- 1 | /home/username/STaR/commonsenseqa/qa_val.tfrecords 2 | -------------------------------------------------------------------------------- /device_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import time 4 | 5 | import jax 6 | import numpy as np 7 | import optax 8 | import random 9 | import wandb 10 | 11 | from mesh_transformer import util 12 | from mesh_transformer.checkpoint import read_ckpt 13 | from mesh_transformer.sampling import nucleaus_sample 14 | from mesh_transformer.transformer_shard import CausalTransformer 15 | from smart_open import open as smart_open 16 | import transformers 17 | import tensorflow_datasets as tfds 18 | from mesh_transformer.util import clip_by_global_norm 19 | import jsonlines 20 | import pprint 21 | from tqdm import tqdm 22 | 23 | basic_open = open 24 | pp = pprint.PrettyPrinter(indent=2).pprint 25 | 26 | 27 | def eval_output(output, answers, context, example_classes, accuracy, target_save, tokenizer, show=False, direct=False, endoftext="<|endoftext|>"): 28 | successful_examples = [] 29 | 
enum_outputs = enumerate(output[1][0][:, :, 0]) 30 | for (idx, o), target, cur_base_context, example_class in zip(enum_outputs, answers, context, example_classes): 31 | cur_output = tokenizer.decode(o) 32 | output_numbers = cur_output.split('\n') 33 | if example_class not in accuracy: 34 | accuracy[example_class] = {'accurate': 0, 'total': 0} 35 | accuracy[example_class]['total'] += 1 36 | if len(output_numbers) == 0: 37 | continue 38 | try: 39 | if args.dataset_mode == "cqa": 40 | output_numbers = output_numbers[0] 41 | if "<|endoftext|>" in output_numbers: 42 | output_numbers = output_numbers.split("<|endoftext|>")[0] 43 | output_prediction = output_numbers[-3] 44 | elif args.dataset_mode == "gsm": 45 | output_prediction = "" 46 | for line_idx, line in enumerate(output_numbers): 47 | if "####" in line: 48 | output_numbers = "\n".join(output_numbers[:line_idx + 1]) 49 | if "<|endoftext|>" in output_numbers: 50 | output_numbers = output_numbers.split("<|endoftext|>")[0] 51 | output_prediction = output_numbers.split("####")[-1].strip() 52 | break 53 | elif args.dataset_mode == "arithmetic": 54 | if len(output_numbers) == 0: 55 | continue 56 | elif "<|endoftext|>" in output_numbers: 57 | prediction_index = output_numbers.index("<|endoftext|>") - 1 58 | elif "" in output_numbers: 59 | prediction_index = output_numbers.index("") + 1 60 | if prediction_index == len(output_numbers): 61 | continue 62 | else: 63 | if direct and len(output_numbers) > 1: 64 | prediction_index = 1 65 | else: 66 | prediction_index = 0 67 | output_prediction = output_numbers[prediction_index] 68 | 69 | if "<|endoftext|>" in output_prediction: 70 | output_prediction = output_prediction.split("<|endoftext|>")[0] 71 | 72 | correct = output_prediction.lower() == target.lower() 73 | if correct: 74 | accuracy[example_class]['accurate'] += 1 75 | with basic_open(target_save, 'a+') as new_train_f: 76 | if args.dataset_mode == "cqa" or args.dataset_mode == "gsm": 77 | new_example = cur_base_context + output_numbers + endoftext 78 | elif args.dataset_mode == "arithmetic": 79 | if args.few_shot_train: 80 | raise NotImplementedError 81 | joined_output = "\n".join(output_numbers[:prediction_index + 1]) 82 | if "<|endoftext|>" in joined_output: 83 | joined_output = joined_output.split("<|endoftext|>")[0] 84 | new_example = cur_base_context + joined_output + endoftext 85 | if show: 86 | print(new_example) 87 | print(new_example, file=new_train_f, end="") 88 | successful_examples.append(idx) 89 | except IndexError: 90 | pass 91 | return successful_examples 92 | 93 | 94 | def get_score(subcounts): 95 | if subcounts['total'] == 0: 96 | return 0 97 | return subcounts['accurate'] / subcounts['total'] 98 | 99 | def question_to_context(data_example, hint=False, dataset_mode='cqa', direct=False): 100 | if dataset_mode == 'cqa': 101 | context = f"Q: {data_example['question']['stem']}\nAnswer Choices:\n" 102 | for choice in data_example['question']['choices']: 103 | if hint and (choice['label'].lower() == data_example['answerKey'].lower()): 104 | context += f"({choice['label'].lower()}) {choice['text']} (CORRECT)\n" 105 | else: 106 | context += f"({choice['label'].lower()}) {choice['text']}\n" 107 | context += "A:" 108 | elif dataset_mode == 'gsm': 109 | context = f"Q: {data_example['question']}" 110 | if hint: 111 | chosen_hint = data_example['answer'] 112 | context += f" ({chosen_hint})" 113 | context += "\nA:" 114 | elif dataset_mode == "arithmetic": 115 | context = "" 116 | for example_split, next_example_split in 
zip(data_example.split('Target:')[:-1], data_example.split('Target:')[1:]): 117 | if direct and "" in example_split: 118 | context += example_split.split("")[-1] 119 | else: 120 | context += example_split 121 | context += "Target:" 122 | if hint: 123 | context += " " + next_example_split.split("\n")[-5] 124 | return context 125 | 126 | 127 | def examples_to_batch(data_examples, few_shot_prompts, seq, tokenizer, hint=False, direct=False, p_show_hint_save=0.1): 128 | batch = { 129 | "base_context": [], 130 | "initial_batch": [], 131 | "lengths": [], 132 | "padded_batch": [], 133 | "answers": [], 134 | "classes": [] 135 | } 136 | for data_class, data_example in data_examples: 137 | batch['classes'].append(data_class) 138 | # Context, without the few-shot prompt 139 | hintless_base_context = question_to_context(data_example, hint=False, dataset_mode=args.dataset_mode, direct=direct) 140 | base_context = question_to_context(data_example, hint=hint, dataset_mode=args.dataset_mode, direct=direct) 141 | if args.dataset_mode == "arithmetic": 142 | few_shot_prompts = base_context.split("\n\n")[:-1] 143 | base_context = base_context.split("\n\n")[-1] 144 | hintless_base_context = hintless_base_context.split("\n\n")[-1] 145 | if random.random() < p_show_hint_save: 146 | hintless_base_context = base_context 147 | # We always want to act as if no hint was given 148 | if args.few_shot_train: 149 | if args.dataset_mode == "arithmetic": 150 | raise NotImplementedError 151 | else: 152 | save_context = "\n\n".join(commonsense_prompts) + "\n\n" 153 | save_context += hintless_base_context 154 | batch['base_context'].append(save_context) 155 | else: 156 | batch['base_context'].append(hintless_base_context) 157 | # Input tokens 158 | if args.no_prompt: 159 | context = "" 160 | else: 161 | context = "\n\n".join(few_shot_prompts) + "\n\n" 162 | context += base_context 163 | tokens = tokenizer.encode(context) 164 | batch['initial_batch'].append(tokens) 165 | # Input lengths 166 | batch['lengths'].append(len(tokens)) 167 | # Padded tokens 168 | provided_ctx = len(tokens) 169 | pad_amount = max(seq - provided_ctx, 0) 170 | if provided_ctx > seq: 171 | tokens = tokens[-seq:] 172 | batch['padded_batch'].append(np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)) 173 | # Answer 174 | if args.dataset_mode == "arithmetic": 175 | if len(data_example.split("\n")) >= 3: 176 | target = data_example.split("\n")[-3] 177 | else: 178 | target = "invalid" 179 | elif args.dataset_mode == "cqa": 180 | target = data_example['answerKey'] 181 | elif args.dataset_mode == "gsm": 182 | target = data_example['answer'].split("#### ")[-1] 183 | batch['answers'].append(target) 184 | batch["lengths"] = np.asarray(batch["lengths"], dtype=np.uint32) 185 | batch["padded_batch"] = np.array(batch["padded_batch"]) 186 | return batch 187 | 188 | 189 | def eval_batch(examples, few_shot_prompts, seq, tok, gen_length, gen_params, accuracy, target_save, hint=False, direct=False): 190 | batch = examples_to_batch(examples, few_shot_prompts, seq, tok, hint=hint, direct=direct, p_show_hint_save=args.p_show_hint_save) 191 | output = network.generate(batch["padded_batch"], batch["lengths"], gen_length, gen_params) 192 | return eval_output( 193 | output, batch["answers"], batch["base_context"], batch["classes"], accuracy, target_save, tok, direct=direct 194 | ) 195 | 196 | 197 | def load_model(params, ckpt_path, devices, mesh_shape): 198 | network = CausalTransformer(params) 199 | start = time.time() 200 | network.state = read_ckpt(network.state, ckpt_path, 
devices.shape[1]) 201 | print(f"{ckpt_path} network loaded in {time.time() - start:.06}s on {jax.device_count()} devices") 202 | local_shards = max(jax.local_device_count() // mesh_shape[1], 1) 203 | del network.state["opt_state"] 204 | network.state = network.move_xmap(network.state, np.zeros(local_shards)) 205 | return network 206 | 207 | def eval_examples(data_examples, few_shot_prompts, few_shot_prompts_hint, direct=False): 208 | accurate_count = {} 209 | tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') 210 | 211 | main_examples, hint_examples = [], [] 212 | pbar = tqdm(data_examples, smoothing=0) 213 | for data_example in pbar: 214 | main_examples.append(data_example) 215 | if len(main_examples) == args.eval_batch_size: 216 | successful_examples = eval_batch( 217 | main_examples, few_shot_prompts, seq, tokenizer, 218 | args.gen_length, gen_params, accurate_count, target_save, direct=direct 219 | ) 220 | for example_idx, example in enumerate(main_examples): 221 | if (example_idx not in successful_examples) and (random.random() < params.get('p_rationalization', 1.)): 222 | hint_examples.append(example) 223 | main_examples = [] 224 | 225 | if args.rationalize and len(hint_examples) >= args.eval_batch_size: 226 | cur_hint_examples = hint_examples[:args.eval_batch_size] 227 | cur_hint_examples = [ 228 | (hint_example_key + "_r", hint_example) for hint_example_key, hint_example in cur_hint_examples 229 | ] 230 | eval_batch( 231 | cur_hint_examples, few_shot_prompts_hint, hint_seq, tokenizer, 232 | args.gen_length, gen_params, accurate_count, target_save, hint=True, direct=direct 233 | ) 234 | hint_examples = hint_examples[args.eval_batch_size:] 235 | pbar.set_description(f"{split} " + ", ".join([ 236 | f"{cur_key}: {get_score(cur_counts):0.4f}" for cur_key, cur_counts in accurate_count.items() 237 | ])) 238 | return accurate_count 239 | 240 | def get_ckpt_path(params, ckpt_step=-1): 241 | bucket = params["bucket"] 242 | model_dir = params["model_dir"] 243 | if ckpt_step == -1: 244 | ckpt_step = params["total_steps"] 245 | return f"gs://{bucket}/" + (f"step_{ckpt_step}/" if ckpt_step > 10000 else f"{model_dir}/step_{ckpt_step}/") 246 | 247 | def set_opt(params): 248 | params["sampler"] = nucleaus_sample 249 | opt = optax.chain( 250 | optax.scale(1 / params.get("gradient_accumulation_steps", 1)), 251 | clip_by_global_norm(1), 252 | optax.scale_by_adam(), 253 | optax.additive_weight_decay(0), 254 | optax.scale(-1), 255 | optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0)) 256 | ) 257 | params["optimizer"] = opt 258 | 259 | def get_dataset(args): 260 | if args.dataset_mode == "cqa": 261 | with jsonlines.open(f'commonsenseqa/{split}_rand_split.jsonl') as reader: 262 | dataset = [("cqa", example) for example in reader] 263 | elif args.dataset_mode == "gsm": 264 | with jsonlines.open(f'gsm/{split}_rand_split.jsonl') as reader: 265 | dataset = [("gsm", example) for example in reader] 266 | elif args.dataset_mode == "arithmetic": 267 | digit_range = list(range(1, 6)) 268 | dataset = [] 269 | for i in digit_range: 270 | with basic_open(f'arithmetic/train_scratch/{i}.txt') as f: 271 | dataset += [(str(i), example) for example in f.read().split('<|endoftext|>')] 272 | 273 | if split == "train": 274 | random.shuffle(dataset) 275 | dataset = dataset[:args.n_train_samples] 276 | return dataset 277 | 278 | def parse_args(): 279 | # Parse command line arguments 280 | parser = argparse.ArgumentParser() 281 | parser.add_argument("--config", type=str, default=None, help="Config file 
location") 282 | parser.add_argument('--direct', action='store_true', help="Whether to use direct prediction, sans scratchpad") 283 | parser.add_argument('--rationalize', action='store_true', help="Whether to use rationalization") 284 | parser.add_argument('--no_prompt', action='store_true', help="Whether to remove prompts during eval") 285 | parser.add_argument('--few_shot_train', action='store_true', help="Whether to remove few-shot-prompts during train") 286 | parser.add_argument('--show_hint_prompt', action='store_true', help="Whether a hint prompt will be necessary") 287 | parser.add_argument("--split", type=str, default="dev", help="Split") 288 | parser.add_argument("--dataset_mode", type=str, default="cqa", help="Which dataset to run on") 289 | parser.add_argument("--n_train_samples", type=int, default=3000, help="Number of training examples") 290 | parser.add_argument("--gen_length", type=int, default=96, help="Generation length") 291 | parser.add_argument("--eval_batch_size", type=int, default=8, help="Size of batches in eval") 292 | parser.add_argument("--p_show_hint_save", type=float, default=0.0, help="Percent of rationalization hints to save") 293 | parser.add_argument("--ckpt_step", type=int, default=-1, help="Which checkpoint to eval. -1 means the final one") 294 | parser.add_argument("--eval_seq", type=int, default=-1, help="Sequence length. -1 means the one in the param file") 295 | 296 | args = parser.parse_args() 297 | return args 298 | 299 | def transform_example(example): 300 | new_example = { 301 | "question": example["english_text"], 302 | "answer": "#### " + example["ans"] 303 | } 304 | return new_example 305 | 306 | if __name__ == "__main__": 307 | args = parse_args() 308 | print(args) 309 | split = args.split 310 | params = json.load(smart_open(args.config)) 311 | 312 | project = params.get("wandb_project", "mesh-transformer-jax") 313 | experiment_details = params["name"].split("_") 314 | wandb_name = "_".join(experiment_details[:-1]) 315 | wandb_iteration = int(experiment_details[-1]) 316 | wandb.init(project=project, name=wandb_name, config=params, resume=True) 317 | 318 | prompts_file = "prompts.txt" if not args.direct else "prompts_direct.txt" 319 | prompts_file = f"{args.dataset_mode}/{prompts_file}" 320 | if args.no_prompt: 321 | commonsense_prompts = [] 322 | else: 323 | with basic_open(prompts_file) as prompts: 324 | commonsense_prompts = prompts.read().split("\n\n") 325 | prompts_hint_file = "prompts_answer_key.txt" if not args.direct else "prompts_direct_answer_key.txt" 326 | prompts_hint_file = f"{args.dataset_mode}/{prompts_hint_file}" 327 | if args.no_prompt and not args.show_hint_prompt: 328 | commonsense_prompts_hint = [] 329 | else: 330 | with basic_open(prompts_hint_file) as prompts: 331 | commonsense_prompts_hint = prompts.read().split("\n\n") 332 | 333 | per_replica_batch = params["per_replica_batch"] 334 | cores_per_replica = params["cores_per_replica"] 335 | target_save = params["target_save"] if split != "dev" else f'{args.dataset_mode}/new_dev.txt' 336 | seq = params["seq"] if args.eval_seq == -1 else args.eval_seq 337 | hint_seq = seq 338 | set_opt(params) 339 | 340 | mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica) 341 | devices = np.array(jax.devices()).reshape(mesh_shape) 342 | ckpt_path = get_ckpt_path(params, args.ckpt_step) 343 | with jax.experimental.maps.mesh(devices, ('dp', 'mp')): 344 | network = load_model(params, ckpt_path, devices, mesh_shape) 345 | 346 | dataset = get_dataset(args) 347 | dataset_keys = 
set([datakey for datakey, _ in dataset]) 348 | 349 | total_batch = per_replica_batch * jax.device_count() // cores_per_replica * args.eval_batch_size 350 | gen_params = {"top_p": np.ones(total_batch) * 0.9, "temp": np.ones(total_batch) * 0.01} 351 | 352 | accurate_count = eval_examples(dataset, commonsense_prompts, commonsense_prompts_hint, direct=args.direct) 353 | for cur_key, cur_counts in accurate_count.items(): 354 | print(f"{split}, {cur_key}, {get_score(cur_counts)}") 355 | wandb.log({f"{split}_{cur_key}_accuracy": get_score(cur_counts), "iteration": wandb_iteration}) 356 | -------------------------------------------------------------------------------- /device_serve.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import threading 4 | import time 5 | from queue import Queue, Empty 6 | 7 | import jax 8 | import numpy as np 9 | import optax 10 | 11 | from mesh_transformer import util 12 | from mesh_transformer.checkpoint import read_ckpt 13 | from mesh_transformer.sampling import nucleaus_sample 14 | from mesh_transformer.transformer_shard import CausalTransformer 15 | import transformers 16 | from smart_open import open 17 | 18 | from mesh_transformer.util import clip_by_global_norm 19 | 20 | from flask import Flask, request, make_response, jsonify 21 | app = Flask(__name__) 22 | 23 | requests_queue = Queue() 24 | 25 | """ 26 | curl --header "Content-Type: application/json" \ 27 | --request POST \ 28 | --data '{"context":"eleutherai", "top_p": 0.9, "temp": 0.75}' \ 29 | http://localhost:5000/complete 30 | """ 31 | 32 | 33 | def _build_cors_prelight_response(): 34 | response = make_response() 35 | response.headers.add("Access-Control-Allow-Origin", "*") 36 | response.headers.add('Access-Control-Allow-Headers', "*") 37 | response.headers.add('Access-Control-Allow-Methods', "*") 38 | return response 39 | 40 | 41 | def _corsify_actual_response(response): 42 | response.headers.add("Access-Control-Allow-Origin", "*") 43 | return response 44 | 45 | 46 | @app.route('/complete', methods=['POST', 'OPTIONS']) 47 | def complete(): 48 | if request.method == "OPTIONS": # CORS preflight 49 | return _build_cors_prelight_response() 50 | elif request.method == "POST": # The actual request following the preflight 51 | content = request.json 52 | 53 | if requests_queue.qsize() > 100: 54 | return {"error": "queue full, try again later"} 55 | 56 | response_queue = Queue() 57 | 58 | requests_queue.put(({ 59 | "context": content["context"], 60 | "top_p": float(content["top_p"]), 61 | "temp": float(content["temp"]) 62 | }, response_queue)) 63 | 64 | return _corsify_actual_response(jsonify({"completion": response_queue.get()})) 65 | else: 66 | raise RuntimeError("Weird - don't know how to handle method {}".format(request.method)) 67 | 68 | 69 | def parse_args(): 70 | # Parse command line arguments 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument("--config", type=str, default=None, help="Config file location") 73 | 74 | args = parser.parse_args() 75 | return args 76 | 77 | 78 | if __name__ == "__main__": 79 | threading.Thread(target=app.run, kwargs={"port": 5000, "host": "0.0.0.0"}).start() 80 | 81 | args = parse_args() 82 | params = json.load(open(args.config)) 83 | 84 | gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1) 85 | per_replica_batch = params["per_replica_batch"] 86 | cores_per_replica = params["cores_per_replica"] 87 | 88 | assert cores_per_replica <= 8 89 | 90 | bucket = 
params["bucket"] 91 | model_dir = params["model_dir"] 92 | layers = params["layers"] 93 | d_model = params["d_model"] 94 | n_heads = params["n_heads"] 95 | n_vocab = params["n_vocab"] 96 | seq = params["seq"] 97 | norm = params["norm"] 98 | 99 | params["sampler"] = nucleaus_sample 100 | opt = optax.chain( 101 | optax.scale(1 / gradient_accumulation_steps), 102 | clip_by_global_norm(1), 103 | optax.scale_by_adam(), 104 | optax.additive_weight_decay(0), 105 | optax.scale(-1), 106 | optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0)) 107 | ) 108 | 109 | params["optimizer"] = opt 110 | 111 | start = time.time() 112 | print(f"jax devices: {jax.device_count()}") 113 | print(f"jax runtime initialized in {time.time() - start:.06}s") 114 | 115 | mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica) 116 | devices = np.array(jax.devices()).reshape(mesh_shape) 117 | 118 | with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f: 119 | meta = json.load(f) 120 | 121 | ckpt_step = meta["checkpoints"][-1] 122 | print(f"using checkpoint {ckpt_step}") 123 | 124 | total_batch = per_replica_batch * jax.device_count() // cores_per_replica * 8 125 | with jax.experimental.maps.mesh(devices, ('dp', 'mp')): 126 | network = CausalTransformer(params) 127 | 128 | start = time.time() 129 | network.state = read_ckpt(network.state, f"gs://{bucket}/{model_dir}/step_{ckpt_step}/", devices.shape[1]) 130 | print(f"network loaded in {time.time() - start:.06}s") 131 | 132 | local_shards = max(jax.local_device_count() // mesh_shape[1], 1) 133 | del network.state["opt_state"] 134 | network.state = network.move_xmap(network.state, np.zeros(local_shards)) 135 | 136 | tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') 137 | 138 | while True: 139 | all_ctx = [] 140 | all_top_p = [] 141 | all_temp = [] 142 | all_q = [] 143 | while len(all_ctx) < total_batch: 144 | try: 145 | o, q = requests_queue.get(block=False) 146 | all_ctx.append(o["context"]) 147 | all_top_p.append(o["top_p"]) 148 | all_temp.append(o["temp"]) 149 | all_q.append(q) 150 | except Empty: 151 | if len(all_ctx): 152 | break 153 | else: 154 | time.sleep(0.01) 155 | 156 | start = time.time() 157 | while len(all_ctx) < total_batch: 158 | all_ctx.append("whatever") 159 | all_top_p.append(1) 160 | all_temp.append(1) 161 | 162 | all_tokenized = [] 163 | all_length = [] 164 | for ctx in all_ctx: 165 | padded_tokens = np.zeros(seq).astype(np.uint32) 166 | length = 0 167 | 168 | try: 169 | tokens = tokenizer.encode(ctx) 170 | provided_ctx = len(tokens) 171 | pad_amount = seq - provided_ctx 172 | 173 | pad_amount = max(pad_amount, 0) 174 | 175 | padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)[-seq:] 176 | length = len(tokens) 177 | except: 178 | print("oops exception") 179 | 180 | all_tokenized.append(padded_tokens) 181 | all_length.append(length) 182 | 183 | output = network.generate(np.array(all_tokenized), 184 | np.array(all_length), 185 | 256, 186 | { 187 | "top_p": np.array(all_top_p), 188 | "temp": np.array(all_temp) 189 | }) 190 | 191 | for o, q in zip(output[1][0][:, :, 0], all_q): 192 | q.put(tokenizer.decode(o)) 193 | 194 | print(f"completion done in {time.time() - start:06}s") 195 | -------------------------------------------------------------------------------- /device_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import time 4 | 5 | import jax 6 | import numpy as np 7 | import optax 8 | 9 | import wandb 10 | from tqdm 
import tqdm 11 | 12 | 13 | from mesh_transformer import util 14 | from mesh_transformer.checkpoint import read_ckpt, write_ckpt 15 | from mesh_transformer.transformer_shard import CausalTransformer 16 | from tfrecord_loader import TFRecordNewInputs 17 | from smart_open import open 18 | from google.cloud import storage 19 | from google.cloud.exceptions import NotFound 20 | 21 | from mesh_transformer.util import clip_by_global_norm, additive_weight_decay 22 | 23 | 24 | def parse_args(): 25 | # Parse command line arguments 26 | parser = argparse.ArgumentParser(description=""" 27 | To use, download the full checkpoint archive, extract and upload to a GCS bucket, and set that as --tune-model-path 28 | Modify the config file: 29 | - set `model_dir` to where the checkpoints should be written during training 30 | - set `train_set`, `val_set` to index files for your data 31 | - set `tpu_size` to 8 (if on a v3-8) 32 | - set `warmup_steps`, `anneal_steps`, `lr`, `end_lr` to the lr schedule for your finetuning run 33 | - the global step will reset to 0, keep that in mind when writing your lr schedule 34 | - set `name` to specify the name of the Weights & Biases run 35 | - set `wandb_project` to specify the Weights & Biases project to log to 36 | To prepare data in the expected data format: 37 | - use the script `create_finetune_tfrecords.py` in this repo to create data in the expected format 38 | - upload the .tfrecords files to GCS 39 | - save their GCS paths to a index file under `data/`, see existing files for examples 40 | """, 41 | formatter_class=argparse.RawTextHelpFormatter) 42 | parser.add_argument("--config", type=str, default=None, help="Config file location") 43 | parser.add_argument("--tune-model-path", type=str, default=None, help="Base model to finetune") 44 | parser.add_argument("--fresh-opt", default=False, action="store_true", help="Use a newly initialized optimizer, ignoring any optimizer state saved in the base checkpoint") 45 | 46 | args = parser.parse_args() 47 | return args 48 | 49 | 50 | def save(network, step, bucket, path, mp, aux=None, keep_n=20, delete_old=True): 51 | assert path 52 | client = storage.Client() 53 | 54 | if aux is None: 55 | aux = {} 56 | 57 | try: 58 | with open(f"gs://{bucket}/{path}/meta.json", "r") as f: 59 | meta = json.load(f) 60 | except: 61 | # create metadata file 62 | with open(f"gs://{bucket}/{path}/meta.json", "w") as f: 63 | json.dump({ 64 | "step": 0, 65 | "checkpoints": [], 66 | "aux": {} 67 | }, f) 68 | 69 | # do sharded checkpoint writing 70 | start = time.time() 71 | res = [] 72 | for shard_id in range(mp): 73 | write_ckpt(network.state, f"gs://{bucket}/{path}/step_{step}/", shard_id) 74 | 75 | print(f"Wrote checkpoint in {time.time() - start:.06}s") 76 | 77 | with open(f"gs://{bucket}/{path}/meta.json", "r") as f: 78 | meta = json.load(f) 79 | 80 | meta["step"] = step 81 | meta["checkpoints"].append(step) 82 | all_aux = meta.get("aux", {}) 83 | 84 | while len(meta["checkpoints"]) > keep_n: 85 | ckpt_to_delete = meta["checkpoints"].pop(0) 86 | 87 | try: 88 | del all_aux[str(ckpt_to_delete)] 89 | except: 90 | print(f"failed to delete the aux state for {step}") 91 | 92 | if delete_old: 93 | print(f"deleting checkpoint {ckpt_to_delete}") 94 | for blob in client.list_blobs(bucket, prefix=f"{path}/step_{ckpt_to_delete}/"): 95 | # print(f"deleting {blob.name}") 96 | assert path in blob.name 97 | blob.delete() 98 | else: 99 | print(f"keeping checkpoint {ckpt_to_delete}") 100 | 101 | all_aux[step] = aux 102 | meta["aux"] = all_aux 103 | 104 | with 
open(f"gs://{bucket}/{path}/meta.json", "w") as f: 105 | json.dump(meta, f) 106 | 107 | 108 | def get_mask_one_locations(single_sequence): 109 | separator = np.where(single_sequence == 25)[0] 110 | separator_locations = [] 111 | endoftext = np.where(single_sequence == 50256)[0] 112 | 113 | for endoftext_token_location in endoftext: 114 | locations = separator[separator <= endoftext_token_location] 115 | if len(locations) > 0: 116 | separator_locations.append(locations[-1]) 117 | else: 118 | separator_locations.append(0) 119 | separator = np.asarray(separator_locations) 120 | mask_one_locations = [(i+1, j) for i, j in zip(separator, endoftext)] 121 | return mask_one_locations 122 | 123 | 124 | def find_real_target_mask(single_sequence): 125 | mask_one_locations = get_mask_one_locations(single_sequence) 126 | 127 | mask = np.zeros(len(single_sequence)) 128 | for i, j in mask_one_locations: 129 | np.put(mask, np.arange(i, j+1), 1.) 130 | return mask 131 | 132 | 133 | def train_step(network, data): 134 | tgt = data[:, :, 1:] 135 | 136 | all_masks = [] 137 | for single_tgt in tgt: 138 | mask = find_real_target_mask(np.squeeze(single_tgt)) 139 | all_masks.append(np.expand_dims(mask, (0, 1))) 140 | all_masks = np.concatenate(all_masks, axis=0) 141 | 142 | inputs = { 143 | "obs": data[:, :, :-1], 144 | "target": tgt, 145 | "mask": all_masks 146 | } 147 | 148 | loss, last_loss, grad_norm, grad_norm_micro = network.train(inputs) 149 | 150 | return ( 151 | np.array(loss).mean(), 152 | np.array(last_loss).mean(), 153 | np.array(grad_norm).mean(), 154 | np.array(grad_norm_micro).mean(), 155 | ) 156 | 157 | 158 | def eval_step(network, data): 159 | tgt = data[:, 1:] 160 | 161 | all_masks = [] 162 | for single_tgt in tgt: 163 | mask_one_locations = get_mask_one_locations(single_tgt) 164 | mask = find_real_target_mask(np.squeeze(single_tgt)) 165 | all_masks.append(np.expand_dims(mask, 0)) 166 | all_masks = np.concatenate(all_masks, axis=0) 167 | 168 | inputs = { 169 | "obs": data[:, :-1], 170 | "target": tgt, 171 | "mask": all_masks 172 | } 173 | 174 | out = network.eval(inputs) 175 | loss = out["loss"] 176 | correct = out['correct'] 177 | 178 | correct_sequences = 0 179 | total_sequences = 0 180 | for i, j in mask_one_locations: 181 | total_sequences += 1 182 | if np.all(correct[0][i:j+1]==1): 183 | correct_sequences += 1 184 | return np.array(loss).mean(), np.array(correct).mean() / all_masks.mean(), correct_sequences, total_sequences 185 | 186 | 187 | if __name__ == "__main__": 188 | args = parse_args() 189 | params = json.load(open(args.config)) 190 | 191 | gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1) 192 | per_replica_batch = params["per_replica_batch"] 193 | cores_per_replica = params["cores_per_replica"] 194 | 195 | assert cores_per_replica <= 8 196 | 197 | bucket = params["bucket"] 198 | model_dir = params["model_dir"] 199 | layers = params["layers"] 200 | d_model = params["d_model"] 201 | n_heads = params["n_heads"] 202 | n_vocab = params["n_vocab"] 203 | seq = params["seq"] 204 | norm = params["norm"] 205 | 206 | val_batches = params["val_batches"] 207 | val_every = params["val_every"] 208 | ckpt_every = params["ckpt_every"] 209 | keep_every = params["keep_every"] 210 | eval_tasks = params["eval_harness_tasks"] 211 | total_steps = params["total_steps"] 212 | 213 | pe = params["pe"] 214 | assert pe in ["fixed", "rotary", "t5"] 215 | 216 | warmup_steps = params["warmup_steps"] 217 | anneal_steps = params["anneal_steps"] 218 | lr = params["lr"] 219 | end_lr = 
params["end_lr"] 220 | weight_decay = params["weight_decay"] 221 | 222 | # alpha parameter for the exponential moving averages used to compute B_simple 223 | noise_scale_alpha = params.get("noise_scale_alpha", 0.01) 224 | 225 | scheduler = util.gpt3_schedule(warmup_steps, anneal_steps, lr, end_lr) 226 | 227 | opt = optax.chain( 228 | optax.scale(1 / gradient_accumulation_steps), 229 | clip_by_global_norm(1), 230 | optax.scale_by_adam(), 231 | additive_weight_decay(weight_decay), 232 | optax.scale(-1), 233 | optax.scale_by_schedule(scheduler) 234 | ) 235 | 236 | params["optimizer"] = opt 237 | 238 | start = time.time() 239 | tpu_size = jax.device_count() 240 | if tpu_size < cores_per_replica: 241 | msg = f"each shard needs a separate device, but device count ({tpu_size}) < shard count ({cores_per_replica})" 242 | raise ValueError(msg) 243 | print(f"jax devices: {tpu_size}") 244 | print(f"jax runtime initialized in {time.time() - start:.06}s") 245 | 246 | mesh_shape = (tpu_size // cores_per_replica, cores_per_replica) 247 | devices = np.array(jax.devices()).reshape(mesh_shape) 248 | 249 | # pick initial ckpt - based on tuning vs train from scratch 250 | 251 | step = 0 252 | initial_ckpt_state_path = None 253 | train_loader = None 254 | 255 | if args.tune_model_path: 256 | print('`--tune_model_path` passed: we are beginning a fine-tuning run') 257 | fine_tuning = True 258 | initial_ckpt_state_path = args.tune_model_path 259 | else: 260 | print('`--tune_model_path` not passed: we are continuing a fine-tuning run from a checkpoint (or we are not fine-tuning)') 261 | fine_tuning = False 262 | initial_ckpt_model_dir = model_dir 263 | initial_ckpt_path = f"gs://{bucket}/{initial_ckpt_model_dir}" 264 | meta_path = f"{initial_ckpt_path}/meta.json" 265 | 266 | try: 267 | with open(meta_path, "r") as f: 268 | meta = json.load(f) 269 | ckpt_step = meta["checkpoints"][-1] 270 | initial_ckpt_state_path = f"{initial_ckpt_path}/step_{ckpt_step}/" 271 | print(f"state will be restored from checkpoint {ckpt_step}") 272 | 273 | step = ckpt_step 274 | train_loader = meta['aux'][str(ckpt_step)].get("train_loader", None) 275 | except NotFound: 276 | # no checkpoint, start at zero 277 | print(f"No checkpoint to load at {initial_ckpt_path}. 
Training from scratch.") 278 | 279 | if initial_ckpt_state_path: 280 | print(f"path to load checkpoint from: {initial_ckpt_state_path}") 281 | else: 282 | print("not loading from a checkpoint") 283 | 284 | # set up datasets 285 | print("setting up datasets") 286 | 287 | train_dataset = TFRecordNewInputs(f"data/{params['train_set']}", 288 | batch_size=( 289 | gradient_accumulation_steps, 290 | per_replica_batch * tpu_size // cores_per_replica), 291 | sample_size=params['seq'], 292 | restore_state=train_loader) 293 | 294 | global_val_batch = per_replica_batch * tpu_size // cores_per_replica 295 | 296 | val_sets = {} 297 | 298 | for k, v in params["val_set"].items(): 299 | val_sets[k] = TFRecordNewInputs( 300 | f"data/{v}", batch_size=(global_val_batch,), sample_size=seq 301 | ) 302 | 303 | # tok/sec metrics 304 | sequences_per_step = gradient_accumulation_steps * (per_replica_batch * tpu_size // cores_per_replica) 305 | tokens_per_step = params['seq'] * sequences_per_step 306 | 307 | # load + run 308 | with jax.experimental.maps.mesh(devices, ('dp', 'mp')): 309 | print("initializing network") 310 | network = CausalTransformer(params) 311 | 312 | if initial_ckpt_state_path: 313 | print("loading network") 314 | if fine_tuning: 315 | # get the scheduler step stored in the just-initialized optimizer 316 | # should be zero 317 | init_sched_state = network.state["opt_state"][-1] 318 | 319 | start = time.time() 320 | network.state = read_ckpt(network.state, initial_ckpt_state_path, devices.shape[1], load_opt=(not args.fresh_opt)) 321 | 322 | if fine_tuning: 323 | # overwrite the loaded scheduler step with zeros 324 | # this makes fine-tuning use the lr schedule in 325 | network.state["opt_state"][-1] = init_sched_state 326 | 327 | print(f"network loaded in {time.time() - start:.06}s") 328 | 329 | print('compiling train fn') 330 | start = time.time() 331 | loss, last_loss, grad_norm, grad_norm_micro = train_step( 332 | network, train_dataset.get_samples() 333 | ) 334 | step += 1 335 | print(f"Train fn compiled in {time.time() - start:.06}s") 336 | 337 | print('compiling eval fn') 338 | start = time.time() 339 | for val_set in val_sets.values(): 340 | eval_step(network, val_set.get_samples()) 341 | val_set.reset() 342 | print(f"Eval fn compiled in {time.time() - start:.06}s") 343 | 344 | project = params.get("wandb_project", "mesh-transformer-jax") 345 | wandb.init(project=project, name=params["name"], config=params) 346 | 347 | G_noise_avg = None 348 | S_noise_avg = None 349 | 350 | while True: 351 | if (step > 1) and ((step % ckpt_every == 1) or (step == total_steps)): 352 | print(f"saving a checkpoint for step {step}") 353 | save(network, step, bucket, model_dir, 354 | mp=cores_per_replica, 355 | aux={"train_loader": train_dataset.get_state()}, 356 | delete_old=True, 357 | ) 358 | 359 | if step == total_steps: 360 | print("training completed!") 361 | exit() 362 | 363 | start = time.time() 364 | loss, last_loss, grad_norm, grad_norm_micro = train_step( 365 | network, train_dataset.get_samples() 366 | ) 367 | step += 1 368 | 369 | steps_per_sec = 1 / (time.time() - start) 370 | tokens_per_sec = tokens_per_step * steps_per_sec 371 | sequences_processed = sequences_per_step * step 372 | tokens_processed = tokens_per_step * step 373 | 374 | ### compute summary stats about the gradient 375 | 376 | # converts from grads-summed-over-microbatch (what `CasualTransformer.train` computes) 377 | # to grads-averaged-over-microbatch (what we want) 378 | # 379 | # (when taking gradient steps, the same conversion 
happens inside the optimizer 380 | # via optax.scale(1 / gradient_accumulation_steps)) 381 | grad_norm = grad_norm / gradient_accumulation_steps 382 | 383 | # compute G_noise and S_noise 384 | # from "An Empirical Model of Large-Batch Training" Appendix A.1 385 | # here, B_big = gradient_accumulation_steps, and B_small = 1 for convenience 386 | gbsmall = grad_norm_micro ** 2 387 | gbbig = grad_norm ** 2 388 | G_noise = (gradient_accumulation_steps * gbbig - gbsmall) / ( 389 | gradient_accumulation_steps - 1 390 | ) 391 | S_noise = (gbsmall - gbbig) / (1 - 1 / gradient_accumulation_steps) 392 | 393 | noise_scale_stats = { 394 | "noise/G_noise": G_noise, 395 | "noise/S_noise": S_noise, 396 | } 397 | 398 | # heuristic to avoid reporting G_noise in very early training when gradients are large 399 | # (these take a long time to wash out of the moving average that defines B_simple) 400 | use_step_in_noise_avgs = gbbig < 2 401 | 402 | if use_step_in_noise_avgs: 403 | # compute moving averages of G_noise and S_noise, for B_simple 404 | if G_noise_avg is None: 405 | G_noise_avg = G_noise 406 | else: 407 | G_noise_avg = (1 - noise_scale_alpha) * G_noise_avg + noise_scale_alpha * G_noise 408 | 409 | if S_noise_avg is None: 410 | S_noise_avg = S_noise 411 | else: 412 | S_noise_avg = (1 - noise_scale_alpha) * S_noise_avg + noise_scale_alpha * S_noise 413 | 414 | B_simple = S_noise_avg / G_noise_avg 415 | 416 | noise_scale_stats.update( 417 | { 418 | "noise/G_noise_avg": G_noise_avg, 419 | "noise/S_noise_avg": S_noise_avg, 420 | "noise/B_simple": B_simple, 421 | } 422 | ) 423 | 424 | wandb_stats = { 425 | "train/loss": loss, 426 | "train/last_loss": last_loss, 427 | "train/steps_per_sec": steps_per_sec, 428 | "train/tokens_per_sec": tokens_per_sec, 429 | "train/grad_norm": grad_norm, 430 | "train/learning_rate": float(scheduler(network.state["opt_state"][-1].count[0].item())), 431 | "sequences_processed": sequences_processed, 432 | "tokens_processed": tokens_processed, 433 | } 434 | wandb_stats.update(noise_scale_stats) 435 | 436 | wandb.log(wandb_stats, step) 437 | -------------------------------------------------------------------------------- /docker/.env: -------------------------------------------------------------------------------- 1 | DRYRUN=false 2 | DOMAIN=yourdomain.com 3 | SUBDOMAIN=gptj.api 4 | NGINX_PATH=./nginx_proxyvm.conf 5 | CERTS_PATH=/path/to/certs/ 6 | DNS_KEYS=/path/to/keys/ 7 | MODEL_DIR=/your/path/to/model/step_383500 8 | TPU_NAME=YOUR_TPU_NAME -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Have tested with a custom Ubuntu-1804 / Python 3.7 / Tensorflow 2.5.0 Base Image 2 | # Not tested with this image. 3 | FROM tensorflow/tensorflow:2.5.0 4 | RUN apt update && \ 5 | apt-get install git -y 6 | 7 | WORKDIR /app/ 8 | COPY . 
/app/ 9 | RUN git clone https://github.com/kingoflolz/mesh-transformer-jax && \ 10 | pip install -r mesh-transformer-jax/requirements.txt && \ 11 | pip install mesh-transformer-jax/ jax==0.2.12 && \ 12 | pip install fastapi uvicorn requests aiofiles aiohttp && \ 13 | ln -s /app/start.sh /start.sh 14 | 15 | ENV PYTHONPATH /app:/app/mesh-transformer-jax:/usr/local/bin/python3 16 | ENV PATH $PYTHONPATH:$PATH 17 | ENV TOKENIZERS_PARALLELISM=true 18 | EXPOSE 80 19 | 20 | CMD ["/start.sh"] 21 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # Deploying GPT-J (TPU VM) in a Docker Container 2 | 3 | This PR enables you to run GPT-J for inference in a Docker container, and it has run stably over the past week. 4 | 5 | The API currently requires two calls. The first is a POST call to /model/predict with the following params: 6 | 7 | ``` 8 | context: str 9 | temp: Optional[float] = 1.0 10 | top_p: Optional[float] = 0.9 11 | top_k: Optional[int] = 50 12 | length: Optional[int] = 256 13 | ``` 14 | 15 | This returns an ID (int) that is later used to retrieve the prediction once it is complete, so the call does not block until generation finishes. 16 | 17 | The second is a POST call to /model/get_prediction with the `qid=id` that was returned; a minimal sketch of both requests appears just before the run commands below. 18 | 19 | ## Getting Started 20 | 21 | You will need the following: 22 | 1) TPU VM: A TPU (v2-8 works fine) 23 | 2) Proxy VM: a second VM within the same zone. This VM only serves to proxy requests to the TPU VM, so it does not require a lot of resources. 24 | 3) [Docker + Docker Compose](https://docs.docker.com/compose/install/) installed on both VMs 25 | 4) TPU VM and Proxy VM on the same GCP network (the default for most setups) 26 | 5) Ports 80 and 443 exposed on the Proxy VM 27 | 6) Your DNS records pointing to the external IP of the Proxy VM. 28 | 29 | The GPT-J checkpoint should be downloaded locally, and its path should be set as the MODEL_DIR variable in the .env file. This mounts the checkpoint directory into the container, preventing repeated re-downloads. 30 | 31 | ## Files to Modify 32 | 33 | `.env` 34 | 35 | ```bash 36 | # Modify these values in .env 37 | DRYRUN=false # Set to true to prevent SSL Let's Encrypt domain validation when testing 38 | DOMAIN=yourdomain.com # Set to your domain 39 | SUBDOMAIN=gptj.api # Set to your subdomain. The endpoint will be gptj.api.yourdomain.com 40 | NGINX_PATH=./nginx_proxyvm.conf # Modify this conf file with the Internal IP of the TPU VM. This value will not change. 41 | CERTS_PATH=/path/to/certs/ # Set this to a path on the Proxy VM that will be used to store SSL Certs. This path will automatically be created. 42 | DNS_KEYS=/path/to/keys/ # Set this to a path on the Proxy VM that will be used to store DNS Keys. This path will automatically be created. 43 | MODEL_DIR=/your/path/to/model/step_383500 # Set this to the path on the TPU VM where the model is stored 44 | TPU_NAME=YOUR_TPU_NAME # Set this to your TPU name as created 45 | ``` 46 | 47 | `nginx_proxyvm.conf` 48 | 49 | Modify the IP to be your TPU VM's *Internal* IP address. 50 | 51 | 52 | ## Running the Docker Container 53 | 54 | The following commands assume this directory (docker/) is your working directory. 
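Once both containers are up (using the commands just below), you can exercise the two-step prediction flow described at the top of this README with plain `curl`. This is a minimal sketch rather than a tested client: the hostname is the example subdomain used elsewhere in this document, the prompt and sampling values are arbitrary, and the JSON field names (`context`, `temp`, `top_p`, `top_k`, `length`, `qid`) follow the Pydantic models in `docker/payloads.py`.

```bash
# 1) Queue a completion request. The service responds immediately with a queue id, e.g. {"qid": 1}.
curl -X POST https://gptj.api.yourdomain.com/model/predict \
  -H "Content-Type: application/json" \
  -d '{"context": "EleutherAI is", "temp": 0.75, "top_p": 0.9, "top_k": 50, "length": 128}'

# 2) Retrieve the completion for that id. This call blocks until the batch containing the request has been generated.
curl -X POST https://gptj.api.yourdomain.com/model/get_prediction \
  -H "Content-Type: application/json" \
  -d '{"qid": 1}'
```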
55 | 56 | On TPU VM: 57 | 58 | `docker-compose up -d` 59 | 60 | On Proxy VM: 61 | 62 | `docker-compose -f compose-proxy.yaml up` 63 | 64 | ## Final Notes 65 | 66 | The reason you can't serve directly from the TPU VM is that there is no control over the TPU VM's port firewalls. I go into more detail about [Dockerization within TPU VMs here](https://trisongz.medium.com/accessing-your-tpus-in-docker-containers-with-tpu-vm-e944f5909dd4). 67 | 68 | The API is _very_ barebones and stripped down for simplicity. It's meant as a starting point and can be further extended for your needs. 69 | 70 | -------------------------------------------------------------------------------- /docker/compose-proxy.yaml: -------------------------------------------------------------------------------- 1 | version: "3.3" 2 | services: 3 | gptjapi: 4 | image: ghcr.io/linuxserver/swag 5 | container_name: gptjapi 6 | cap_add: 7 | - NET_ADMIN 8 | environment: 9 | - PUID=1000 10 | - PGID=1000 11 | - TZ=America/Chicago 12 | - URL=${DOMAIN} 13 | - SUBDOMAINS=${SUBDOMAIN}, 14 | - VALIDATION=http 15 | - STAGING=${DRYRUN} 16 | - ONLY_SUBDOMAINS=true 17 | volumes: 18 | - ${NGINX_PATH}:/config/nginx/proxy-confs/app.subdomain.conf 19 | - ${CERTS_PATH}:/config/etc/letsencrypt/ 20 | - ${DNS_KEYS}:/config/keys/ 21 | ports: 22 | - 443:443 23 | - 80:80 24 | restart: unless-stopped 25 | networks: 26 | - production 27 | networks: 28 | production: 29 | -------------------------------------------------------------------------------- /docker/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: "3.3" 2 | services: 3 | gptjserver: 4 | image: gptjserver 5 | container_name: gptjserver 6 | cap_add: 7 | - ALL 8 | environment: 9 | - TPU_NAME=${TPU_NAME} 10 | - TF_CPP_MIN_LOG_LEVEL=0 11 | - XRT_TPU_CONFIG="localservice;0;localhost:51011" 12 | - TF_XLA_FLAGS=--tf_xla_enable_xla_devices 13 | build: 14 | context: . 
15 | dockerfile: Dockerfile 16 | ports: 17 | - 8080:80 18 | volumes: 19 | - ${MODEL_DIR}:/app/model 20 | - /var/run/docker.sock:/var/run/docker.sock 21 | - /usr/share/tpu/:/usr/share/tpu/ 22 | - /lib/libtpu.so:/lib/libtpu.so 23 | privileged: true 24 | restart: unless-stopped 25 | devices: 26 | - "/dev:/dev" 27 | networks: 28 | - production 29 | networks: 30 | production: 31 | -------------------------------------------------------------------------------- /docker/main.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import uvicorn 4 | import traceback 5 | import logging 6 | 7 | from fastapi import FastAPI 8 | from starlette.middleware.cors import CORSMiddleware 9 | from .payloads import CompletionPayload, CompletionResponse, QueueRequest, QueueResponse 10 | from .ops import get_gptj_model 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | app = FastAPI() 15 | app.add_middleware( 16 | CORSMiddleware, 17 | allow_origins=["*"], 18 | allow_credentials=True, 19 | allow_methods=["*"], 20 | allow_headers=["*"], 21 | ) 22 | 23 | 24 | # globals 25 | MODEL_API = None 26 | 27 | 28 | @app.on_event("startup") 29 | async def startup_event(): 30 | global MODEL_API 31 | try: 32 | MODEL_API = get_gptj_model() 33 | MODEL_API.load_model() 34 | MODEL_API.start_background() 35 | except Exception as e: 36 | logger.debug(f"Model could not be loaded: {str(e)}") 37 | traceback.print_exc() 38 | 39 | 40 | @app.post("/model/get_prediction") 41 | def get_prediction(payload: QueueRequest) -> CompletionResponse: 42 | res = MODEL_API.wait_for_queue(payload.qid) 43 | return CompletionResponse(**res) 44 | 45 | 46 | @app.post("/model/predict") 47 | def model_prediction(payload: CompletionPayload) -> QueueResponse: 48 | res = MODEL_API.add_to_queue(payload) 49 | return QueueResponse(qid=res['qid']) 50 | 51 | 52 | if __name__ == "__main__": 53 | uvicorn.run(app, host="0.0.0.0", port=8000, reload=True) 54 | -------------------------------------------------------------------------------- /docker/nginx_proxyvm.conf: -------------------------------------------------------------------------------- 1 | server { 2 | listen 443; 3 | server_name gptj.api.yourdomain.com; 4 | include /config/nginx/ssl.conf; 5 | client_max_body_size 0; 6 | location / { 7 | include /config/nginx/proxy.conf; 8 | # The below IP should be your TPUVM Internal IP 9 | proxy_pass http://10.000.0.1:8080; 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /docker/ops.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import time 4 | import jax 5 | import optax 6 | import threading 7 | import numpy as np 8 | import logging 9 | from queue import Queue, Empty 10 | from jax.experimental import maps 11 | from transformers import GPT2TokenizerFast 12 | 13 | from mesh_transformer.checkpoint import read_ckpt 14 | from mesh_transformer.sampling import nucleaus_sample 15 | from mesh_transformer.transformer_shard import CausalTransformer 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | # This prevents fastapi from creating multiple models in multi-worker mode 20 | # leading to OOM crashes 21 | gptj_model = None 22 | gptj_model_lock = threading.Lock() 23 | 24 | def compile_model(): 25 | global gptj_model 26 | with gptj_model_lock: 27 | if gptj_model: 28 | return 29 | gptj_model = GPTJ() 30 | 31 | def get_gptj_model(): 32 | compile_model() 33 | return gptj_model 34 | 35 | def timer(start_time=None): 36 | if not 
start_time: 37 | return time.time() 38 | return time.time() - start_time 39 | 40 | 41 | class GPTJ: 42 | def __init__(self): 43 | self.params = { 44 | "layers": 28, 45 | "d_model": 4096, 46 | "n_heads": 16, 47 | "n_vocab": 50400, 48 | "norm": "layernorm", 49 | "pe": "rotary", 50 | "pe_rotary_dims": 64, 51 | "seq": 2048, 52 | "cores_per_replica": 8, 53 | "per_replica_batch": 1, 54 | "sampler": nucleaus_sample, 55 | "optimizer": optax.scale(0) 56 | } 57 | self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2') 58 | self.queue_ids = {} 59 | self.qidx = 0 60 | self.queue = Queue() 61 | self.network = None 62 | self.lock = threading.Lock() 63 | self._alive_time = timer() 64 | 65 | def load_model(self): 66 | if self.network: 67 | logger.info('Attempting to reload model when model is loaded. Returning') 68 | return 69 | with self.lock: 70 | logger.info('Loading Model') 71 | start = timer() 72 | logger.info(f"JAX Devices: {jax.device_count()}") 73 | logger.info(f"JAX Runtime Initialized in {timer(start):.06} secs") 74 | mesh_shape = (jax.device_count() // self.params['cores_per_replica'], self.params['cores_per_replica']) 75 | self.devices = np.array(jax.devices()).reshape(mesh_shape) 76 | self.total_batch = self.params['per_replica_batch'] * jax.device_count() // self.params['cores_per_replica'] * 8 77 | maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(self.devices, ('dp', 'mp'))) 78 | network = CausalTransformer(self.params) 79 | logger.info(f'Loading Checkpoint') 80 | network.state = read_ckpt(network.state, "/app/model/", self.devices.shape[1]) 81 | logger.info(f"GPTJ Network loaded in {timer(start):.06} secs. Total Batch Size: {self.total_batch}") 82 | del network.state["opt_state"] 83 | network.state = network.move_xmap(network.state, np.zeros(self.params['cores_per_replica'])) 84 | self.network = network 85 | 86 | 87 | def start_background(self): 88 | with self.lock: 89 | t = threading.Thread(target=self.background) 90 | t.start() 91 | 92 | def prepare_item(self, context, length=256): 93 | tokens = self.tokenizer.encode(context) 94 | logger.info(tokens) 95 | token_length = len(tokens) 96 | pad_amount = self.params['seq'] - token_length 97 | pad_amount = max(pad_amount, 0) 98 | padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)[-self.params['seq']:] 99 | return {'tokens': padded_tokens, 'length': token_length} 100 | 101 | # Single Item - Not Tested 102 | def infer(self, context, top_p=0.9, top_k=40, temp=1.0, length=256, **kwargs): 103 | item = self.prepare_item(context, length) 104 | batched_tokens = np.array([item['tokens']] * self.total_batch) 105 | batched_lengths = np.array([item['length']] * self.total_batch) 106 | start = timer() 107 | output = self.network.generate( 108 | batched_tokens, batched_lengths, length, 109 | { 110 | "top_p": np.ones(self.total_batch) * top_p, 111 | "top_k": np.ones(self.total_batch) * top_k, 112 | "temp": np.ones(self.total_batch) * temp, 113 | } 114 | ) 115 | samples = [] 116 | decoded_tokens = output[1][0] 117 | end_time = timer(start) 118 | for o in decoded_tokens[:, :, 0]: 119 | res = { 120 | 'context': context, 121 | 'completion': self.tokenizer.decode(o), 122 | 'time': end_time 123 | } 124 | samples.append(res) 125 | logger.info(f"Completion done in {end_time:06} secs") 126 | return samples 127 | 128 | def infer_batch(self, batch, **kwargs): 129 | logger.info(f'Starting Inference on Batch') 130 | batch_items = {'tokens': [], 'lengths': [], 'top_p': [], 'top_k': [], 'temp': []} 131 | max_lengths, contexts = [], [] 132 | for 
req in batch: 133 | req = self.to_data(req) 134 | item = self.prepare_item(req['context'], req['length']) 135 | batch_items['tokens'].append(item['tokens']) 136 | batch_items['lengths'].append(item['length']) 137 | batch_items['top_p'].append(req['top_p']) 138 | batch_items['top_k'].append(req['top_k']) 139 | batch_items['temp'].append(req['temp']) 140 | max_lengths.append(req['length']) 141 | contexts.append(req['context']) 142 | 143 | max_length = max(max_lengths) 144 | for key, vals in batch_items.items(): 145 | batch_items[key] = np.array(vals) 146 | start = timer() 147 | logger.info(f'Completed Preparing Batch') 148 | output = self.network.generate( 149 | batch_items['tokens'], batch_items['lengths'], max_length, 150 | { 151 | "top_p": batch_items['top_p'], 152 | "top_k": batch_items['top_k'], 153 | "temp": batch_items['temp'], 154 | } 155 | ) 156 | logger.info(f'Completed Generation') 157 | samples = [] 158 | end_time = timer(start) 159 | for pred, ctx in zip(output[1][0][:, :, 0], contexts): 160 | res = { 161 | 'context': ctx, 162 | 'completion': self.tokenizer.decode(pred), 163 | 'time': end_time 164 | } 165 | samples.append(res) 166 | logger.info(f"Completion done in {end_time:06} secs") 167 | return samples 168 | 169 | def add_to_queue(self, item): 170 | self.qidx += 1 171 | self.queue.put({'item': self.to_data(item), 'qidx': self.qidx}) 172 | self.queue_ids[self.qidx] = Queue() 173 | return {'qid': self.qidx} 174 | 175 | def wait_for_queue(self, qid): 176 | if not self.queue_ids.get(qid): 177 | return {'Error': 'QID not found'} 178 | return self.queue_ids[qid].get() 179 | 180 | def background(self): 181 | logger.info(f'Init Background') 182 | maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(self.devices, ('dp', 'mp'))) 183 | while True: 184 | batch, qids = [], [] 185 | while len(batch) <= self.total_batch: 186 | try: 187 | req = self.queue.get(block=False) 188 | logger.info(f'Got Queue Item: {req}') 189 | batch.append(req['item']) 190 | qids.append(req['qidx']) 191 | 192 | except Empty: 193 | if len(batch): 194 | break 195 | else: 196 | time.sleep(0.01) 197 | batch_size = len(batch) 198 | logger.info(f'Working on Batch: {batch_size} - {qids}') 199 | while len(batch) < self.total_batch: 200 | batch.append(self.placeholder_item) 201 | start = timer() 202 | results = self.infer_batch(batch) 203 | for res, qid in zip(results, qids): 204 | self.queue_ids[qid].put(res) 205 | logger.info(f'Completed Current Batch of {batch_size} Items in {timer(start):.2f} secs') 206 | 207 | @property 208 | def placeholder_item(self): 209 | return {'context': 'nada', 'top_p': 0.9, 'top_k': 40, 'temp': 1.0, 'length': 1} 210 | 211 | def to_data(self, item): 212 | try: 213 | return {'context': item.context, 'top_p': item.top_p, 'top_k': item.top_k, 'temp': item.temp, 'length': item.length} 214 | except: 215 | return {'context': item.get('context', ''), 'top_p': item.get('top_p', 0.9), 'top_k': item.get('top_k', 40), 'temp': item.get('temp', 1.0), 'length': item.get('length', 256)} 216 | 217 | @property 218 | def alive_time(self): 219 | return timer(self._alive_time) -------------------------------------------------------------------------------- /docker/payloads.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import Dict, Optional, Any, List 3 | 4 | 5 | class ModelPayload(BaseModel): 6 | inputs: Optional[Any] = None 7 | params: Optional[Dict]= {} 8 | 9 | class CompletionPayload(BaseModel): 10 | context: str 11 | 
temp: Optional[float] = 1.0 12 | top_p: Optional[float] = 0.9 13 | top_k: Optional[int] = 50 14 | length: Optional[int] = 256 15 | 16 | class CompletionResponse(BaseModel): 17 | context: str 18 | completion: str 19 | time: float 20 | 21 | class QueueResponse(BaseModel): 22 | qid: int 23 | 24 | class QueueRequest(BaseModel): 25 | qid: int 26 | -------------------------------------------------------------------------------- /docker/start.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env sh 2 | set -e 3 | 4 | if [ -f /app/app/main.py ]; then 5 | DEFAULT_MODULE_NAME=app.main 6 | elif [ -f /app/main.py ]; then 7 | DEFAULT_MODULE_NAME=main 8 | fi 9 | 10 | MODULE_NAME=${MODULE_NAME:-$DEFAULT_MODULE_NAME} 11 | VARIABLE_NAME=${VARIABLE_NAME:-app} 12 | export APP_MODULE=${APP_MODULE:-"$MODULE_NAME:$VARIABLE_NAME"} 13 | 14 | HOST=${HOST:-0.0.0.0} 15 | PORT=${PORT:-80} 16 | LOG_LEVEL=${LOG_LEVEL:-info} 17 | 18 | # If there's a prestart.sh script in the /app directory or other path specified, run it before starting 19 | PRE_START_PATH=${PRE_START_PATH:-/app/prestart.sh} 20 | echo "Checking for script in $PRE_START_PATH" 21 | if [ -f $PRE_START_PATH ] ; then 22 | echo "Running script $PRE_START_PATH" 23 | . "$PRE_START_PATH" 24 | else 25 | echo "There is no script $PRE_START_PATH" 26 | fi 27 | 28 | # Start Uvicorn with live reload 29 | exec uvicorn --reload --reload-dir /app/mesh-transformer-jax --host $HOST --port $PORT --log-level $LOG_LEVEL "$APP_MODULE" --access-log -------------------------------------------------------------------------------- /eval_harness.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | from lm_eval import evaluator, tasks 5 | 6 | from mesh_transformer.build_model import build_model 7 | from tasks import EvalHarnessAdaptor 8 | 9 | 10 | def parse_args(): 11 | # Parse command line arguments 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--tpu", type=str, help="Name of TPU to train on.") 14 | parser.add_argument("--tpu_region", type=str, help="Region of TPU to train on.") 15 | parser.add_argument("--preemptible", action="store_true") 16 | 17 | parser.add_argument("--config", type=str, default=None, help="Config file location") 18 | 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | if __name__ == "__main__": 24 | args = parse_args() 25 | params = json.load(open(args.config)) 26 | 27 | tpu_name = args.tpu 28 | region = args.tpu_region 29 | preemptible = args.preemptible 30 | 31 | gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1) 32 | per_replica_batch = params["per_replica_batch"] 33 | tpu_size = params["tpu_size"] 34 | cores_per_replica = params["cores_per_replica"] 35 | 36 | bucket = params["bucket"] 37 | model_dir = params["model_dir"] 38 | layers = params["layers"] 39 | d_model = params["d_model"] 40 | n_heads = params["n_heads"] 41 | n_vocab = params["n_vocab"] 42 | seq = params["seq"] 43 | norm = params["norm"] 44 | pe = params["pe"] 45 | 46 | total_batch = per_replica_batch * tpu_size // cores_per_replica * 4 47 | 48 | t = build_model(params, tpu_name, region, preemptible) 49 | adaptor = EvalHarnessAdaptor(t, seq, total_batch, shrink=pe != "fixed") 50 | 51 | step, aux = t.load(bucket, model_dir) 52 | t.move() 53 | 54 | results = evaluator.evaluate(adaptor, tasks.get_task_dict(["lambada", 55 | "piqa", 56 | "hellaswag", 57 | "winogrande", 58 | "mathqa", 59 | "pubmedqa", 60 | # "boolq", 61 
| # "cb", 62 | # "copa", 63 | # "multirc", 64 | # "record", 65 | # "wic", 66 | # "wsc", 67 | ]), False, 0, None) 68 | dumped = json.dumps(results, indent=2) 69 | print(dumped) 70 | 71 | results = evaluator.evaluate(adaptor, tasks.get_task_dict(["lambada_cloze", 72 | ]), False, 15, None) 73 | 74 | dumped = json.dumps(results, indent=2) 75 | print(dumped) -------------------------------------------------------------------------------- /gsm/prompts.txt: -------------------------------------------------------------------------------- 1 | Q: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? 2 | A: Natalia sold 48/2 = <<48/2=24>>24 clips in May. 3 | Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. 4 | #### 72 5 | 6 | Q: Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet? 7 | A: In the beginning, Betty has only 100 / 2 = $<<100/2=50>>50. 8 | Betty's grandparents gave her 15 * 2 = $<<15*2=30>>30. 9 | This means, Betty needs 100 - 50 - 30 - 15 = $<<100-50-30-15=5>>5 more. 10 | #### 5 11 | 12 | Q: Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read? 13 | A: Maila read 12 x 2 = <<12*2=24>>24 pages today. 14 | So she was able to read a total of 12 + 24 = <<12+24=36>>36 pages since yesterday. 15 | There are 120 - 36 = <<120-36=84>>84 pages left to be read. 16 | Since she wants to read half of the remaining pages tomorrow, then she should read 84/2 = <<84/2=42>>42 pages. 17 | #### 42 18 | 19 | Q: Mark has a garden with flowers. He planted plants of three different colors in it. Ten of them are yellow, and there are 80% more of those in purple. There are only 25% as many green flowers as there are yellow and purple flowers. How many flowers does Mark have in his garden? 20 | A: There are 80/100 * 10 = <<80/100*10=8>>8 more purple flowers than yellow flowers. 21 | So in Mark's garden, there are 10 + 8 = <<10+8=18>>18 purple flowers. 22 | Purple and yellow flowers sum up to 10 + 18 = <<10+18=28>>28 flowers. 23 | That means in Mark's garden there are 25/100 * 28 = <<25/100*28=7>>7 green flowers. 24 | So in total Mark has 28 + 7 = <<28+7=35>>35 plants in his garden. 25 | #### 35 26 | 27 | Q: Alexis is applying for a new job and bought a new set of business clothes to wear to the interview. She went to a department store with a budget of $200 and spent $30 on a button-up shirt, $46 on suit pants, $38 on a suit coat, $11 on socks, and $18 on a belt. She also purchased a pair of shoes, but lost the receipt for them. She has $16 left from her budget. How much did Alexis pay for the shoes? 28 | A: Let S be the amount Alexis paid for the shoes. 29 | She spent S + 30 + 46 + 38 + 11 + 18 = S + <<+30+46+38+11+18=143>>143. 30 | She used all but $16 of her budget, so S + 143 = 200 - 16 = 184. 31 | Thus, Alexis paid S = 184 - 143 = $<<184-143=41>>41 for the shoes. 32 | #### 41 33 | 34 | Q: Tina makes $18.00 an hour. If she works more than 8 hours per shift, she is eligible for overtime, which is paid by your hourly wage + 1/2 your hourly wage. If she works 10 hours every day for 5 days, how much money does she make? 
35 | A: She works 8 hours a day for $18 per hour so she makes 8*18 = $<<8*18=144.00>>144.00 per 8-hour shift 36 | She works 10 hours a day and anything over 8 hours is eligible for overtime, so she gets 10-8 = <<10-8=2>>2 hours of overtime 37 | Overtime is calculated as time and a half so and she makes $18/hour so her overtime pay is 18*.5 = $<<18*.5=9.00>>9.00 38 | Her overtime pay is 18+9 = $<<18+9=27.00>>27.00 39 | Her base pay is $144.00 per 8-hour shift and she works 5 days and makes 5 * $144 = $<<144*5=720.00>>720.00 40 | Her overtime pay is $27.00 per hour and she works 2 hours of overtime per day and makes 27*2 = $<<27*2=54.00>>54.00 in overtime pay 41 | 2 hours of overtime pay for 5 days means she makes 54*5 = $270.00 42 | In 5 days her base pay is $720.00 and she makes $270.00 in overtime pay so she makes $720 + $270 = $<<720+270=990.00>>990.00 43 | #### 990 -------------------------------------------------------------------------------- /gsm/prompts_answer_key.txt: -------------------------------------------------------------------------------- 1 | Q: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? (72) 2 | A: Natalia sold 48/2 = <<48/2=24>>24 clips in May. 3 | Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. 4 | #### 72 5 | 6 | Q: Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet? (5) 7 | A: In the beginning, Betty has only 100 / 2 = $<<100/2=50>>50. 8 | Betty's grandparents gave her 15 * 2 = $<<15*2=30>>30. 9 | This means, Betty needs 100 - 50 - 30 - 15 = $<<100-50-30-15=5>>5 more. 10 | #### 5 11 | 12 | Q: Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read? (42) 13 | A: Maila read 12 x 2 = <<12*2=24>>24 pages today. 14 | So she was able to read a total of 12 + 24 = <<12+24=36>>36 pages since yesterday. 15 | There are 120 - 36 = <<120-36=84>>84 pages left to be read. 16 | Since she wants to read half of the remaining pages tomorrow, then she should read 84/2 = <<84/2=42>>42 pages. 17 | #### 42 18 | 19 | Q: Mark has a garden with flowers. He planted plants of three different colors in it. Ten of them are yellow, and there are 80% more of those in purple. There are only 25% as many green flowers as there are yellow and purple flowers. How many flowers does Mark have in his garden? (35) 20 | A: There are 80/100 * 10 = <<80/100*10=8>>8 more purple flowers than yellow flowers. 21 | So in Mark's garden, there are 10 + 8 = <<10+8=18>>18 purple flowers. 22 | Purple and yellow flowers sum up to 10 + 18 = <<10+18=28>>28 flowers. 23 | That means in Mark's garden there are 25/100 * 28 = <<25/100*28=7>>7 green flowers. 24 | So in total Mark has 28 + 7 = <<28+7=35>>35 plants in his garden. 25 | #### 35 26 | 27 | Q: Alexis is applying for a new job and bought a new set of business clothes to wear to the interview. She went to a department store with a budget of $200 and spent $30 on a button-up shirt, $46 on suit pants, $38 on a suit coat, $11 on socks, and $18 on a belt. She also purchased a pair of shoes, but lost the receipt for them. She has $16 left from her budget. 
How much did Alexis pay for the shoes? (41) 28 | A: Let S be the amount Alexis paid for the shoes. 29 | She spent S + 30 + 46 + 38 + 11 + 18 = S + <<+30+46+38+11+18=143>>143. 30 | She used all but $16 of her budget, so S + 143 = 200 - 16 = 184. 31 | Thus, Alexis paid S = 184 - 143 = $<<184-143=41>>41 for the shoes. 32 | #### 41 33 | 34 | Q: Tina makes $18.00 an hour. If she works more than 8 hours per shift, she is eligible for overtime, which is paid by your hourly wage + 1/2 your hourly wage. If she works 10 hours every day for 5 days, how much money does she make? (990) 35 | A: She works 8 hours a day for $18 per hour so she makes 8*18 = $<<8*18=144.00>>144.00 per 8-hour shift 36 | She works 10 hours a day and anything over 8 hours is eligible for overtime, so she gets 10-8 = <<10-8=2>>2 hours of overtime 37 | Overtime is calculated as time and a half so and she makes $18/hour so her overtime pay is 18*.5 = $<<18*.5=9.00>>9.00 38 | Her overtime pay is 18+9 = $<<18+9=27.00>>27.00 39 | Her base pay is $144.00 per 8-hour shift and she works 5 days and makes 5 * $144 = $<<144*5=720.00>>720.00 40 | Her overtime pay is $27.00 per hour and she works 2 hours of overtime per day and makes 27*2 = $<<27*2=54.00>>54.00 in overtime pay 41 | 2 hours of overtime pay for 5 days means she makes 54*5 = $270.00 42 | In 5 days her base pay is $720.00 and she makes $270.00 in overtime pay so she makes $720 + $270 = $<<720+270=990.00>>990.00 43 | #### 990 -------------------------------------------------------------------------------- /gsm/prompts_direct_answer_key.txt: -------------------------------------------------------------------------------- 1 | Q: What do people use to absorb extra ink from a fountain pen? blotter 2 | A: Therefore, the answer is blotter. 3 | 4 | Q: What home entertainment equipment requires cable? television 5 | A: Therefore, the answer is television. 6 | 7 | Q: The fox walked from the city into the forest, what was it looking for? natural habitat 8 | A: Therefore, the answer is natural habitat. 9 | 10 | Q: Sammy wanted to go to where the people were. Where might he go? populated areas 11 | A: Therefore, the answer is populated areas. 12 | 13 | Q: Where do you put your grapes just before checking out? grocery cart 14 | A: Therefore, the answer is grocery cart. 15 | 16 | Q: Google Maps and other highway and street GPS services have replaced what? atlas 17 | A: Therefore, the answer is atlas. 18 | 19 | Q: Before getting a divorce, what did the wife feel who was doing all the work? bitterness 20 | A: Therefore, the answer is bitterness. -------------------------------------------------------------------------------- /howto_finetune.md: -------------------------------------------------------------------------------- 1 | # How to Fine-Tune GPT-J - The Basics 2 | 3 | Before anything else, you'll likely want to apply for access to the TPU Research Cloud (TRC). Combined with a Google Cloud free trial, that should allow you to do everything here for free. Once you're in TRC, you need to create a project, then with the name of the new project fill out the form that was emailed to you. Use the script `create_finetune_tfrecords.py` to prepare your data as tfrecords; I might do a separate guide on that. Another thing you might want to do is fork the mesh-transformer-jax repo to make it easier to add and modify the config files. 4 | 5 | 0. [Install the Google Cloud SDK](https://cloud.google.com/sdk/docs/install). We'll need it later. 6 | 7 | 1. 
If you didn't make a project and activate TPU access through TRC yet (or if you plan on paying out of pocket), [make one now](https://console.cloud.google.com/projectcreate). 8 | 9 | 2. TPUs use Google Cloud buckets for storage, so go ahead and [create one now](https://console.cloud.google.com/storage/create-bucket). Make sure it's in the region the TPU VM will be in; the email from TRC will tell you which region(s) you can use free TPUs in. 10 | 11 | 3. You'll need the full pretrained weights in order to fine-tune the model. [Download those here](https://the-eye.eu/public/AI/GPT-J-6B/step_383500.tar.zstd). 12 | 13 | Now that you have a bucket on the cloud and the weights on your PC, you need to upload the weights to the bucket in two steps: 14 | 15 | 4. Decompress and extract `GPT-J-6B/step_383500.tar.zstd` so you're left with the uncompressed folder containing the sharded checkpoint. 16 | 17 | 5. Open the Google Cloud SDK and run the following command, replacing the path names as appropriate: `gsutil -m cp -R LOCAL_PATH_TO/step_383500 gs://YOUR-BUCKET`. If that works, the console will show the files being uploaded. *Note: This took about 12 hours for me, uploading to the Netherlands from California; hopefully you'll have a better geographic situation than I did! I also initially made the mistake of uploading the still-packed .tar. Don't do that; TPU VMs don't have enough local storage for you to unpack it. To avoid needing to re-upload, I had to unpack it in Colab.* 18 | 19 | You'll want to upload tfrecords of your data as well; you can do that here or through the web interface, but trust me when I say you don't want to upload the nearly 70GB weights through the web interface. 20 | 21 | Note that steps 6 and 7, preparing the index and config files, can be done later on by editing the base repo in the VM's text editor. It's more efficient to instead make these changes to your own fork of the repo as follows: 22 | 23 | 6. In the data folder, create a new file `foo.train.index`, replacing `foo` with whatever you want to refer to your dataset as. For each tfrecord in your bucket that you intend to train with, add the path as a line in the index. Make `foo.val.index` and do the same for your validation dataset (if you have one). See the existing files for examples. 24 | 25 | 7. Duplicate the config file `6B_roto_256.json` and rename it to something appropriate for your project. Open it up and make these edits: 26 | - `tpu_size`: Change from `256` to `8` 27 | - `bucket`: Change to your bucket 28 | - `model_dir`: Change to the directory you'd like to save your checkpoints in 29 | - `train_set` and `val_set`: Change to the index files from the last step 30 | - `eval_harness_tasks`: Can be removed if you don't plan on using the eval harness 31 | - `val_every` & `ckpt_every` & `keep_every`: Usage should be intuitive. Don't set any of the `*_every` values to 0, though, or you'll get a divide-by-zero error. If you don't have a `val_set`, just set `val_every` to something higher than `total_steps`. 32 | - `val_batches`: This should equal the number of sequences in your val dataset. You can find this number at the end of the file name of the .tfrecords file produced by `create_finetune_tfrecords.py`. 33 | - `name`: Change to a name for your model. 34 | - `warmup_steps`, `lr`, `val_batches`, etc.: see the *Learning Rate Notes* section at the end of the guide. 35 | 36 | 37 | 8. Push the changes to your GitHub repo. 38 | 39 | 9.
Follow [this guide](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm) up to and including the step **"Connect to your Cloud TPU VM"**. 40 | 41 | At this point you should have remote access to the TPU VM! 42 | 43 | 10. In the new VM terminal, type `git clone https://github.com/kingoflolz/mesh-transformer-jax` (or, preferably, your own fork, after pushing the config and index files). 44 | 45 | 11. Move to the new directory with `cd mesh-transformer-jax` and run `pip install -r requirements.txt`. Since the requirements.txt file doesn't pin the exact jax version required for finetuning, run `pip install jax==0.2.12` and you'll be all set. 46 | 47 | 12. Finally, run `python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/`. If everything is set up correctly, this will begin the fine-tuning process. First the model has to be loaded into memory; once `loading network` is displayed on the console, it took about 10-15 minutes for me before the next step, setting up WandB for logging. Option 3 allows you to skip that if you aren't using WandB. A step-1 checkpoint will be saved, and the real training will start. If you have a small dataset, this will go by quickly; TPU VMs can train at a rate of ~5000 tokens/second. 48 | 49 | 13. You did it! Now don't forget any cleanup steps you need to take, like shutting down your TPU VM or removing unneeded data from buckets, so that you don't have any unexpected charges from Google later. 50 | 51 | ## Now what? 52 | 53 | This guide is labeled "The Basics"; anything we haven't covered so far is out of scope, but go check out the rest of the repository! Try `python3 device_sample.py --config=configs/YOUR_CONFIG.json` for a basic sampling interface. Use `slim_model.py` to prepare an easier-to-deploy slim version of your new weights for inference. Experiment! 54 | 55 | ### Running with HuggingFace 56 | To use the model in HuggingFace's `transformers` library using PyTorch, you'll need to convert the weights 57 | into a format that it recognizes. This can be done using `to_hf_weights.py`. It's recommended that you use `slim_model.py` before attempting to move the weights to a PyTorch/`transformers` format. Use `python to_hf_weights.py --help` to see usage details. 58 | 59 | *Note: as of 9/1/2021, GPT-J has been merged into the `main` branch of `transformers` but has not yet been included in a release. Run `pip install git+https://github.com/huggingface/transformers#transformers` to install the current `main` branch.* 60 | 61 | ## Learning Rate Notes 62 | 63 | **Thanks to nostalgebraist for talking about this!** They're the one who explained this part on Discord; I'm just paraphrasing, really: 64 | 65 | The first thing you want to determine is how long a training epoch will be. `gradient_accumulation_steps` is your batch size; it defaults to `16`, and nostalgebraist recommends `32`. Your .tfrecord files should have a number in the file name indicating how many sequences are in the dataset. Divide that number by the batch size and the result is how many steps are in an epoch. Now we can write the schedule. 66 | 67 | `lr` is recommended to be between `1e-5` and `5e-5`, with `end_lr` set to 1/5 or 1/10 of `lr`. 68 | `weight_decay` can remain `0.1`. `total_steps` should be at least one epoch, possibly longer if you have a validation 69 | set to determine your training loss with. 70 | `warmup_steps` should be 5-10% of total, and finally `anneal_steps` should be `total_steps - warmup_steps`.
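To make the arithmetic concrete, here is a minimal sketch of how these values relate (plain Python, using the hypothetical numbers from the worked example below):

```python
import math

n_sequences = 1147                # read this off the .tfrecords file name
gradient_accumulation_steps = 16  # this is your batch size

steps_per_epoch = math.ceil(n_sequences / gradient_accumulation_steps)  # 72

schedule = {
    "lr": 5e-5,                                    # recommended range: 1e-5 to 5e-5
    "end_lr": 1e-5,                                # 1/5 of lr
    "weight_decay": 0.1,
    "total_steps": steps_per_epoch,                # at least one epoch
    "warmup_steps": round(0.1 * steps_per_epoch),  # 5-10% of total -> 7
}
schedule["anneal_steps"] = schedule["total_steps"] - schedule["warmup_steps"]  # 65
print(schedule)
```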
71 | (The `lr` is set to `end_lr` after `warmup_steps+anneal_steps` and then keeps training until `total_steps`, 72 | but usually you should stop after annealing is done) 73 | 74 | To illustrate: I have a small dataset that tokenized into 1147 sequences as a .tfrecord. Dividing by `gradient_accumulation_steps` set to `16`, rounding up to ensure I use all the data, equals 72 steps per epoch. I'll set `lr` to `5e-5`, `end_lr` to a fifth of that, `1e-5`; that may be too much, it's on the high end of the recommended range. I'll set `total_steps` to `72` for one epoch, since I don't have a validation set. Then I'll set `anneal_steps` to `65` and `warmup_steps` to `7`. Simple as that, but you may need to fiddle with the specifics on your own. 75 | -------------------------------------------------------------------------------- /iteration_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import shutil 5 | import argparse 6 | 7 | def record_folder(cur_iter): 8 | return f"{task}/{experiment_name}/{experiment_name}_{cur_iter}" 9 | 10 | 11 | def parse_args(): 12 | # Parse command line arguments 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--no_prompt', action='store_true', help="Whether to remove prompts during eval") 15 | parser.add_argument("--base_epochs", type=float, default=1., help="Epochs for the first iteration") 16 | parser.add_argument("--add_epochs", type=float, default=0.2, help="Epochs to add each iteration") 17 | parser.add_argument("--few_shot_train", action='store_true', help="Whether to use few shot training") 18 | parser.add_argument("--steady_grow", action='store_true', help="Whether to use a fixed number of epochs") 19 | parser.add_argument("--start_steps", type=float, default=40., help="Steps for the first iteration") 20 | parser.add_argument("--exponential_grow", action='store_true', help="Whether to use a fixed number of epochs") 21 | parser.add_argument("--add_steps", type=float, default=20., help="Steps to add each iteration") 22 | parser.add_argument("--grow_steps", type=float, default=1.2, help="Steps to add each iteration") 23 | parser.add_argument("--p_rationalization", type=float, default=1., help="Percent of wrong examples to rationalize") 24 | parser.add_argument("--p_show_hint_save", type=float, default=0., help="Percent of rationalization hints to save") 25 | parser.add_argument('--rationalize', action='store_true', help="Whether to use rationalization") 26 | 27 | parser.add_argument("--start_iter", type=int, default=1, help="Starting iteration") 28 | parser.add_argument("--n_iters", type=int, default=64, help="Upper limit on outer loop iterations") 29 | parser.add_argument("--copy_n", type=int, default=0, help="Number of files to copy each iteration") 30 | parser.add_argument("--n_train_samples", type=int, default=10000, help="Number of training examples") 31 | parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Batch size") 32 | 33 | parser.add_argument("--task", type=str, default="commonsenseqa", help="Whether to run on arithmetic") 34 | parser.add_argument('--direct', action='store_true', help="Whether to use direct prediction, sans scratchpad") 35 | parser.add_argument("--gen_length", type=int, default=96, help="Length of generated output") 36 | parser.add_argument("--sequence_count", type=int, default=10, help="Sequences per batch on average") 37 | parser.add_argument("--base_model_location", type=str, 
default="gs://checkpoint-bucket/step_383500/", help="Finetuning ckpt") 38 | parser.add_argument('--dry_run', action='store_true', help="Whether to do a quick run to visualize output") 39 | parser.add_argument('--skip_eval', action='store_true', help="Whether to skip evaluation (e.g. arithmetic)") 40 | 41 | args = parser.parse_args() 42 | return args 43 | 44 | def gen_train(): 45 | train_cmd = f"python3 device_inference.py --config={prev_config} --split=train --gen_length={args.gen_length} --p_show_hint_save={args.p_show_hint_save} " 46 | if task != "commonsenseqa": 47 | train_cmd += f" --dataset_mode={task} " 48 | if args.rationalize: 49 | train_cmd += " --rationalize " 50 | if args.few_shot_train: 51 | train_cmd += " --few_shot_train " 52 | if cur_iter > 1 and args.no_prompt: 53 | train_cmd += f" --no_prompt --eval_seq {eval_seq} " 54 | train_cmd += f" --n_train_samples={args.n_train_samples} " 55 | train_cmd += f" >> result_logs/{experiment_name}.txt" 56 | print(f"Generating training set {cur_iter} using model {cur_iter - 1}: {train_cmd}") 57 | if not args.dry_run and (cur_iter >= args.start_iter): 58 | if (cur_iter == 1) and os.path.exists(record_folder(0) + f"/{experiment_name}_0.txt"): 59 | print("First file cached") 60 | else: 61 | os.system(train_cmd) 62 | 63 | def gen_records(): 64 | gen_cmd = f'python3 create_finetune_tfrecords.py {record_folder(cur_iter - 1)} {record_folder(cur_iter - 1)}' 65 | print(f"Creating records for finetuning {cur_iter}: {gen_cmd}") 66 | if not args.dry_run and (cur_iter >= args.start_iter): 67 | os.system(gen_cmd) 68 | train_set = f"{experiment_name}/{exp_iteration}.index" 69 | with open(f"data/{train_set}", "w") as new_data_file: 70 | new_data_file.write(f"{record_folder(cur_iter - 1)}.tfrecords") 71 | return train_set 72 | 73 | def get_n_steps(): 74 | if args.steady_grow: 75 | return int(args.start_steps + args.add_steps * (cur_iter - 1)) 76 | elif args.exponential_grow: 77 | return int(args.start_steps * (args.grow_steps ** (cur_iter - 1))) 78 | else: 79 | # Count data points 80 | total_count = 0 81 | for cur_file in sorted(os.listdir(record_folder(cur_iter - 1)), key=lambda x: int(x.split('.')[0].split("_")[-1])): 82 | with open(f"{record_folder(cur_iter - 1)}/{cur_file}", encoding='utf-8') as train_file: 83 | train_file_text = train_file.read() 84 | total_count += len(train_file_text.split("\n\n")) 85 | print(len(train_file_text.split("\n\n"))) 86 | train_epochs = args.base_epochs + args.add_epochs * (cur_iter - 1) 87 | cur_steps = int(total_count * train_epochs // (args.gradient_accumulation_steps * args.sequence_count)) 88 | return cur_steps 89 | 90 | def gen_config(train_set): 91 | print(f"Creating new config file {cur_iter}") 92 | config_name = f'configs/{experiment_name}/{exp_iteration}.json' 93 | os.makedirs(record_folder(cur_iter), exist_ok=True) 94 | with open(prev_config, encoding='utf-8') as base_json_file: 95 | new_json = json.load(base_json_file) 96 | new_json["model_dir"] = f"strangeloop/{exp_iteration}" 97 | new_json["train_set"] = train_set 98 | new_json["target_save"] = record_folder(cur_iter) + f"/{exp_iteration}.txt" 99 | new_json["total_steps"] = get_n_steps() 100 | new_json["name"] = exp_iteration 101 | new_json["p_rationalization"] = args.p_rationalization 102 | new_json["gradient_accumulation_steps"] = args.gradient_accumulation_steps 103 | with open(config_name, "w", encoding='utf-8') as new_json_file: 104 | json.dump(new_json, new_json_file, indent=2) 105 | return config_name 106 | 107 | def train_model(): 108 | model_cmd = 
f"python3 device_train.py --config {config_name} --tune-model-path={args.base_model_location}" 109 | print(f"Train model {cur_iter}: {model_cmd}") 110 | if not args.dry_run and (cur_iter >= args.start_iter): 111 | os.system(model_cmd) 112 | 113 | def eval_model(): 114 | eval_cmd = f"python3 device_inference.py --config={config_name} --split=dev --gen_length={args.gen_length} --p_show_hint_save={args.p_show_hint_save} " 115 | if task != "commonsenseqa": 116 | eval_cmd += f" --dataset_mode={task} " 117 | if args.no_prompt: 118 | eval_cmd += f" --no_prompt --eval_seq {eval_seq} " 119 | if args.few_shot_train: 120 | eval_cmd += " --few_shot_train " 121 | eval_cmd += f" >> result_logs/{experiment_name}.txt" 122 | print(f"Eval model {cur_iter}: {eval_cmd}") 123 | if not args.dry_run and (cur_iter >= args.start_iter) and not args.skip_eval: 124 | os.system(eval_cmd) 125 | 126 | def copy_files(): 127 | all_files = sorted(os.listdir(record_folder(cur_iter - 1)), key=lambda x: int(x.split('.')[0].split("_")[-1])) 128 | relevant_files = all_files[-args.copy_n:] 129 | for cur_file in relevant_files: 130 | shutil.copy(f"{record_folder(cur_iter - 1)}/{cur_file}", record_folder(cur_iter)) 131 | 132 | def make_first_config(): 133 | with open(prev_config, encoding='utf-8') as base_json_file: 134 | new_json = json.load(base_json_file) 135 | os.makedirs(record_folder(0), exist_ok=True) 136 | new_json["target_save"] = record_folder(0) + f"/{experiment_name}_0.txt" 137 | new_json["name"] = f"{experiment_name}_0" 138 | new_json["p_rationalization"] = args.p_rationalization 139 | new_json["gradient_accumulation_steps"] = args.gradient_accumulation_steps 140 | with open(prev_config, "w", encoding='utf-8') as base_json_file: 141 | json.dump(new_json, base_json_file, indent=2) 142 | return new_json 143 | 144 | if __name__ == "__main__": 145 | args = parse_args() 146 | print(args) 147 | task = args.task 148 | experiment_name = "_".join(sys.argv[1:]) 149 | experiment_name = ''.join(ch for ch in experiment_name if ch.isalnum() or ch == "_") 150 | if args.no_prompt: 151 | eval_seq = 128 + args.gen_length 152 | os.makedirs(f"configs/{experiment_name}", exist_ok=True) 153 | shutil.copy(f"configs/qa_base.json", f"configs/{experiment_name}/base.json") 154 | prev_config = f"configs/{experiment_name}/base.json" 155 | new_json = make_first_config() 156 | 157 | os.makedirs(f'data/{experiment_name}', exist_ok=True) 158 | os.makedirs(f'{task}/{experiment_name}', exist_ok=True) 159 | os.makedirs(f'result_logs/', exist_ok=True) 160 | with open(f"result_logs/{experiment_name}.txt", "a+") as f: 161 | print("================================", file=f) 162 | print(args, file=f) 163 | for cur_iter in range(1, args.n_iters): 164 | exp_iteration = f"{experiment_name}_{cur_iter}" 165 | gen_train() # Generate the training set 166 | train_set = gen_records() # Create the tfrecords from the data 167 | config_name = gen_config(train_set) # Create the new configuration file 168 | train_model() # Train the new model 169 | eval_model() # Evaluate the new model 170 | prev_config = config_name # Prepare for next iteration 171 | if args.copy_n > 0: 172 | copy_files() 173 | -------------------------------------------------------------------------------- /mesh_transformer/TPU_cluster.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | import time 4 | 5 | import ray 6 | 7 | from typing import Callable 8 | import numpy as np 9 | 10 | from mesh_transformer.train_actor import 
NetworkRunner 11 | from google.cloud import storage 12 | from smart_open import open 13 | from func_timeout import func_set_timeout 14 | 15 | 16 | class TPUCluster: 17 | @func_set_timeout(1200) 18 | def __init__(self, 19 | mesh_shape, 20 | node_count, 21 | model: Callable, 22 | version=1): 23 | assert ray.is_initialized() # needs a valid ray cluster to start 24 | self.nodes = [] 25 | self.node_count = node_count 26 | self.dp, self.mp = mesh_shape 27 | self.version = version 28 | 29 | start = time.time() 30 | 31 | for i in range(node_count): 32 | self.nodes.append(NetworkRunner.options(max_concurrency=2).remote(mesh_shape, model)) 33 | 34 | for n in self.nodes: 35 | n.run.remote() 36 | 37 | params = [] 38 | for n in self.nodes: 39 | params.append(n.get_params.remote()) 40 | 41 | self.param_count = ray.get(params)[0] 42 | print(f"Ray actors created in {time.time() - start:.06}s") 43 | 44 | @func_set_timeout(600) 45 | def train(self, data): 46 | data_chunks = np.array_split(data, len(self.nodes), axis=1) 47 | 48 | res = [] 49 | for n, d in zip(self.nodes, data_chunks): 50 | res.append(n.train.remote({ 51 | "obs": d[:, :, :-1], 52 | "target": d[:, :, 1:], 53 | })) 54 | 55 | res = ray.get(res) 56 | 57 | loss = [] 58 | last_loss = [] 59 | 60 | for r in res: 61 | loss.append(r[0]) 62 | last_loss.append(r[1]) 63 | 64 | return np.array(loss).mean(), np.array(last_loss).mean() 65 | 66 | @func_set_timeout(600) 67 | def eval(self, data): 68 | if isinstance(data, dict): 69 | data_chunked = [{} for _ in self.nodes] 70 | for k, v in data.items(): 71 | v_chunks = np.array_split(v, len(self.nodes), axis=0) 72 | for idx, v_chunk in enumerate(v_chunks): 73 | data_chunked[idx][k] = v_chunk 74 | 75 | res = [] 76 | for n, d in zip(self.nodes, data_chunked): 77 | res.append(n.eval.remote(d)) 78 | 79 | total = 0 80 | correct = 0 81 | last_correct = 0 82 | 83 | total_last_loss = 0 84 | mask_loss = [] 85 | each_correct = [] 86 | 87 | for input, output in zip(data_chunked, ray.get(res)): 88 | correct_and_valid = np.logical_and(output["correct"], input["eval_mask"]) 89 | 90 | correct_tokens_count = np.sum(correct_and_valid, -1) 91 | valid_tokens_count = np.sum(input["eval_mask"], -1) 92 | 93 | correct_example = np.logical_and(valid_tokens_count == correct_tokens_count, valid_tokens_count > 0) 94 | valid_example = valid_tokens_count > 0 95 | last_correct_example = correct_and_valid[:, -1] 96 | 97 | each_correct += correct_example.tolist() 98 | 99 | total += sum(valid_example) 100 | correct += sum(correct_example) 101 | last_correct += sum(last_correct_example) 102 | total_last_loss += sum(valid_example * output["last_loss"]) 103 | 104 | valid_loss = np.sum(output["all_loss"] * input["eval_mask"], -1) 105 | mask_loss += valid_loss.tolist() 106 | 107 | return { 108 | "total": total, 109 | "correct": correct, 110 | "last_correct": last_correct, 111 | "last_loss": total_last_loss, 112 | "mask_loss": np.array(mask_loss), 113 | "each_correct": np.array(each_correct) 114 | } 115 | else: 116 | data_chunks = np.array_split(data, len(self.nodes), axis=0) 117 | 118 | res = [] 119 | for n, d in zip(self.nodes, data_chunks): 120 | res.append(n.eval.remote({ 121 | "obs": d[:, :-1], 122 | "target": d[:, 1:], 123 | })) 124 | 125 | return np.array([i["loss"] for i in ray.get(res)]).mean() 126 | 127 | @func_set_timeout(600) 128 | def generate(self, context, ctx_length, gen_len): 129 | context = np.array_split(context, len(self.nodes), axis=0) 130 | ctx_length = np.array_split(ctx_length, len(self.nodes), axis=0) 131 | 132 | res = [] 
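        # fan the batch out across the Ray actors: each node generates for its own slice of the contexts and context lengths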
133 | for n, ctx, l in zip(self.nodes, context, ctx_length): 134 | res.append(n.generate.remote(( 135 | ctx, 136 | np.ones(len(ctx), dtype=np.uint32) * l, 137 | gen_len 138 | ))) 139 | 140 | return np.concatenate([i[1][0][:, :, 0] for i in ray.get(res)], axis=0) 141 | 142 | @func_set_timeout(600) 143 | def move(self): 144 | start = time.time() 145 | res = [] 146 | for node in self.nodes: 147 | res.append(node.move_params.remote()) 148 | ray.get(res) 149 | 150 | print(f"Moved weights to TPU in {time.time() - start:.06}s") 151 | 152 | @func_set_timeout(1800) 153 | def load(self, bucket, path): 154 | with open(f"gs://{bucket}/{path}/meta.json", "r") as f: 155 | meta = json.load(f) 156 | 157 | ckpt_step = meta["checkpoints"][-1] 158 | 159 | # do replicated checkpoint reading 160 | start = time.time() 161 | res = [] 162 | for node in self.nodes: 163 | res.append(node.load_ckpt.remote(f"gs://{bucket}/{path}/step_{ckpt_step}/")) 164 | 165 | # make sure they all read from the same checkpoint 166 | step = np.array(ray.get(res)) 167 | assert (step[0] == step).all() 168 | step = int(step[0]) 169 | 170 | print(f"Checkpoint@step{step} restored in {time.time() - start:.06}s") 171 | return step, meta["aux"][str(ckpt_step)] 172 | 173 | @func_set_timeout(600) 174 | def save(self, step, bucket, path, aux=None, init=False, overwrite=False, keep_n=3, delete_old=True): 175 | assert path 176 | client = storage.Client() 177 | 178 | if aux is None: 179 | aux = {} 180 | 181 | if init: 182 | # check existing checkpoint folder does not exist, and delete it if it does 183 | for blob in client.list_blobs(bucket, prefix=f"{path}/"): 184 | assert overwrite 185 | # print(f"deleting {blob.name}") 186 | assert path in blob.name 187 | blob.delete() 188 | 189 | # create metadata file 190 | with open(f"gs://{bucket}/{path}/meta.json", "w") as f: 191 | json.dump({ 192 | "step": 0, 193 | "checkpoints": [], 194 | "aux": {} 195 | }, f) 196 | 197 | # do sharded checkpoint writing 198 | start = time.time() 199 | res = [] 200 | 201 | if self.version == 1: 202 | for shard_id, node in zip(range(self.mp), itertools.cycle(self.nodes)): 203 | res.append(node.write_ckpt.remote(f"gs://{bucket}/{path}/step_{step}/", shard_id)) 204 | elif self.version == 2: 205 | for node in self.nodes: 206 | res.append(node.write_ckpt.remote(f"gs://{bucket}/{path}/step_{step}", 0)) 207 | 208 | ray.get(res) 209 | print(f"Wrote checkpoint in {time.time() - start:.06}s") 210 | 211 | with open(f"gs://{bucket}/{path}/meta.json", "r") as f: 212 | meta = json.load(f) 213 | 214 | meta["step"] = step 215 | meta["checkpoints"].append(step) 216 | all_aux = meta.get("aux", {}) 217 | 218 | while len(meta["checkpoints"]) > keep_n: 219 | ckpt_to_delete = meta["checkpoints"].pop(0) 220 | 221 | try: 222 | del all_aux[str(ckpt_to_delete)] 223 | except: 224 | print(f"failed to delete the aux state for {step}") 225 | 226 | if delete_old: 227 | print(f"deleting checkpoint {ckpt_to_delete}") 228 | for blob in client.list_blobs(bucket, prefix=f"{path}/step_{ckpt_to_delete}/"): 229 | # print(f"deleting {blob.name}") 230 | assert path in blob.name 231 | blob.delete() 232 | else: 233 | print(f"keeping checkpoint {ckpt_to_delete}") 234 | 235 | all_aux[step] = aux 236 | meta["aux"] = all_aux 237 | 238 | with open(f"gs://{bucket}/{path}/meta.json", "w") as f: 239 | json.dump(meta, f) 240 | -------------------------------------------------------------------------------- /mesh_transformer/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ezelikman/STaR/6ffd3c4743da6c7a7369834c434359d9f7e0ac38/mesh_transformer/__init__.py -------------------------------------------------------------------------------- /mesh_transformer/build_model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import multiprocessing 3 | 4 | import optax 5 | import ray 6 | 7 | from mesh_transformer import util 8 | from mesh_transformer.TPU_cluster import TPUCluster 9 | from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerV2 10 | from mesh_transformer.util import clip_by_global_norm, additive_weight_decay 11 | from ray_tpu import create_tpu, wait_til, get_connection, start_ray 12 | 13 | 14 | def build_model(params, tpu_name, region, preemptible, version=1): 15 | gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1) 16 | cores_per_replica = params["cores_per_replica"] 17 | tpu_size = params["tpu_size"] 18 | 19 | warmup_steps = params["warmup_steps"] 20 | anneal_steps = params["anneal_steps"] 21 | lr = params["lr"] 22 | end_lr = params["end_lr"] 23 | weight_decay = params["weight_decay"] 24 | 25 | assert tpu_size in [8, 32, 128, 256, 512] 26 | 27 | create_tpu(tpu_name, region, f"v3-{tpu_size}", preemptible) 28 | assert wait_til(tpu_name, region, {'state': 'READY', 'health': 'HEALTHY'}) 29 | 30 | conns = get_connection(tpu_name, region) 31 | 32 | assert len(conns) * 8 == tpu_size, "wrong size TPU for config" 33 | 34 | head_info = ray.init(include_dashboard=False, object_store_memory=10**9) 35 | address = head_info['redis_address'] 36 | 37 | with multiprocessing.pool.ThreadPool(processes=len(conns)) as p: 38 | p.map(functools.partial(start_ray, address=address, version=version), conns) 39 | 40 | opt = optax.chain( 41 | optax.scale(1 / gradient_accumulation_steps), 42 | clip_by_global_norm(1, use_psum=(version == 1)), 43 | optax.scale_by_adam(), 44 | additive_weight_decay(weight_decay), 45 | optax.scale(-1), 46 | optax.scale_by_schedule(util.gpt3_schedule(warmup_steps, anneal_steps, lr, end_lr)) 47 | ) 48 | 49 | params["optimizer"] = opt 50 | 51 | if version == 2: 52 | model_fn = functools.partial(CausalTransformerV2, params) 53 | elif version == 1: 54 | model_fn = functools.partial(CausalTransformer, params) 55 | else: 56 | raise Exception(f"Version {version} does not exist") 57 | 58 | t = TPUCluster((tpu_size // cores_per_replica, cores_per_replica), len(conns), model_fn, version=version) 59 | return t 60 | -------------------------------------------------------------------------------- /mesh_transformer/checkpoint.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import io 3 | import json 4 | import time 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | import multiprocessing 10 | 11 | import ray 12 | from smart_open import open 13 | 14 | from mesh_transformer.util import head_print 15 | 16 | pieces = 16 # how many files to split each shard across 17 | 18 | 19 | def fix_dtype(pytree): 20 | def fix(x): 21 | if x.dtype == np.dtype('V2'): 22 | x.dtype = jnp.bfloat16 23 | return jnp.asarray(x) 24 | 25 | return jax.tree_map(fix, pytree) 26 | 27 | 28 | @functools.partial(jax.jit, backend="cpu") 29 | def index_weights(weights, idx): 30 | cpu_device = jax.devices("cpu")[0] 31 | return jax.device_put(jax.tree_map(lambda i: i[idx], weights), cpu_device) 32 | 33 | 34 | def write(x, ckpt_dir): 35 | # start = time.time() 36 | idx, i = x 37 | 
file_path = ckpt_dir + f"{idx}.npz" 38 | for _ in range(3): 39 | try: 40 | with open(file_path, "wb") as f: 41 | np.savez(f, *i) 42 | # cloudpickle.dump(i, f) 43 | # print(f"written {idx} in {time.time() - start:.06}s") 44 | return 45 | except: 46 | print("save failed, trying again") 47 | 48 | print("save failed 3 times, exiting") 49 | raise Exception("save failed") 50 | 51 | 52 | def split(a, n): 53 | k, m = divmod(len(a), n) 54 | return (a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n)) 55 | 56 | 57 | def write_ckpt(pytree, dir, shard): 58 | # ckpt_dir = Path(dir) 59 | # ckpt_dir.mkdir(parents=True, exist_ok=True) 60 | 61 | flattened, structure = jax.tree_flatten(pytree) 62 | 63 | start = time.time() 64 | # cpu_flattened = jax.device_put(flattened, cpu_device) 65 | cpu_flattened = index_weights(flattened, shard) 66 | # print(f"Moved indexed in {time.time() - start:.06}s") 67 | 68 | cpu_flattened_chunked = split(cpu_flattened, pieces) 69 | 70 | # start = time.time() 71 | # cpu_float = move_weights(cpu_flattened) 72 | # print(f"changed weight types in {time.time() - start:.06}s") 73 | 74 | with multiprocessing.pool.ThreadPool(pieces) as p: 75 | write_fn = functools.partial(write, ckpt_dir=f"{dir}shard_{shard}/") 76 | 77 | start = time.time() 78 | list((p.imap_unordered(write_fn, enumerate(cpu_flattened_chunked)))) 79 | # print(f"written to gcs in {time.time() - start:.06}s") 80 | 81 | 82 | def read_shard(ckpt_dir): 83 | out = [] 84 | for idx in range(16): 85 | file_path = ckpt_dir + f"{idx}.npz" 86 | with open(file_path, "rb") as f: 87 | buf = f.read() 88 | f_io = io.BytesIO(buf) 89 | deserialized = np.load(f_io) 90 | for i in deserialized: 91 | out.append(deserialized[i]) 92 | return out 93 | 94 | 95 | def reshard(x, old_shape): 96 | if len(x.shape) == 1: 97 | # print("epoch") 98 | # print(x) 99 | out = x[0:1] 100 | 101 | elif len(x.shape) == 2: 102 | # print(f"LN/bias {x.shape}") 103 | # print(x[:, :16]) 104 | 105 | if (x[1:] == x[-1]).all(): 106 | # print("LN") 107 | if (x[1:] == 0).all() or (x[1:] == 1).all(): 108 | out = x[0:1] 109 | else: 110 | # print("shard bias") 111 | out = x[0:1] * x.shape[0] / old_shape[0] 112 | else: 113 | # print("bias") 114 | out = x.reshape(old_shape) 115 | 116 | print(out[:, :16]) 117 | 118 | elif len(x.shape) == 3: 119 | # print(f"weight {x.shape}") 120 | if x.shape[0] * x.shape[2] == old_shape[2]: 121 | # print("case 1") 122 | out = jnp.transpose(x, (1, 0, 2)).reshape(old_shape) 123 | elif x.shape[0] * x.shape[1] == old_shape[1]: 124 | # print("case 2") 125 | out = x.reshape(old_shape) 126 | else: 127 | raise Exception(f"unimplemented, {x.shape}, {old_shape}") 128 | else: 129 | raise Exception(f"unimplemented, {x}") 130 | 131 | return out 132 | 133 | 134 | def read_ckpt(pytree, dir, shards_in, shards_out=None, load_opt=True): 135 | if shards_out is None: 136 | shards_out = shards_in 137 | 138 | old_flattened, structure = jax.tree_flatten(pytree) 139 | 140 | original_opt_state = pytree["opt_state"] 141 | 142 | # TODO: figure out how to use a process pool here for more speed 143 | with multiprocessing.pool.ThreadPool(shards_in) as p: 144 | start = time.time() 145 | shards = list((p.imap(read_shard, [f"{dir}shard_{i}/" for i in range(shards_in)]))) 146 | print(f"read from disk/gcs in {time.time() - start:.06}s") 147 | 148 | def _unshard(shards, old_flattened): 149 | unsharded = [] 150 | 151 | for old, *all_shards in zip(old_flattened, *shards): 152 | x = np.stack(all_shards) 153 | # No idea why this is V2...? 
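            # (likely because bfloat16 has no native NumPy dtype, so the arrays round-trip through np.savez/np.load as 2-byte void ('V2'); reinterpreting the dtype below restores them as bfloat16)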
154 | if x.dtype == np.dtype('V2'): 155 | x.dtype = jnp.bfloat16 156 | 157 | if shards_out != shards_in: 158 | x = reshard(x, old.shape) 159 | unsharded.append(x) 160 | 161 | assert x.shape == old.shape, f"Incompatible checkpoints {x.shape} vs {old.shape}" 162 | return unsharded 163 | try: 164 | unsharded = _unshard(shards, old_flattened) 165 | except AssertionError: 166 | load_opt = False # no opt to load in ckpt 167 | del pytree['opt_state'] 168 | old_flattened, structure = jax.tree_flatten(pytree) 169 | unsharded = _unshard(shards, old_flattened) 170 | 171 | loaded_pytree = jax.tree_unflatten(structure, unsharded) 172 | 173 | if not load_opt: 174 | loaded_pytree['opt_state'] = original_opt_state 175 | return loaded_pytree 176 | 177 | 178 | def read_ckpt_lowmem(pytree, dir, shards_in, shards_out=None, load_opt=True): 179 | if shards_out is None: 180 | shards_out = shards_in 181 | 182 | old_flattened, structure = jax.tree_flatten(pytree) 183 | 184 | original_opt_state = pytree["opt_state"] 185 | 186 | def _unshard(): 187 | start = time.time() 188 | unsharded = [] 189 | devices = jax.devices() 190 | device_count = len(devices) 191 | device_index = 0 192 | 193 | for file_index in range(pieces): 194 | array_keys = [*np.load(f"{dir}shard_0/{file_index}.npz").keys()] 195 | for array_index in range(len(array_keys)): 196 | unstacked = [] 197 | for shard_index in range(shards_in): 198 | npz = np.load(f"{dir}shard_{shard_index}/{file_index}.npz") 199 | array = npz[array_keys[array_index]] 200 | if array.dtype == 'V2': 201 | array.dtype = jnp.bfloat16 202 | unstacked.append(array) 203 | 204 | x = jax.device_put(jnp.stack(unstacked), device=devices[device_index % device_count]) 205 | 206 | if shards_out != shards_in: 207 | x = reshard(x, old_flattened[device_index].shape) 208 | unsharded.append(x) 209 | 210 | assert x.shape == old_flattened[device_index].shape, f"Incompatible checkpoints {x.shape} vs {old_flattened[device_index].shape}" 211 | device_index += 1 212 | 213 | print(f"read from disk/gcs in {time.time() - start:.06}s") 214 | return unsharded 215 | 216 | try: 217 | unsharded = _unshard() 218 | except AssertionError: 219 | load_opt = False # no opt to load in ckpt 220 | del pytree['opt_state'] 221 | old_flattened, structure = jax.tree_flatten(pytree) 222 | unsharded = _unshard() 223 | 224 | loaded_pytree = jax.tree_unflatten(structure, unsharded) 225 | 226 | if not load_opt: 227 | loaded_pytree['opt_state'] = original_opt_state 228 | return loaded_pytree 229 | 230 | 231 | def parallel_write(arrays, fname): 232 | # TODO: make this actually parallel 233 | with open(fname, "wb") as f: 234 | np.savez(f, *arrays) 235 | 236 | 237 | def parallel_read(old, fname, validate=True): 238 | old_vals, treedef = jax.tree_flatten(old) 239 | 240 | if "gs://" in fname: 241 | # TODO: make this actually parallel 242 | with open(fname, "rb") as f: 243 | buf = f.read() 244 | f_io = io.BytesIO(buf) 245 | loaded = np.load(f_io) 246 | else: 247 | loaded = np.load(fname, mmap_mode='r') 248 | 249 | new_vals = [] 250 | for i in loaded: 251 | new_vals.append(loaded[i]) 252 | 253 | assert len(new_vals) == len(old_vals), "Incompatible checkpoint" 254 | 255 | for o, n in zip(new_vals, old_vals): 256 | if validate: 257 | assert o.shape == n.shape, "Incompatible checkpoint" 258 | 259 | return jax.tree_unflatten(treedef, fix_dtype(new_vals)) 260 | 261 | 262 | def tree_flatten_with_names(pytree, is_leaf, path="", to_id=id): 263 | id_to_name = {} 264 | if getattr(pytree, "items", None): 265 | for k, v in pytree.items(): 266 | 
k_path = f"{path}/{k}" 267 | if is_leaf(v): 268 | id_to_name[to_id(v)] = k_path 269 | else: 270 | id_to_name = {**id_to_name, **tree_flatten_with_names(v, is_leaf=is_leaf, path=k_path)} 271 | elif getattr(pytree, "__getitem__", None): 272 | for v in pytree: 273 | if is_leaf(v): 274 | id_to_name[to_id(v)] = path 275 | else: 276 | id_to_name = {**id_to_name, **tree_flatten_with_names(v, is_leaf=is_leaf, path=path)} 277 | else: 278 | id_to_name[to_id(pytree)] = path 279 | return id_to_name 280 | 281 | 282 | def tree_leaves_with_names(pytree, to_id=id): 283 | leaves = jax.tree_leaves(pytree) 284 | is_leaf = lambda x: not isinstance(x, list) and to_id(x) in [to_id(x) for x in leaves] 285 | return tree_flatten_with_names(pytree, is_leaf) 286 | 287 | 288 | def write_ckpt_v2(model_state, dir): 289 | start = time.time() 290 | if jax.host_id() == 0: 291 | param_map = tree_leaves_with_names(model_state["params"]) 292 | opt_map = tree_leaves_with_names(model_state["opt_state"]) 293 | 294 | meta = { 295 | "total_hosts": jax.host_count(), 296 | "step": int(model_state["step"]), 297 | "param_order": [param_map[id(i)] for i in jax.tree_leaves(model_state["params"])], 298 | "opt_order": [opt_map[id(i)] for i in jax.tree_leaves(model_state["opt_state"])] 299 | } 300 | 301 | print("step:", model_state["step"]) 302 | with open(dir + "/meta.json", "w") as f: 303 | json.dump(meta, f) 304 | print(f"meta written in {time.time() - start:.06}s") 305 | 306 | start = time.time() 307 | parallel_write(jax.tree_flatten(model_state["params"])[0], dir + f"/params/shard_{jax.host_id()}.npz") 308 | head_print(f"params written in {time.time() - start:.06}s") 309 | 310 | start = time.time() 311 | parallel_write(jax.tree_flatten(model_state["opt_state"])[0], dir + f"/opt_state/shard_{jax.host_id()}.npz") 312 | head_print(f"opt_state written in {time.time() - start:.06}s") 313 | 314 | 315 | def read_sharded_v2(state, dir, checkpoint_hosts, state_shard): 316 | files_per_host = checkpoint_hosts // jax.host_count() 317 | 318 | assert files_per_host >= 1, "can't restore model to larger pod than was trained on (yet)" 319 | assert jax.host_count() * files_per_host == checkpoint_hosts, "weird host count" 320 | 321 | if files_per_host == 1: 322 | head_print("using fast path of checkpoint restore (save shards == read shards)") 323 | parallel_read(state, dir + f"/shard_{jax.host_id()}.npz") 324 | 325 | @ray.remote 326 | def read_remote(old, fname): 327 | return parallel_read(old, fname, validate=False) 328 | 329 | start_idx = files_per_host * jax.host_id() 330 | 331 | skeleton = jax.tree_map(lambda x: jnp.zeros_like(x, shape=()), state) # a full pytree just to carry dtypes 332 | 333 | refs = [ 334 | read_remote.remote(skeleton, f"{dir}/shard_{i}.npz") 335 | for i in range(start_idx, start_idx + files_per_host) 336 | ] 337 | 338 | values = ray.get(refs) 339 | 340 | def all_array_equal(iterator): 341 | try: 342 | iterator = iter(iterator) 343 | first = next(iterator) 344 | return all(jnp.array_equal(first, rest) for rest in iterator) 345 | except StopIteration: 346 | return True 347 | 348 | def reshard_v2(old, shard_strategy, *new_values): 349 | rep_dim_count = shard_strategy.count(None) 350 | total_dim_count = len(shard_strategy) 351 | 352 | # head_print("old.shape", old.shape) 353 | # head_print("shard_strategy", shard_strategy) 354 | 355 | assert len(old.shape) == total_dim_count 356 | 357 | if rep_dim_count == total_dim_count: 358 | # fully replicated 359 | assert all_array_equal(new_values) 360 | return fix_dtype(new_values[0]) 361 | 
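        # otherwise the parameter is sharded: locate the single dimension tagged with the model-parallel axis ('mp') and concatenate the per-shard pieces along it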
362 | shard_dim = [idx for idx, dim in enumerate(shard_strategy) if dim is not None and "mp" in dim] 363 | 364 | # only support sharding in 1d for now 365 | assert len(shard_dim) == 1 366 | shard_dim = shard_dim[0] 367 | 368 | ret_val = jnp.concatenate(fix_dtype(new_values), axis=shard_dim) 369 | assert old.shape == ret_val.shape 370 | 371 | return jax.device_put(ret_val, jax.devices("cpu")[0]) 372 | 373 | # head_print("state", jax.tree_structure(state)) 374 | # head_print("state_shard", jax.tree_structure(state_shard)) 375 | # head_print("values", jax.tree_structure(values[0])) 376 | 377 | return jax.tree_multimap(reshard_v2, *([state, state_shard] + values)) 378 | 379 | 380 | def load_ckpt_v2(model_state, dir, state_shard, load_opt): 381 | start = time.time() 382 | with open(dir + "meta.json", "r") as f: 383 | meta = json.load(f) 384 | 385 | ckpt_hosts = meta["total_hosts"] 386 | 387 | head_print(f"meta loaded in {time.time() - start:.06}s") 388 | 389 | new_state = { 390 | "step": np.array([meta["step"]]), 391 | } 392 | 393 | start = time.time() 394 | new_state["params"] = read_sharded_v2(model_state["params"], 395 | dir + "params", 396 | ckpt_hosts, 397 | state_shard["params"]) 398 | head_print(f"params loaded in {time.time() - start:.06}s") 399 | 400 | if not load_opt: 401 | return new_state 402 | 403 | start = time.time() 404 | new_state["opt_state"] = read_sharded_v2(model_state["opt_state"], 405 | dir + "opt_state", 406 | ckpt_hosts, 407 | state_shard["opt_state"]) 408 | head_print(f"opt_state loaded in {time.time() - start:.06}s") 409 | 410 | return new_state 411 | -------------------------------------------------------------------------------- /mesh_transformer/sampling.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | 5 | # takes in a logit distribution, softmax and then sample 6 | def softmax_sample(key, logits, _, temp=1): 7 | return jax.random.categorical(key, logits/temp, -1).astype(jnp.uint32), None 8 | 9 | 10 | def nucleaus_filter(logits, top_p=0.9, top_k=None): 11 | sorted_logits = jnp.sort(logits)[:, ::-1] # sort descending 12 | sorted_indices = jnp.argsort(logits)[:, ::-1] 13 | cumulative_probs = jnp.cumsum(jax.nn.softmax(sorted_logits), axis=-1) 14 | 15 | if top_k is not None: 16 | # Keep only top_k tokens 17 | indices_range = jnp.arange(len(sorted_indices[0])) 18 | indices_range = jnp.stack([indices_range] * len(sorted_indices), axis=0) 19 | 20 | sorted_indices_to_remove = jnp.where(indices_range >= top_k, sorted_indices, 0) 21 | 22 | _, indices_to_remove = jax.lax.sort_key_val(sorted_indices, sorted_indices_to_remove) 23 | 24 | logit_mask = 1e10 * indices_to_remove 25 | 26 | logits -= logit_mask 27 | 28 | # Remove tokens with cumulative probability above a threshold 29 | sorted_indices_to_remove = cumulative_probs > top_p 30 | sorted_indices_to_remove = jnp.concatenate((jnp.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove), axis=-1)[:, :-1] 31 | 32 | _, indices_to_remove = jax.lax.sort_key_val(sorted_indices, sorted_indices_to_remove) 33 | 34 | logit_mask = 1e10 * indices_to_remove 35 | 36 | logits -= logit_mask 37 | 38 | return logits 39 | 40 | 41 | def nucleaus_sample(key, logits, _, top_p=0.9, temp=1, top_k=None): 42 | logits = nucleaus_filter(logits, top_p, top_k=top_k) 43 | 44 | return softmax_sample(key, logits, None, temp=temp) 45 | 46 | 47 | if __name__ == "__main__": 48 | import numpy as np 49 | logits = np.array([[-2, -1, 0, 0.8, 0, 0.1, 0.3, 0.4, 0.5, 
0.6, 0.7, -3]]) 50 | print(nucleaus_filter(logits)) -------------------------------------------------------------------------------- /mesh_transformer/train_actor.py: -------------------------------------------------------------------------------- 1 | import ray 2 | import time 3 | import numpy as np 4 | from queue import Queue 5 | 6 | from mesh_transformer.util import head_print 7 | 8 | 9 | @ray.remote(resources={"tpu": 1}) 10 | class NetworkRunner(object): 11 | def __init__(self, mesh_shape, network_builder): 12 | self.mesh_shape = mesh_shape 13 | self.network_builder = network_builder 14 | 15 | self.input_q = Queue(maxsize=1) 16 | self.output_q = Queue(maxsize=1) 17 | 18 | def run(self): 19 | print(f"jax runtime initialization starting") 20 | import jax 21 | from jax.experimental.maps import thread_resources, ResourceEnv, Mesh 22 | import haiku as hk 23 | # jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True 24 | 25 | thread_resources.env = ResourceEnv(Mesh(np.empty((), dtype=object), ()), ()) 26 | 27 | start = time.time() 28 | jax.devices() 29 | 30 | import warnings 31 | warnings.filterwarnings("ignore") 32 | warnings.filterwarnings("ignore", category=ResourceWarning) 33 | 34 | if jax.host_id() == 0: 35 | warnings.filterwarnings("default") 36 | 37 | head_print(f"jax devices: {jax.device_count()}") 38 | head_print(f"jax runtime initialized in {time.time() - start:.06}s") 39 | devices = np.array(jax.devices()).reshape(self.mesh_shape) 40 | 41 | with jax.experimental.maps.mesh(devices, ('dp', 'mp')): 42 | start = time.time() 43 | network = self.network_builder() 44 | head_print(f"Initialized in {time.time() - start:.06}s") 45 | 46 | while True: 47 | operation, input = self.input_q.get() 48 | if operation == "train": 49 | self.output_q.put(network.train(input)) 50 | elif operation == "eval": 51 | self.output_q.put(network.eval(input)) 52 | elif operation == "generate": 53 | self.output_q.put(network.generate(*input)) 54 | elif operation == "write_ckpt": 55 | path, shard = input 56 | network.write_ckpt(path, shard) 57 | self.output_q.put(None) 58 | elif operation == "load_ckpt": 59 | network.load_ckpt(input) 60 | self.output_q.put(network.state["step"][0]) 61 | elif operation == "get_params": 62 | self.output_q.put(hk.data_structures.tree_size(network.state['params'])) 63 | elif operation == "move_params": 64 | # only needed for inference, otherwise first train step does this 65 | local_shards = max(jax.local_device_count() // self.mesh_shape[1], 1) 66 | 67 | # delete the optimizer states otherwise it OOMs for some reason 68 | # TODO: use ShardedDeviceArray or something to get around this for bigger models 69 | del network.state["opt_state"] 70 | network.state = network.move_xmap(network.state, np.zeros(local_shards)) 71 | self.output_q.put(None) 72 | else: 73 | raise Exception("Not implemented") 74 | 75 | def get_params(self): 76 | self.input_q.put(("get_params", None)) 77 | return self.output_q.get() 78 | 79 | def train(self, sample): 80 | self.input_q.put(("train", sample)) 81 | return self.output_q.get() 82 | 83 | def eval(self, sample): 84 | self.input_q.put(("eval", sample)) 85 | return self.output_q.get() 86 | 87 | def generate(self, input): 88 | self.input_q.put(("generate", input)) 89 | return self.output_q.get() 90 | 91 | def write_ckpt(self, path, shard): 92 | self.input_q.put(("write_ckpt", (path, shard))) 93 | return self.output_q.get() 94 | 95 | def load_ckpt(self, path): 96 | self.input_q.put(("load_ckpt", path)) 97 | return self.output_q.get() 98 | 99 | def 
move_params(self): 100 | self.input_q.put(("move_params", None)) 101 | return self.output_q.get() 102 | -------------------------------------------------------------------------------- /mesh_transformer/util.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax.experimental.pjit import with_sharding_constraint 4 | from optax import AdditiveWeightDecayState, GradientTransformation, OptState 5 | 6 | 7 | # same as with_sharding_constraint but doesn't fail if run outside of pjit/mesh context 8 | def maybe_shard(x, resource): 9 | try: 10 | return with_sharding_constraint(x, resource) 11 | except ValueError as e: 12 | print(e) 13 | return x 14 | 15 | 16 | def gpt3_schedule(warmup_steps, 17 | total_steps, 18 | peak_lr, 19 | end_lr): 20 | def sch(step): 21 | warmup_pct = jnp.clip(step, 0, warmup_steps) / warmup_steps 22 | anneal_pct = jnp.clip(step - warmup_steps, 0, total_steps) / total_steps 23 | 24 | return warmup_pct * peak_lr - (peak_lr - end_lr) * (1 - jnp.cos(jnp.pi * anneal_pct)) / 2 25 | 26 | return sch 27 | 28 | 29 | def global_norm(updates, use_psum=True): 30 | pre_sqrt = sum([jnp.sum(jnp.square(x)) for x in jax.tree_leaves(updates)]) 31 | if use_psum: 32 | pre_sqrt = jax.lax.psum(pre_sqrt, "shard") 33 | return jnp.sqrt(pre_sqrt) 34 | 35 | 36 | class ClipByGlobalNormState(OptState): 37 | """The `clip_by_global_norm` transformation is stateless.""" 38 | 39 | 40 | def clip_by_global_norm(max_norm, use_psum=True) -> GradientTransformation: 41 | """Clip updates using their global norm. 42 | 43 | References: 44 | [Pascanu et al, 2012](https://arxiv.org/abs/1211.5063) 45 | 46 | Args: 47 | max_norm: the maximum global norm for an update. 48 | 49 | Returns: 50 | An (init_fn, update_fn) tuple. 51 | """ 52 | 53 | def init_fn(_): 54 | return ClipByGlobalNormState() 55 | 56 | def update_fn(updates, state, params=None): 57 | del params 58 | g_norm = global_norm(updates, use_psum=use_psum) 59 | trigger = g_norm < max_norm 60 | updates = jax.tree_map( 61 | lambda t: jnp.where(trigger, t, (t / g_norm) * max_norm), updates) 62 | return updates, state 63 | 64 | return GradientTransformation(init_fn, update_fn) 65 | 66 | 67 | def additive_weight_decay(weight_decay: float = 0.0) -> GradientTransformation: 68 | """Add parameter scaled by `weight_decay`, to all parameters with more than one dim (i.e. exclude ln, bias etc) 69 | 70 | Args: 71 | weight_decay: a scalar weight decay rate. 72 | 73 | Returns: 74 | An (init_fn, update_fn) tuple. 
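    Note: the decay term is added rather than subtracted here because the whole update is negated later in the optimizer chain (`optax.scale(-1)` in `build_model.py`), which makes this decoupled (AdamW-style) weight decay.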
75 | """ 76 | 77 | def init_fn(_): 78 | return AdditiveWeightDecayState() 79 | 80 | def update_fn(updates, state, params): 81 | updates = jax.tree_multimap(lambda g, p: g + weight_decay * p * (len(g.shape) > 1), updates, params) 82 | return updates, state 83 | 84 | return GradientTransformation(init_fn, update_fn) 85 | 86 | 87 | def to_f32(t): 88 | return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t) 89 | 90 | 91 | def to_bf16(t): 92 | return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t) 93 | 94 | 95 | def to_f16(t): 96 | return jax.tree_map(lambda x: x.astype(jnp.float16) if x.dtype == jnp.float32 else x, t) 97 | 98 | 99 | # identity in forward pass, psum in backward 100 | @jax.custom_vjp 101 | def f_psum(x): 102 | return x 103 | 104 | 105 | def f_psum_fwd(x): 106 | return f_psum(x), None 107 | 108 | 109 | def f_psum_bwd(_, g): 110 | return jax.lax.psum(g, "shard"), 111 | 112 | 113 | f_psum.defvjp(f_psum_fwd, f_psum_bwd) 114 | 115 | 116 | # identity in forward pass, pmean in backward 117 | @jax.custom_vjp 118 | def f_pmean(x): 119 | return x 120 | 121 | 122 | def f_pmean_fwd(x): 123 | return f_psum(x), None 124 | 125 | 126 | def f_pmean_bwd(_, g): 127 | return jax.lax.pmean(g, "shard"), 128 | 129 | 130 | f_pmean.defvjp(f_pmean_fwd, f_pmean_bwd) 131 | 132 | 133 | # psum in forward pass, identity in backward 134 | @jax.custom_vjp 135 | def g_psum(x): 136 | return jax.lax.psum(x, "shard") 137 | 138 | 139 | def g_psum_fwd(x): 140 | return g_psum(x), None 141 | 142 | 143 | def g_psum_bwd(_, g): 144 | return g, 145 | 146 | 147 | g_psum.defvjp(g_psum_fwd, g_psum_bwd) 148 | 149 | 150 | def shard_axis(x, axis_size, axis_name): 151 | # in_shape = x.shape 152 | assert x.shape[0] % axis_size == 0 153 | 154 | x = x.reshape((axis_size, -1) + x.shape[1:]) 155 | 156 | x = x[jax.lax.axis_index(axis_name)] 157 | # print("shard out", x.shape, "in", in_shape) 158 | 159 | # assert np.prod(x.shape) * axis_size == np.prod(in_shape) 160 | 161 | return x 162 | 163 | 164 | def unshard_axis(x, axis_name): 165 | # in_shape = x.shape 166 | x = jax.lax.all_gather(x, axis_name) 167 | 168 | x = x.reshape((-1, ) + x.shape[2:]) 169 | 170 | # assert x.shape[-1] == 4096 171 | # print("unshard out", x.shape, "in", in_shape) 172 | return x 173 | 174 | 175 | # print but only on the first node 176 | def head_print(*args, **kwargs): 177 | if jax.host_id() == 0: 178 | print(*args, **kwargs) 179 | 180 | 181 | if __name__ == "__main__": 182 | sch = gpt3_schedule(1_000, 20_000, 1e-4, 1e-5) 183 | 184 | for i in range(150): 185 | i = i * 200 186 | print(i, sch(i)) 187 | -------------------------------------------------------------------------------- /ray_tpu.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | import subprocess 4 | import time 5 | 6 | import glob 7 | import requests 8 | from fabric import Connection 9 | 10 | 11 | @functools.lru_cache() 12 | def get_bearer(): 13 | return subprocess.check_output("gcloud auth print-access-token", shell=True).decode("utf-8").strip() 14 | 15 | 16 | @functools.lru_cache() 17 | def get_project(): 18 | return subprocess.check_output("gcloud config list --format 'value(core.project)'", shell=True).decode( 19 | "utf-8").strip() 20 | 21 | 22 | def create_tpu( 23 | name, 24 | zone, 25 | type, 26 | preemptible, 27 | ): 28 | headers = { 29 | 'Authorization': f'Bearer {get_bearer()}', 30 | 'Content-Type': 'application/json', 31 | } 32 | 33 | try: 34 | 
status = check_tpu(name, zone) 35 | 36 | if status["state"] not in ["CREATING", "READY"]: 37 | print("deleting TPU") 38 | delete_tpu(name, zone) 39 | 40 | while True: 41 | try: 42 | print("deleting check") 43 | print(check_tpu(name, zone)["state"]) 44 | 45 | time.sleep(1) 46 | except: 47 | break 48 | except: 49 | pass 50 | 51 | params = ( 52 | ('node_id', name), 53 | ) 54 | 55 | data = {"accelerator_type": 56 | type, 57 | "runtime_version": 58 | 'v2-alpha', 59 | "network_config": 60 | {"enable_external_ips": True}, 61 | } 62 | 63 | if preemptible: 64 | data["schedulingConfig"] = {"preemptible": True} 65 | 66 | response = requests.post(f'https://tpu.googleapis.com/v2alpha1/projects/{get_project()}/locations/{zone}/nodes', 67 | headers=headers, params=params, json=data) 68 | 69 | print(response.json()) 70 | 71 | return response.status_code == 200 72 | 73 | 74 | def check_tpu(name, zone): 75 | headers = { 76 | 'Authorization': f'Bearer {get_bearer()}', 77 | } 78 | 79 | response = requests.get( 80 | f'https://tpu.googleapis.com/v2alpha1/projects/{get_project()}/locations/{zone}/nodes/{name}', 81 | headers=headers) 82 | 83 | return response.json() 84 | 85 | 86 | def delete_tpu(name, zone): 87 | headers = { 88 | 'Authorization': f'Bearer {get_bearer()}', 89 | } 90 | 91 | response = requests.delete( 92 | f'https://tpu.googleapis.com/v2alpha1/projects/{get_project()}/locations/{zone}/nodes/{name}', 93 | headers=headers) 94 | 95 | return response.json() 96 | 97 | 98 | def wait_til(name, zone, state): 99 | while True: 100 | ret = check_tpu(name, zone) 101 | 102 | print("wait_til check") 103 | print(ret) 104 | 105 | matches = True 106 | for k, expected_v in state.items(): 107 | if k not in ret: 108 | matches = False 109 | continue 110 | if ret[k] != expected_v: 111 | matches = False 112 | 113 | if "error" in ret: 114 | return False 115 | 116 | if ret["state"] == "TERMINATED": 117 | return False 118 | 119 | if matches: 120 | return True 121 | 122 | time.sleep(1) 123 | 124 | 125 | def get_connection( 126 | name, 127 | zone, 128 | ): 129 | info = check_tpu(name, zone) 130 | outputs = [] 131 | for i in info["networkEndpoints"]: 132 | outputs.append(Connection(i["ipAddress"], 133 | connect_kwargs={ 134 | "key_filename": os.path.expanduser('~/.ssh/google_compute_engine'), })) 135 | return outputs 136 | 137 | 138 | def start_ray(conn, address, version=1): 139 | conn.sudo('rm -rf *.py') 140 | conn.sudo('rm -rf mesh_transformer') 141 | 142 | for i in glob.glob("*.py"): 143 | conn.put(i, "") 144 | 145 | conn.run("mkdir mesh_transformer -p") 146 | 147 | for i in glob.glob("mesh_transformer/*.py"): 148 | conn.put(i, "mesh_transformer/") 149 | 150 | conn.sudo('python3 setup.py install', hide=True) 151 | 152 | if version == 2: 153 | conn.put("scripts/init_ray_v2.sh", "/tmp/ray-tpu.sh") 154 | else: 155 | conn.put("scripts/init_ray.sh", "/tmp/ray-tpu.sh") 156 | conn.sudo('chmod +x /tmp/ray-tpu.sh', hide=True) 157 | conn.sudo('/tmp/ray-tpu.sh', hide=True) 158 | try: 159 | conn.run('ray stop -f', hide=True) 160 | except: 161 | pass 162 | 163 | time.sleep(1) 164 | 165 | conn.run(f"TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD={32 * 1024**3} ray start --address={address} --resources='" + '{"tpu": 1}\' --include-dashboard False', hide=True) 166 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy~=1.19.5 2 | tqdm~=4.62.0 3 | wandb>=0.11.2 4 | einops~=0.3.0 5 | requests~=2.25.1 6 | 
fabric~=2.6.0 7 | optax==0.0.9 8 | dm-haiku==0.0.5 9 | git+https://github.com/EleutherAI/lm-evaluation-harness/ 10 | ray[default]==1.4.1 11 | jax~=0.2.12 12 | Flask~=1.1.2 13 | cloudpickle~=1.3.0 14 | tensorflow-cpu~=2.6.0 15 | google-cloud-storage~=1.36.2 16 | transformers 17 | smart_open[gcs] 18 | func_timeout 19 | ftfy 20 | fastapi 21 | lm_dataformat 22 | pathy 23 | -------------------------------------------------------------------------------- /resharding_example.py: -------------------------------------------------------------------------------- 1 | # This was tested with an RTX 3090, peak memory usage is approximately 22.4GB during inference, and 19GB when loading the model 2 | # The following environment variables were also used: XLA_PYTHON_CLIENT_PREALLOCATE=false XLA_PYTHON_CLIENT_ALLOCATOR=platform 3 | 4 | import time 5 | 6 | import jax 7 | from jax.experimental import maps 8 | import numpy as np 9 | import optax 10 | import transformers 11 | 12 | from mesh_transformer.checkpoint import read_ckpt 13 | from mesh_transformer.sampling import nucleaus_sample 14 | from mesh_transformer.transformer_shard import CausalTransformer 15 | 16 | params = { 17 | "layers": 28, 18 | "d_model": 4096, 19 | "n_heads": 16, 20 | "n_vocab": 50400, 21 | "norm": "layernorm", 22 | "pe": "rotary", 23 | "pe_rotary_dims": 64, 24 | "early_cast": True, 25 | "seq": 2048, 26 | "cores_per_replica": 1, # only running on one GPU 27 | "per_replica_batch": 1, 28 | } 29 | 30 | per_replica_batch = params["per_replica_batch"] 31 | cores_per_replica = params["cores_per_replica"] 32 | seq = params["seq"] 33 | 34 | 35 | params["sampler"] = nucleaus_sample 36 | 37 | # here we "remove" the optimizer parameters from the model (as we don't need them for inference) 38 | params["optimizer"] = optax.scale(0) 39 | 40 | devices = np.array([jax.devices()[0]]).reshape((1, 1)) 41 | maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp'))) 42 | 43 | tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') 44 | 45 | network = CausalTransformer(params) 46 | 47 | start = time.time() 48 | 49 | # here we load a checkpoint which was written with 8 shards into 1 shard 50 | network.state = read_ckpt(network.state, "step_383500/", 8, shards_out=cores_per_replica) 51 | 52 | # move the state to CPU/system memory so it's not duplicated by xmap 53 | network.state = jax.device_put(network.state, jax.devices("cpu")[0]) 54 | 55 | def infer(context, top_k=40, top_p=0.9, temp=1.0, gen_len=512): 56 | tokens = tokenizer.encode(context) 57 | 58 | provided_ctx = len(tokens) 59 | pad_amount = seq - provided_ctx 60 | 61 | padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32) 62 | batched_tokens = np.array([padded_tokens] * per_replica_batch) 63 | length = np.ones(per_replica_batch, dtype=np.uint32) * len(tokens) 64 | 65 | start = time.time() 66 | output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(per_replica_batch) * top_p, "top_k": top_k is not None and (np.ones(per_replica_batch, dtype=np.int32) * top_k) or None, "temp": np.ones(per_replica_batch) * temp}) 67 | 68 | samples = [] 69 | decoded_tokens = output[1][0] 70 | 71 | for o in decoded_tokens[:, :, 0]: 72 | samples.append(tokenizer.decode(o)) 73 | 74 | print(f"completion done in {time.time() - start:06}s") 75 | return samples 76 | 77 | 78 | infer("EleutherAI is") 79 | -------------------------------------------------------------------------------- /scripts/create_serve_tpu.sh: 
-------------------------------------------------------------------------------- 1 | gcloud alpha compute tpus tpu-vm create "$1" --zone europe-west4-a --accelerator-type v3-8 --version v2-alpha 2 | sleep 120 3 | gcloud alpha compute tpus tpu-vm ssh "$1" --zone europe-west4-a --command="$(< scripts/deploy_server.sh)" -------------------------------------------------------------------------------- /scripts/deploy_server.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | rm -r mesh-transformer-jax || true 5 | 6 | git clone https://github.com/kingoflolz/mesh-transformer-jax 7 | pip install -r mesh-transformer-jax/requirements.txt 8 | pip install mesh-transformer-jax/ jax==0.2.12 9 | 10 | pushd mesh-transformer-jax || exit 11 | screen -d -m python3 device_serve.py --config configs/6B_roto_256.json -------------------------------------------------------------------------------- /scripts/init_ray.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | sudo /usr/bin/docker-credential-gcr configure-docker 5 | 6 | sudo docker rm libtpu || true 7 | sudo docker create --name libtpu gcr.io/cloud-tpu-v2-images/libtpu:libtpu_20210518_RC00 "/bin/bash" && sudo docker cp libtpu:libtpu.so /lib 8 | 9 | # this locks the python executable down to hopefully stop it from being fiddled with... 10 | screen -d -m python -c 'import time; time.sleep(999999999)' 11 | 12 | # initializes jax and installs ray on cloud TPUs 13 | sudo pip install --upgrade jaxlib==0.1.67 jax==0.2.12 ray[default]==1.5.1 fabric dataclasses optax git+https://github.com/deepmind/dm-haiku tqdm cloudpickle smart_open[gcs] einops func_timeout -------------------------------------------------------------------------------- /scripts/init_ray_v2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | # this locks the python executable down to hopefully stop it from being fiddled with...
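# unlike scripts/init_ray.sh, this variant does not extract libtpu.so from the cloud-tpu docker image; the jax[tpu] wheel installed below pulls in the TPU runtime instead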
5 | screen -d -m python -c 'import time; time.sleep(999999999)' 6 | 7 | # initializes jax and installs ray on cloud TPUs 8 | sudo pip install "jax[tpu]>=0.2.18" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 9 | sudo pip install --upgrade ray[default]==1.5.1 fabric dataclasses optax git+https://github.com/deepmind/dm-haiku tqdm cloudpickle smart_open[gcs] einops func_timeout -------------------------------------------------------------------------------- /scripts/init_serve.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | # initializes jax and installs ray on cloud TPUs 5 | sudo pip install --upgrade jaxlib jax==0.2.12 ray==1.2.0 fabric dataclasses optax git+https://github.com/deepmind/dm-haiku tqdm cloudpickle smart_open[gcs] einops func_timeout transformers flask -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='mesh_transformer', 5 | version='0.0.0', 6 | packages=find_packages(include=['mesh_transformer', 'mesh_transformer.*']) 7 | ) 8 | -------------------------------------------------------------------------------- /slim_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import time 4 | 5 | import jax 6 | import numpy as np 7 | import optax 8 | 9 | from mesh_transformer import util 10 | from mesh_transformer.checkpoint import read_ckpt, write_ckpt 11 | from mesh_transformer.transformer_shard import CausalTransformer 12 | from smart_open import open 13 | 14 | from mesh_transformer.util import clip_by_global_norm, to_bf16, to_f16 15 | 16 | 17 | def parse_args(): 18 | # Parse command line arguments 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--config", type=str, default=None, help="Config file location") 21 | parser.add_argument("--ckpt-step", type=int, default=-1, help="Step number of the checkpoint to convert (if not specified, converts the most recent checkpoint)") 22 | parser.add_argument("--f16", default=False, action="store_true", help="Convert to float16 (instead of bfloat16)") 23 | 24 | args = parser.parse_args() 25 | return args 26 | 27 | 28 | if __name__ == "__main__": 29 | args = parse_args() 30 | params = json.load(open(args.config)) 31 | convert_fn = to_f16 if args.f16 else to_bf16 32 | 33 | cores_per_replica = params["cores_per_replica"] 34 | 35 | assert cores_per_replica <= 8 36 | 37 | bucket = params["bucket"] 38 | model_dir = params["model_dir"] 39 | 40 | params["optimizer"] = optax.chain( 41 | optax.scale(1), 42 | clip_by_global_norm(1), 43 | optax.scale_by_adam(), 44 | optax.additive_weight_decay(0), 45 | optax.scale(-1), 46 | optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0)) 47 | ) 48 | 49 | start = time.time() 50 | print(f"jax devices: {jax.device_count()}") 51 | print(f"jax runtime initialized in {time.time() - start:.06}s") 52 | 53 | mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica) 54 | devices = np.array(jax.devices()).reshape(mesh_shape) 55 | 56 | with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f: 57 | meta = json.load(f) 58 | 59 | if args.ckpt_step > -1: 60 | ckpt_step = args.ckpt_step 61 | else: 62 | ckpt_step = meta["checkpoints"][-1] 63 | print(f"using checkpoint {ckpt_step}") 64 | 65 | with jax.experimental.maps.mesh(devices, 
('dp', 'mp')): 66 | network = CausalTransformer(params) 67 | 68 | start = time.time() 69 | network.state = read_ckpt(network.state, f"gs://{bucket}/{model_dir}/step_{ckpt_step}/", devices.shape[1]) 70 | print(f"network loaded in {time.time() - start:.06}s") 71 | 72 | start = time.time() 73 | del network.state["opt_state"] 74 | 75 | network.state["params"] = convert_fn(network.state["params"]) 76 | print(f"network converted in {time.time() - start:.06}s") 77 | 78 | suffix = "_slim_f16" if args.f16 else "_slim" 79 | 80 | for i in range(cores_per_replica): 81 | write_ckpt(network.state, f"gs://{bucket}/{model_dir}{suffix}/step_{ckpt_step}/", i) 82 | print(f"written shard {i}") 83 | -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from tasks.eval_harness import EvalHarnessAdaptor -------------------------------------------------------------------------------- /tasks/eval_harness.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import transformers 4 | from lm_eval.base import LM 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | from tasks.util import sample_batch, shrink_seq 9 | import multiprocessing 10 | import ftfy 11 | 12 | tokenizer = None 13 | 14 | 15 | def process_init(): 16 | global tokenizer 17 | tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') 18 | tokenizer.model_max_length = int(1e30) 19 | tokenizer.pad_token = "<|endoftext|>" 20 | 21 | assert tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373] 22 | 23 | 24 | def process_request(x, seq): 25 | global tokenizer 26 | 27 | ctx, cont = x 28 | 29 | ctx_tokens = tokenizer.encode("<|endoftext|>" + ftfy.fix_text(ctx, normalization="NFKC")) 30 | cont_tokens = tokenizer.encode(ftfy.fix_text(cont, normalization="NFKC")) 31 | 32 | all_tokens = ctx_tokens + cont_tokens 33 | all_tokens = np.array(all_tokens)[-seq:] # truncate sequence at seq length 34 | 35 | provided_ctx = len(all_tokens) - 1 36 | pad_amount = seq - provided_ctx 37 | 38 | return { 39 | "obs": np.pad(all_tokens[:-1], ((0, pad_amount),), constant_values=50256), 40 | "target": np.pad(all_tokens[1:], ((0, pad_amount),), constant_values=50256), 41 | "ctx_length": seq, 42 | "eval_mask": np.logical_and( 43 | np.arange(0, seq) > len(all_tokens) - len(cont_tokens) - 2, 44 | np.arange(0, seq) < len(all_tokens) - 1 45 | ), 46 | } 47 | 48 | 49 | class EvalHarnessAdaptor(LM): 50 | def greedy_until(self, requests): 51 | raise Exception("unimplemented") 52 | 53 | def loglikelihood_rolling(self, requests): 54 | raise Exception("unimplemented") 55 | 56 | def __init__(self, tpu_cluster, seq, batch, shrink, min_seq=None): 57 | super().__init__() 58 | self.tpu = tpu_cluster 59 | self.seq = seq 60 | self.batch = batch 61 | self.shrink = shrink 62 | self.min_seq = min_seq 63 | 64 | self.pool = multiprocessing.Pool(initializer=process_init) 65 | process_init() 66 | 67 | def convert_requests(self, requests): 68 | return self.pool.imap(partial(process_request, seq=self.seq), requests) 69 | 70 | def loglikelihood(self, requests): 71 | output = [] 72 | 73 | r = self.convert_requests(requests) 74 | zero_example = process_request(requests[0], self.seq) 75 | 76 | for b in tqdm(sample_batch(r, self.batch, zero_example), 77 | desc="LM eval harness", 78 | total=len(requests) // self.batch): 79 | 80 | if self.shrink: 81 | b = shrink_seq(b, min_seq=self.min_seq) 82 | 
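# self.tpu.eval(b) runs the padded batch on the TPU cluster; out["mask_loss"] is the loss over the eval-masked continuation tokens and out["each_correct"] marks whether the continuation matches greedy decoding, which the loop below converts into the (log-likelihood, is_greedy) pairs that lm-eval-harness expects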
83 | out = self.tpu.eval(b) 84 | 85 | for loss, correct in zip(out["mask_loss"], out["each_correct"]): 86 | output.append((float(-loss), bool(correct))) 87 | 88 | return output 89 | -------------------------------------------------------------------------------- /tasks/util.py: -------------------------------------------------------------------------------- 1 | from itertools import zip_longest 2 | 3 | import numpy as np 4 | 5 | 6 | def grouper(n, iterable, fillvalue): 7 | "grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx" 8 | args = [iter(iterable)] * n 9 | return zip_longest(fillvalue=fillvalue, *args) 10 | 11 | 12 | # divide the seq length by 2 until it would truncate actual context 13 | def shrink_seq(examples, min_seq=None): 14 | length = examples["obs"].shape[-1] 15 | 16 | new_length = length // 2 17 | 18 | if min_seq is not None: 19 | if new_length < min_seq: 20 | return examples 21 | 22 | max_length = np.max(examples["eval_mask"] * np.arange(0, length)) + 1 23 | 24 | if max_length < new_length: 25 | examples["obs"] = examples["obs"][:, :new_length] 26 | examples["target"] = examples["target"][:, :new_length] 27 | examples["eval_mask"] = examples["eval_mask"][:, :new_length] 28 | 29 | return shrink_seq(examples, min_seq=min_seq) 30 | else: 31 | return examples 32 | 33 | 34 | def sample_batch(examples, bs, zero_example_shape): 35 | zero_example = { 36 | "obs": np.zeros_like(zero_example_shape["obs"]), 37 | "target": np.zeros_like(zero_example_shape["target"]), 38 | "eval_mask": np.zeros_like(zero_example_shape["eval_mask"]), 39 | "ctx_length": 0, 40 | } 41 | 42 | for batch in grouper(bs, examples, zero_example): 43 | batch_flattened = { 44 | "obs": [], 45 | "target": [], 46 | "eval_mask": [], 47 | "ctx_length": [], 48 | } 49 | 50 | for sample in batch: 51 | batch_flattened["obs"].append(sample["obs"]) 52 | batch_flattened["target"].append(sample["target"]) 53 | batch_flattened["eval_mask"].append(sample["eval_mask"]) 54 | batch_flattened["ctx_length"].append(sample["ctx_length"]) 55 | 56 | batch_flattened["obs"] = np.array(batch_flattened["obs"]) 57 | batch_flattened["target"] = np.array(batch_flattened["target"]) 58 | batch_flattened["eval_mask"] = np.array(batch_flattened["eval_mask"]) 59 | batch_flattened["ctx_length"] = np.array(batch_flattened["ctx_length"]) 60 | 61 | yield batch_flattened 62 | -------------------------------------------------------------------------------- /tfrecord_loader.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import tensorflow as tf 3 | import numpy as np 4 | from transformers import GPT2TokenizerFast 5 | import itertools 6 | 7 | 8 | class TFRecordLoader: 9 | def __init__(self, index_fname, batch_size, parse_fn, map_fn=None, restore_state=None): 10 | if restore_state is not None: 11 | self.file_idx = restore_state["file_idx"] 12 | self.file_idx_init = False 13 | self.used = restore_state["used"] 14 | else: 15 | self.file_idx = 0 16 | self.file_idx_init = True 17 | self.used = [] 18 | 19 | self.index = open(index_fname).read().splitlines() 20 | self.clean_index = list(filter(lambda x: x not in self.used, self.index)) 21 | self.bs = batch_size 22 | # self.seq = sample_size 23 | self.parse_fn = parse_fn 24 | 25 | if map_fn: 26 | self.map_fn = map_fn 27 | else: 28 | self.map_fn = lambda x: x 29 | 30 | self.sample_fn = self.sample_once() 31 | 32 | def reset(self): 33 | self.file_idx = 0 34 | self.file_idx_init = True 35 | self.used = [] 36 | 37 | self.clean_index = list(filter(lambda x: x not in self.used, 
self.index)) 38 | self.sample_fn = self.sample_once() 39 | 40 | def sample_once(self): 41 | for i in self.clean_index: 42 | compression = "ZLIB" if "zstd" in i else "" 43 | 44 | file = tf.data.TFRecordDataset(i, compression_type=compression).map(self.parse_fn, num_parallel_calls=tf.data.AUTOTUNE) 45 | file = file.apply(tf.data.experimental.dense_to_ragged_batch(np.prod(self.bs), drop_remainder=True)) 46 | file = file.prefetch(10) 47 | 48 | for file_idx, data in enumerate(file): 49 | data = jax.tree_map(lambda x: x.numpy(), data) 50 | data = self.map_fn(data) 51 | 52 | if not self.file_idx_init and file_idx <= self.file_idx: 53 | if file_idx % 1000 == 0: 54 | print(f"skipping to batch {self.file_idx}, currently at {file_idx}") 55 | continue 56 | self.file_idx_init = True 57 | self.file_idx = file_idx 58 | yield jax.tree_map(lambda x: x.reshape(self.bs + x.shape[1:]), data) 59 | self.used.append(i) 60 | self.file_idx = 0 61 | 62 | # this loops infinitely, use .sample_once to get an iterator for validation 63 | def get_samples(self): 64 | try: 65 | return next(self.sample_fn) 66 | except StopIteration: 67 | self.reset() 68 | return self.get_samples() 69 | 70 | def get_state(self): 71 | return { 72 | "used": self.used, 73 | "file_idx": self.file_idx 74 | } 75 | 76 | 77 | class TFRecordNewInputs(TFRecordLoader): 78 | def __init__(self, index_fname, batch_size, sample_size, restore_state=None): 79 | def tf_parse(example_proto): 80 | features = { 81 | "text": tf.io.VarLenFeature(tf.int64) 82 | } 83 | parsed_features = tf.io.parse_single_example(example_proto, features) 84 | 85 | return tf.cast(tf.sparse.to_dense(tf.sparse.reorder(parsed_features["text"])), tf.uint32) 86 | 87 | super().__init__(index_fname, batch_size, tf_parse, restore_state=restore_state) 88 | 89 | 90 | class TFRecordWIT(TFRecordLoader): 91 | def __init__(self, index_fname, batch_size, restore_state=None, text_tokens=256): 92 | self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") 93 | self.tokenizer.pad_token = "<|endoftext|>" 94 | self.tokenizer.add_special_tokens({'sep_token': '<|sep|>', 'pad_token': '<|pad|>'}) 95 | 96 | def map_fn(example): 97 | tokenizer = self.tokenizer 98 | 99 | def decode(x): 100 | return tokenizer(["<|endoftext|>" + i.decode() for i in x])["input_ids"] 101 | 102 | texts = [ 103 | decode(example["context_page_description"]), 104 | decode(example["context_section_description"]), 105 | decode(example["caption_reference_description"]), 106 | decode(example["caption_alt_text_description"]), 107 | decode(example["caption_attribution_description"]), 108 | ] 109 | 110 | output = [] 111 | 112 | for text, dalle in zip(zip(*texts), example["dalle"]): 113 | all_text = list(itertools.chain(*text))[-text_tokens+1:] 114 | 115 | all_text += [tokenizer.pad_token_id] * ((text_tokens - 1) - len(all_text)) 116 | 117 | assert len(all_text) == text_tokens - 1 118 | 119 | all_tokens = all_text + [tokenizer.sep_token_id] + list(dalle + tokenizer.vocab_size + 1) 120 | output.append(all_tokens) 121 | 122 | return np.array(output) 123 | 124 | def tf_parse(example_proto): 125 | features = { 126 | "page_title": tf.io.FixedLenFeature([], tf.string), 127 | "section_title": tf.io.FixedLenFeature([], tf.string), 128 | "hierarchical_section_title": tf.io.FixedLenFeature([], tf.string), 129 | "caption_reference_description": tf.io.FixedLenFeature([], tf.string), 130 | "caption_attribution_description": tf.io.FixedLenFeature([], tf.string), 131 | "caption_alt_text_description": tf.io.FixedLenFeature([], tf.string), 132 | 
"mime_type": tf.io.FixedLenFeature([], tf.string), 133 | "context_page_description": tf.io.FixedLenFeature([], tf.string), 134 | "context_section_description": tf.io.FixedLenFeature([], tf.string), 135 | 136 | "dalle": tf.io.FixedLenFeature([1024], tf.int64), 137 | } 138 | 139 | parsed_features = tf.io.parse_single_example(example_proto, features) 140 | 141 | return parsed_features 142 | 143 | super().__init__(index_fname, batch_size, tf_parse, map_fn, restore_state=restore_state) 144 | 145 | 146 | if __name__ == "__main__": 147 | # d = TFRecordNewInputs("data/pile.val.index", (8, 32), 2048) 148 | # for idx, i in enumerate(d.sample_once()): 149 | # print(i) 150 | # break 151 | 152 | d = TFRecordWIT("data/wit_dalle.train.index", (8, 32)) 153 | for idx, i in enumerate(d.sample_once()): 154 | print(i) 155 | break 156 | 157 | print() 158 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import time 4 | 5 | import numpy as np 6 | import wandb 7 | from tqdm import tqdm 8 | 9 | from mesh_transformer.build_model import build_model 10 | from lm_eval import evaluator, tasks 11 | from tasks.eval_harness import EvalHarnessAdaptor 12 | from tfrecord_loader import TFRecordNewInputs 13 | import multiprocessing 14 | 15 | 16 | def parse_args(): 17 | # Parse command line arguments 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--tpu", type=str, help="Name of TPU to train on.") 20 | parser.add_argument("--tpu_region", type=str, help="Region of TPU to train on.") 21 | parser.add_argument("--preemptible", action="store_true") 22 | 23 | parser.add_argument("--config", type=str, default=None, help="Config file location") 24 | 25 | parser.add_argument("--new", action="store_true", help="If set, deletes previous checkpoint, if it exists, and " 26 | "starts a new training run") 27 | 28 | parser.add_argument("--version", type=int, default=1, help="Choose which model version to use") 29 | 30 | args = parser.parse_args() 31 | return args 32 | 33 | 34 | if __name__ == "__main__": 35 | # huggingface tokenizers gets very angry if you fork 36 | multiprocessing.set_start_method("spawn") 37 | 38 | args = parse_args() 39 | params = json.load(open(args.config)) 40 | 41 | if args.new: 42 | print(f"Starting experiment {params['name']} from scratch! 
" 43 | f"all data in gs://{params['bucket']}/{params['model_dir']}/ will be deleted") 44 | input("Hit enter to continue") 45 | 46 | tpu_name = args.tpu 47 | region = args.tpu_region 48 | preemptible = args.preemptible 49 | clean_start = args.new 50 | 51 | gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1) 52 | per_replica_batch = params["per_replica_batch"] 53 | tpu_size = params["tpu_size"] 54 | cores_per_replica = params["cores_per_replica"] 55 | 56 | bucket = params["bucket"] 57 | model_dir = params["model_dir"] 58 | layers = params["layers"] 59 | d_model = params["d_model"] 60 | n_heads = params["n_heads"] 61 | n_vocab = params["n_vocab"] 62 | seq = params["seq"] 63 | norm = params["norm"] 64 | 65 | val_batches = params["val_batches"] 66 | val_every = params["val_every"] 67 | ckpt_every = params["ckpt_every"] 68 | keep_every = params["keep_every"] 69 | eval_tasks = params["eval_harness_tasks"] 70 | total_steps = params["total_steps"] 71 | 72 | pe = params["pe"] 73 | assert pe in ["fixed", "rotary", "t5"] 74 | 75 | t = build_model(params, tpu_name, region, preemptible, version=args.version) 76 | 77 | try: 78 | t.save(0, bucket, model_dir, init=True, overwrite=clean_start) 79 | step = 0 80 | train_load_restore = None 81 | except Exception as e: 82 | print(f"Save failed with error {e}, trying to load instead...", e) 83 | step, aux = t.load(bucket, model_dir) 84 | train_load_restore = aux.get("train_loader", None) 85 | 86 | if train_load_restore is None: 87 | print("Failed to restore train loader state") 88 | 89 | train_dataset = TFRecordNewInputs(f"data/{params['train_set']}", 90 | batch_size=( 91 | gradient_accumulation_steps, 92 | per_replica_batch * tpu_size // cores_per_replica), 93 | sample_size=params['seq'], 94 | restore_state=train_load_restore) 95 | 96 | global_val_batch = int(per_replica_batch * tpu_size // cores_per_replica * params.get("val_batch_multiplier", 1)) 97 | 98 | val_sets = {} 99 | 100 | for k, v in params['val_set'].items(): 101 | val_sets[k] = TFRecordNewInputs(f"data/{v}", 102 | batch_size=(global_val_batch,), 103 | sample_size=seq) 104 | 105 | # use dynamic seq length unless pe is fixed 106 | adaptor = EvalHarnessAdaptor(t, 107 | seq, 108 | global_val_batch, 109 | shrink=pe != "fixed", 110 | min_seq=1024 if args.version == 2 else None) # work around suboptimal pjit layout 111 | 112 | start = time.time() 113 | t.train(train_dataset.get_samples()) 114 | print(f"Train fn compiled in {time.time() - start:.06}s") 115 | 116 | start = time.time() 117 | for val_set in val_sets.values(): 118 | t.eval(val_set.get_samples()) 119 | print(f"Eval fn compiled in {time.time() - start:.06}s") 120 | 121 | project = params.get("wandb_project", "mesh-transformer-jax") 122 | wandb.init(project=project, entity="eleutherai", name=params["name"], config=params) 123 | 124 | eval_task_dict = tasks.get_task_dict(eval_tasks) 125 | 126 | pbar = tqdm(initial=step, total=total_steps, desc="Training progress") 127 | 128 | while True: 129 | loss, last_loss = t.train(train_dataset.get_samples()) 130 | wandb.log({'train/loss': loss, 'train/last_loss': last_loss}, step) 131 | 132 | if (step % ckpt_every == 0 and step) or step == total_steps: 133 | t.save(step, bucket, model_dir, 134 | aux={"train_loader": train_dataset.get_state()}, 135 | init=False, 136 | delete_old=step % keep_every != 0) 137 | 138 | if step == total_steps: 139 | print("training completed!") 140 | exit() 141 | 142 | if step % val_every == 0: 143 | for name, val_set in val_sets.items(): 144 | val_loss = [] 
145 | for i, _ in tqdm(zip(val_set.sample_once(), range(val_batches)), 146 | desc=f"validation for step {step}, set {name}", 147 | total=val_batches): 148 | val_loss.append(t.eval(i)) 149 | val_loss = np.array(val_loss).mean() 150 | print(f"validation loss for step {step}, set {name}: {val_loss}") 151 | 152 | wandb.log({f'val/loss_{name}': float(val_loss)}, step) 153 | 154 | results = evaluator.evaluate(adaptor, eval_task_dict, False, 0, None) 155 | 156 | flat_results = {} 157 | 158 | for task_name, task_res in results["results"].items(): 159 | version = results["versions"][task_name] 160 | for metric_name, metric_res in task_res.items(): 161 | flat_results[f"{task_name}-v{version}/{metric_name}"] = float(metric_res) 162 | 163 | dumped = json.dumps(results, indent=2) 164 | print(f"step {step} val results: {dumped}") 165 | wandb.log(flat_results, step) 166 | step += 1 167 | 168 | pbar.set_postfix({'loss': loss, 'last_loss': last_loss}) 169 | pbar.update() 170 | --------------------------------------------------------------------------------
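Usage note: the optimizer helpers defined in mesh_transformer/util.py (gpt3_schedule, clip_by_global_norm, additive_weight_decay) are designed to be composed into a single optax chain, in the order slim_model.py uses when it reconstructs a throwaway optimizer. A minimal sketch of such a chain follows, assuming the optax==0.0.9 pin from requirements.txt; the warmup, step-count, accumulation, and learning-rate numbers are illustrative placeholders rather than values taken from any config in this repository.

import optax
from mesh_transformer.util import additive_weight_decay, clip_by_global_norm, gpt3_schedule

# linear warmup to 1.2e-4 over 3k steps, then cosine anneal towards 1.2e-5 over 350k steps (illustrative values)
scheduler = gpt3_schedule(3_000, 350_000, 1.2e-4, 1.2e-5)

optimizer = optax.chain(
    optax.scale(1 / 16),                       # e.g. average over 16 gradient accumulation steps
    clip_by_global_norm(1.0, use_psum=False),  # use_psum=False outside xmap/pjit, where no "shard" axis exists
    optax.scale_by_adam(),
    additive_weight_decay(0.1),                # only applied to parameters with more than one dim (see util.py)
    optax.scale(-1),                           # turn the update into a descent step
    optax.scale_by_schedule(scheduler),
)

The resulting GradientTransformation is what CausalTransformer expects in params["optimizer"]; resharding_example.py drops the optimizer entirely by passing optax.scale(0) there when only inference is needed.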