├── .github
│   └── workflows
│       └── lint_python.yml
├── .gitignore
├── CITATION.bib
├── LICENSE.txt
├── README.md
├── arithmetic
│   └── gen_sums.py
├── benchmarks.md
├── colab_demo.ipynb
├── commonsenseqa
│   ├── dev_rand_split.jsonl
│   ├── prompts.txt
│   ├── prompts_answer_key.txt
│   ├── prompts_direct.txt
│   ├── prompts_direct_answer_key.txt
│   ├── test_rand_split_no_answers.jsonl
│   └── train_rand_split.jsonl
├── configs
│   ├── iterative_base
│   │   └── base.json
│   └── qa_base.json
├── create_finetune_tfrecords.py
├── data
│   └── qa.val.index
├── device_inference.py
├── device_serve.py
├── device_train.py
├── docker
│   ├── .env
│   ├── Dockerfile
│   ├── README.md
│   ├── compose-proxy.yaml
│   ├── docker-compose.yaml
│   ├── main.py
│   ├── nginx_proxyvm.conf
│   ├── ops.py
│   ├── payloads.py
│   └── start.sh
├── eval_harness.py
├── gsm
│   ├── dev_rand_split.jsonl
│   ├── prompts.txt
│   ├── prompts_answer_key.txt
│   ├── prompts_direct_answer_key.txt
│   └── train_rand_split.jsonl
├── howto_finetune.md
├── iteration_train.py
├── mesh_transformer
│   ├── TPU_cluster.py
│   ├── __init__.py
│   ├── build_model.py
│   ├── checkpoint.py
│   ├── layers.py
│   ├── sampling.py
│   ├── train_actor.py
│   ├── transformer_shard.py
│   └── util.py
├── ray_tpu.py
├── requirements.txt
├── resharding_example.py
├── scripts
│   ├── create_serve_tpu.sh
│   ├── deploy_server.sh
│   ├── init_ray.sh
│   ├── init_ray_v2.sh
│   └── init_serve.sh
├── setup.py
├── slim_model.py
├── tasks
│   ├── __init__.py
│   ├── eval_harness.py
│   └── util.py
├── tfrecord_loader.py
├── to_hf_weights.py
└── train.py
/.github/workflows/lint_python.yml:
--------------------------------------------------------------------------------
1 | name: lint_python
2 | on: [pull_request, push]
3 | jobs:
4 | lint_python:
5 | runs-on: ubuntu-latest
6 | steps:
7 | - uses: actions/checkout@v2
8 | - uses: actions/setup-python@v2
9 | - run: pip install bandit black codespell flake8 isort mypy pytest pyupgrade safety
10 | - run: bandit --recursive --skip B101 . || true # B101 is assert statements
11 | - run: black --check . || true
12 | - run: codespell
13 | - run: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
14 | - run: flake8 . --count --exit-zero --max-complexity=19 --max-line-length=245 --show-source --statistics
15 | - run: isort --check-only --profile black . || true
16 | - run: pip install -r requirements.txt
17 | - run: mypy --ignore-missing-imports . || true
18 | - run: pytest . || true
19 | - run: pytest --doctest-modules . || true
20 | - run: shopt -s globstar && pyupgrade --py36-plus **/*.py || true
21 | - run: safety check
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.pprof
3 | /ckpt/
4 | wandb
5 | data/*
6 | !data/qa.val.index
7 | configs/*
8 | !configs/iterative_base/base.json
9 | !configs/iterative_base
10 |
11 | commonsenseqa/*
12 | !commonsenseqa/*.py
13 | !commonsenseqa/prompts*.txt
14 | !commonsenseqa/*.jsonl
15 |
16 | gsm/*
17 | !gsm/*.py
18 | !gsm/*.jsonl
19 | !gsm/prompts*.txt
20 |
21 | arithmetic/*
22 | !arithmetic/gen_sums.py
23 | iterative*.txt
24 | .vscode/settings.json
25 | !configs/qa_base.json
26 | result_logs
27 | result_logs/*
28 | !configs/qa_base.json
29 |
--------------------------------------------------------------------------------
/CITATION.bib:
--------------------------------------------------------------------------------
1 | @inproceedings{
2 | zelikman2022star,
3 | title={{ST}aR: Bootstrapping Reasoning With Reasoning},
4 | author={Eric Zelikman and Yuhuai Wu and Jesse Mu and Noah Goodman},
5 | booktitle={Advances in Neural Information Processing Systems},
6 | editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
7 | year={2022},
8 | url={https://openreview.net/forum?id=_3ELRdg2sgI}
9 | }
10 |
11 | @misc{mesh-transformer-jax,
12 | author = {Wang, Ben},
13 | title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
14 | howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}},
15 | year = 2021,
16 | month = May
17 | }
18 |
19 | @misc{gpt-j,
20 | author = {Wang, Ben and Komatsuzaki, Aran},
21 | title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
22 | howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}},
23 | year = 2021,
24 | month = May
25 | }
26 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | https://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | Copyright 2021 Ben Wang
179 |
180 | Licensed under the Apache License, Version 2.0 (the "License");
181 | you may not use this file except in compliance with the License.
182 | You may obtain a copy of the License at
183 |
184 | https://www.apache.org/licenses/LICENSE-2.0
185 |
186 | Unless required by applicable law or agreed to in writing, software
187 | distributed under the License is distributed on an "AS IS" BASIS,
188 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
189 | See the License for the specific language governing permissions and
190 | limitations under the License.
191 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Table of contents
2 | 1. [STaR](#star)
3 | 2. [Mesh Transformer JAX](#mesh-transformer-jax)
4 | 1. [Updates](#updates)
5 | 3. [Pretrained Models](#pretrained-models)
6 | 1. [GPT-J-6B](#gpt-j-6b)
7 | 1. [Links](#links)
8 | 2. [Acknowledgments](#acknowledgments)
9 | 3. [License](#license)
10 | 4. [Model Details](#model-details)
11 | 5. [Zero-Shot Evaluations](#zero-shot-evaluations)
12 | 4. [Architecture and Usage](#architecture-and-usage)
13 | 1. [Fine-tuning](#fine-tuning)
14 | 2. [JAX Dependency](#jax-dependency)
15 | 5. [TODO](#todo)
16 |
17 | # STaR
18 | Code for [STaR: Bootstrapping Reasoning With Reasoning (NeurIPS 2022)](https://openreview.net/forum?id=_3ELRdg2sgI). This library is built on top of [mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax) and incorporates masked training from [this repo](https://github.com/albertqjiang/mesh-transformer-jax). In order to run it, launch `iteration_train.py` with any desired arguments. The README is left mostly unchanged, as `iteration_train` largely wraps around `device_train.py`, `device_inference.py`, and `create_finetune_tfrecords.py`.
19 |
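At a high level, each outer iteration generates rationales for the training questions, keeps the ones that reach the correct answer (optionally "rationalizing" failures by hinting the correct answer, controlled by `p_rationalization` in the configs), and fine-tunes a fresh copy of the base model on the collected examples. The sketch below is only a schematic of that loop; the callables passed in are hypothetical stand-ins, not the actual interface of `iteration_train.py`, `device_inference.py`, or `device_train.py`.

```python
# Schematic of the STaR outer loop. The callables are hypothetical stand-ins
# for what iteration_train.py does by wrapping device_inference.py,
# create_finetune_tfrecords.py, and device_train.py.
import random

def star_loop(base_model, generate_rationale, rationalize, finetune_from_base,
              train_examples, n_iterations=10, p_rationalization=1.0):
    model = base_model
    for _ in range(n_iterations):
        keep = []
        for question, gold_answer in train_examples:
            # 1) few-shot prompt the current model for a rationale + answer
            rationale, answer = generate_rationale(model, question)
            if answer == gold_answer:
                keep.append((question, rationale, answer))
            elif random.random() < p_rationalization:
                # 2) rationalization: hint the gold answer and keep the
                #    rationale only if the model now reproduces that answer
                rationale, answer = rationalize(model, question, gold_answer)
                if answer == gold_answer:
                    keep.append((question, rationale, answer))
        # 3) fine-tune a fresh copy of the base model on the kept examples
        model = finetune_from_base(keep)
    return model
```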
20 | # Mesh Transformer JAX
21 |
22 | A haiku library using the `xmap`/`pjit` operators in JAX for model parallelism of transformers.
23 |
24 | The parallelism scheme is similar to the [original Megatron-LM](https://arxiv.org/abs/1909.08053), which is efficient
25 | on TPUs due to the high speed 2d mesh network. There is also an experimental model version which implements [ZeRo style
26 | sharding](https://arxiv.org/abs/1910.02054).
27 |
28 | This library is designed for scalability up to approximately 40B parameters on TPUv3s, beyond which different
29 | parallelism strategies should be used. See other implementations such as
30 | [GPT-NeoX](https://github.com/EleutherAI/gpt-neox) or [DeepSpeed](https://github.com/microsoft/DeepSpeed) for that.
31 |
32 | One future direction for research is integrating this codebase with
33 | [swarm-jax](https://github.com/kingoflolz/swarm-jax), to achieve further scalability with pipeline parallelism.
34 |
35 | ## Updates
36 |
37 | **12-07-21**: Added [guide to fine tuning](howto_finetune.md)
38 |
39 | # Pretrained Models
40 |
41 | ## GPT-J-6B
42 |
43 | A 6 billion parameter, autoregressive text generation model trained on [The Pile](https://pile.eleuther.ai/).
44 |
45 | ### Links
46 |
47 | [Slim weights (bf16 weights only, for inference, 9GB)](https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd)
48 |
49 | [Full weights (including optimizer params, 61GB)](https://the-eye.eu/public/AI/GPT-J-6B/step_383500.tar.zstd)
50 |
51 | [Colab demo](http://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb)
52 |
53 | [Web demo](https://6b.eleuther.ai/)
54 |
55 | [Aran's blog post](https://arankomatsuzaki.wordpress.com/2021/06/04/gpt-j/)
56 |
57 | ### Acknowledgments
58 |
59 | This project would not have been possible without compute generously provided by the
60 | [TPU Research Cloud](https://sites.research.google/trc/) with assistance from [EleutherAI](https://eleuther.ai/).
61 |
62 | Thanks to the Cloud TPU team at Google for providing early access to the Cloud TPU VM alpha
63 | ([now publicly available!](https://cloud.google.com/blog/products/compute/introducing-cloud-tpu-vms))
64 |
65 | Thanks to everyone who has helped out in one way or another (listed alphabetically):
66 | - [Aran Komatsuzaki](https://twitter.com/arankomatsuzaki) for advice with experiment design and writing the blog posts.
67 | - [James Bradbury](https://twitter.com/jekbradbury) for valuable assistance with debugging JAX issues.
68 | - [Janko Prester](https://github.com/jprester) for creating the web demo frontend.
69 | - [Laurence Golding](https://github.com/researcher2) for adding some features to the web demo.
70 | - [Leo Gao](https://twitter.com/nabla_theta) for running zero shot evaluations for the baseline models for the table.
71 |
72 | ### License
73 | The weights of GPT-J-6B are licensed under version 2.0 of the Apache License.
74 |
75 | ### Model Details
76 |
77 | | Hyperparameter | Value |
78 | |-------------------|--------|
79 | | n_parameters | 6,053,381,344 |
80 | | n_layers | 28* |
81 | | d_model | 4,096 |
82 | | d_ff | 16,384 |
83 | | n_heads | 16 |
84 | | d_head | 256 |
85 | | n_ctx | 2,048 |
86 | | n_vocab | 50,257 (same tokenizer as GPT-2/3) |
87 | | position encoding | [Rotary position encodings (RoPE)](https://arxiv.org/abs/2104.09864) |
88 | | RoPE dimensions | [64](https://github.com/kingoflolz/mesh-transformer-jax/blob/f2aa66e0925de6593dcbb70e72399b97b4130482/mesh_transformer/layers.py#L223) |
89 |
90 | `*` each layer consists of one feedforward block and one self-attention block
91 |
92 | The model consists of 28 layers with a model dimension of 4096, and a feedforward dimension of 16384. The model
93 | dimension is split into 16 heads, each with a dimension of 256. Rotary position encodings (RoPE) are applied to 64
94 | dimensions of each head. The model is trained with a tokenization vocabulary of 50257, using the same set of BPEs as
95 | GPT-2/GPT-3.
96 |
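As a quick sanity check, the hyperparameters above are mutually consistent: 16 heads of dimension 256 give the model dimension of 4096, and a rough parameter count from these shapes lands near the reported total. The snippet below is only a back-of-the-envelope estimate; it assumes untied input/output embeddings and the padded vocabulary of 50,400 used in the configs in this repo, and it ignores biases and layer norms.

```python
# Back-of-the-envelope check of the GPT-J-6B hyperparameters listed above.
# Ignores biases/layer norms; assumes untied input and output embeddings.
layers, d_model, d_ff, n_heads, d_head = 28, 4096, 16384, 16, 256
n_vocab_padded = 50400  # the configs pad the 50,257-token vocab to 50,400

assert d_model == n_heads * d_head  # 16 heads x 256 = 4096

per_layer = 4 * d_model ** 2 + 2 * d_model * d_ff  # attention (q,k,v,o) + MLP
embeddings = 2 * n_vocab_padded * d_model           # input embedding + LM head
print(f"{layers * per_layer + embeddings:,}")       # ~6.05B, close to 6,053,381,344
```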
97 | ### Zero-Shot Evaluations
98 |
99 | Models roughly sorted by performance, or by FLOPs if not available.
100 |
101 | | Model | Weights | Training FLOPs | LAMBADA PPL ↓ | LAMBADA Acc ↑ | Winogrande ↑ | Hellaswag ↑ | PIQA ↑ | Dataset Size (GB) |
102 | |-----------------|---------|----------------|--- |--- |--- |--- |--- |-------------------|
103 | | Chance | ✔ | 0 | ~a lot | ~0% | 50% | 25% | 25% | 0 |
104 | | GPT-3-Ada‡ | ✘ | ----- | 9.95 | 51.6% | 52.9% | 43.4% | 70.5% | ----- |
105 | | GPT-2-1.5B | ✔ | ----- | 10.63 | 51.21% | 59.4% | 50.9% | 70.8% | 40 |
106 | | GPTNeo-1.3B‡ | ✔ | 3.0e21 | 7.50 | 57.2% | 55.0% | 48.9% | 71.1% | 825 |
107 | | Megatron-2.5B* | ✘ | 2.4e21 | ----- | 61.7% | ----- | ----- | ----- | 174 |
108 | | GPTNeo-2.7B‡ | ✔ | 6.8e21 | 5.63 | 62.2% | 56.5% | 55.8% | 73.0% | 825 |
109 | | GPT-3-1.3B*‡ | ✘ | 2.4e21 | 5.44 | 63.6% | 58.7% | 54.7% | 75.1% | ~800 |
110 | | GPT-3-Babbage‡ | ✘ | ----- | 5.58 | 62.4% | 59.0% | 54.5% | 75.5% | ----- |
111 | | Megatron-8.3B* | ✘ | 7.8e21 | ----- | 66.5% | ----- | ----- | ----- | 174 |
112 | | GPT-3-2.7B*‡ | ✘ | 4.8e21 | 4.60 | 67.1% | 62.3% | 62.8% | 75.6% | ~800 |
113 | | Megatron-11B† | ✔ | 1.0e22 | ----- | ----- | ----- | ----- | ----- | 161 |
114 | | **GPT-J-6B**‡ | ✔ | 1.5e22 | 3.99 | 69.7% | 65.3% | 66.1% | 76.5% | 825 |
115 | | GPT-3-6.7B*‡ | ✘ | 1.2e22 | 4.00 | 70.3% | 64.5% | 67.4% | 78.0% | ~800 |
116 | | GPT-3-Curie‡ | ✘ | ----- | 4.00 | 69.3% | 65.6% | 68.5% | 77.9% | ----- |
117 | | GPT-3-13B*‡ | ✘ | 2.3e22 | 3.56 | 72.5% | 67.9% | 70.9% | 78.5% | ~800 |
118 | | GPT-3-175B*‡ | ✘ | 3.1e23 | 3.00 | 76.2% | 70.2% | 78.9% | 81.0% | ~800 |
119 | | GPT-3-Davinci‡ | ✘ | ----- | 3.0 | 75% | 72% | 78% | 80% | ----- |
120 | | Gopher 230B* | ✘ | 6.31E+23 | ----- | 74.50% | 70.10% | 79.20% | 81.80% | 1344 |
121 | | MT-NLG 530B*‡ | ✘ | ----- | ----- | 76.6% | 73.0% | 80.2% | 82.0% | ----- |
122 |
123 | `*` represents evaluation numbers reported by their respective authors; all other numbers are provided by
124 | running the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/) either with the released
125 | weights or with API access. Due to subtle implementation differences as well as different zero shot task framing, these
126 | might not be directly comparable. See [this blog post](https://www.eleuther.ai/research-log/gpt3-model-sizes/) for more
127 | details.
128 |
129 | `†` The Megatron-11B model provides no comparable metrics, and several implementations using the released weights do not
130 | reproduce the generation quality and evaluations. (see [1](https://github.com/huggingface/transformers/pull/10301)
131 | [2](https://github.com/pytorch/fairseq/issues/2358) [3](https://github.com/pytorch/fairseq/issues/2719))
132 | Thus, evaluation was not attempted.
133 |
134 | `‡` These models have been trained with data which contains possible test set contamination. The OpenAI GPT-3 models
135 | failed to deduplicate training data for certain test sets, while the GPT-Neo models, as well as this one, are
136 | trained on The Pile, which has not been deduplicated against any test sets.
137 |
138 | # Architecture and Usage
139 |
140 | Most scripts in this repository are designed to be run on TPUs, which under the
141 | [TPU-VM architecture](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) are virtual machines
142 | which can run arbitrary code. Most of these scripts are designed to spin up a TPU, SSH into it to set up the dependencies
143 | and copy code over from the local directory, and then start a [Ray](https://github.com/ray-project/ray.git) worker
144 | which can accept RPC calls.
145 |
146 | The TPU VMs handle running model training steps and evaluation, as well as checkpoint saving and loading, while the driver Python
147 | program handles data loading and general orchestration (such as when to save checkpoints, etc.).
148 |
149 | This means that most scripts (`train.py`, `eval_harness.py` etc) expect to be running on a GCE virtual machine in the
150 | same region as the TPUs, to minimize RPC latency and data transfer cost. Other scripts
151 | (usually ones which don't take a `--tpu` argument, such as `device_sample.py`, `device_serve.py` or `device_train.py`)
152 | expect to be run directly on a TPUVM. The device_* scripts **only work on a v3-8** and not on larger pods.
153 |
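To make the driver/worker split concrete, the pattern is roughly the one below: the TPU VM hosts a Ray actor that owns the model state and exposes RPC-style methods, and the driver calls into it. This is only an illustrative sketch with made-up class and method names; the real actors live in `mesh_transformer` (e.g. `train_actor.py`, `TPU_cluster.py`) and have a different interface.

```python
# Illustrative sketch of the driver <-> TPU-VM worker split using Ray.
# Class and method names here are hypothetical, not this repo's actual API.
import ray

ray.init()  # the driver; in the real setup, workers run on the TPU VMs

@ray.remote
class TrainWorker:
    def __init__(self, config):
        self.config = config  # model/optimizer state would be built here
        self.step = 0

    def train_step(self, batch):
        self.step += 1        # a forward/backward pass would run on the TPU
        return {"step": self.step, "loss": float("nan")}

worker = TrainWorker.remote({"per_replica_batch": 1})
print(ray.get(worker.train_step.remote(batch=None)))
ray.shutdown()
```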
154 | Furthermore, there is an example (`resharding_example.py`) of how to convert the provided checkpoints (which have 8
155 | shards in the case of GPT-J-6B) down to a smaller number, such as for when running on GPU(s).
156 |
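As a toy illustration of what resharding means here (this is not `resharding_example.py`, and real checkpoints also contain replicated parameters and optimizer state): a sharded parameter is stored as per-device slices along some axis, so changing the shard count is essentially a concatenate-and-resplit of those slices.

```python
# Toy illustration of resharding: 8 per-device slices of a weight matrix are
# merged and re-split into fewer shards. Shapes here are made up.
import numpy as np

full = np.arange(8 * 16, dtype=np.float32).reshape(8, 16)
old_shards = np.split(full, 8, axis=1)       # as stored in an 8-shard checkpoint

merged = np.concatenate(old_shards, axis=1)  # gather the sharded axis
new_shards = np.split(merged, 2, axis=1)     # e.g. reshard for 2 devices

assert np.array_equal(np.concatenate(new_shards, axis=1), full)
```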
157 | ### Fine-tuning
158 |
159 | To fine-tune the model, run `device_train.py` on a TPU VM. Using a TPU v3-8, you can fine-tune at a rate of ~5000
160 | tokens/second, which should be sufficient for small-to-medium-size datasets.
161 |
162 | Please read the [step by step guide](howto_finetune.md) for thorough fine-tuning instructions.
163 |
164 | ### JAX Dependency
165 |
166 | Note that this library has specific requirements for the JAX version. To use the v1 models (including
167 | GPT-J-6B), `jax==0.2.12` is required, which in turn depends on `jaxlib==0.1.68`. **If these versions are not used, you will get
168 | cryptic xmap errors.**
169 |
170 | However, to use the v2 model code (no publicly released weights), the newest JAX version can be used.
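If you are unsure which versions are installed, a quick runtime check like the one below (written against the v1 requirement above) can save some debugging:

```python
# Sanity-check the installed JAX/jaxlib versions before loading v1 (GPT-J-6B)
# checkpoints; other versions tend to fail with cryptic xmap errors.
import jax
import jaxlib

print("jax", jax.__version__, "jaxlib", jaxlib.__version__)
if (jax.__version__, jaxlib.__version__) != ("0.2.12", "0.1.68"):
    print("Warning: the v1 models expect jax==0.2.12 and jaxlib==0.1.68.")
```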
171 | # Citation
172 |
173 | To cite this repository:
174 | ```
175 | @inproceedings{
176 | zelikman2022star,
177 | title={{ST}aR: Bootstrapping Reasoning With Reasoning},
178 | author={Eric Zelikman and Yuhuai Wu and Jesse Mu and Noah Goodman},
179 | booktitle={Advances in Neural Information Processing Systems},
180 | editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
181 | year={2022},
182 | url={https://openreview.net/forum?id=_3ELRdg2sgI}
183 | }
184 | ```
185 |
186 | To cite the base repository:
187 | ```
188 | @misc{mesh-transformer-jax,
189 | author = {Wang, Ben},
190 | title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
191 | howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}},
192 | year = 2021,
193 | month = May
194 | }
195 | ```
196 |
197 | To cite the weights of GPT-J-6B:
198 | ```
199 | @misc{gpt-j,
200 | author = {Wang, Ben and Komatsuzaki, Aran},
201 | title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
202 | howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}},
203 | year = 2021,
204 | month = May
205 | }
206 | ```
207 |
208 | If you use this repository or any of the pretrained weights to do something cool, we would love to hear about it.
209 | Feel free to open a github issue or reach out over email (in profile).
210 |
211 | # TODO
212 | - [x] disentangle heads and shards
213 | - [x] test/benchmark on TPU
214 | - [x] implement gradient checkpointing
215 | - [x] fix initialization
216 | - [x] mixed precision
217 | - [x] deal with preemptible TPUs
218 | - [x] test and validate generation
219 | - [x] shard activations instead of replicating for memory efficiency (in v2)
220 | - [x] support ZeRO style sharding (in v2)
221 |
--------------------------------------------------------------------------------
/arithmetic/gen_sums.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 |
4 | def to_list(num):
5 | list_str = list(str(num))
6 | return list_str
7 |
8 | def to_str(num):
9 | return " ".join(to_list(num))
10 |
11 | # n_samples = 10000
12 | n_samples = 2000
13 | context_examples = 50
14 | examples_per_prompt = 1
15 | # examples_per_prompt = 10
16 |
17 | val_start = 6
18 | max_digits = 10
19 | # include_scrachpad = True
20 | include_scrachpad = False
21 | # fixed_examples = True
22 | fixed_examples = False
23 | randomized_digits = False
24 |
25 | def gen_examples(n_examples, digits, randomized_digits=False, min_digit=1, max_digit=8):
26 | complete_str = []
27 | for _ in range(n_examples):
28 | if randomized_digits:
29 | digits = random.randrange(min_digit + 1, max_digit)
30 | first_number = random.randrange(int(10 ** digits), int(10 ** (digits + 1)))
31 | second_number = random.randrange(int(10 ** digits), int(10 ** (digits + 1)))
32 | input_sum = f'{to_str(first_number)} + {to_str(second_number)}'
33 | resultant_str = f'Input:\n{input_sum}\nTarget:\n'
34 | if include_scrachpad:
35 | scratch_pad = f'\n{input_sum} , C: 0\n'
36 | carry = 0
37 | running_sum = ''
38 | initial = True
39 | for first_digit, second_digit in reversed(list(zip(
40 | to_list(first_number), to_list(second_number)
41 | ))):
42 | dig_sum = int(first_digit) + int(second_digit) + carry
43 | if not initial:
44 | scratch_pad += f'{first_digit} + {second_digit} , {running_sum}C: {carry}\n'
45 | carry = int(dig_sum >= 10)
46 | running_sum = f'{dig_sum % 10} {running_sum}'
47 | initial = False
48 | scratch_pad += f', {running_sum}C: {carry}\n'
49 | scratch_pad += f'{carry} {running_sum}'.strip() + '\n'
50 | scratch_pad += '\n'
51 | resultant_str += scratch_pad
52 | resultant_str += to_str(first_number + second_number)
53 | resultant_str += '\n\n'
54 | complete_str.append(resultant_str)
55 | return complete_str
56 |
57 |
58 | if fixed_examples and randomized_digits:
59 | few_shot_str = gen_examples(examples_per_prompt - 1, None, randomized_digits, 1, val_start - 1)
60 | for digits in range(max_digits):
61 | folder_name = 'val' if digits + 1 >= val_start else 'train'
62 | if include_scrachpad:
63 | folder_name += '_scratch'
64 | if fixed_examples and not randomized_digits:
65 | few_shot_str = gen_examples(context_examples, digits, randomized_digits)
66 | else:
67 | folder_name += '_direct'
68 | if not os.path.isdir(folder_name):
69 | os.mkdir(folder_name)
70 | complete_str = ''
71 | for _ in range(n_samples):
72 | n_gen_examples = 1 if fixed_examples else examples_per_prompt
73 | if fixed_examples:
74 | complete_str += "".join(random.sample(few_shot_str, examples_per_prompt - 1))
75 | complete_str += "".join(gen_examples(n_gen_examples, digits))
76 | complete_str += '<|endoftext|>'
77 |
78 | with open(f'{folder_name}/{digits + 1}.txt', 'w') as f:
79 | f.write(complete_str)
80 |
81 | if digits + 2 == val_start or digits + 1 == max_digits:
82 | os.system(f'python3 ../create_finetune_tfrecords.py ./{folder_name}/ arithmetic_{folder_name}')
83 |
--------------------------------------------------------------------------------
/benchmarks.md:
--------------------------------------------------------------------------------
1 | Note that everything on this page is quite outdated and is only roughly accurate when considering newer features
2 | such as RoPE.
3 |
4 | # Benchmarks (v3-8)
5 |
6 | (see `tpuv38_example.py`):
7 |
8 | ## ~2.7B model
9 | ```
10 | Initialized in 121.842s
11 | Total parameters: 2722382080
12 | Compiled in 49.0534s
13 | it: 0, loss: 20.311113357543945
14 |
15 | it: 90, loss: 3.987450361251831
16 | 100 steps in 109.385s
17 | effective flops (not including attn): 2.4466e+14
18 | ```
19 |
20 | ## ~4.8B model
21 | ```
22 | Initialized in 101.016s
23 | Total parameters: 4836720896
24 | Compiled in 52.7404s
25 | it: 0, loss: 4.632925987243652
26 |
27 | it: 40, loss: 3.2406811714172363
28 | 50 steps in 102.559s
29 | effective flops (not including attn): 2.31803e+14
30 | ```
31 |
32 | ## 10B model
33 | ```
34 | Initialized in 152.762s
35 | Total parameters: 10073579776
36 | Compiled in 92.6539s
37 | it: 0, loss: 5.3125
38 |
39 | it: 40, loss: 3.65625
40 | 50 steps in 100.235s
41 | effective flops (not including attn): 2.46988e+14
42 | ```
43 |
44 | # Benchmarks (v3-32)
45 |
46 | (see `eval.py`):
47 |
48 | ## 6B model
49 | ```
50 | "layers": 28,
51 | "d_model": 4096,
52 | "n_heads": 16,
53 | "n_vocab": 50400,
54 |
55 | "seq": 2048,
56 | "cores_per_replica": 8,
57 | "per_replica_batch": 1,
58 | "gradient_accumulation_steps": 8,
59 |
60 | params: 6053381856
61 | 32 iters done in 107.935s, avg 3.37298s, 0.296473/s
62 | effective flops (not including attn): 7.05692e+14
63 | MXU flops: 1.04523e+15
64 | ```
65 |
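The "effective flops" figure above can be approximately reconstructed from the config and throughput using the usual ~6 FLOPs per parameter per token estimate (attention excluded, as noted). The snippet below is a reconstruction for the 6B run, not the repo's own benchmarking code:

```python
# Reconstructing the 6B "effective flops" number from the config above.
params = 6_053_381_856
seq_len = 2048
replicas = 32 // 8                    # v3-32 cores / cores_per_replica=8
seqs_per_iter = replicas * 1 * 8      # per_replica_batch=1, grad accumulation=8
tokens_per_iter = seqs_per_iter * seq_len   # 65,536 tokens per iteration
iters_per_second = 0.296473

effective_flops = 6 * params * tokens_per_iter * iters_per_second
print(f"{effective_flops:.5e}")       # ~7.057e+14, matching the figure above
```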
66 | ## Note that the below models do not currently work
67 | They require a larger degree of model parallelism than is currently implemented, but benchmark numbers should be
68 | reasonably representative.
69 |
70 | ## 13B model
71 | ```
72 | "layers": 28,
73 | "d_model": 6144,
74 | "n_heads": 32,
75 | "n_vocab": 50400,
76 |
77 | "seq": 2048,
78 | "cores_per_replica": 16,
79 | "per_replica_batch": 1,
80 | "gradient_accumulation_steps": 16,
81 |
82 | params: 13312183008
83 | 32 iters done in 250.86s, avg 7.83937s, 0.127561/s
84 | effective flops (not including attn): 6.67727e+14
85 | MXU flops: 9.80066e+14
86 | ```
87 |
88 | ## 23B model
89 | ```
90 | "layers": 28,
91 | "d_model": 8192,
92 | "n_heads": 32,
93 | "n_vocab": 50400,
94 |
95 | "seq": 2048,
96 | "cores_per_replica": 32,
97 | "per_replica_batch": 1,
98 | "gradient_accumulation_steps": 32,
99 |
100 | params: 23398107360
101 | 16 iters done in 221.33s, avg 13.8331s, 0.0722902/s
102 | effective flops (not including attn): 6.65107e+14
103 | MXU flops: 9.88548e+14
104 | ```
--------------------------------------------------------------------------------
/commonsenseqa/prompts.txt:
--------------------------------------------------------------------------------
1 | Q: What do people use to absorb extra ink from a fountain pen?
2 | Answer Choices:
3 | (a) shirt pocket
4 | (b) calligrapher's hand
5 | (c) inkwell
6 | (d) desk drawer
7 | (e) blotter
8 | A: The answer must be used to absorb extra ink. Blotters are designed to absorb liquids. Therefore, the answer is blotter (e).
9 |
10 | Q: What home entertainment equipment requires cable?
11 | Answer Choices:
12 | (a) radio shack
13 | (b) substation
14 | (c) television
15 | (d) cabinet
16 | (e) desk
17 | A: The answer must require cable. Cable is used to provide satellite channels to televisions. Therefore, the answer is television (c).
18 |
19 | Q: The fox walked from the city into the forest, what was it looking for?
20 | Answer Choices:
21 | (a) pretty flowers
22 | (b) hen house
23 | (c) natural habitat
24 | (d) storybook
25 | (e) dense forest
26 | A: The answer must be a reason for a fox to go into the forest. The forest is a fox's natural habitat. Therefore, the answer is natural habitat (c).
27 |
28 | Q: Sammy wanted to go to where the people were. Where might he go?
29 | Answer Choices:
30 | (a) populated areas
31 | (b) race track
32 | (c) desert
33 | (d) apartment
34 | (e) roadblock
35 | A: The answer must be a place with many people. Populated areas, by definition, have a lot of people. Therefore, the answer is populated areas (a).
36 |
37 | Q: Where do you put your grapes just before checking out?
38 | Answer Choices:
39 | (a) mouth
40 | (b) grocery cart
41 | (c) super market
42 | (d) fruit basket
43 | (e) fruit market
44 | A: The answer should be the place where grocery items are placed before checking out. Of the above choices, grocery cart makes the most sense for holding grocery items. Therefore, the answer is grocery cart (b).
45 |
46 | Q: Google Maps and other highway and street GPS services have replaced what?
47 | Answer Choices:
48 | (a) united states
49 | (b) mexico
50 | (c) countryside
51 | (d) atlas
52 | (e) oceans
53 | A: The answer must be something that used to do what Google Maps and GPS services do, which is give directions. Atlases were also used to give directions. Therefore, the answer is atlas (d).
54 |
55 | Q: Before getting a divorce, what did the wife feel who was doing all the work?
56 | Answer Choices:
57 | (a) harder
58 | (b) anguish
59 | (c) bitterness
60 | (d) tears
61 | (e) sadness
62 | A: The answer should be a feeling which would cause someone who was doing all the work to get divorced. If someone feels bitter towards their spouse, they are likely to want a divorce. Therefore, the answer is bitterness (c).
--------------------------------------------------------------------------------
/commonsenseqa/prompts_answer_key.txt:
--------------------------------------------------------------------------------
1 | Q: What do people use to absorb extra ink from a fountain pen?
2 | Answer Choices:
3 | (a) shirt pocket
4 | (b) calligrapher's hand
5 | (c) inkwell
6 | (d) desk drawer
7 | (e) blotter (CORRECT)
8 | A: The answer must be used to absorb extra ink. Blotters are designed to absorb liquids. Therefore, the answer is blotter (e).
9 |
10 | Q: What home entertainment equipment requires cable?
11 | Answer Choices:
12 | (a) radio shack
13 | (b) substation
14 | (c) television (CORRECT)
15 | (d) cabinet
16 | (e) desk
17 | A: The answer must require cable. Cable is used to provide satellite channels to televisions. Therefore, the answer is television (c).
18 |
19 | Q: The fox walked from the city into the forest, what was it looking for?
20 | Answer Choices:
21 | (a) pretty flowers
22 | (b) hen house
23 | (c) natural habitat (CORRECT)
24 | (d) storybook
25 | (e) dense forest
26 | A: The answer must be a reason for a fox to go into the forest. The forest is a fox's natural habitat. Therefore, the answer is natural habitat (c).
27 |
28 | Q: Sammy wanted to go to where the people were. Where might he go?
29 | Answer Choices:
30 | (a) populated areas (CORRECT)
31 | (b) race track
32 | (c) desert
33 | (d) apartment
34 | (e) roadblock
35 | A: The answer must be a place with many people. Populated areas, by definition, have a lot of people. Therefore, the answer is populated areas (a).
36 |
37 | Q: Where do you put your grapes just before checking out?
38 | Answer Choices:
39 | (a) mouth
40 | (b) grocery cart (CORRECT)
41 | (c) super market
42 | (d) fruit basket
43 | (e) fruit market
44 | A: The answer should be the place where grocery items are placed before checking out. Of the above choices, grocery cart makes the most sense for holding grocery items. Therefore, the answer is grocery cart (b).
45 |
46 | Q: Google Maps and other highway and street GPS services have replaced what?
47 | Answer Choices:
48 | (a) united states
49 | (b) mexico
50 | (c) countryside
51 | (d) atlas (CORRECT)
52 | (e) oceans
53 | A: The answer must be something that used to do what Google Maps and GPS services do, which is give directions. Atlases were also used to give directions. Therefore, the answer is atlas (d).
54 |
55 | Q: Before getting a divorce, what did the wife feel who was doing all the work?
56 | Answer Choices:
57 | (a) harder
58 | (b) anguish
59 | (c) bitterness (CORRECT)
60 | (d) tears
61 | (e) sadness
62 | A: The answer should be a feeling which would cause someone who was doing all the work to get divorced. If someone feels bitter towards their spouse, they are likely to want a divorce. Therefore, the answer is bitterness (c).
--------------------------------------------------------------------------------
/commonsenseqa/prompts_direct.txt:
--------------------------------------------------------------------------------
1 | Q: What do people use to absorb extra ink from a fountain pen?
2 | Answer Choices:
3 | (a) shirt pocket
4 | (b) calligrapher's hand
5 | (c) inkwell
6 | (d) desk drawer
7 | (e) blotter
8 | A: (e).
9 |
10 | Q: What home entertainment equipment requires cable?
11 | Answer Choices:
12 | (a) radio shack
13 | (b) substation
14 | (c) television
15 | (d) cabinet
16 | (e) desk
17 | A: (c).
18 |
19 | Q: The fox walked from the city into the forest, what was it looking for?
20 | Answer Choices:
21 | (a) pretty flowers
22 | (b) hen house
23 | (c) natural habitat
24 | (d) storybook
25 | (e) dense forest
26 | A: (c).
27 |
28 | Q: Sammy wanted to go to where the people were. Where might he go?
29 | Answer Choices:
30 | (a) populated areas
31 | (b) race track
32 | (c) desert
33 | (d) apartment
34 | (e) roadblock
35 | A: (a).
36 |
37 | Q: Where do you put your grapes just before checking out?
38 | Answer Choices:
39 | (a) mouth
40 | (b) grocery cart
41 | (c) super market
42 | (d) fruit basket
43 | (e) fruit market
44 | A: (b).
45 |
46 | Q: Google Maps and other highway and street GPS services have replaced what?
47 | Answer Choices:
48 | (a) united states
49 | (b) mexico
50 | (c) countryside
51 | (d) atlas
52 | (e) oceans
53 | A: (d).
54 |
55 | Q: Before getting a divorce, what did the wife feel who was doing all the work?
56 | Answer Choices:
57 | (a) harder
58 | (b) anguish
59 | (c) bitterness
60 | (d) tears
61 | (e) sadness
62 | A: (c).
--------------------------------------------------------------------------------
/commonsenseqa/prompts_direct_answer_key.txt:
--------------------------------------------------------------------------------
1 | Q: What do people use to absorb extra ink from a fountain pen?
2 | Answer Choices:
3 | (a) shirt pocket
4 | (b) calligrapher's hand
5 | (c) inkwell
6 | (d) desk drawer
7 | (e) blotter (CORRECT)
8 | A: (e).
9 |
10 | Q: What home entertainment equipment requires cable?
11 | Answer Choices:
12 | (a) radio shack
13 | (b) substation
14 | (c) television (CORRECT)
15 | (d) cabinet
16 | (e) desk
17 | A: (c).
18 |
19 | Q: The fox walked from the city into the forest, what was it looking for?
20 | Answer Choices:
21 | (a) pretty flowers
22 | (b) hen house
23 | (c) natural habitat (CORRECT)
24 | (d) storybook
25 | (e) dense forest
26 | A: (c).
27 |
28 | Q: Sammy wanted to go to where the people were. Where might he go?
29 | Answer Choices:
30 | (a) populated areas (CORRECT)
31 | (b) race track
32 | (c) desert
33 | (d) apartment
34 | (e) roadblock
35 | A: (a).
36 |
37 | Q: Where do you put your grapes just before checking out?
38 | Answer Choices:
39 | (a) mouth
40 | (b) grocery cart (CORRECT)
41 | (c) super market
42 | (d) fruit basket
43 | (e) fruit market
44 | A: (b).
45 |
46 | Q: Google Maps and other highway and street GPS services have replaced what?
47 | Answer Choices:
48 | (a) united states
49 | (b) mexico
50 | (c) countryside
51 | (d) atlas (CORRECT)
52 | (e) oceans
53 | A: (d).
54 |
55 | Q: Before getting a divorce, what did the wife feel who was doing all the work?
56 | Answer Choices:
57 | (a) harder
58 | (b) anguish
59 | (c) bitterness (CORRECT)
60 | (d) tears
61 | (e) sadness
62 | A: (c).
--------------------------------------------------------------------------------
/configs/iterative_base/base.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": 28,
3 | "d_model": 4096,
4 | "n_heads": 16,
5 | "n_vocab": 50400,
6 | "norm": "layernorm",
7 | "pe": "rotary",
8 | "pe_rotary_dims": 64,
9 | "seq": 2048,
10 | "cores_per_replica": 8,
11 | "per_replica_batch": 1,
12 | "gradient_accumulation_steps": 2,
13 | "warmup_steps": 100,
14 | "anneal_steps": 300000,
15 | "lr": 1e-06,
16 | "end_lr": 1e-06,
17 | "weight_decay": 0.0,
18 | "total_steps": 383500,
19 | "tpu_size": 8,
20 | "bucket": "checkpoint-bucket",
21 | "model_dir": "full_qa_4",
22 | "train_set": "qa_train_4.index",
23 | "val_set": {
24 | "index": "qa.val.index"
25 | },
26 | "eval_harness_tasks": [
27 | "lambada",
28 | "piqa",
29 | "hellaswag",
30 | "winogrande",
31 | "mathqa",
32 | "pubmedqa"
33 | ],
34 | "val_batches": 100,
35 | "val_every": 20,
36 | "ckpt_every": 20,
37 | "keep_every": 10000,
38 | "name": "full_6",
39 | "wandb_project": "full_6",
40 | "comment": "",
41 | "target_save_folder": "commonsenseqa/iterative_full/iterative_full_0",
42 | "target_save": "commonsenseqa/iterative_negative_fast/iterative_negative_fast_0/iterative_negative_fast_0.txt"
43 | }
44 |
--------------------------------------------------------------------------------
/configs/qa_base.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": 28,
3 | "d_model": 4096,
4 | "n_heads": 16,
5 | "n_vocab": 50400,
6 | "norm": "layernorm",
7 | "pe": "rotary",
8 | "pe_rotary_dims": 64,
9 | "seq": 1536,
10 | "cores_per_replica": 8,
11 | "per_replica_batch": 1,
12 | "gradient_accumulation_steps": 8,
13 | "warmup_steps": 100,
14 | "anneal_steps": 300000,
15 | "lr": 1e-06,
16 | "end_lr": 1e-06,
17 | "weight_decay": 0.0,
18 | "total_steps": 383500,
19 | "tpu_size": 8,
20 | "p_rationalization": 1.0,
21 | "bucket": "checkpoint-bucket",
22 | "model_dir": "full_qa_4",
23 | "train_set": "qa_train_4.index",
24 | "val_set": {
25 | "index": "qa.val.index"
26 | },
27 | "eval_harness_tasks": [
28 | "lambada",
29 | "piqa",
30 | "hellaswag",
31 | "winogrande",
32 | "mathqa",
33 | "pubmedqa"
34 | ],
35 | "val_batches": 100,
36 | "val_every": 10000,
37 | "ckpt_every": 10000,
38 | "keep_every": 10000,
39 | "name": "slow_grow_full_epoch_0",
40 | "wandb_project": "full_6",
41 | "comment": "",
42 | "target_save_folder": "commonsenseqa/iterative_full/iterative_full_0",
43 | "target_save": "commonsenseqa/slow_grow_full_epoch/slow_grow_full_epoch_0/slow_grow_full_epoch_0.txt"
44 | }
45 |
--------------------------------------------------------------------------------
/create_finetune_tfrecords.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import re
4 | import random
5 |
6 | from pathlib import Path
7 | from typing import List
8 |
9 | import ftfy
10 | import tensorflow as tf
11 | from lm_dataformat import Reader
12 | from transformers import GPT2TokenizerFast
13 | from tqdm import tqdm
14 |
15 |
16 | def parse_args():
17 | parser = argparse.ArgumentParser(description="""
18 | Converts a text dataset into the training data format expected by the model.
19 |
20 | Adapted from the script create_tfrecords.py in the gpt-neo repo.
21 |
22 | - Your text dataset:
23 | - can be provided as .txt files, or as an archive (.tar.gz, .xz, jsonl.zst).
24 | - can be one file or multiple
25 | - using a single large file may use too much memory and crash - if this occurs, split the file up into a few files
26 | - the model's end-of-text separator is added between the contents of each file
27 | - if the string '<|endoftext|>' appears inside a file, it is treated as the model's end-of-text separator (not the actual string '<|endoftext|>')
28 | - this behavior can be disabled with --treat-eot-as-text
29 |
30 | This script creates a single .tfrecords file as output
31 | - Why: the model's data loader ignores "trailing" data (< 1 batch) at the end of a .tfrecords file
32 | - this causes data loss if you have many .tfrecords files
33 | - This is probably not appropriate for very large datasets
34 | """, formatter_class=argparse.RawTextHelpFormatter)
35 | parser.add_argument(
36 | "input_path",
37 | type=str,
38 | help="Path to an input file, or a directory that contains the input files.",
39 | )
40 | parser.add_argument("name", type=str,
41 | help="Name of output file will be {name}_{seqnum}.tfrecords, where seqnum is total sequence count")
42 | parser.add_argument("--output-dir", type=str, default="", help="Output directory (default: current directory)")
43 |
44 | cleaning_args = parser.add_argument_group('data cleaning arguments')
45 |
46 | cleaning_args.add_argument("--normalize-with-ftfy", action="store_true", help="Normalize text with ftfy")
47 | cleaning_args.add_argument("--normalize-with-wikitext-detokenize",
48 | action="store_true", help="Use wikitext detokenizer")
49 | minu_help = "Exclude repetitive documents made up of < MIN_UNIQUE_TOKENS unique tokens. These can produce large gradients."
50 | minu_help += " Set <= 0 to disable. If enabled, 200 is a good default value. (Default: 0)"
51 | cleaning_args.add_argument("--min-unique-tokens", type=int, default=0,
52 | help=minu_help)
53 |
54 | shuffle_pack_args = parser.add_argument_group('data shuffling/packing arguments')
55 | repack_ep_help = "Repeat the data N_REPACK_EPOCHS times, shuffled differently in each repetition. Recommended for multi-epoch training (set this to your intended number of epochs)."
56 | shuffle_pack_args.add_argument("--n-repack-epochs",
57 | type=int, default=1,
58 | help=repack_ep_help
59 | )
60 | shuffle_pack_args.add_argument("--seed", type=int, default=10,
61 | help="random seed for shuffling data (default: 10)")
62 | shuffle_pack_args.add_argument("--preserve-data-order",
63 | default=False, action="store_true",
64 | help="Disables shuffling, so the input and output data have the same order.")
65 |
66 | misc_args = parser.add_argument_group('miscellaneous arguments')
67 | misc_args.add_argument("--verbose",
68 | default=False, action="store_true",
69 | help="Prints extra information, such as the text removed by --min-unique-tokens")
70 |
71 | args = parser.parse_args()
72 |
73 | # convert input_path to pathy
74 | args.input_path = Path(args.input_path)
75 |
76 | return args
77 |
78 |
79 | def get_files(input_path: Path) -> List[str]:
80 | supported_file_types = ["jsonl.zst", ".txt", ".xz", ".tar.gz"]
81 | if input_path.is_dir():
82 | # get all files with supported file types
83 | files = [list(Path(input_path).glob(f"*{ft}")) for ft in supported_file_types]
84 | # flatten list
85 | files = [f for sublist in files for f in sublist]
86 | assert files, f"No files with supported types found in directory: {input_path}"
87 | elif input_path.is_file():
88 | assert any(
89 | str(input_path).endswith(f_type) for f_type in supported_file_types
90 | ), f"Input file type must be one of: {supported_file_types}"
91 | files = [input_path]
92 | else:
93 | raise FileNotFoundError(f"No such file or directory: {input_path=}")
94 |
95 | return [str(f) for f in files]
96 |
97 |
98 | def wikitext_detokenizer(string):
99 | # contractions
100 | string = string.replace("s '", "s'")
101 | string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
102 | # number separators
103 | string = string.replace(" @-@ ", "-")
104 | string = string.replace(" @,@ ", ",")
105 | string = string.replace(" @.@ ", ".")
106 | # punctuation
107 | string = string.replace(" : ", ": ")
108 | string = string.replace(" ; ", "; ")
109 | string = string.replace(" . ", ". ")
110 | string = string.replace(" ! ", "! ")
111 | string = string.replace(" ? ", "? ")
112 | string = string.replace(" , ", ", ")
113 | # double brackets
114 | string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
115 | string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
116 | string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
117 | string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
118 | string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
119 | # miscellaneous
120 | string = string.replace("= = = =", "====")
121 | string = string.replace("= = =", "===")
122 | string = string.replace("= =", "==")
123 | string = string.replace(" " + chr(176) + " ", chr(176))
124 | string = string.replace(" \n", "\n")
125 | string = string.replace("\n ", "\n")
126 | string = string.replace(" N ", " 1 ")
127 | string = string.replace(" 's", "'s")
128 |
129 | return string
130 |
131 |
132 | def _int64_feature(value):
133 | """
134 | Returns an int64_list from a bool / enum / int / uint.
135 | """
136 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
137 |
138 |
139 | def write_to_file(writer, data):
140 | """
141 | writes data to tfrecord file
142 | """
143 | feature = {
144 | "text": _int64_feature(data)
145 | }
146 | tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
147 | writer.write(tf_example.SerializeToString())
148 |
149 |
150 | def write_tfrecord(sequences, fp):
151 | with tf.io.TFRecordWriter(fp) as writer:
152 | for seq in sequences:
153 | write_to_file(writer, seq)
154 |
155 |
156 | def split_list(l, n):
157 | # splits list/string into n size chunks
158 | return [l[i:i + n] for i in range(0, len(l), n)]
159 |
160 |
161 | def enforce_min_unique(seqs, min_unique_tokens, enc, verbose=False):
162 | for seq in tqdm(seqs, mininterval=1, smoothing=0, desc="enforce_min_unique_tokens"):
163 | if len(set(seq)) >= min_unique_tokens:
164 | yield seq
165 | elif verbose:
166 | text = enc.decode(seq)
167 | print(f"excluding with {len(set(seq))} unique tokens:\n\n{repr(text)}\n\n")
168 |
169 |
170 | def eot_splitting_generator(string_iterable, encoder):
171 | """
172 | Given strings, splits them internally on <|endoftext|> and yields (generally more) strings
173 | """
174 | for doc in string_iterable:
175 | for d in doc.split(encoder.eos_token):
176 | if len(d) > 0:
177 | yield d
178 |
179 |
180 | def prep_and_tokenize_generator(string_iterable, encoder, normalize_with_ftfy, normalize_with_wikitext_detokenize):
181 | """
182 | Given strings, does data cleaning / tokenization and yields arrays of tokens
183 | """
184 | for doc in string_iterable:
185 | if normalize_with_ftfy: # fix text with ftfy if specified
186 | doc = ftfy.fix_text(doc, normalization='NFKC')
187 | if normalize_with_wikitext_detokenize:
188 | doc = wikitext_detokenizer(doc)
189 | tokens = encoder.encode(doc) + [encoder.eos_token_id]
190 | yield tokens
191 |
192 |
193 | def file_to_tokenized_docs_generator(file_path, encoder, args):
194 | """
195 | Given a file path, reads the file and tokenizes the contents
196 |
197 | Yields token arrays of arbitrary, unequal length
198 | """
199 | reader = Reader(file_path)
200 | string_iterable = reader.stream_data(threaded=False)
201 | string_iterable = eot_splitting_generator(string_iterable, encoder)
202 |
203 | token_list_gen = prep_and_tokenize_generator(string_iterable,
204 | encoder,
205 | normalize_with_ftfy=args.normalize_with_ftfy,
206 | normalize_with_wikitext_detokenize=args.normalize_with_wikitext_detokenize
207 | )
208 | return token_list_gen
209 |
210 |
211 | def read_files_to_tokenized_docs(files, args, encoder):
212 | docs = []
213 |
214 | if args.preserve_data_order:
215 | files = sorted(files)
216 | else:
217 | random.shuffle(files)
218 |
219 | for f in tqdm(files, mininterval=10, smoothing=0, desc="reading/tokenizing files"):
220 | docs.extend(file_to_tokenized_docs_generator(f, encoder, args))
221 |
222 | if not args.preserve_data_order:
223 | # shuffle at individual document level
224 | random.shuffle(docs)
225 |
226 | return docs
227 |
228 |
229 | def arrays_to_sequences(token_list_iterable, sequence_length=1537):  # presumably the model seq length (1536) + 1, so inputs/targets can be offset by one token
230 | """
231 | Given token arrays of arbitrary lengths, concats/splits them into arrays of equal length
232 |
233 | Returns equal-length token arrays, followed by a final array of trailing tokens (which may be shorter)
234 | """
235 | accum = []
236 | for l in token_list_iterable:
237 | accum.extend(l)
238 |
239 | if len(accum) > sequence_length:
240 | chunks = split_list(accum, sequence_length)
241 | yield from chunks[:-1]
242 | accum = chunks[-1]
243 |
244 | if len(accum) > 0:
245 | yield accum
246 |
247 |
248 | def chunk_and_finalize(arrays, args, encoder):
249 | sequences = list(arrays_to_sequences(arrays))
250 |
251 | full_seqs, trailing_data = sequences[:-1], sequences[-1]
252 |
253 | if args.min_unique_tokens > 0:
254 | full_seqs = list(enforce_min_unique(full_seqs, args.min_unique_tokens, encoder, args.verbose))
255 |
256 | if not args.preserve_data_order:
257 | random.shuffle(full_seqs)
258 |
259 | return full_seqs, trailing_data
260 |
261 |
262 | def create_tfrecords(files, args):
263 | GPT2TokenizerFast.max_model_input_sizes['gpt2'] = 1e20 # disables a misleading warning
264 | encoder = GPT2TokenizerFast.from_pretrained('gpt2')
265 |
266 | random.seed(args.seed)
267 |
268 | all_sequences_across_epochs = []
269 |
270 | docs = read_files_to_tokenized_docs(files, args, encoder)
271 |
272 | full_seqs, trailing_data = chunk_and_finalize(docs, args, encoder)
273 |
274 | all_sequences_across_epochs.extend(full_seqs)
275 |
276 | # ep 2+
277 | for ep_ix in range(1, args.n_repack_epochs):
278 | # re-shuffle
279 | if not args.preserve_data_order:
280 | random.shuffle(docs)
281 | full_seqs, trailing_data = chunk_and_finalize(docs, args, encoder)
282 | else:
283 | # if we're preserving data order, we can still "repack" by shifting everything
284 | # with the trailing data of the last epoch at the beginning
285 | seqs_with_prefix = [trailing_data] + full_seqs
286 | full_seqs, trailing_data = chunk_and_finalize(seqs_with_prefix, args, encoder)
287 |
288 | all_sequences_across_epochs.extend(full_seqs)
289 |
290 | # final
291 | print(f"dropped {len(trailing_data)} tokens of trailing data")
292 |
293 | total_sequence_len = len(all_sequences_across_epochs)
294 |
295 | fp = os.path.join(args.output_dir, f"{args.name}.tfrecords")
296 | write_tfrecord(all_sequences_across_epochs, fp)
297 |
298 |
299 | if __name__ == "__main__":
300 | args = parse_args()
301 |
302 | if args.output_dir:
303 | os.makedirs(args.output_dir, exist_ok=True)
304 | files = get_files(args.input_path)
305 | print(f"Creating TFRecords from files: {files}")
306 |
307 | results = create_tfrecords(files, args)
308 |
--------------------------------------------------------------------------------
/data/qa.val.index:
--------------------------------------------------------------------------------
1 | /home/username/STaR/commonsenseqa/qa_val.tfrecords
2 |
--------------------------------------------------------------------------------
/device_inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import time
4 |
5 | import jax
6 | import numpy as np
7 | import optax
8 | import random
9 | import wandb
10 |
11 | from mesh_transformer import util
12 | from mesh_transformer.checkpoint import read_ckpt
13 | from mesh_transformer.sampling import nucleaus_sample
14 | from mesh_transformer.transformer_shard import CausalTransformer
15 | from smart_open import open as smart_open
16 | import transformers
17 | import tensorflow_datasets as tfds
18 | from mesh_transformer.util import clip_by_global_norm
19 | import jsonlines
20 | import pprint
21 | from tqdm import tqdm
22 |
23 | basic_open = open
24 | pp = pprint.PrettyPrinter(indent=2).pprint
25 |
26 |
27 | def eval_output(output, answers, context, example_classes, accuracy, target_save, tokenizer, show=False, direct=False, endoftext="<|endoftext|>"):
28 | successful_examples = []
29 | enum_outputs = enumerate(output[1][0][:, :, 0])
30 | for (idx, o), target, cur_base_context, example_class in zip(enum_outputs, answers, context, example_classes):
31 | cur_output = tokenizer.decode(o)
32 | output_numbers = cur_output.split('\n')
33 | if example_class not in accuracy:
34 | accuracy[example_class] = {'accurate': 0, 'total': 0}
35 | accuracy[example_class]['total'] += 1
36 | if len(output_numbers) == 0:
37 | continue
38 | try:
39 | if args.dataset_mode == "cqa":
40 | output_numbers = output_numbers[0]
41 | if "<|endoftext|>" in output_numbers:
42 | output_numbers = output_numbers.split("<|endoftext|>")[0]
43 | output_prediction = output_numbers[-3]
44 | elif args.dataset_mode == "gsm":
45 | output_prediction = ""
46 | for line_idx, line in enumerate(output_numbers):
47 | if "####" in line:
48 | output_numbers = "\n".join(output_numbers[:line_idx + 1])
49 | if "<|endoftext|>" in output_numbers:
50 | output_numbers = output_numbers.split("<|endoftext|>")[0]
51 | output_prediction = output_numbers.split("####")[-1].strip()
52 | break
53 | elif args.dataset_mode == "arithmetic":
54 | if len(output_numbers) == 0:
55 | continue
56 | elif "<|endoftext|>" in output_numbers:
57 | prediction_index = output_numbers.index("<|endoftext|>") - 1
58 |                 elif "</scratch>" in output_numbers:
59 |                     prediction_index = output_numbers.index("</scratch>") + 1
60 | if prediction_index == len(output_numbers):
61 | continue
62 | else:
63 | if direct and len(output_numbers) > 1:
64 | prediction_index = 1
65 | else:
66 | prediction_index = 0
67 | output_prediction = output_numbers[prediction_index]
68 |
69 | if "<|endoftext|>" in output_prediction:
70 | output_prediction = output_prediction.split("<|endoftext|>")[0]
71 |
72 | correct = output_prediction.lower() == target.lower()
73 | if correct:
74 | accuracy[example_class]['accurate'] += 1
75 | with basic_open(target_save, 'a+') as new_train_f:
76 | if args.dataset_mode == "cqa" or args.dataset_mode == "gsm":
77 | new_example = cur_base_context + output_numbers + endoftext
78 | elif args.dataset_mode == "arithmetic":
79 | if args.few_shot_train:
80 | raise NotImplementedError
81 | joined_output = "\n".join(output_numbers[:prediction_index + 1])
82 | if "<|endoftext|>" in joined_output:
83 | joined_output = joined_output.split("<|endoftext|>")[0]
84 | new_example = cur_base_context + joined_output + endoftext
85 | if show:
86 | print(new_example)
87 | print(new_example, file=new_train_f, end="")
88 | successful_examples.append(idx)
89 | except IndexError:
90 | pass
91 | return successful_examples
92 |
93 |
94 | def get_score(subcounts):
95 | if subcounts['total'] == 0:
96 | return 0
97 | return subcounts['accurate'] / subcounts['total']
98 |
99 | def question_to_context(data_example, hint=False, dataset_mode='cqa', direct=False):
100 | if dataset_mode == 'cqa':
101 | context = f"Q: {data_example['question']['stem']}\nAnswer Choices:\n"
102 | for choice in data_example['question']['choices']:
103 | if hint and (choice['label'].lower() == data_example['answerKey'].lower()):
104 | context += f"({choice['label'].lower()}) {choice['text']} (CORRECT)\n"
105 | else:
106 | context += f"({choice['label'].lower()}) {choice['text']}\n"
107 | context += "A:"
108 | elif dataset_mode == 'gsm':
109 | context = f"Q: {data_example['question']}"
110 | if hint:
111 | chosen_hint = data_example['answer']
112 | context += f" ({chosen_hint})"
113 | context += "\nA:"
114 | elif dataset_mode == "arithmetic":
115 | context = ""
116 | for example_split, next_example_split in zip(data_example.split('Target:')[:-1], data_example.split('Target:')[1:]):
117 |             if direct and "</scratch>" in example_split:
118 |                 context += example_split.split("</scratch>")[-1]
119 | else:
120 | context += example_split
121 | context += "Target:"
122 | if hint:
123 | context += " " + next_example_split.split("\n")[-5]
124 | return context
125 |
126 |
127 | def examples_to_batch(data_examples, few_shot_prompts, seq, tokenizer, hint=False, direct=False, p_show_hint_save=0.1):
128 | batch = {
129 | "base_context": [],
130 | "initial_batch": [],
131 | "lengths": [],
132 | "padded_batch": [],
133 | "answers": [],
134 | "classes": []
135 | }
136 | for data_class, data_example in data_examples:
137 | batch['classes'].append(data_class)
138 | # Context, without the few-shot prompt
139 | hintless_base_context = question_to_context(data_example, hint=False, dataset_mode=args.dataset_mode, direct=direct)
140 | base_context = question_to_context(data_example, hint=hint, dataset_mode=args.dataset_mode, direct=direct)
141 | if args.dataset_mode == "arithmetic":
142 | few_shot_prompts = base_context.split("\n\n")[:-1]
143 | base_context = base_context.split("\n\n")[-1]
144 | hintless_base_context = hintless_base_context.split("\n\n")[-1]
145 | if random.random() < p_show_hint_save:
146 | hintless_base_context = base_context
147 | # We always want to act as if no hint was given
148 | if args.few_shot_train:
149 | if args.dataset_mode == "arithmetic":
150 | raise NotImplementedError
151 | else:
152 | save_context = "\n\n".join(commonsense_prompts) + "\n\n"
153 | save_context += hintless_base_context
154 | batch['base_context'].append(save_context)
155 | else:
156 | batch['base_context'].append(hintless_base_context)
157 | # Input tokens
158 | if args.no_prompt:
159 | context = ""
160 | else:
161 | context = "\n\n".join(few_shot_prompts) + "\n\n"
162 | context += base_context
163 | tokens = tokenizer.encode(context)
164 | batch['initial_batch'].append(tokens)
165 | # Input lengths
166 | batch['lengths'].append(len(tokens))
167 | # Padded tokens
168 | provided_ctx = len(tokens)
169 | pad_amount = max(seq - provided_ctx, 0)
170 | if provided_ctx > seq:
171 | tokens = tokens[-seq:]
172 | batch['padded_batch'].append(np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32))
173 | # Answer
174 | if args.dataset_mode == "arithmetic":
175 | if len(data_example.split("\n")) >= 3:
176 | target = data_example.split("\n")[-3]
177 | else:
178 | target = "invalid"
179 | elif args.dataset_mode == "cqa":
180 | target = data_example['answerKey']
181 | elif args.dataset_mode == "gsm":
182 | target = data_example['answer'].split("#### ")[-1]
183 | batch['answers'].append(target)
184 | batch["lengths"] = np.asarray(batch["lengths"], dtype=np.uint32)
185 | batch["padded_batch"] = np.array(batch["padded_batch"])
186 | return batch
187 |
188 |
189 | def eval_batch(examples, few_shot_prompts, seq, tok, gen_length, gen_params, accuracy, target_save, hint=False, direct=False):
190 | batch = examples_to_batch(examples, few_shot_prompts, seq, tok, hint=hint, direct=direct, p_show_hint_save=args.p_show_hint_save)
191 | output = network.generate(batch["padded_batch"], batch["lengths"], gen_length, gen_params)
192 | return eval_output(
193 | output, batch["answers"], batch["base_context"], batch["classes"], accuracy, target_save, tok, direct=direct
194 | )
195 |
196 |
197 | def load_model(params, ckpt_path, devices, mesh_shape):
198 | network = CausalTransformer(params)
199 | start = time.time()
200 | network.state = read_ckpt(network.state, ckpt_path, devices.shape[1])
201 | print(f"{ckpt_path} network loaded in {time.time() - start:.06}s on {jax.device_count()} devices")
202 | local_shards = max(jax.local_device_count() // mesh_shape[1], 1)
203 | del network.state["opt_state"]
204 | network.state = network.move_xmap(network.state, np.zeros(local_shards))
205 | return network
206 |
207 | def eval_examples(data_examples, few_shot_prompts, few_shot_prompts_hint, direct=False):
208 | accurate_count = {}
209 | tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
210 |
211 | main_examples, hint_examples = [], []
212 | pbar = tqdm(data_examples, smoothing=0)
213 | for data_example in pbar:
214 | main_examples.append(data_example)
215 | if len(main_examples) == args.eval_batch_size:
216 | successful_examples = eval_batch(
217 | main_examples, few_shot_prompts, seq, tokenizer,
218 | args.gen_length, gen_params, accurate_count, target_save, direct=direct
219 | )
220 | for example_idx, example in enumerate(main_examples):
221 | if (example_idx not in successful_examples) and (random.random() < params.get('p_rationalization', 1.)):
222 | hint_examples.append(example)
223 | main_examples = []
224 |
225 | if args.rationalize and len(hint_examples) >= args.eval_batch_size:
226 | cur_hint_examples = hint_examples[:args.eval_batch_size]
227 | cur_hint_examples = [
228 | (hint_example_key + "_r", hint_example) for hint_example_key, hint_example in cur_hint_examples
229 | ]
230 | eval_batch(
231 | cur_hint_examples, few_shot_prompts_hint, hint_seq, tokenizer,
232 | args.gen_length, gen_params, accurate_count, target_save, hint=True, direct=direct
233 | )
234 | hint_examples = hint_examples[args.eval_batch_size:]
235 | pbar.set_description(f"{split} " + ", ".join([
236 | f"{cur_key}: {get_score(cur_counts):0.4f}" for cur_key, cur_counts in accurate_count.items()
237 | ]))
238 | return accurate_count
239 |
240 | def get_ckpt_path(params, ckpt_step=-1):
241 | bucket = params["bucket"]
242 | model_dir = params["model_dir"]
243 | if ckpt_step == -1:
244 | ckpt_step = params["total_steps"]
245 | return f"gs://{bucket}/" + (f"step_{ckpt_step}/" if ckpt_step > 10000 else f"{model_dir}/step_{ckpt_step}/")
246 |
247 | def set_opt(params):
248 | params["sampler"] = nucleaus_sample
249 | opt = optax.chain(
250 | optax.scale(1 / params.get("gradient_accumulation_steps", 1)),
251 | clip_by_global_norm(1),
252 | optax.scale_by_adam(),
253 | optax.additive_weight_decay(0),
254 | optax.scale(-1),
255 | optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0))
256 | )
257 | params["optimizer"] = opt
258 |
259 | def get_dataset(args):
260 | if args.dataset_mode == "cqa":
261 | with jsonlines.open(f'commonsenseqa/{split}_rand_split.jsonl') as reader:
262 | dataset = [("cqa", example) for example in reader]
263 | elif args.dataset_mode == "gsm":
264 | with jsonlines.open(f'gsm/{split}_rand_split.jsonl') as reader:
265 | dataset = [("gsm", example) for example in reader]
266 | elif args.dataset_mode == "arithmetic":
267 | digit_range = list(range(1, 6))
268 | dataset = []
269 | for i in digit_range:
270 | with basic_open(f'arithmetic/train_scratch/{i}.txt') as f:
271 | dataset += [(str(i), example) for example in f.read().split('<|endoftext|>')]
272 |
273 | if split == "train":
274 | random.shuffle(dataset)
275 | dataset = dataset[:args.n_train_samples]
276 | return dataset
277 |
278 | def parse_args():
279 | # Parse command line arguments
280 | parser = argparse.ArgumentParser()
281 | parser.add_argument("--config", type=str, default=None, help="Config file location")
282 | parser.add_argument('--direct', action='store_true', help="Whether to use direct prediction, sans scratchpad")
283 | parser.add_argument('--rationalize', action='store_true', help="Whether to use rationalization")
284 | parser.add_argument('--no_prompt', action='store_true', help="Whether to remove prompts during eval")
285 | parser.add_argument('--few_shot_train', action='store_true', help="Whether to remove few-shot-prompts during train")
286 | parser.add_argument('--show_hint_prompt', action='store_true', help="Whether a hint prompt will be necessary")
287 | parser.add_argument("--split", type=str, default="dev", help="Split")
288 | parser.add_argument("--dataset_mode", type=str, default="cqa", help="Which dataset to run on")
289 | parser.add_argument("--n_train_samples", type=int, default=3000, help="Number of training examples")
290 | parser.add_argument("--gen_length", type=int, default=96, help="Generation length")
291 | parser.add_argument("--eval_batch_size", type=int, default=8, help="Size of batches in eval")
292 | parser.add_argument("--p_show_hint_save", type=float, default=0.0, help="Percent of rationalization hints to save")
293 | parser.add_argument("--ckpt_step", type=int, default=-1, help="Which checkpoint to eval. -1 means the final one")
294 | parser.add_argument("--eval_seq", type=int, default=-1, help="Sequence length. -1 means the one in the param file")
295 |
296 | args = parser.parse_args()
297 | return args
298 |
299 | def transform_example(example):
300 | new_example = {
301 | "question": example["english_text"],
302 | "answer": "#### " + example["ans"]
303 | }
304 | return new_example
305 |
306 | if __name__ == "__main__":
307 | args = parse_args()
308 | print(args)
309 | split = args.split
310 | params = json.load(smart_open(args.config))
311 |
312 | project = params.get("wandb_project", "mesh-transformer-jax")
313 | experiment_details = params["name"].split("_")
314 | wandb_name = "_".join(experiment_details[:-1])
315 | wandb_iteration = int(experiment_details[-1])
316 | wandb.init(project=project, name=wandb_name, config=params, resume=True)
317 |
318 | prompts_file = "prompts.txt" if not args.direct else "prompts_direct.txt"
319 | prompts_file = f"{args.dataset_mode}/{prompts_file}"
320 | if args.no_prompt:
321 | commonsense_prompts = []
322 | else:
323 | with basic_open(prompts_file) as prompts:
324 | commonsense_prompts = prompts.read().split("\n\n")
325 | prompts_hint_file = "prompts_answer_key.txt" if not args.direct else "prompts_direct_answer_key.txt"
326 | prompts_hint_file = f"{args.dataset_mode}/{prompts_hint_file}"
327 | if args.no_prompt and not args.show_hint_prompt:
328 | commonsense_prompts_hint = []
329 | else:
330 | with basic_open(prompts_hint_file) as prompts:
331 | commonsense_prompts_hint = prompts.read().split("\n\n")
332 |
333 | per_replica_batch = params["per_replica_batch"]
334 | cores_per_replica = params["cores_per_replica"]
335 | target_save = params["target_save"] if split != "dev" else f'{args.dataset_mode}/new_dev.txt'
336 | seq = params["seq"] if args.eval_seq == -1 else args.eval_seq
337 | hint_seq = seq
338 | set_opt(params)
339 |
340 | mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
341 | devices = np.array(jax.devices()).reshape(mesh_shape)
342 | ckpt_path = get_ckpt_path(params, args.ckpt_step)
343 | with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
344 | network = load_model(params, ckpt_path, devices, mesh_shape)
345 |
346 | dataset = get_dataset(args)
347 | dataset_keys = set([datakey for datakey, _ in dataset])
348 |
349 | total_batch = per_replica_batch * jax.device_count() // cores_per_replica * args.eval_batch_size
350 | gen_params = {"top_p": np.ones(total_batch) * 0.9, "temp": np.ones(total_batch) * 0.01}
351 |
352 | accurate_count = eval_examples(dataset, commonsense_prompts, commonsense_prompts_hint, direct=args.direct)
353 | for cur_key, cur_counts in accurate_count.items():
354 | print(f"{split}, {cur_key}, {get_score(cur_counts)}")
355 | wandb.log({f"{split}_{cur_key}_accuracy": get_score(cur_counts), "iteration": wandb_iteration})
356 |
--------------------------------------------------------------------------------
/device_serve.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import threading
4 | import time
5 | from queue import Queue, Empty
6 |
7 | import jax
8 | import numpy as np
9 | import optax
10 |
11 | from mesh_transformer import util
12 | from mesh_transformer.checkpoint import read_ckpt
13 | from mesh_transformer.sampling import nucleaus_sample
14 | from mesh_transformer.transformer_shard import CausalTransformer
15 | import transformers
16 | from smart_open import open
17 |
18 | from mesh_transformer.util import clip_by_global_norm
19 |
20 | from flask import Flask, request, make_response, jsonify
21 | app = Flask(__name__)
22 |
23 | requests_queue = Queue()
24 |
25 | """
26 | curl --header "Content-Type: application/json" \
27 | --request POST \
28 | --data '{"context":"eleutherai", "top_p": 0.9, "temp": 0.75}' \
29 | http://localhost:5000/complete
30 | """
31 |
32 |
33 | def _build_cors_prelight_response():
34 | response = make_response()
35 | response.headers.add("Access-Control-Allow-Origin", "*")
36 | response.headers.add('Access-Control-Allow-Headers', "*")
37 | response.headers.add('Access-Control-Allow-Methods', "*")
38 | return response
39 |
40 |
41 | def _corsify_actual_response(response):
42 | response.headers.add("Access-Control-Allow-Origin", "*")
43 | return response
44 |
45 |
46 | @app.route('/complete', methods=['POST', 'OPTIONS'])
47 | def complete():
48 | if request.method == "OPTIONS": # CORS preflight
49 | return _build_cors_prelight_response()
50 | elif request.method == "POST": # The actual request following the preflight
51 | content = request.json
52 |
53 | if requests_queue.qsize() > 100:
54 | return {"error": "queue full, try again later"}
55 |
56 | response_queue = Queue()
57 |
58 | requests_queue.put(({
59 | "context": content["context"],
60 | "top_p": float(content["top_p"]),
61 | "temp": float(content["temp"])
62 | }, response_queue))
63 |
64 | return _corsify_actual_response(jsonify({"completion": response_queue.get()}))
65 | else:
66 | raise RuntimeError("Weird - don't know how to handle method {}".format(request.method))
67 |
68 |
69 | def parse_args():
70 | # Parse command line arguments
71 | parser = argparse.ArgumentParser()
72 | parser.add_argument("--config", type=str, default=None, help="Config file location")
73 |
74 | args = parser.parse_args()
75 | return args
76 |
77 |
78 | if __name__ == "__main__":
79 | threading.Thread(target=app.run, kwargs={"port": 5000, "host": "0.0.0.0"}).start()
80 |
81 | args = parse_args()
82 | params = json.load(open(args.config))
83 |
84 | gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1)
85 | per_replica_batch = params["per_replica_batch"]
86 | cores_per_replica = params["cores_per_replica"]
87 |
88 | assert cores_per_replica <= 8
89 |
90 | bucket = params["bucket"]
91 | model_dir = params["model_dir"]
92 | layers = params["layers"]
93 | d_model = params["d_model"]
94 | n_heads = params["n_heads"]
95 | n_vocab = params["n_vocab"]
96 | seq = params["seq"]
97 | norm = params["norm"]
98 |
99 | params["sampler"] = nucleaus_sample
100 | opt = optax.chain(
101 | optax.scale(1 / gradient_accumulation_steps),
102 | clip_by_global_norm(1),
103 | optax.scale_by_adam(),
104 | optax.additive_weight_decay(0),
105 | optax.scale(-1),
106 | optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0))
107 | )
108 |
109 | params["optimizer"] = opt
110 |
111 | start = time.time()
112 | print(f"jax devices: {jax.device_count()}")
113 | print(f"jax runtime initialized in {time.time() - start:.06}s")
114 |
115 | mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
116 | devices = np.array(jax.devices()).reshape(mesh_shape)
117 |
118 | with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f:
119 | meta = json.load(f)
120 |
121 | ckpt_step = meta["checkpoints"][-1]
122 | print(f"using checkpoint {ckpt_step}")
123 |
124 | total_batch = per_replica_batch * jax.device_count() // cores_per_replica * 8
125 | with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
126 | network = CausalTransformer(params)
127 |
128 | start = time.time()
129 | network.state = read_ckpt(network.state, f"gs://{bucket}/{model_dir}/step_{ckpt_step}/", devices.shape[1])
130 | print(f"network loaded in {time.time() - start:.06}s")
131 |
132 | local_shards = max(jax.local_device_count() // mesh_shape[1], 1)
133 | del network.state["opt_state"]
134 | network.state = network.move_xmap(network.state, np.zeros(local_shards))
135 |
136 | tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
137 |
138 | while True:
139 | all_ctx = []
140 | all_top_p = []
141 | all_temp = []
142 | all_q = []
143 | while len(all_ctx) < total_batch:
144 | try:
145 | o, q = requests_queue.get(block=False)
146 | all_ctx.append(o["context"])
147 | all_top_p.append(o["top_p"])
148 | all_temp.append(o["temp"])
149 | all_q.append(q)
150 | except Empty:
151 | if len(all_ctx):
152 | break
153 | else:
154 | time.sleep(0.01)
155 |
156 | start = time.time()
157 | while len(all_ctx) < total_batch:
158 | all_ctx.append("whatever")
159 | all_top_p.append(1)
160 | all_temp.append(1)
161 |
162 | all_tokenized = []
163 | all_length = []
164 | for ctx in all_ctx:
165 | padded_tokens = np.zeros(seq).astype(np.uint32)
166 | length = 0
167 |
168 | try:
169 | tokens = tokenizer.encode(ctx)
170 | provided_ctx = len(tokens)
171 | pad_amount = seq - provided_ctx
172 |
173 | pad_amount = max(pad_amount, 0)
174 |
175 | padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)[-seq:]
176 | length = len(tokens)
177 | except:
178 | print("oops exception")
179 |
180 | all_tokenized.append(padded_tokens)
181 | all_length.append(length)
182 |
183 | output = network.generate(np.array(all_tokenized),
184 | np.array(all_length),
185 | 256,
186 | {
187 | "top_p": np.array(all_top_p),
188 | "temp": np.array(all_temp)
189 | })
190 |
191 | for o, q in zip(output[1][0][:, :, 0], all_q):
192 | q.put(tokenizer.decode(o))
193 |
194 | print(f"completion done in {time.time() - start:06}s")
195 |
--------------------------------------------------------------------------------
/device_train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import time
4 |
5 | import jax
6 | import numpy as np
7 | import optax
8 |
9 | import wandb
10 | from tqdm import tqdm
11 |
12 |
13 | from mesh_transformer import util
14 | from mesh_transformer.checkpoint import read_ckpt, write_ckpt
15 | from mesh_transformer.transformer_shard import CausalTransformer
16 | from tfrecord_loader import TFRecordNewInputs
17 | from smart_open import open
18 | from google.cloud import storage
19 | from google.cloud.exceptions import NotFound
20 |
21 | from mesh_transformer.util import clip_by_global_norm, additive_weight_decay
22 |
23 |
24 | def parse_args():
25 | # Parse command line arguments
26 | parser = argparse.ArgumentParser(description="""
27 | To use, download the full checkpoint archive, extract and upload to a GCS bucket, and set that as --tune-model-path
28 | Modify the config file:
29 | - set `model_dir` to where the checkpoints should be written during training
30 | - set `train_set`, `val_set` to index files for your data
31 | - set `tpu_size` to 8 (if on a v3-8)
32 | - set `warmup_steps`, `anneal_steps`, `lr`, `end_lr` to the lr schedule for your finetuning run
33 | - the global step will reset to 0, keep that in mind when writing your lr schedule
34 | - set `name` to specify the name of the Weights & Biases run
35 | - set `wandb_project` to specify the Weights & Biases project to log to
36 | To prepare data in the expected data format:
37 | - use the script `create_finetune_tfrecords.py` in this repo to create data in the expected format
38 | - upload the .tfrecords files to GCS
39 |     - save their GCS paths to an index file under `data/`, see existing files for examples
40 | """,
41 | formatter_class=argparse.RawTextHelpFormatter)
42 | parser.add_argument("--config", type=str, default=None, help="Config file location")
43 | parser.add_argument("--tune-model-path", type=str, default=None, help="Base model to finetune")
44 | parser.add_argument("--fresh-opt", default=False, action="store_true", help="Use a newly initialized optimizer, ignoring any optimizer state saved in the base checkpoint")
45 |
46 | args = parser.parse_args()
47 | return args
48 |
49 |
50 | def save(network, step, bucket, path, mp, aux=None, keep_n=20, delete_old=True):
51 | assert path
52 | client = storage.Client()
53 |
54 | if aux is None:
55 | aux = {}
56 |
57 | try:
58 | with open(f"gs://{bucket}/{path}/meta.json", "r") as f:
59 | meta = json.load(f)
60 | except:
61 | # create metadata file
62 | with open(f"gs://{bucket}/{path}/meta.json", "w") as f:
63 | json.dump({
64 | "step": 0,
65 | "checkpoints": [],
66 | "aux": {}
67 | }, f)
68 |
69 | # do sharded checkpoint writing
70 | start = time.time()
71 | res = []
72 | for shard_id in range(mp):
73 | write_ckpt(network.state, f"gs://{bucket}/{path}/step_{step}/", shard_id)
74 |
75 | print(f"Wrote checkpoint in {time.time() - start:.06}s")
76 |
77 | with open(f"gs://{bucket}/{path}/meta.json", "r") as f:
78 | meta = json.load(f)
79 |
80 | meta["step"] = step
81 | meta["checkpoints"].append(step)
82 | all_aux = meta.get("aux", {})
83 |
84 | while len(meta["checkpoints"]) > keep_n:
85 | ckpt_to_delete = meta["checkpoints"].pop(0)
86 |
87 | try:
88 | del all_aux[str(ckpt_to_delete)]
89 | except:
90 |             print(f"failed to delete the aux state for {ckpt_to_delete}")
91 |
92 | if delete_old:
93 | print(f"deleting checkpoint {ckpt_to_delete}")
94 | for blob in client.list_blobs(bucket, prefix=f"{path}/step_{ckpt_to_delete}/"):
95 | # print(f"deleting {blob.name}")
96 | assert path in blob.name
97 | blob.delete()
98 | else:
99 | print(f"keeping checkpoint {ckpt_to_delete}")
100 |
101 | all_aux[step] = aux
102 | meta["aux"] = all_aux
103 |
104 | with open(f"gs://{bucket}/{path}/meta.json", "w") as f:
105 | json.dump(meta, f)
106 |
107 |
108 | def get_mask_one_locations(single_sequence):
109 |     separator = np.where(single_sequence == 25)[0]  # GPT-2 token 25 is ':'
110 |     separator_locations = []
111 |     endoftext = np.where(single_sequence == 50256)[0]  # GPT-2 token 50256 is '<|endoftext|>'
112 |
113 | for endoftext_token_location in endoftext:
114 | locations = separator[separator <= endoftext_token_location]
115 | if len(locations) > 0:
116 | separator_locations.append(locations[-1])
117 | else:
118 | separator_locations.append(0)
119 | separator = np.asarray(separator_locations)
120 | mask_one_locations = [(i+1, j) for i, j in zip(separator, endoftext)]
121 | return mask_one_locations
122 |
123 |
124 | def find_real_target_mask(single_sequence):
125 | mask_one_locations = get_mask_one_locations(single_sequence)
126 |
127 | mask = np.zeros(len(single_sequence))
128 | for i, j in mask_one_locations:
129 | np.put(mask, np.arange(i, j+1), 1.)
130 | return mask
131 |
132 |
133 | def train_step(network, data):
134 | tgt = data[:, :, 1:]
135 |
136 | all_masks = []
137 | for single_tgt in tgt:
138 | mask = find_real_target_mask(np.squeeze(single_tgt))
139 | all_masks.append(np.expand_dims(mask, (0, 1)))
140 | all_masks = np.concatenate(all_masks, axis=0)
141 |
142 | inputs = {
143 | "obs": data[:, :, :-1],
144 | "target": tgt,
145 | "mask": all_masks
146 | }
147 |
148 | loss, last_loss, grad_norm, grad_norm_micro = network.train(inputs)
149 |
150 | return (
151 | np.array(loss).mean(),
152 | np.array(last_loss).mean(),
153 | np.array(grad_norm).mean(),
154 | np.array(grad_norm_micro).mean(),
155 | )
156 |
157 |
158 | def eval_step(network, data):
159 | tgt = data[:, 1:]
160 |
161 | all_masks = []
162 | for single_tgt in tgt:
163 | mask_one_locations = get_mask_one_locations(single_tgt)
164 | mask = find_real_target_mask(np.squeeze(single_tgt))
165 | all_masks.append(np.expand_dims(mask, 0))
166 | all_masks = np.concatenate(all_masks, axis=0)
167 |
168 | inputs = {
169 | "obs": data[:, :-1],
170 | "target": tgt,
171 | "mask": all_masks
172 | }
173 |
174 | out = network.eval(inputs)
175 | loss = out["loss"]
176 | correct = out['correct']
177 |
178 | correct_sequences = 0
179 | total_sequences = 0
180 | for i, j in mask_one_locations:
181 | total_sequences += 1
182 | if np.all(correct[0][i:j+1]==1):
183 | correct_sequences += 1
184 | return np.array(loss).mean(), np.array(correct).mean() / all_masks.mean(), correct_sequences, total_sequences
185 |
186 |
187 | if __name__ == "__main__":
188 | args = parse_args()
189 | params = json.load(open(args.config))
190 |
191 | gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1)
192 | per_replica_batch = params["per_replica_batch"]
193 | cores_per_replica = params["cores_per_replica"]
194 |
195 | assert cores_per_replica <= 8
196 |
197 | bucket = params["bucket"]
198 | model_dir = params["model_dir"]
199 | layers = params["layers"]
200 | d_model = params["d_model"]
201 | n_heads = params["n_heads"]
202 | n_vocab = params["n_vocab"]
203 | seq = params["seq"]
204 | norm = params["norm"]
205 |
206 | val_batches = params["val_batches"]
207 | val_every = params["val_every"]
208 | ckpt_every = params["ckpt_every"]
209 | keep_every = params["keep_every"]
210 | eval_tasks = params["eval_harness_tasks"]
211 | total_steps = params["total_steps"]
212 |
213 | pe = params["pe"]
214 | assert pe in ["fixed", "rotary", "t5"]
215 |
216 | warmup_steps = params["warmup_steps"]
217 | anneal_steps = params["anneal_steps"]
218 | lr = params["lr"]
219 | end_lr = params["end_lr"]
220 | weight_decay = params["weight_decay"]
221 |
222 | # alpha parameter for the exponential moving averages used to compute B_simple
223 | noise_scale_alpha = params.get("noise_scale_alpha", 0.01)
224 |
225 | scheduler = util.gpt3_schedule(warmup_steps, anneal_steps, lr, end_lr)
226 |
227 | opt = optax.chain(
228 | optax.scale(1 / gradient_accumulation_steps),
229 | clip_by_global_norm(1),
230 | optax.scale_by_adam(),
231 | additive_weight_decay(weight_decay),
232 | optax.scale(-1),
233 | optax.scale_by_schedule(scheduler)
234 | )
235 |
236 | params["optimizer"] = opt
237 |
238 | start = time.time()
239 | tpu_size = jax.device_count()
240 | if tpu_size < cores_per_replica:
241 | msg = f"each shard needs a separate device, but device count ({tpu_size}) < shard count ({cores_per_replica})"
242 | raise ValueError(msg)
243 | print(f"jax devices: {tpu_size}")
244 | print(f"jax runtime initialized in {time.time() - start:.06}s")
245 |
246 | mesh_shape = (tpu_size // cores_per_replica, cores_per_replica)
247 | devices = np.array(jax.devices()).reshape(mesh_shape)
248 |
249 | # pick initial ckpt - based on tuning vs train from scratch
250 |
251 | step = 0
252 | initial_ckpt_state_path = None
253 | train_loader = None
254 |
255 | if args.tune_model_path:
256 | print('`--tune_model_path` passed: we are beginning a fine-tuning run')
257 | fine_tuning = True
258 | initial_ckpt_state_path = args.tune_model_path
259 | else:
260 | print('`--tune_model_path` not passed: we are continuing a fine-tuning run from a checkpoint (or we are not fine-tuning)')
261 | fine_tuning = False
262 | initial_ckpt_model_dir = model_dir
263 | initial_ckpt_path = f"gs://{bucket}/{initial_ckpt_model_dir}"
264 | meta_path = f"{initial_ckpt_path}/meta.json"
265 |
266 | try:
267 | with open(meta_path, "r") as f:
268 | meta = json.load(f)
269 | ckpt_step = meta["checkpoints"][-1]
270 | initial_ckpt_state_path = f"{initial_ckpt_path}/step_{ckpt_step}/"
271 | print(f"state will be restored from checkpoint {ckpt_step}")
272 |
273 | step = ckpt_step
274 | train_loader = meta['aux'][str(ckpt_step)].get("train_loader", None)
275 | except NotFound:
276 | # no checkpoint, start at zero
277 | print(f"No checkpoint to load at {initial_ckpt_path}. Training from scratch.")
278 |
279 | if initial_ckpt_state_path:
280 | print(f"path to load checkpoint from: {initial_ckpt_state_path}")
281 | else:
282 | print("not loading from a checkpoint")
283 |
284 | # set up datasets
285 | print("setting up datasets")
286 |
287 | train_dataset = TFRecordNewInputs(f"data/{params['train_set']}",
288 | batch_size=(
289 | gradient_accumulation_steps,
290 | per_replica_batch * tpu_size // cores_per_replica),
291 | sample_size=params['seq'],
292 | restore_state=train_loader)
293 |
294 | global_val_batch = per_replica_batch * tpu_size // cores_per_replica
295 |
296 | val_sets = {}
297 |
298 | for k, v in params["val_set"].items():
299 | val_sets[k] = TFRecordNewInputs(
300 | f"data/{v}", batch_size=(global_val_batch,), sample_size=seq
301 | )
302 |
303 | # tok/sec metrics
304 | sequences_per_step = gradient_accumulation_steps * (per_replica_batch * tpu_size // cores_per_replica)
305 | tokens_per_step = params['seq'] * sequences_per_step
306 |
307 | # load + run
308 | with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
309 | print("initializing network")
310 | network = CausalTransformer(params)
311 |
312 | if initial_ckpt_state_path:
313 | print("loading network")
314 | if fine_tuning:
315 | # get the scheduler step stored in the just-initialized optimizer
316 | # should be zero
317 | init_sched_state = network.state["opt_state"][-1]
318 |
319 | start = time.time()
320 | network.state = read_ckpt(network.state, initial_ckpt_state_path, devices.shape[1], load_opt=(not args.fresh_opt))
321 |
322 | if fine_tuning:
323 | # overwrite the loaded scheduler step with zeros
324 |                 # this makes fine-tuning use the lr schedule from the config, starting from step 0
325 | network.state["opt_state"][-1] = init_sched_state
326 |
327 | print(f"network loaded in {time.time() - start:.06}s")
328 |
329 | print('compiling train fn')
330 | start = time.time()
331 | loss, last_loss, grad_norm, grad_norm_micro = train_step(
332 | network, train_dataset.get_samples()
333 | )
334 | step += 1
335 | print(f"Train fn compiled in {time.time() - start:.06}s")
336 |
337 | print('compiling eval fn')
338 | start = time.time()
339 | for val_set in val_sets.values():
340 | eval_step(network, val_set.get_samples())
341 | val_set.reset()
342 | print(f"Eval fn compiled in {time.time() - start:.06}s")
343 |
344 | project = params.get("wandb_project", "mesh-transformer-jax")
345 | wandb.init(project=project, name=params["name"], config=params)
346 |
347 | G_noise_avg = None
348 | S_noise_avg = None
349 |
350 | while True:
351 | if (step > 1) and ((step % ckpt_every == 1) or (step == total_steps)):
352 | print(f"saving a checkpoint for step {step}")
353 | save(network, step, bucket, model_dir,
354 | mp=cores_per_replica,
355 | aux={"train_loader": train_dataset.get_state()},
356 | delete_old=True,
357 | )
358 |
359 | if step == total_steps:
360 | print("training completed!")
361 | exit()
362 |
363 | start = time.time()
364 | loss, last_loss, grad_norm, grad_norm_micro = train_step(
365 | network, train_dataset.get_samples()
366 | )
367 | step += 1
368 |
369 | steps_per_sec = 1 / (time.time() - start)
370 | tokens_per_sec = tokens_per_step * steps_per_sec
371 | sequences_processed = sequences_per_step * step
372 | tokens_processed = tokens_per_step * step
373 |
374 | ### compute summary stats about the gradient
375 |
376 |             # converts from grads-summed-over-microbatch (what `CausalTransformer.train` computes)
377 | # to grads-averaged-over-microbatch (what we want)
378 | #
379 | # (when taking gradient steps, the same conversion happens inside the optimizer
380 | # via optax.scale(1 / gradient_accumulation_steps))
381 | grad_norm = grad_norm / gradient_accumulation_steps
382 |
383 | # compute G_noise and S_noise
384 | # from "An Empirical Model of Large-Batch Training" Appendix A.1
385 | # here, B_big = gradient_accumulation_steps, and B_small = 1 for convenience
386 | gbsmall = grad_norm_micro ** 2
387 | gbbig = grad_norm ** 2
388 | G_noise = (gradient_accumulation_steps * gbbig - gbsmall) / (
389 | gradient_accumulation_steps - 1
390 | )
391 | S_noise = (gbsmall - gbbig) / (1 - 1 / gradient_accumulation_steps)
392 |
393 | noise_scale_stats = {
394 | "noise/G_noise": G_noise,
395 | "noise/S_noise": S_noise,
396 | }
397 |
398 | # heuristic to avoid reporting G_noise in very early training when gradients are large
399 | # (these take a long time to wash out of the moving average that defines B_simple)
400 | use_step_in_noise_avgs = gbbig < 2
401 |
402 | if use_step_in_noise_avgs:
403 | # compute moving averages of G_noise and S_noise, for B_simple
404 | if G_noise_avg is None:
405 | G_noise_avg = G_noise
406 | else:
407 | G_noise_avg = (1 - noise_scale_alpha) * G_noise_avg + noise_scale_alpha * G_noise
408 |
409 | if S_noise_avg is None:
410 | S_noise_avg = S_noise
411 | else:
412 | S_noise_avg = (1 - noise_scale_alpha) * S_noise_avg + noise_scale_alpha * S_noise
413 |
414 | B_simple = S_noise_avg / G_noise_avg
415 |
416 | noise_scale_stats.update(
417 | {
418 | "noise/G_noise_avg": G_noise_avg,
419 | "noise/S_noise_avg": S_noise_avg,
420 | "noise/B_simple": B_simple,
421 | }
422 | )
423 |
424 | wandb_stats = {
425 | "train/loss": loss,
426 | "train/last_loss": last_loss,
427 | "train/steps_per_sec": steps_per_sec,
428 | "train/tokens_per_sec": tokens_per_sec,
429 | "train/grad_norm": grad_norm,
430 | "train/learning_rate": float(scheduler(network.state["opt_state"][-1].count[0].item())),
431 | "sequences_processed": sequences_processed,
432 | "tokens_processed": tokens_processed,
433 | }
434 | wandb_stats.update(noise_scale_stats)
435 |
436 | wandb.log(wandb_stats, step)
437 |
--------------------------------------------------------------------------------
/docker/.env:
--------------------------------------------------------------------------------
1 | DRYRUN=false
2 | DOMAIN=yourdomain.com
3 | SUBDOMAIN=gptj.api
4 | NGINX_PATH=./nginx_proxyvm.conf
5 | CERTS_PATH=/path/to/certs/
6 | DNS_KEYS=/path/to/keys/
7 | MODEL_DIR=/your/path/to/model/step_383500
8 | TPU_NAME=YOUR_TPU_NAME
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # Tested with a custom Ubuntu 18.04 / Python 3.7 / TensorFlow 2.5.0 base image.
2 | # Not tested with the base image below.
3 | FROM tensorflow/tensorflow:2.5.0
4 | RUN apt update && \
5 | apt-get install git -y
6 |
7 | WORKDIR /app/
8 | COPY . /app/
9 | RUN git clone https://github.com/kingoflolz/mesh-transformer-jax && \
10 | pip install -r mesh-transformer-jax/requirements.txt && \
11 | pip install mesh-transformer-jax/ jax==0.2.12 && \
12 | pip install fastapi uvicorn requests aiofiles aiohttp && \
13 | ln -s /app/start.sh /start.sh
14 |
15 | ENV PYTHONPATH /app:/app/mesh-transformer-jax:/usr/local/bin/python3
16 | ENV PATH $PYTHONPATH:$PATH
17 | ENV TOKENIZERS_PARALLELISM=true
18 | EXPOSE 80
19 |
20 | CMD ["/start.sh"]
21 |
--------------------------------------------------------------------------------
/docker/README.md:
--------------------------------------------------------------------------------
1 | # Deploying GPT-J (TPU VM) in a Docker Container
2 |
3 | This PR enables you to run GPT-J for inference in a Docker container, and has run stably over the past week.
4 |
5 | Getting a completion currently requires two calls. The first is a POST call to /model/predict with the params
6 |
7 | ```
8 | context: str
9 | temp: Optional[float] = 1.0
10 | top_p: Optional[float] = 0.9
11 | top_k: Optional[int] = 50
12 | length: Optional[int] = 256
13 | ```
14 |
15 | This returns an ID (int) that you later use to retrieve the prediction once it is complete, rather than blocking until generation finishes.
16 | 
17 | The second call is a POST to /model/get_prediction with the `qid` that was returned, as sketched below.
18 |
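As a reference, here is a minimal Python sketch of the two-call flow. The base URL is an assumption (it follows the `SUBDOMAIN`/`DOMAIN` values from `.env` below); substitute your own endpoint, and treat this as hypothetical client code rather than part of the server.

```python
import requests

BASE_URL = "https://gptj.api.yourdomain.com"  # assumed endpoint; use your own domain or proxy address

# First call: queue a completion request and receive a queue id (qid)
payload = {
    "context": "eleutherai",
    "temp": 1.0,
    "top_p": 0.9,
    "top_k": 50,
    "length": 256,
}
qid = requests.post(f"{BASE_URL}/model/predict", json=payload).json()["qid"]

# Second call: retrieve the completion for that qid (waits until it is ready)
result = requests.post(f"{BASE_URL}/model/get_prediction", json={"qid": qid}).json()
print(result["completion"])
```
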
19 | ## Getting Started
20 |
21 | You will need the following:
22 | 1) TPU VM: A TPU (v2-8 works fine)
23 | 2) Proxy VM: a second VM within the same zone. This VM will only serve to proxy all requests to the TPUVM, so it does not require a lot of resources.
24 | 3) [Docker + Docker Compose](https://docs.docker.com/compose/install/) Installed on both VMs
25 | 4) TPU VM and Proxy VM on the same GCP Network (default for most)
26 | 5) Ports 80 and 443 exposed on the Proxy VM
27 | 6) Your DNS records pointing to the external IP of the Proxy VM.
28 |
29 | The GPT-J checkpoint should be downloaded locally, and its path should be set as the MODEL_DIR variable in the .env file. This mounts the checkpoint directory into the container as a volume, preventing repeated re-downloads.
30 |
31 | ## Files to Modify
32 |
33 | `.env`
34 |
35 | ```bash
36 | # Modify these values in .env
37 | DRYRUN=false # Set to true to skip SSL Let's Encrypt domain validation when testing
38 | DOMAIN=yourdomain.com # Set to your domain
39 | SUBDOMAIN=gptj.api # Set to your subdomain. The endpoint will be gptj.api.yourdomain.com
40 | NGINX_PATH=./nginx_proxyvm.conf # Modify this conf file with the Internal IP of the TPU VM. This value will not change.
41 | CERTS_PATH=/path/to/certs/ # Set this to a path on the Proxy VM that will be used to store SSL certs. This path will be created automatically.
42 | DNS_KEYS=/path/to/keys/ # Set this to a path on the Proxy VM that will be used to store DNS keys. This path will be created automatically.
43 | MODEL_DIR=/your/path/to/model/step_383500 # Set this to the path on TPU VM that the model is stored in
44 | TPU_NAME=YOUR_TPU_NAME # Set this to your TPU Name as created
45 | ```
46 |
47 | `nginx_proxyvm.conf`
48 |
49 | Modify the IP to your TPU VM's *Internal* IP Address
50 |
51 |
52 | ## Running the Docker Container
53 |
54 | Assuming you are in this directory as the working directory.
55 |
56 | On TPU VM:
57 |
58 | `docker-compose up -d`
59 |
60 | On Proxy VM:
61 |
62 | `docker-compose -f compose-proxy.yaml up`
63 |
64 | ## Final Notes
65 |
66 | The reason you can't serve directly from the TPU VM is that you have no control over the TPU VM's port firewall rules. I go into more detail about [Dockerization within TPU VMs here](https://trisongz.medium.com/accessing-your-tpus-in-docker-containers-with-tpu-vm-e944f5909dd4).
67 |
68 | The API is _very_ barebones and stripped down for simplicity. It's meant as a starting point and can be further extended for your needs.
69 |
70 |
--------------------------------------------------------------------------------
/docker/compose-proxy.yaml:
--------------------------------------------------------------------------------
1 | version: "3.3"
2 | services:
3 | gptjapi:
4 | image: ghcr.io/linuxserver/swag
5 | container_name: gptjapi
6 | cap_add:
7 | - NET_ADMIN
8 | environment:
9 | - PUID=1000
10 | - PGID=1000
11 | - TZ=America/Chicago
12 | - URL=${DOMAIN}
13 | - SUBDOMAINS=${SUBDOMAIN},
14 | - VALIDATION=http
15 | - STAGING=${DRYRUN}
16 | - ONLY_SUBDOMAINS=true
17 | volumes:
18 | - ${NGINX_PATH}:/config/nginx/proxy-confs/app.subdomain.conf
19 | - ${CERTS_PATH}:/config/etc/letsencrypt/
20 | - ${DNS_KEYS}:/config/keys/
21 | ports:
22 | - 443:443
23 | - 80:80
24 | restart: unless-stopped
25 | networks:
26 | - production
27 | networks:
28 | production:
29 |
--------------------------------------------------------------------------------
/docker/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | version: "3.3"
2 | services:
3 | gptjserver:
4 | image: gptjserver
5 | container_name: gptjserver
6 | cap_add:
7 | - ALL
8 | environment:
9 | - TPU_NAME=${TPU_NAME}
10 | - TF_CPP_MIN_LOG_LEVEL=0
11 | - XRT_TPU_CONFIG="localservice;0;localhost:51011"
12 | - TF_XLA_FLAGS=--tf_xla_enable_xla_devices
13 | build:
14 | context: .
15 | dockerfile: Dockerfile
16 | ports:
17 | - 8080:80
18 | volumes:
19 | - ${MODEL_DIR}:/app/model
20 | - /var/run/docker.sock:/var/run/docker.sock
21 | - /usr/share/tpu/:/usr/share/tpu/
22 | - /lib/libtpu.so:/lib/libtpu.so
23 | privileged: true
24 | restart: unless-stopped
25 | devices:
26 | - "/dev:/dev"
27 | networks:
28 | - production
29 | networks:
30 | production:
31 |
--------------------------------------------------------------------------------
/docker/main.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import uvicorn
4 | import traceback
5 | import logging
6 |
7 | from fastapi import FastAPI
8 | from starlette.middleware.cors import CORSMiddleware
9 | from .payloads import CompletionPayload, CompletionResponse, QueueRequest, QueueResponse
10 | from .ops import get_gptj_model
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 | app = FastAPI()
15 | app.add_middleware(
16 | CORSMiddleware,
17 | allow_origins=["*"],
18 | allow_credentials=True,
19 | allow_methods=["*"],
20 | allow_headers=["*"],
21 | )
22 |
23 |
24 | # globals
25 | MODEL_API = None
26 |
27 |
28 | @app.on_event("startup")
29 | async def startup_event():
30 | global MODEL_API
31 | try:
32 | MODEL_API = get_gptj_model()
33 | MODEL_API.load_model()
34 | MODEL_API.start_background()
35 | except Exception as e:
36 | logger.debug(f"Model could not be loaded: {str(e)}")
37 | traceback.print_exc()
38 |
39 |
40 | @app.post("/model/get_prediction")
41 | def get_prediction(payload: QueueRequest) -> CompletionResponse:
42 | res = MODEL_API.wait_for_queue(payload.qid)
43 | return CompletionResponse(**res)
44 |
45 |
46 | @app.post("/model/predict")
47 | def model_prediction(payload: CompletionPayload) -> QueueResponse:
48 | res = MODEL_API.add_to_queue(payload)
49 | return QueueResponse(qid=res['qid'])
50 |
51 |
52 | if __name__ == "__main__":
53 | uvicorn.run(app, host="0.0.0.0", port=8000, reload=True)
54 |
--------------------------------------------------------------------------------
/docker/nginx_proxyvm.conf:
--------------------------------------------------------------------------------
1 | server {
2 | listen 443;
3 | server_name gptj.api.yourdomain.com;
4 | include /config/nginx/ssl.conf;
5 | client_max_body_size 0;
6 | location / {
7 | include /config/nginx/proxy.conf;
8 | # The below IP should be your TPUVM Internal IP
9 | proxy_pass http://10.000.0.1:8080;
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/docker/ops.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import time
4 | import jax
5 | import optax
6 | import threading
7 | import numpy as np
8 | import logging
9 | from queue import Queue, Empty
10 | from jax.experimental import maps
11 | from transformers import GPT2TokenizerFast
12 |
13 | from mesh_transformer.checkpoint import read_ckpt
14 | from mesh_transformer.sampling import nucleaus_sample
15 | from mesh_transformer.transformer_shard import CausalTransformer
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 | # This prevents fastapi from creating multiple models in multi-worker mode
20 | # leading to OOM crashes
21 | gptj_model = None
22 | gptj_model_lock = threading.Lock()
23 |
24 | def compile_model():
25 | global gptj_model
26 | with gptj_model_lock:
27 | if gptj_model:
28 | return
29 | gptj_model = GPTJ()
30 |
31 | def get_gptj_model():
32 | compile_model()
33 | return gptj_model
34 |
35 | def timer(start_time=None):
36 | if not start_time:
37 | return time.time()
38 | return time.time() - start_time
39 |
40 |
41 | class GPTJ:
42 | def __init__(self):
43 | self.params = {
44 | "layers": 28,
45 | "d_model": 4096,
46 | "n_heads": 16,
47 | "n_vocab": 50400,
48 | "norm": "layernorm",
49 | "pe": "rotary",
50 | "pe_rotary_dims": 64,
51 | "seq": 2048,
52 | "cores_per_replica": 8,
53 | "per_replica_batch": 1,
54 | "sampler": nucleaus_sample,
55 | "optimizer": optax.scale(0)
56 | }
57 | self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
58 | self.queue_ids = {}
59 | self.qidx = 0
60 | self.queue = Queue()
61 | self.network = None
62 | self.lock = threading.Lock()
63 | self._alive_time = timer()
64 |
65 | def load_model(self):
66 | if self.network:
67 | logger.info('Attempting to reload model when model is loaded. Returning')
68 | return
69 | with self.lock:
70 | logger.info('Loading Model')
71 | start = timer()
72 | logger.info(f"JAX Devices: {jax.device_count()}")
73 | logger.info(f"JAX Runtime Initialized in {timer(start):.06} secs")
74 | mesh_shape = (jax.device_count() // self.params['cores_per_replica'], self.params['cores_per_replica'])
75 | self.devices = np.array(jax.devices()).reshape(mesh_shape)
76 | self.total_batch = self.params['per_replica_batch'] * jax.device_count() // self.params['cores_per_replica'] * 8
77 | maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(self.devices, ('dp', 'mp')))
78 | network = CausalTransformer(self.params)
79 | logger.info(f'Loading Checkpoint')
80 | network.state = read_ckpt(network.state, "/app/model/", self.devices.shape[1])
81 | logger.info(f"GPTJ Network loaded in {timer(start):.06} secs. Total Batch Size: {self.total_batch}")
82 | del network.state["opt_state"]
83 | network.state = network.move_xmap(network.state, np.zeros(self.params['cores_per_replica']))
84 | self.network = network
85 |
86 |
87 | def start_background(self):
88 | with self.lock:
89 | t = threading.Thread(target=self.background)
90 | t.start()
91 |
92 | def prepare_item(self, context, length=256):
93 | tokens = self.tokenizer.encode(context)
94 | logger.info(tokens)
95 | token_length = len(tokens)
96 | pad_amount = self.params['seq'] - token_length
97 | pad_amount = max(pad_amount, 0)
98 | padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)[-self.params['seq']:]
99 | return {'tokens': padded_tokens, 'length': token_length}
100 |
101 | # Single Item - Not Tested
102 | def infer(self, context, top_p=0.9, top_k=40, temp=1.0, length=256, **kwargs):
103 | item = self.prepare_item(context, length)
104 | batched_tokens = np.array([item['tokens']] * self.total_batch)
105 | batched_lengths = np.array([item['length']] * self.total_batch)
106 | start = timer()
107 | output = self.network.generate(
108 | batched_tokens, batched_lengths, length,
109 | {
110 | "top_p": np.ones(self.total_batch) * top_p,
111 | "top_k": np.ones(self.total_batch) * top_k,
112 | "temp": np.ones(self.total_batch) * temp,
113 | }
114 | )
115 | samples = []
116 | decoded_tokens = output[1][0]
117 | end_time = timer(start)
118 | for o in decoded_tokens[:, :, 0]:
119 | res = {
120 | 'context': context,
121 | 'completion': self.tokenizer.decode(o),
122 | 'time': end_time
123 | }
124 | samples.append(res)
125 | logger.info(f"Completion done in {end_time:06} secs")
126 | return samples
127 |
128 | def infer_batch(self, batch, **kwargs):
129 | logger.info(f'Starting Inference on Batch')
130 | batch_items = {'tokens': [], 'lengths': [], 'top_p': [], 'top_k': [], 'temp': []}
131 | max_lengths, contexts = [], []
132 | for req in batch:
133 | req = self.to_data(req)
134 | item = self.prepare_item(req['context'], req['length'])
135 | batch_items['tokens'].append(item['tokens'])
136 | batch_items['lengths'].append(item['length'])
137 | batch_items['top_p'].append(req['top_p'])
138 | batch_items['top_k'].append(req['top_k'])
139 | batch_items['temp'].append(req['temp'])
140 | max_lengths.append(req['length'])
141 | contexts.append(req['context'])
142 |
143 | max_length = max(max_lengths)
144 | for key, vals in batch_items.items():
145 | batch_items[key] = np.array(vals)
146 | start = timer()
147 | logger.info(f'Completed Preparing Batch')
148 | output = self.network.generate(
149 | batch_items['tokens'], batch_items['lengths'], max_length,
150 | {
151 | "top_p": batch_items['top_p'],
152 | "top_k": batch_items['top_k'],
153 | "temp": batch_items['temp'],
154 | }
155 | )
156 | logger.info(f'Completed Generation')
157 | samples = []
158 | end_time = timer(start)
159 | for pred, ctx in zip(output[1][0][:, :, 0], contexts):
160 | res = {
161 | 'context': ctx,
162 | 'completion': self.tokenizer.decode(pred),
163 | 'time': end_time
164 | }
165 | samples.append(res)
166 | logger.info(f"Completion done in {end_time:06} secs")
167 | return samples
168 |
169 | def add_to_queue(self, item):
170 | self.qidx += 1
171 | self.queue.put({'item': self.to_data(item), 'qidx': self.qidx})
172 | self.queue_ids[self.qidx] = Queue()
173 | return {'qid': self.qidx}
174 |
175 | def wait_for_queue(self, qid):
176 | if not self.queue_ids.get(qid):
177 | return {'Error': 'QID not found'}
178 | return self.queue_ids[qid].get()
179 |
180 | def background(self):
181 | logger.info(f'Init Background')
182 | maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(self.devices, ('dp', 'mp')))
183 | while True:
184 | batch, qids = [], []
185 |             while len(batch) < self.total_batch:
186 | try:
187 | req = self.queue.get(block=False)
188 | logger.info(f'Got Queue Item: {req}')
189 | batch.append(req['item'])
190 | qids.append(req['qidx'])
191 |
192 | except Empty:
193 | if len(batch):
194 | break
195 | else:
196 | time.sleep(0.01)
197 | batch_size = len(batch)
198 | logger.info(f'Working on Batch: {batch_size} - {qids}')
199 | while len(batch) < self.total_batch:
200 | batch.append(self.placeholder_item)
201 | start = timer()
202 | results = self.infer_batch(batch)
203 | for res, qid in zip(results, qids):
204 | self.queue_ids[qid].put(res)
205 | logger.info(f'Completed Current Batch of {batch_size} Items in {timer(start):.2f} secs')
206 |
207 | @property
208 | def placeholder_item(self):
209 | return {'context': 'nada', 'top_p': 0.9, 'top_k': 40, 'temp': 1.0, 'length': 1}
210 |
211 | def to_data(self, item):
212 | try:
213 | return {'context': item.context, 'top_p': item.top_p, 'top_k': item.top_k, 'temp': item.temp, 'length': item.length}
214 | except:
215 | return {'context': item.get('context', ''), 'top_p': item.get('top_p', 0.9), 'top_k': item.get('top_k', 40), 'temp': item.get('temp', 1.0), 'length': item.get('length', 256)}
216 |
217 | @property
218 | def alive_time(self):
219 | return timer(self._alive_time)
--------------------------------------------------------------------------------
/docker/payloads.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | from typing import Dict, Optional, Any, List
3 |
4 |
5 | class ModelPayload(BaseModel):
6 | inputs: Optional[Any] = None
7 | params: Optional[Dict]= {}
8 |
9 | class CompletionPayload(BaseModel):
10 | context: str
11 | temp: Optional[float] = 1.0
12 | top_p: Optional[float] = 0.9
13 | top_k: Optional[int] = 50
14 | length: Optional[int] = 256
15 |
16 | class CompletionResponse(BaseModel):
17 | context: str
18 | completion: str
19 | time: float
20 |
21 | class QueueResponse(BaseModel):
22 | qid: int
23 |
24 | class QueueRequest(BaseModel):
25 | qid: int
26 |
--------------------------------------------------------------------------------
/docker/start.sh:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env sh
2 | set -e
3 |
4 | if [ -f /app/app/main.py ]; then
5 | DEFAULT_MODULE_NAME=app.main
6 | elif [ -f /app/main.py ]; then
7 | DEFAULT_MODULE_NAME=main
8 | fi
9 |
10 | MODULE_NAME=${MODULE_NAME:-$DEFAULT_MODULE_NAME}
11 | VARIABLE_NAME=${VARIABLE_NAME:-app}
12 | export APP_MODULE=${APP_MODULE:-"$MODULE_NAME:$VARIABLE_NAME"}
13 |
14 | HOST=${HOST:-0.0.0.0}
15 | PORT=${PORT:-80}
16 | LOG_LEVEL=${LOG_LEVEL:-info}
17 |
18 | # If there's a prestart.sh script in the /app directory or other path specified, run it before starting
19 | PRE_START_PATH=${PRE_START_PATH:-/app/prestart.sh}
20 | echo "Checking for script in $PRE_START_PATH"
21 | if [ -f $PRE_START_PATH ] ; then
22 | echo "Running script $PRE_START_PATH"
23 | . "$PRE_START_PATH"
24 | else
25 | echo "There is no script $PRE_START_PATH"
26 | fi
27 |
28 | # Start Uvicorn with live reload
29 | exec uvicorn --reload --reload-dir /app/mesh-transformer-jax --host $HOST --port $PORT --log-level $LOG_LEVEL "$APP_MODULE" --access-log
--------------------------------------------------------------------------------
/eval_harness.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | from lm_eval import evaluator, tasks
5 |
6 | from mesh_transformer.build_model import build_model
7 | from tasks import EvalHarnessAdaptor
8 |
9 |
10 | def parse_args():
11 | # Parse command line arguments
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--tpu", type=str, help="Name of TPU to train on.")
14 | parser.add_argument("--tpu_region", type=str, help="Region of TPU to train on.")
15 | parser.add_argument("--preemptible", action="store_true")
16 |
17 | parser.add_argument("--config", type=str, default=None, help="Config file location")
18 |
19 | args = parser.parse_args()
20 | return args
21 |
22 |
23 | if __name__ == "__main__":
24 | args = parse_args()
25 | params = json.load(open(args.config))
26 |
27 | tpu_name = args.tpu
28 | region = args.tpu_region
29 | preemptible = args.preemptible
30 |
31 | gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1)
32 | per_replica_batch = params["per_replica_batch"]
33 | tpu_size = params["tpu_size"]
34 | cores_per_replica = params["cores_per_replica"]
35 |
36 | bucket = params["bucket"]
37 | model_dir = params["model_dir"]
38 | layers = params["layers"]
39 | d_model = params["d_model"]
40 | n_heads = params["n_heads"]
41 | n_vocab = params["n_vocab"]
42 | seq = params["seq"]
43 | norm = params["norm"]
44 | pe = params["pe"]
45 |
46 | total_batch = per_replica_batch * tpu_size // cores_per_replica * 4
47 |
48 | t = build_model(params, tpu_name, region, preemptible)
49 | adaptor = EvalHarnessAdaptor(t, seq, total_batch, shrink=pe != "fixed")
50 |
51 | step, aux = t.load(bucket, model_dir)
52 | t.move()
53 |
54 | results = evaluator.evaluate(adaptor, tasks.get_task_dict(["lambada",
55 | "piqa",
56 | "hellaswag",
57 | "winogrande",
58 | "mathqa",
59 | "pubmedqa",
60 | # "boolq",
61 | # "cb",
62 | # "copa",
63 | # "multirc",
64 | # "record",
65 | # "wic",
66 | # "wsc",
67 | ]), False, 0, None)
68 | dumped = json.dumps(results, indent=2)
69 | print(dumped)
70 |
71 | results = evaluator.evaluate(adaptor, tasks.get_task_dict(["lambada_cloze",
72 | ]), False, 15, None)
73 |
74 | dumped = json.dumps(results, indent=2)
75 | print(dumped)
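76 |
77 | # Illustrative invocation (flag names as defined in parse_args above; the TPU name,
78 | # region, and config path are placeholders):
79 | #   python3 eval_harness.py --tpu YOUR_TPU --tpu_region YOUR_REGION --config configs/qa_base.json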
--------------------------------------------------------------------------------
/gsm/prompts.txt:
--------------------------------------------------------------------------------
1 | Q: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
2 | A: Natalia sold 48/2 = <<48/2=24>>24 clips in May.
3 | Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.
4 | #### 72
5 |
6 | Q: Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?
7 | A: In the beginning, Betty has only 100 / 2 = $<<100/2=50>>50.
8 | Betty's grandparents gave her 15 * 2 = $<<15*2=30>>30.
9 | This means, Betty needs 100 - 50 - 30 - 15 = $<<100-50-30-15=5>>5 more.
10 | #### 5
11 |
12 | Q: Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?
13 | A: Maila read 12 x 2 = <<12*2=24>>24 pages today.
14 | So she was able to read a total of 12 + 24 = <<12+24=36>>36 pages since yesterday.
15 | There are 120 - 36 = <<120-36=84>>84 pages left to be read.
16 | Since she wants to read half of the remaining pages tomorrow, then she should read 84/2 = <<84/2=42>>42 pages.
17 | #### 42
18 |
19 | Q: Mark has a garden with flowers. He planted plants of three different colors in it. Ten of them are yellow, and there are 80% more of those in purple. There are only 25% as many green flowers as there are yellow and purple flowers. How many flowers does Mark have in his garden?
20 | A: There are 80/100 * 10 = <<80/100*10=8>>8 more purple flowers than yellow flowers.
21 | So in Mark's garden, there are 10 + 8 = <<10+8=18>>18 purple flowers.
22 | Purple and yellow flowers sum up to 10 + 18 = <<10+18=28>>28 flowers.
23 | That means in Mark's garden there are 25/100 * 28 = <<25/100*28=7>>7 green flowers.
24 | So in total Mark has 28 + 7 = <<28+7=35>>35 plants in his garden.
25 | #### 35
26 |
27 | Q: Alexis is applying for a new job and bought a new set of business clothes to wear to the interview. She went to a department store with a budget of $200 and spent $30 on a button-up shirt, $46 on suit pants, $38 on a suit coat, $11 on socks, and $18 on a belt. She also purchased a pair of shoes, but lost the receipt for them. She has $16 left from her budget. How much did Alexis pay for the shoes?
28 | A: Let S be the amount Alexis paid for the shoes.
29 | She spent S + 30 + 46 + 38 + 11 + 18 = S + <<+30+46+38+11+18=143>>143.
30 | She used all but $16 of her budget, so S + 143 = 200 - 16 = 184.
31 | Thus, Alexis paid S = 184 - 143 = $<<184-143=41>>41 for the shoes.
32 | #### 41
33 |
34 | Q: Tina makes $18.00 an hour. If she works more than 8 hours per shift, she is eligible for overtime, which is paid by your hourly wage + 1/2 your hourly wage. If she works 10 hours every day for 5 days, how much money does she make?
35 | A: She works 8 hours a day for $18 per hour so she makes 8*18 = $<<8*18=144.00>>144.00 per 8-hour shift
36 | She works 10 hours a day and anything over 8 hours is eligible for overtime, so she gets 10-8 = <<10-8=2>>2 hours of overtime
37 | Overtime is calculated as time and a half so and she makes $18/hour so her overtime pay is 18*.5 = $<<18*.5=9.00>>9.00
38 | Her overtime pay is 18+9 = $<<18+9=27.00>>27.00
39 | Her base pay is $144.00 per 8-hour shift and she works 5 days and makes 5 * $144 = $<<144*5=720.00>>720.00
40 | Her overtime pay is $27.00 per hour and she works 2 hours of overtime per day and makes 27*2 = $<<27*2=54.00>>54.00 in overtime pay
41 | 2 hours of overtime pay for 5 days means she makes 54*5 = $270.00
42 | In 5 days her base pay is $720.00 and she makes $270.00 in overtime pay so she makes $720 + $270 = $<<720+270=990.00>>990.00
43 | #### 990
--------------------------------------------------------------------------------
/gsm/prompts_answer_key.txt:
--------------------------------------------------------------------------------
1 | Q: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? (72)
2 | A: Natalia sold 48/2 = <<48/2=24>>24 clips in May.
3 | Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.
4 | #### 72
5 |
6 | Q: Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet? (5)
7 | A: In the beginning, Betty has only 100 / 2 = $<<100/2=50>>50.
8 | Betty's grandparents gave her 15 * 2 = $<<15*2=30>>30.
9 | This means, Betty needs 100 - 50 - 30 - 15 = $<<100-50-30-15=5>>5 more.
10 | #### 5
11 |
12 | Q: Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read? (42)
13 | A: Maila read 12 x 2 = <<12*2=24>>24 pages today.
14 | So she was able to read a total of 12 + 24 = <<12+24=36>>36 pages since yesterday.
15 | There are 120 - 36 = <<120-36=84>>84 pages left to be read.
16 | Since she wants to read half of the remaining pages tomorrow, then she should read 84/2 = <<84/2=42>>42 pages.
17 | #### 42
18 |
19 | Q: Mark has a garden with flowers. He planted plants of three different colors in it. Ten of them are yellow, and there are 80% more of those in purple. There are only 25% as many green flowers as there are yellow and purple flowers. How many flowers does Mark have in his garden? (35)
20 | A: There are 80/100 * 10 = <<80/100*10=8>>8 more purple flowers than yellow flowers.
21 | So in Mark's garden, there are 10 + 8 = <<10+8=18>>18 purple flowers.
22 | Purple and yellow flowers sum up to 10 + 18 = <<10+18=28>>28 flowers.
23 | That means in Mark's garden there are 25/100 * 28 = <<25/100*28=7>>7 green flowers.
24 | So in total Mark has 28 + 7 = <<28+7=35>>35 plants in his garden.
25 | #### 35
26 |
27 | Q: Alexis is applying for a new job and bought a new set of business clothes to wear to the interview. She went to a department store with a budget of $200 and spent $30 on a button-up shirt, $46 on suit pants, $38 on a suit coat, $11 on socks, and $18 on a belt. She also purchased a pair of shoes, but lost the receipt for them. She has $16 left from her budget. How much did Alexis pay for the shoes? (41)
28 | A: Let S be the amount Alexis paid for the shoes.
29 | She spent S + 30 + 46 + 38 + 11 + 18 = S + <<+30+46+38+11+18=143>>143.
30 | She used all but $16 of her budget, so S + 143 = 200 - 16 = 184.
31 | Thus, Alexis paid S = 184 - 143 = $<<184-143=41>>41 for the shoes.
32 | #### 41
33 |
34 | Q: Tina makes $18.00 an hour. If she works more than 8 hours per shift, she is eligible for overtime, which is paid by your hourly wage + 1/2 your hourly wage. If she works 10 hours every day for 5 days, how much money does she make? (990)
35 | A: She works 8 hours a day for $18 per hour so she makes 8*18 = $<<8*18=144.00>>144.00 per 8-hour shift
36 | She works 10 hours a day and anything over 8 hours is eligible for overtime, so she gets 10-8 = <<10-8=2>>2 hours of overtime
37 | Overtime is calculated as time and a half so and she makes $18/hour so her overtime pay is 18*.5 = $<<18*.5=9.00>>9.00
38 | Her overtime pay is 18+9 = $<<18+9=27.00>>27.00
39 | Her base pay is $144.00 per 8-hour shift and she works 5 days and makes 5 * $144 = $<<144*5=720.00>>720.00
40 | Her overtime pay is $27.00 per hour and she works 2 hours of overtime per day and makes 27*2 = $<<27*2=54.00>>54.00 in overtime pay
41 | 2 hours of overtime pay for 5 days means she makes 54*5 = $270.00
42 | In 5 days her base pay is $720.00 and she makes $270.00 in overtime pay so she makes $720 + $270 = $<<720+270=990.00>>990.00
43 | #### 990
--------------------------------------------------------------------------------
/gsm/prompts_direct_answer_key.txt:
--------------------------------------------------------------------------------
1 | Q: What do people use to absorb extra ink from a fountain pen? blotter
2 | A: Therefore, the answer is blotter.
3 |
4 | Q: What home entertainment equipment requires cable? television
5 | A: Therefore, the answer is television.
6 |
7 | Q: The fox walked from the city into the forest, what was it looking for? natural habitat
8 | A: Therefore, the answer is natural habitat.
9 |
10 | Q: Sammy wanted to go to where the people were. Where might he go? populated areas
11 | A: Therefore, the answer is populated areas.
12 |
13 | Q: Where do you put your grapes just before checking out? grocery cart
14 | A: Therefore, the answer is grocery cart.
15 |
16 | Q: Google Maps and other highway and street GPS services have replaced what? atlas
17 | A: Therefore, the answer is atlas.
18 |
19 | Q: Before getting a divorce, what did the wife feel who was doing all the work? bitterness
20 | A: Therefore, the answer is bitterness.
--------------------------------------------------------------------------------
/howto_finetune.md:
--------------------------------------------------------------------------------
1 | # How to Fine-Tune GPT-J - The Basics
2 |
3 | Before anything else, you'll likely want to apply for access to the TPU Research Cloud (TRC). Combined with a Google Cloud free trial, that should allow you to do everything here for free. Once you're in TRC, create a project, then fill out the form that was emailed to you with the name of the new project. Use the script `create_finetune_tfrecords.py` to prepare your data as tfrecords; I might do a separate guide on that. You might also want to fork the mesh-transformer-jax repo to make it easier to add and modify the config files.
4 |
5 | 0. [Install the Google Cloud SDK](https://cloud.google.com/sdk/docs/install). We'll need it later.
6 |
7 | 1. If you didn't make a project and activate TPU access through TRC yet (or if you plan on paying out of pocket), [make one now](https://console.cloud.google.com/projectcreate).
8 |
9 | 2. TPUs use Google Cloud buckets for storage, so go ahead and [create one now](https://console.cloud.google.com/storage/create-bucket). Make sure it's in the same region your TPU VM will be in; the email from TRC will tell you which region(s) you can use free TPUs in.
10 |
11 | 3. You'll need the full pretrained weights in order to fine-tune the model. [Download those here](https://the-eye.eu/public/AI/GPT-J-6B/step_383500.tar.zstd).
12 |
13 | Now that you have a bucket on the cloud and the weights on your PC, you need to upload the weights to the bucket in two steps:
14 |
15 | 4. Decompress and extract `GPT-J-6B/step_383500.tar.zstd` so you're left with the uncompressed folder containing the sharded checkpoint.
16 |
17 | 5. Open the Google Cloud SDK and run the following command, replacing the path names as appropriate: `gsutil -m cp -R LOCAL_PATH_TO/step_383500 gs://YOUR-BUCKET`. If that works, the console will show the files being uploaded. *Note: This took about 12 hours for me, uploading to the Netherlands from California; hopefully you'll have a better geographic situation than I did! I also initially made the mistake of uploading the still-packed .tar. Don't do that; TPU VMs don't have enough local storage for you to unpack it. To avoid needing to re-upload, I had to unpack it in Colab.*
18 |
19 | You'll want to upload tfrecords of your data as well; you can do that here or through the web interface, but trust me when I say you don't want to upload the nearly 70GB of weights through the web interface.
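For example, something like `gsutil -m cp -R LOCAL_PATH_TO/your_tfrecords gs://YOUR-BUCKET` works here too (the folder name is a placeholder for wherever your tfrecords ended up).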
20 |
21 | Note that steps 6 and 7, preparing the index and config files, can be done later on by editing the base repo in the VM's text editor. It's more efficient to instead make these changes to your own fork of the repo as follows:
22 |
23 | 6. In the data folder, create a new file `foo.train.index`, replacing `foo` with whatever you want to refer to your dataset as. For each tfrecord in your bucket that you intend to train with, add the path as a line in the index. Make `foo.val.index` and do the same for your validation dataset (if you have one). See the existing files for examples.
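For instance, a hypothetical `foo.train.index` could consist of a single line such as `gs://YOUR-BUCKET/data/foo_1147.tfrecords`, with one such path per tfrecord you want to train on (the folder and file names here are made up).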
24 |
25 | 7. Duplicate the config file `6B_roto_256.json` and rename it to something appropriate for your project. Open it up and make these edits (a sketch of the result follows this list):
26 | - `tpu_size`: Change from `256` to `8`
27 | - `bucket`: Change to your bucket
28 | - `model_dir`: Change to the directory you'd like to save your checkpoints in
29 | - `train_set` and `val_set`: Change to the index files from the last step
30 | - `eval_harness_tasks`: Can be removed if you don't plan on using the eval harness
31 |     - `val_every` & `ckpt_every` & `keep_every`: Usage should be intuitive. Don't set any of the `foo_every` values to 0, though, or you'll get a divide-by-zero error. If you don't have a `val_set`, just set `val_every` to something higher than `total_steps`.
32 |     - `val_batches`: This should equal the number of sequences in your validation dataset. You can find this number at the end of the name of the .tfrecords file produced by `create_finetune_tfrecords.py`.
33 | - `name`: Change to a name for your model.
34 | - `warmup_steps`, `lr`, `val_batches`, etc.: see the *Learning Rate Notes* section at the end of the guide.
35 |
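To give a rough idea, the edited fields of a hypothetical config might end up looking something like the snippet below; every name and number here is illustrative (they match the worked example in the *Learning Rate Notes* section), and all other keys from the base file stay as they were:

```json
{
  "tpu_size": 8,
  "bucket": "YOUR-BUCKET",
  "model_dir": "finetune_ckpts/my_model",
  "train_set": "foo.train.index",
  "name": "my_model",
  "total_steps": 72,
  "warmup_steps": 7,
  "anneal_steps": 65,
  "lr": 5e-5,
  "end_lr": 1e-5,
  "weight_decay": 0.1
}
```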
36 |
37 | 8. Push the changes to your GitHub repo.
38 |
39 | 9. Follow [this guide](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm) up to and including the step **"Connect to your Cloud TPU VM"**.
40 |
41 | At this point you should have remote access to the TPU VM!
42 |
43 | 10. In the new VM terminal, type `git clone https://github.com/kingoflolz/mesh-transformer-jax` (or, preferably, your own fork, after pushing the config and index files)
44 |
45 | 11. Move to the new directory with `cd mesh-transformer-jax` and run `pip install -r requirements.txt`. Since the requirements.txt file doesn't pin the exact jax version required for finetuning, run `pip install jax==0.2.12` and you'll be all set.
46 |
47 | 12. Finally, run `python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/`. If everything is set up correctly, this will begin the fine-tuning process. First the model has to be loaded into memory; after `loading network` appeared on the console, it took about 10-15 minutes before the next step, setting up WandB for logging. Option 3 allows you to skip that if you aren't using WandB. A step 1 checkpoint will be saved, and then the real training will start. If you have a small dataset, this will go by quickly; TPU VMs can train at a rate of ~5000 tokens/second.
48 |
49 | 13. You did it! Now don't forget any clean up steps you need to take like shutting down your TPU VM or removing unneeded data in buckets, so that you don't have any unexpected charges from Google later.
50 |
51 | ## Now what?
52 |
53 | This guide is labeled "The Basics"; anything we haven't covered so far is out of scope, but go check out the rest of the repository! Try `python3 device_sample.py --config=configs/YOUR_CONFIG.json` for a basic sampling interface. Use `slim_model.py` to prepare an easier-to-deploy slim version of your new weights for inference. Experiment!
54 |
55 | ### Running with HuggingFace
56 | To use the model in HuggingFace's `transformers` library with pytorch, you'll need to convert the weights
57 | into a format that it recognizes. This can be done using `to_hf_weights.py`. It's recommended that you run `slim_model.py` before attempting to convert the weights to the pytorch/transformers format. Use `python to_hf_weights.py --help` to see usage details.
58 |
59 | *Note: as of 9/1/2021, GPT-J has been merged into the `main` branch of `transformers` but has not yet been included in a release. Run `pip install git+https://github.com/huggingface/transformers#egg=transformers` to install the current `main` branch.*
60 |
61 | ## Learning Rate Notes
62 |
63 | **Thanks to nostalgebraist for talking about this!** They're the one who explained this part on Discord; I'm just paraphrasing, really:
64 |
65 | The first thing you want to determine is how long a training epoch will be. `gradient_accumulation_steps` is your batch size; it defaults to `16`, and nostalgebraist recommends `32`. Your .tfrecord files should have a number in the file name indicating how many sequences are in the dataset. Divide that number by the batch size, and the result is how many steps are in an epoch. Now we can write the schedule.
66 |
67 | `lr` is recommended to be between `1e-5` and `5e-5`, with `end_lr` set to 1/5 or 1/10 of `lr`.
68 | `weight_decay` can remain `0.1`. `total_steps` should be at least one epoch, possibly longer if you have a validation
69 | set to monitor your loss with.
70 | `warmup_steps` should be 5-10% of the total, and finally `anneal_steps` should be `total_steps - warmup_steps`.
71 | (The `lr` is set to `end_lr` after `warmup_steps + anneal_steps` and then training continues at that rate until `total_steps`,
72 | but usually you should stop after annealing is done.)
73 |
74 | To illustrate: I have a small dataset that tokenized into 1147 sequences as a .tfrecord. Dividing by `gradient_accumulation_steps` set to `16`, and rounding up to make sure I use all the data, gives 72 steps per epoch. I'll set `lr` to `5e-5` and `end_lr` to a fifth of that, `1e-5`; that may be too much, since it's at the high end of the recommended range. I'll set `total_steps` to `72` for one epoch, since I don't have a validation set. Then I'll set `anneal_steps` to `65` and `warmup_steps` to `7`. Simple as that, but you may need to fiddle with the specifics on your own.
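A quick sketch of that arithmetic, using the numbers from the example above (nothing here is specific to the codebase, it's just the calculation spelled out):

```python
import math

sequences = 1147       # taken from the .tfrecord file name
batch_size = 16        # gradient_accumulation_steps

steps_per_epoch = math.ceil(sequences / batch_size)   # 72
total_steps = steps_per_epoch                          # one epoch, no val set
warmup_steps = round(0.1 * total_steps)                # 7  (~10% of total)
anneal_steps = total_steps - warmup_steps              # 65
lr, end_lr = 5e-5, 1e-5                                # end_lr = lr / 5

print(steps_per_epoch, warmup_steps, anneal_steps)     # 72 7 65
```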
75 |
--------------------------------------------------------------------------------
/iteration_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import shutil
5 | import argparse
6 |
7 | def record_folder(cur_iter):
8 | return f"{task}/{experiment_name}/{experiment_name}_{cur_iter}"
9 |
10 |
11 | def parse_args():
12 | # Parse command line arguments
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--no_prompt', action='store_true', help="Whether to remove prompts during eval")
15 | parser.add_argument("--base_epochs", type=float, default=1., help="Epochs for the first iteration")
16 | parser.add_argument("--add_epochs", type=float, default=0.2, help="Epochs to add each iteration")
17 | parser.add_argument("--few_shot_train", action='store_true', help="Whether to use few shot training")
18 | parser.add_argument("--steady_grow", action='store_true', help="Whether to use a fixed number of epochs")
19 | parser.add_argument("--start_steps", type=float, default=40., help="Steps for the first iteration")
20 | parser.add_argument("--exponential_grow", action='store_true', help="Whether to use a fixed number of epochs")
21 | parser.add_argument("--add_steps", type=float, default=20., help="Steps to add each iteration")
22 | parser.add_argument("--grow_steps", type=float, default=1.2, help="Steps to add each iteration")
23 | parser.add_argument("--p_rationalization", type=float, default=1., help="Percent of wrong examples to rationalize")
24 | parser.add_argument("--p_show_hint_save", type=float, default=0., help="Percent of rationalization hints to save")
25 | parser.add_argument('--rationalize', action='store_true', help="Whether to use rationalization")
26 |
27 | parser.add_argument("--start_iter", type=int, default=1, help="Starting iteration")
28 | parser.add_argument("--n_iters", type=int, default=64, help="Upper limit on outer loop iterations")
29 | parser.add_argument("--copy_n", type=int, default=0, help="Number of files to copy each iteration")
30 | parser.add_argument("--n_train_samples", type=int, default=10000, help="Number of training examples")
31 | parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Batch size")
32 |
33 | parser.add_argument("--task", type=str, default="commonsenseqa", help="Whether to run on arithmetic")
34 | parser.add_argument('--direct', action='store_true', help="Whether to use direct prediction, sans scratchpad")
35 | parser.add_argument("--gen_length", type=int, default=96, help="Length of generated output")
36 | parser.add_argument("--sequence_count", type=int, default=10, help="Sequences per batch on average")
37 | parser.add_argument("--base_model_location", type=str, default="gs://checkpoint-bucket/step_383500/", help="Finetuning ckpt")
38 | parser.add_argument('--dry_run', action='store_true', help="Whether to do a quick run to visualize output")
39 | parser.add_argument('--skip_eval', action='store_true', help="Whether to skip evaluation (e.g. arithmetic)")
40 |
41 | args = parser.parse_args()
42 | return args
43 |
44 | def gen_train():
45 | train_cmd = f"python3 device_inference.py --config={prev_config} --split=train --gen_length={args.gen_length} --p_show_hint_save={args.p_show_hint_save} "
46 | if task != "commonsenseqa":
47 | train_cmd += f" --dataset_mode={task} "
48 | if args.rationalize:
49 | train_cmd += " --rationalize "
50 | if args.few_shot_train:
51 | train_cmd += " --few_shot_train "
52 | if cur_iter > 1 and args.no_prompt:
53 | train_cmd += f" --no_prompt --eval_seq {eval_seq} "
54 | train_cmd += f" --n_train_samples={args.n_train_samples} "
55 | train_cmd += f" >> result_logs/{experiment_name}.txt"
56 | print(f"Generating training set {cur_iter} using model {cur_iter - 1}: {train_cmd}")
57 | if not args.dry_run and (cur_iter >= args.start_iter):
58 | if (cur_iter == 1) and os.path.exists(record_folder(0) + f"/{experiment_name}_0.txt"):
59 | print("First file cached")
60 | else:
61 | os.system(train_cmd)
62 |
63 | def gen_records():
64 | gen_cmd = f'python3 create_finetune_tfrecords.py {record_folder(cur_iter - 1)} {record_folder(cur_iter - 1)}'
65 | print(f"Creating records for finetuning {cur_iter}: {gen_cmd}")
66 | if not args.dry_run and (cur_iter >= args.start_iter):
67 | os.system(gen_cmd)
68 | train_set = f"{experiment_name}/{exp_iteration}.index"
69 | with open(f"data/{train_set}", "w") as new_data_file:
70 | new_data_file.write(f"{record_folder(cur_iter - 1)}.tfrecords")
71 | return train_set
72 |
73 | def get_n_steps():
74 | if args.steady_grow:
75 | return int(args.start_steps + args.add_steps * (cur_iter - 1))
76 | elif args.exponential_grow:
77 | return int(args.start_steps * (args.grow_steps ** (cur_iter - 1)))
78 | else:
79 | # Count data points
80 | total_count = 0
81 | for cur_file in sorted(os.listdir(record_folder(cur_iter - 1)), key=lambda x: int(x.split('.')[0].split("_")[-1])):
82 | with open(f"{record_folder(cur_iter - 1)}/{cur_file}", encoding='utf-8') as train_file:
83 | train_file_text = train_file.read()
84 | total_count += len(train_file_text.split("\n\n"))
85 | print(len(train_file_text.split("\n\n")))
86 | train_epochs = args.base_epochs + args.add_epochs * (cur_iter - 1)
87 | cur_steps = int(total_count * train_epochs // (args.gradient_accumulation_steps * args.sequence_count))
88 | return cur_steps
89 |
90 | def gen_config(train_set):
91 | print(f"Creating new config file {cur_iter}")
92 | config_name = f'configs/{experiment_name}/{exp_iteration}.json'
93 | os.makedirs(record_folder(cur_iter), exist_ok=True)
94 | with open(prev_config, encoding='utf-8') as base_json_file:
95 | new_json = json.load(base_json_file)
96 | new_json["model_dir"] = f"strangeloop/{exp_iteration}"
97 | new_json["train_set"] = train_set
98 | new_json["target_save"] = record_folder(cur_iter) + f"/{exp_iteration}.txt"
99 | new_json["total_steps"] = get_n_steps()
100 | new_json["name"] = exp_iteration
101 | new_json["p_rationalization"] = args.p_rationalization
102 | new_json["gradient_accumulation_steps"] = args.gradient_accumulation_steps
103 | with open(config_name, "w", encoding='utf-8') as new_json_file:
104 | json.dump(new_json, new_json_file, indent=2)
105 | return config_name
106 |
107 | def train_model():
108 | model_cmd = f"python3 device_train.py --config {config_name} --tune-model-path={args.base_model_location}"
109 | print(f"Train model {cur_iter}: {model_cmd}")
110 | if not args.dry_run and (cur_iter >= args.start_iter):
111 | os.system(model_cmd)
112 |
113 | def eval_model():
114 | eval_cmd = f"python3 device_inference.py --config={config_name} --split=dev --gen_length={args.gen_length} --p_show_hint_save={args.p_show_hint_save} "
115 | if task != "commonsenseqa":
116 | eval_cmd += f" --dataset_mode={task} "
117 | if args.no_prompt:
118 | eval_cmd += f" --no_prompt --eval_seq {eval_seq} "
119 | if args.few_shot_train:
120 | eval_cmd += " --few_shot_train "
121 | eval_cmd += f" >> result_logs/{experiment_name}.txt"
122 | print(f"Eval model {cur_iter}: {eval_cmd}")
123 | if not args.dry_run and (cur_iter >= args.start_iter) and not args.skip_eval:
124 | os.system(eval_cmd)
125 |
126 | def copy_files():
127 | all_files = sorted(os.listdir(record_folder(cur_iter - 1)), key=lambda x: int(x.split('.')[0].split("_")[-1]))
128 | relevant_files = all_files[-args.copy_n:]
129 | for cur_file in relevant_files:
130 | shutil.copy(f"{record_folder(cur_iter - 1)}/{cur_file}", record_folder(cur_iter))
131 |
132 | def make_first_config():
133 | with open(prev_config, encoding='utf-8') as base_json_file:
134 | new_json = json.load(base_json_file)
135 | os.makedirs(record_folder(0), exist_ok=True)
136 | new_json["target_save"] = record_folder(0) + f"/{experiment_name}_0.txt"
137 | new_json["name"] = f"{experiment_name}_0"
138 | new_json["p_rationalization"] = args.p_rationalization
139 | new_json["gradient_accumulation_steps"] = args.gradient_accumulation_steps
140 | with open(prev_config, "w", encoding='utf-8') as base_json_file:
141 | json.dump(new_json, base_json_file, indent=2)
142 | return new_json
143 |
144 | if __name__ == "__main__":
145 | args = parse_args()
146 | print(args)
147 | task = args.task
148 | experiment_name = "_".join(sys.argv[1:])
149 | experiment_name = ''.join(ch for ch in experiment_name if ch.isalnum() or ch == "_")
150 | if args.no_prompt:
151 | eval_seq = 128 + args.gen_length
152 | os.makedirs(f"configs/{experiment_name}", exist_ok=True)
153 | shutil.copy(f"configs/qa_base.json", f"configs/{experiment_name}/base.json")
154 | prev_config = f"configs/{experiment_name}/base.json"
155 | new_json = make_first_config()
156 |
157 | os.makedirs(f'data/{experiment_name}', exist_ok=True)
158 | os.makedirs(f'{task}/{experiment_name}', exist_ok=True)
159 |     os.makedirs('result_logs/', exist_ok=True)
160 | with open(f"result_logs/{experiment_name}.txt", "a+") as f:
161 | print("================================", file=f)
162 | print(args, file=f)
163 | for cur_iter in range(1, args.n_iters):
164 | exp_iteration = f"{experiment_name}_{cur_iter}"
165 | gen_train() # Generate the training set
166 | train_set = gen_records() # Create the tfrecords from the data
167 | config_name = gen_config(train_set) # Create the new configuration file
168 | train_model() # Train the new model
169 | eval_model() # Evaluate the new model
170 | prev_config = config_name # Prepare for next iteration
171 | if args.copy_n > 0:
172 | copy_files()
173 |
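174 | # Illustrative invocation (flag names as defined in parse_args above; values are placeholders):
175 | #   python3 iteration_train.py --task=commonsenseqa --rationalize --p_rationalization=1.0 --n_iters=30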
--------------------------------------------------------------------------------
/mesh_transformer/TPU_cluster.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import json
3 | import time
4 |
5 | import ray
6 |
7 | from typing import Callable
8 | import numpy as np
9 |
10 | from mesh_transformer.train_actor import NetworkRunner
11 | from google.cloud import storage
12 | from smart_open import open
13 | from func_timeout import func_set_timeout
14 |
15 |
16 | class TPUCluster:
17 | @func_set_timeout(1200)
18 | def __init__(self,
19 | mesh_shape,
20 | node_count,
21 | model: Callable,
22 | version=1):
23 | assert ray.is_initialized() # needs a valid ray cluster to start
24 | self.nodes = []
25 | self.node_count = node_count
26 | self.dp, self.mp = mesh_shape
27 | self.version = version
28 |
29 | start = time.time()
30 |
31 | for i in range(node_count):
32 | self.nodes.append(NetworkRunner.options(max_concurrency=2).remote(mesh_shape, model))
33 |
34 | for n in self.nodes:
35 | n.run.remote()
36 |
37 | params = []
38 | for n in self.nodes:
39 | params.append(n.get_params.remote())
40 |
41 | self.param_count = ray.get(params)[0]
42 | print(f"Ray actors created in {time.time() - start:.06}s")
43 |
44 | @func_set_timeout(600)
45 | def train(self, data):
46 | data_chunks = np.array_split(data, len(self.nodes), axis=1)
47 |
48 | res = []
49 | for n, d in zip(self.nodes, data_chunks):
50 | res.append(n.train.remote({
51 | "obs": d[:, :, :-1],
52 | "target": d[:, :, 1:],
53 | }))
54 |
55 | res = ray.get(res)
56 |
57 | loss = []
58 | last_loss = []
59 |
60 | for r in res:
61 | loss.append(r[0])
62 | last_loss.append(r[1])
63 |
64 | return np.array(loss).mean(), np.array(last_loss).mean()
65 |
66 | @func_set_timeout(600)
67 | def eval(self, data):
68 | if isinstance(data, dict):
69 | data_chunked = [{} for _ in self.nodes]
70 | for k, v in data.items():
71 | v_chunks = np.array_split(v, len(self.nodes), axis=0)
72 | for idx, v_chunk in enumerate(v_chunks):
73 | data_chunked[idx][k] = v_chunk
74 |
75 | res = []
76 | for n, d in zip(self.nodes, data_chunked):
77 | res.append(n.eval.remote(d))
78 |
79 | total = 0
80 | correct = 0
81 | last_correct = 0
82 |
83 | total_last_loss = 0
84 | mask_loss = []
85 | each_correct = []
86 |
87 | for input, output in zip(data_chunked, ray.get(res)):
88 | correct_and_valid = np.logical_and(output["correct"], input["eval_mask"])
89 |
90 | correct_tokens_count = np.sum(correct_and_valid, -1)
91 | valid_tokens_count = np.sum(input["eval_mask"], -1)
92 |
93 | correct_example = np.logical_and(valid_tokens_count == correct_tokens_count, valid_tokens_count > 0)
94 | valid_example = valid_tokens_count > 0
95 | last_correct_example = correct_and_valid[:, -1]
96 |
97 | each_correct += correct_example.tolist()
98 |
99 | total += sum(valid_example)
100 | correct += sum(correct_example)
101 | last_correct += sum(last_correct_example)
102 | total_last_loss += sum(valid_example * output["last_loss"])
103 |
104 | valid_loss = np.sum(output["all_loss"] * input["eval_mask"], -1)
105 | mask_loss += valid_loss.tolist()
106 |
107 | return {
108 | "total": total,
109 | "correct": correct,
110 | "last_correct": last_correct,
111 | "last_loss": total_last_loss,
112 | "mask_loss": np.array(mask_loss),
113 | "each_correct": np.array(each_correct)
114 | }
115 | else:
116 | data_chunks = np.array_split(data, len(self.nodes), axis=0)
117 |
118 | res = []
119 | for n, d in zip(self.nodes, data_chunks):
120 | res.append(n.eval.remote({
121 | "obs": d[:, :-1],
122 | "target": d[:, 1:],
123 | }))
124 |
125 | return np.array([i["loss"] for i in ray.get(res)]).mean()
126 |
127 | @func_set_timeout(600)
128 | def generate(self, context, ctx_length, gen_len):
129 | context = np.array_split(context, len(self.nodes), axis=0)
130 | ctx_length = np.array_split(ctx_length, len(self.nodes), axis=0)
131 |
132 | res = []
133 | for n, ctx, l in zip(self.nodes, context, ctx_length):
134 | res.append(n.generate.remote((
135 | ctx,
136 | np.ones(len(ctx), dtype=np.uint32) * l,
137 | gen_len
138 | )))
139 |
140 | return np.concatenate([i[1][0][:, :, 0] for i in ray.get(res)], axis=0)
141 |
142 | @func_set_timeout(600)
143 | def move(self):
144 | start = time.time()
145 | res = []
146 | for node in self.nodes:
147 | res.append(node.move_params.remote())
148 | ray.get(res)
149 |
150 | print(f"Moved weights to TPU in {time.time() - start:.06}s")
151 |
152 | @func_set_timeout(1800)
153 | def load(self, bucket, path):
154 | with open(f"gs://{bucket}/{path}/meta.json", "r") as f:
155 | meta = json.load(f)
156 |
157 | ckpt_step = meta["checkpoints"][-1]
158 |
159 | # do replicated checkpoint reading
160 | start = time.time()
161 | res = []
162 | for node in self.nodes:
163 | res.append(node.load_ckpt.remote(f"gs://{bucket}/{path}/step_{ckpt_step}/"))
164 |
165 | # make sure they all read from the same checkpoint
166 | step = np.array(ray.get(res))
167 | assert (step[0] == step).all()
168 | step = int(step[0])
169 |
170 | print(f"Checkpoint@step{step} restored in {time.time() - start:.06}s")
171 | return step, meta["aux"][str(ckpt_step)]
172 |
173 | @func_set_timeout(600)
174 | def save(self, step, bucket, path, aux=None, init=False, overwrite=False, keep_n=3, delete_old=True):
175 | assert path
176 | client = storage.Client()
177 |
178 | if aux is None:
179 | aux = {}
180 |
181 | if init:
182 | # check existing checkpoint folder does not exist, and delete it if it does
183 | for blob in client.list_blobs(bucket, prefix=f"{path}/"):
184 | assert overwrite
185 | # print(f"deleting {blob.name}")
186 | assert path in blob.name
187 | blob.delete()
188 |
189 | # create metadata file
190 | with open(f"gs://{bucket}/{path}/meta.json", "w") as f:
191 | json.dump({
192 | "step": 0,
193 | "checkpoints": [],
194 | "aux": {}
195 | }, f)
196 |
197 | # do sharded checkpoint writing
198 | start = time.time()
199 | res = []
200 |
201 | if self.version == 1:
202 | for shard_id, node in zip(range(self.mp), itertools.cycle(self.nodes)):
203 | res.append(node.write_ckpt.remote(f"gs://{bucket}/{path}/step_{step}/", shard_id))
204 | elif self.version == 2:
205 | for node in self.nodes:
206 | res.append(node.write_ckpt.remote(f"gs://{bucket}/{path}/step_{step}", 0))
207 |
208 | ray.get(res)
209 | print(f"Wrote checkpoint in {time.time() - start:.06}s")
210 |
211 | with open(f"gs://{bucket}/{path}/meta.json", "r") as f:
212 | meta = json.load(f)
213 |
214 | meta["step"] = step
215 | meta["checkpoints"].append(step)
216 | all_aux = meta.get("aux", {})
217 |
218 | while len(meta["checkpoints"]) > keep_n:
219 | ckpt_to_delete = meta["checkpoints"].pop(0)
220 |
221 | try:
222 | del all_aux[str(ckpt_to_delete)]
223 |                 except KeyError:
224 |                     print(f"failed to delete the aux state for {ckpt_to_delete}")
225 |
226 | if delete_old:
227 | print(f"deleting checkpoint {ckpt_to_delete}")
228 | for blob in client.list_blobs(bucket, prefix=f"{path}/step_{ckpt_to_delete}/"):
229 | # print(f"deleting {blob.name}")
230 | assert path in blob.name
231 | blob.delete()
232 | else:
233 | print(f"keeping checkpoint {ckpt_to_delete}")
234 |
235 | all_aux[step] = aux
236 | meta["aux"] = all_aux
237 |
238 | with open(f"gs://{bucket}/{path}/meta.json", "w") as f:
239 | json.dump(meta, f)
240 |
--------------------------------------------------------------------------------
/mesh_transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ezelikman/STaR/6ffd3c4743da6c7a7369834c434359d9f7e0ac38/mesh_transformer/__init__.py
--------------------------------------------------------------------------------
/mesh_transformer/build_model.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import multiprocessing
3 |
4 | import optax
5 | import ray
6 |
7 | from mesh_transformer import util
8 | from mesh_transformer.TPU_cluster import TPUCluster
9 | from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerV2
10 | from mesh_transformer.util import clip_by_global_norm, additive_weight_decay
11 | from ray_tpu import create_tpu, wait_til, get_connection, start_ray
12 |
13 |
14 | def build_model(params, tpu_name, region, preemptible, version=1):
15 | gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1)
16 | cores_per_replica = params["cores_per_replica"]
17 | tpu_size = params["tpu_size"]
18 |
19 | warmup_steps = params["warmup_steps"]
20 | anneal_steps = params["anneal_steps"]
21 | lr = params["lr"]
22 | end_lr = params["end_lr"]
23 | weight_decay = params["weight_decay"]
24 |
25 | assert tpu_size in [8, 32, 128, 256, 512]
26 |
27 | create_tpu(tpu_name, region, f"v3-{tpu_size}", preemptible)
28 | assert wait_til(tpu_name, region, {'state': 'READY', 'health': 'HEALTHY'})
29 |
30 | conns = get_connection(tpu_name, region)
31 |
32 | assert len(conns) * 8 == tpu_size, "wrong size TPU for config"
33 |
34 | head_info = ray.init(include_dashboard=False, object_store_memory=10**9)
35 | address = head_info['redis_address']
36 |
37 | with multiprocessing.pool.ThreadPool(processes=len(conns)) as p:
38 | p.map(functools.partial(start_ray, address=address, version=version), conns)
39 |
40 | opt = optax.chain(
41 | optax.scale(1 / gradient_accumulation_steps),
42 | clip_by_global_norm(1, use_psum=(version == 1)),
43 | optax.scale_by_adam(),
44 | additive_weight_decay(weight_decay),
45 | optax.scale(-1),
46 | optax.scale_by_schedule(util.gpt3_schedule(warmup_steps, anneal_steps, lr, end_lr))
47 | )
48 |
49 | params["optimizer"] = opt
50 |
51 | if version == 2:
52 | model_fn = functools.partial(CausalTransformerV2, params)
53 | elif version == 1:
54 | model_fn = functools.partial(CausalTransformer, params)
55 | else:
56 | raise Exception(f"Version {version} does not exist")
57 |
58 | t = TPUCluster((tpu_size // cores_per_replica, cores_per_replica), len(conns), model_fn, version=version)
59 | return t
60 |
--------------------------------------------------------------------------------
/mesh_transformer/checkpoint.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import io
3 | import json
4 | import time
5 |
6 | import jax
7 | import jax.numpy as jnp
8 | import numpy as np
9 | import multiprocessing
10 |
11 | import ray
12 | from smart_open import open
13 |
14 | from mesh_transformer.util import head_print
15 |
16 | pieces = 16 # how many files to split each shard across
17 |
18 |
19 | def fix_dtype(pytree):
20 | def fix(x):
21 | if x.dtype == np.dtype('V2'):
22 | x.dtype = jnp.bfloat16
23 | return jnp.asarray(x)
24 |
25 | return jax.tree_map(fix, pytree)
26 |
27 |
28 | @functools.partial(jax.jit, backend="cpu")
29 | def index_weights(weights, idx):
30 | cpu_device = jax.devices("cpu")[0]
31 | return jax.device_put(jax.tree_map(lambda i: i[idx], weights), cpu_device)
32 |
33 |
34 | def write(x, ckpt_dir):
35 | # start = time.time()
36 | idx, i = x
37 | file_path = ckpt_dir + f"{idx}.npz"
38 | for _ in range(3):
39 | try:
40 | with open(file_path, "wb") as f:
41 | np.savez(f, *i)
42 | # cloudpickle.dump(i, f)
43 | # print(f"written {idx} in {time.time() - start:.06}s")
44 | return
45 |         except Exception:
46 | print("save failed, trying again")
47 |
48 | print("save failed 3 times, exiting")
49 | raise Exception("save failed")
50 |
51 |
52 | def split(a, n):
53 | k, m = divmod(len(a), n)
54 | return (a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n))
55 |
56 |
57 | def write_ckpt(pytree, dir, shard):
58 | # ckpt_dir = Path(dir)
59 | # ckpt_dir.mkdir(parents=True, exist_ok=True)
60 |
61 | flattened, structure = jax.tree_flatten(pytree)
62 |
63 | start = time.time()
64 | # cpu_flattened = jax.device_put(flattened, cpu_device)
65 | cpu_flattened = index_weights(flattened, shard)
66 | # print(f"Moved indexed in {time.time() - start:.06}s")
67 |
68 | cpu_flattened_chunked = split(cpu_flattened, pieces)
69 |
70 | # start = time.time()
71 | # cpu_float = move_weights(cpu_flattened)
72 | # print(f"changed weight types in {time.time() - start:.06}s")
73 |
74 | with multiprocessing.pool.ThreadPool(pieces) as p:
75 | write_fn = functools.partial(write, ckpt_dir=f"{dir}shard_{shard}/")
76 |
77 | start = time.time()
78 | list((p.imap_unordered(write_fn, enumerate(cpu_flattened_chunked))))
79 | # print(f"written to gcs in {time.time() - start:.06}s")
80 |
81 |
82 | def read_shard(ckpt_dir):
83 | out = []
84 | for idx in range(16):
85 | file_path = ckpt_dir + f"{idx}.npz"
86 | with open(file_path, "rb") as f:
87 | buf = f.read()
88 | f_io = io.BytesIO(buf)
89 | deserialized = np.load(f_io)
90 | for i in deserialized:
91 | out.append(deserialized[i])
92 | return out
93 |
94 |
95 | def reshard(x, old_shape):
96 | if len(x.shape) == 1:
97 | # print("epoch")
98 | # print(x)
99 | out = x[0:1]
100 |
101 | elif len(x.shape) == 2:
102 | # print(f"LN/bias {x.shape}")
103 | # print(x[:, :16])
104 |
105 | if (x[1:] == x[-1]).all():
106 | # print("LN")
107 | if (x[1:] == 0).all() or (x[1:] == 1).all():
108 | out = x[0:1]
109 | else:
110 | # print("shard bias")
111 | out = x[0:1] * x.shape[0] / old_shape[0]
112 | else:
113 | # print("bias")
114 | out = x.reshape(old_shape)
115 |
116 | print(out[:, :16])
117 |
118 | elif len(x.shape) == 3:
119 | # print(f"weight {x.shape}")
120 | if x.shape[0] * x.shape[2] == old_shape[2]:
121 | # print("case 1")
122 | out = jnp.transpose(x, (1, 0, 2)).reshape(old_shape)
123 | elif x.shape[0] * x.shape[1] == old_shape[1]:
124 | # print("case 2")
125 | out = x.reshape(old_shape)
126 | else:
127 | raise Exception(f"unimplemented, {x.shape}, {old_shape}")
128 | else:
129 | raise Exception(f"unimplemented, {x}")
130 |
131 | return out
132 |
133 |
134 | def read_ckpt(pytree, dir, shards_in, shards_out=None, load_opt=True):
135 | if shards_out is None:
136 | shards_out = shards_in
137 |
138 | old_flattened, structure = jax.tree_flatten(pytree)
139 |
140 | original_opt_state = pytree["opt_state"]
141 |
142 | # TODO: figure out how to use a process pool here for more speed
143 | with multiprocessing.pool.ThreadPool(shards_in) as p:
144 | start = time.time()
145 | shards = list((p.imap(read_shard, [f"{dir}shard_{i}/" for i in range(shards_in)])))
146 | print(f"read from disk/gcs in {time.time() - start:.06}s")
147 |
148 | def _unshard(shards, old_flattened):
149 | unsharded = []
150 |
151 | for old, *all_shards in zip(old_flattened, *shards):
152 | x = np.stack(all_shards)
153 | # No idea why this is V2...?
154 | if x.dtype == np.dtype('V2'):
155 | x.dtype = jnp.bfloat16
156 |
157 | if shards_out != shards_in:
158 | x = reshard(x, old.shape)
159 | unsharded.append(x)
160 |
161 | assert x.shape == old.shape, f"Incompatible checkpoints {x.shape} vs {old.shape}"
162 | return unsharded
163 | try:
164 | unsharded = _unshard(shards, old_flattened)
165 | except AssertionError:
166 | load_opt = False # no opt to load in ckpt
167 | del pytree['opt_state']
168 | old_flattened, structure = jax.tree_flatten(pytree)
169 | unsharded = _unshard(shards, old_flattened)
170 |
171 | loaded_pytree = jax.tree_unflatten(structure, unsharded)
172 |
173 | if not load_opt:
174 | loaded_pytree['opt_state'] = original_opt_state
175 | return loaded_pytree
176 |
177 |
178 | def read_ckpt_lowmem(pytree, dir, shards_in, shards_out=None, load_opt=True):
179 | if shards_out is None:
180 | shards_out = shards_in
181 |
182 | old_flattened, structure = jax.tree_flatten(pytree)
183 |
184 | original_opt_state = pytree["opt_state"]
185 |
186 | def _unshard():
187 | start = time.time()
188 | unsharded = []
189 | devices = jax.devices()
190 | device_count = len(devices)
191 | device_index = 0
192 |
193 | for file_index in range(pieces):
194 | array_keys = [*np.load(f"{dir}shard_0/{file_index}.npz").keys()]
195 | for array_index in range(len(array_keys)):
196 | unstacked = []
197 | for shard_index in range(shards_in):
198 | npz = np.load(f"{dir}shard_{shard_index}/{file_index}.npz")
199 | array = npz[array_keys[array_index]]
200 | if array.dtype == 'V2':
201 | array.dtype = jnp.bfloat16
202 | unstacked.append(array)
203 |
204 | x = jax.device_put(jnp.stack(unstacked), device=devices[device_index % device_count])
205 |
206 | if shards_out != shards_in:
207 | x = reshard(x, old_flattened[device_index].shape)
208 | unsharded.append(x)
209 |
210 | assert x.shape == old_flattened[device_index].shape, f"Incompatible checkpoints {x.shape} vs {old_flattened[device_index].shape}"
211 | device_index += 1
212 |
213 | print(f"read from disk/gcs in {time.time() - start:.06}s")
214 | return unsharded
215 |
216 | try:
217 | unsharded = _unshard()
218 | except AssertionError:
219 | load_opt = False # no opt to load in ckpt
220 | del pytree['opt_state']
221 | old_flattened, structure = jax.tree_flatten(pytree)
222 | unsharded = _unshard()
223 |
224 | loaded_pytree = jax.tree_unflatten(structure, unsharded)
225 |
226 | if not load_opt:
227 | loaded_pytree['opt_state'] = original_opt_state
228 | return loaded_pytree
229 |
230 |
231 | def parallel_write(arrays, fname):
232 | # TODO: make this actually parallel
233 | with open(fname, "wb") as f:
234 | np.savez(f, *arrays)
235 |
236 |
237 | def parallel_read(old, fname, validate=True):
238 | old_vals, treedef = jax.tree_flatten(old)
239 |
240 | if "gs://" in fname:
241 | # TODO: make this actually parallel
242 | with open(fname, "rb") as f:
243 | buf = f.read()
244 | f_io = io.BytesIO(buf)
245 | loaded = np.load(f_io)
246 | else:
247 | loaded = np.load(fname, mmap_mode='r')
248 |
249 | new_vals = []
250 | for i in loaded:
251 | new_vals.append(loaded[i])
252 |
253 | assert len(new_vals) == len(old_vals), "Incompatible checkpoint"
254 |
255 | for o, n in zip(new_vals, old_vals):
256 | if validate:
257 | assert o.shape == n.shape, "Incompatible checkpoint"
258 |
259 | return jax.tree_unflatten(treedef, fix_dtype(new_vals))
260 |
261 |
262 | def tree_flatten_with_names(pytree, is_leaf, path="", to_id=id):
263 | id_to_name = {}
264 | if getattr(pytree, "items", None):
265 | for k, v in pytree.items():
266 | k_path = f"{path}/{k}"
267 | if is_leaf(v):
268 | id_to_name[to_id(v)] = k_path
269 | else:
270 | id_to_name = {**id_to_name, **tree_flatten_with_names(v, is_leaf=is_leaf, path=k_path)}
271 | elif getattr(pytree, "__getitem__", None):
272 | for v in pytree:
273 | if is_leaf(v):
274 | id_to_name[to_id(v)] = path
275 | else:
276 | id_to_name = {**id_to_name, **tree_flatten_with_names(v, is_leaf=is_leaf, path=path)}
277 | else:
278 | id_to_name[to_id(pytree)] = path
279 | return id_to_name
280 |
281 |
282 | def tree_leaves_with_names(pytree, to_id=id):
283 | leaves = jax.tree_leaves(pytree)
284 |     is_leaf = lambda x: not isinstance(x, list) and to_id(x) in [to_id(leaf) for leaf in leaves]
285 | return tree_flatten_with_names(pytree, is_leaf)
286 |
287 |
288 | def write_ckpt_v2(model_state, dir):
289 | start = time.time()
290 | if jax.host_id() == 0:
291 | param_map = tree_leaves_with_names(model_state["params"])
292 | opt_map = tree_leaves_with_names(model_state["opt_state"])
293 |
294 | meta = {
295 | "total_hosts": jax.host_count(),
296 | "step": int(model_state["step"]),
297 | "param_order": [param_map[id(i)] for i in jax.tree_leaves(model_state["params"])],
298 | "opt_order": [opt_map[id(i)] for i in jax.tree_leaves(model_state["opt_state"])]
299 | }
300 |
301 | print("step:", model_state["step"])
302 | with open(dir + "/meta.json", "w") as f:
303 | json.dump(meta, f)
304 | print(f"meta written in {time.time() - start:.06}s")
305 |
306 | start = time.time()
307 | parallel_write(jax.tree_flatten(model_state["params"])[0], dir + f"/params/shard_{jax.host_id()}.npz")
308 | head_print(f"params written in {time.time() - start:.06}s")
309 |
310 | start = time.time()
311 | parallel_write(jax.tree_flatten(model_state["opt_state"])[0], dir + f"/opt_state/shard_{jax.host_id()}.npz")
312 | head_print(f"opt_state written in {time.time() - start:.06}s")
313 |
314 |
315 | def read_sharded_v2(state, dir, checkpoint_hosts, state_shard):
316 | files_per_host = checkpoint_hosts // jax.host_count()
317 |
318 | assert files_per_host >= 1, "can't restore model to larger pod than was trained on (yet)"
319 | assert jax.host_count() * files_per_host == checkpoint_hosts, "weird host count"
320 |
321 | if files_per_host == 1:
322 | head_print("using fast path of checkpoint restore (save shards == read shards)")
323 | parallel_read(state, dir + f"/shard_{jax.host_id()}.npz")
324 |
325 | @ray.remote
326 | def read_remote(old, fname):
327 | return parallel_read(old, fname, validate=False)
328 |
329 | start_idx = files_per_host * jax.host_id()
330 |
331 | skeleton = jax.tree_map(lambda x: jnp.zeros_like(x, shape=()), state) # a full pytree just to carry dtypes
332 |
333 | refs = [
334 | read_remote.remote(skeleton, f"{dir}/shard_{i}.npz")
335 | for i in range(start_idx, start_idx + files_per_host)
336 | ]
337 |
338 | values = ray.get(refs)
339 |
340 | def all_array_equal(iterator):
341 | try:
342 | iterator = iter(iterator)
343 | first = next(iterator)
344 | return all(jnp.array_equal(first, rest) for rest in iterator)
345 | except StopIteration:
346 | return True
347 |
348 | def reshard_v2(old, shard_strategy, *new_values):
349 | rep_dim_count = shard_strategy.count(None)
350 | total_dim_count = len(shard_strategy)
351 |
352 | # head_print("old.shape", old.shape)
353 | # head_print("shard_strategy", shard_strategy)
354 |
355 | assert len(old.shape) == total_dim_count
356 |
357 | if rep_dim_count == total_dim_count:
358 | # fully replicated
359 | assert all_array_equal(new_values)
360 | return fix_dtype(new_values[0])
361 |
362 | shard_dim = [idx for idx, dim in enumerate(shard_strategy) if dim is not None and "mp" in dim]
363 |
364 | # only support sharding in 1d for now
365 | assert len(shard_dim) == 1
366 | shard_dim = shard_dim[0]
367 |
368 | ret_val = jnp.concatenate(fix_dtype(new_values), axis=shard_dim)
369 | assert old.shape == ret_val.shape
370 |
371 | return jax.device_put(ret_val, jax.devices("cpu")[0])
372 |
373 | # head_print("state", jax.tree_structure(state))
374 | # head_print("state_shard", jax.tree_structure(state_shard))
375 | # head_print("values", jax.tree_structure(values[0]))
376 |
377 | return jax.tree_multimap(reshard_v2, *([state, state_shard] + values))
378 |
379 |
380 | def load_ckpt_v2(model_state, dir, state_shard, load_opt):
381 | start = time.time()
382 | with open(dir + "meta.json", "r") as f:
383 | meta = json.load(f)
384 |
385 | ckpt_hosts = meta["total_hosts"]
386 |
387 | head_print(f"meta loaded in {time.time() - start:.06}s")
388 |
389 | new_state = {
390 | "step": np.array([meta["step"]]),
391 | }
392 |
393 | start = time.time()
394 | new_state["params"] = read_sharded_v2(model_state["params"],
395 | dir + "params",
396 | ckpt_hosts,
397 | state_shard["params"])
398 | head_print(f"params loaded in {time.time() - start:.06}s")
399 |
400 | if not load_opt:
401 | return new_state
402 |
403 | start = time.time()
404 | new_state["opt_state"] = read_sharded_v2(model_state["opt_state"],
405 | dir + "opt_state",
406 | ckpt_hosts,
407 | state_shard["opt_state"])
408 | head_print(f"opt_state loaded in {time.time() - start:.06}s")
409 |
410 | return new_state
411 |
--------------------------------------------------------------------------------
/mesh_transformer/sampling.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 |
4 |
5 | # takes in a logit distribution, softmax and then sample
6 | def softmax_sample(key, logits, _, temp=1):
7 | return jax.random.categorical(key, logits/temp, -1).astype(jnp.uint32), None
8 |
9 |
10 | def nucleaus_filter(logits, top_p=0.9, top_k=None):
11 | sorted_logits = jnp.sort(logits)[:, ::-1] # sort descending
12 | sorted_indices = jnp.argsort(logits)[:, ::-1]
13 | cumulative_probs = jnp.cumsum(jax.nn.softmax(sorted_logits), axis=-1)
14 |
15 | if top_k is not None:
16 | # Keep only top_k tokens
17 | indices_range = jnp.arange(len(sorted_indices[0]))
18 | indices_range = jnp.stack([indices_range] * len(sorted_indices), axis=0)
19 |
20 | sorted_indices_to_remove = jnp.where(indices_range >= top_k, sorted_indices, 0)
21 |
22 | _, indices_to_remove = jax.lax.sort_key_val(sorted_indices, sorted_indices_to_remove)
23 |
24 | logit_mask = 1e10 * indices_to_remove
25 |
26 | logits -= logit_mask
27 |
28 | # Remove tokens with cumulative probability above a threshold
29 | sorted_indices_to_remove = cumulative_probs > top_p
30 | sorted_indices_to_remove = jnp.concatenate((jnp.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove), axis=-1)[:, :-1]
31 |
32 | _, indices_to_remove = jax.lax.sort_key_val(sorted_indices, sorted_indices_to_remove)
33 |
34 | logit_mask = 1e10 * indices_to_remove
35 |
36 | logits -= logit_mask
37 |
38 | return logits
39 |
40 |
41 | def nucleaus_sample(key, logits, _, top_p=0.9, temp=1, top_k=None):
42 | logits = nucleaus_filter(logits, top_p, top_k=top_k)
43 |
44 | return softmax_sample(key, logits, None, temp=temp)
45 |
46 |
47 | if __name__ == "__main__":
48 | import numpy as np
49 | logits = np.array([[-2, -1, 0, 0.8, 0, 0.1, 0.3, 0.4, 0.5, 0.6, 0.7, -3]])
50 | print(nucleaus_filter(logits))
--------------------------------------------------------------------------------
/mesh_transformer/train_actor.py:
--------------------------------------------------------------------------------
1 | import ray
2 | import time
3 | import numpy as np
4 | from queue import Queue
5 |
6 | from mesh_transformer.util import head_print
7 |
8 |
9 | @ray.remote(resources={"tpu": 1})
10 | class NetworkRunner(object):
11 | def __init__(self, mesh_shape, network_builder):
12 | self.mesh_shape = mesh_shape
13 | self.network_builder = network_builder
14 |
15 | self.input_q = Queue(maxsize=1)
16 | self.output_q = Queue(maxsize=1)
17 |
18 | def run(self):
19 | print(f"jax runtime initialization starting")
20 | import jax
21 | from jax.experimental.maps import thread_resources, ResourceEnv, Mesh
22 | import haiku as hk
23 | # jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True
24 |
25 | thread_resources.env = ResourceEnv(Mesh(np.empty((), dtype=object), ()), ())
26 |
27 | start = time.time()
28 | jax.devices()
29 |
30 | import warnings
31 | warnings.filterwarnings("ignore")
32 | warnings.filterwarnings("ignore", category=ResourceWarning)
33 |
34 | if jax.host_id() == 0:
35 | warnings.filterwarnings("default")
36 |
37 | head_print(f"jax devices: {jax.device_count()}")
38 | head_print(f"jax runtime initialized in {time.time() - start:.06}s")
39 | devices = np.array(jax.devices()).reshape(self.mesh_shape)
40 |
41 | with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
42 | start = time.time()
43 | network = self.network_builder()
44 | head_print(f"Initialized in {time.time() - start:.06}s")
45 |
46 | while True:
47 | operation, input = self.input_q.get()
48 | if operation == "train":
49 | self.output_q.put(network.train(input))
50 | elif operation == "eval":
51 | self.output_q.put(network.eval(input))
52 | elif operation == "generate":
53 | self.output_q.put(network.generate(*input))
54 | elif operation == "write_ckpt":
55 | path, shard = input
56 | network.write_ckpt(path, shard)
57 | self.output_q.put(None)
58 | elif operation == "load_ckpt":
59 | network.load_ckpt(input)
60 | self.output_q.put(network.state["step"][0])
61 | elif operation == "get_params":
62 | self.output_q.put(hk.data_structures.tree_size(network.state['params']))
63 | elif operation == "move_params":
64 | # only needed for inference, otherwise first train step does this
65 | local_shards = max(jax.local_device_count() // self.mesh_shape[1], 1)
66 |
67 | # delete the optimizer states otherwise it OOMs for some reason
68 | # TODO: use ShardedDeviceArray or something to get around this for bigger models
69 | del network.state["opt_state"]
70 | network.state = network.move_xmap(network.state, np.zeros(local_shards))
71 | self.output_q.put(None)
72 | else:
73 | raise Exception("Not implemented")
74 |
75 | def get_params(self):
76 | self.input_q.put(("get_params", None))
77 | return self.output_q.get()
78 |
79 | def train(self, sample):
80 | self.input_q.put(("train", sample))
81 | return self.output_q.get()
82 |
83 | def eval(self, sample):
84 | self.input_q.put(("eval", sample))
85 | return self.output_q.get()
86 |
87 | def generate(self, input):
88 | self.input_q.put(("generate", input))
89 | return self.output_q.get()
90 |
91 | def write_ckpt(self, path, shard):
92 | self.input_q.put(("write_ckpt", (path, shard)))
93 | return self.output_q.get()
94 |
95 | def load_ckpt(self, path):
96 | self.input_q.put(("load_ckpt", path))
97 | return self.output_q.get()
98 |
99 | def move_params(self):
100 | self.input_q.put(("move_params", None))
101 | return self.output_q.get()
102 |
--------------------------------------------------------------------------------
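NetworkRunner is a Ray actor pinned to a host that advertises the "tpu" resource; run() blocks forever servicing the two single-slot queues, so it has to be launched as a background task before the thin wrapper methods are used. A minimal sketch of driving one such actor, assuming a Ray cluster is already running and a network_builder callable (as wired up in TPU_cluster.py) constructs the CausalTransformer:

import ray
from mesh_transformer.train_actor import NetworkRunner

ray.init(address="auto")  # attach to the cluster started via ray_tpu.start_ray

# network_builder is an assumed zero-argument callable closing over the model config
runner = NetworkRunner.remote((1, 8), network_builder)
runner.run.remote()  # start the blocking service loop in the background

n_params = ray.get(runner.get_params.remote())  # round-trips through the queues
out = ray.get(runner.train.remote(batch))       # batch: dict of numpy arrays (assumed)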
/mesh_transformer/util.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax.experimental.pjit import with_sharding_constraint
4 | from optax import AdditiveWeightDecayState, GradientTransformation, OptState
5 |
6 |
7 | # same as with_sharding_constraint but doesn't fail if run outside of pjit/mesh context
8 | def maybe_shard(x, resource):
9 | try:
10 | return with_sharding_constraint(x, resource)
11 | except ValueError as e:
12 | print(e)
13 | return x
14 |
15 |
16 | def gpt3_schedule(warmup_steps,
17 | total_steps,
18 | peak_lr,
19 | end_lr):
20 | def sch(step):
21 | warmup_pct = jnp.clip(step, 0, warmup_steps) / warmup_steps
22 | anneal_pct = jnp.clip(step - warmup_steps, 0, total_steps) / total_steps
23 |
24 | return warmup_pct * peak_lr - (peak_lr - end_lr) * (1 - jnp.cos(jnp.pi * anneal_pct)) / 2
25 |
26 | return sch
27 |
28 |
29 | def global_norm(updates, use_psum=True):
30 | pre_sqrt = sum([jnp.sum(jnp.square(x)) for x in jax.tree_leaves(updates)])
31 | if use_psum:
32 | pre_sqrt = jax.lax.psum(pre_sqrt, "shard")
33 | return jnp.sqrt(pre_sqrt)
34 |
35 |
36 | class ClipByGlobalNormState(OptState):
37 | """The `clip_by_global_norm` transformation is stateless."""
38 |
39 |
40 | def clip_by_global_norm(max_norm, use_psum=True) -> GradientTransformation:
41 | """Clip updates using their global norm.
42 |
43 | References:
44 | [Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)
45 |
46 | Args:
47 | max_norm: the maximum global norm for an update.
48 |
49 | Returns:
50 | An (init_fn, update_fn) tuple.
51 | """
52 |
53 | def init_fn(_):
54 | return ClipByGlobalNormState()
55 |
56 | def update_fn(updates, state, params=None):
57 | del params
58 | g_norm = global_norm(updates, use_psum=use_psum)
59 | trigger = g_norm < max_norm
60 | updates = jax.tree_map(
61 | lambda t: jnp.where(trigger, t, (t / g_norm) * max_norm), updates)
62 | return updates, state
63 |
64 | return GradientTransformation(init_fn, update_fn)
65 |
66 |
67 | def additive_weight_decay(weight_decay: float = 0.0) -> GradientTransformation:
68 | """Add parameter scaled by `weight_decay`, to all parameters with more than one dim (i.e. exclude ln, bias etc)
69 |
70 | Args:
71 | weight_decay: a scalar weight decay rate.
72 |
73 | Returns:
74 | An (init_fn, update_fn) tuple.
75 | """
76 |
77 | def init_fn(_):
78 | return AdditiveWeightDecayState()
79 |
80 | def update_fn(updates, state, params):
81 | updates = jax.tree_multimap(lambda g, p: g + weight_decay * p * (len(g.shape) > 1), updates, params)
82 | return updates, state
83 |
84 | return GradientTransformation(init_fn, update_fn)
85 |
86 |
87 | def to_f32(t):
88 | return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
89 |
90 |
91 | def to_bf16(t):
92 | return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
93 |
94 |
95 | def to_f16(t):
96 | return jax.tree_map(lambda x: x.astype(jnp.float16) if x.dtype == jnp.float32 else x, t)
97 |
98 |
99 | # identity in forward pass, psum in backward
100 | @jax.custom_vjp
101 | def f_psum(x):
102 | return x
103 |
104 |
105 | def f_psum_fwd(x):
106 | return f_psum(x), None
107 |
108 |
109 | def f_psum_bwd(_, g):
110 | return jax.lax.psum(g, "shard"),
111 |
112 |
113 | f_psum.defvjp(f_psum_fwd, f_psum_bwd)
114 |
115 |
116 | # identity in forward pass, pmean in backward
117 | @jax.custom_vjp
118 | def f_pmean(x):
119 | return x
120 |
121 |
122 | def f_pmean_fwd(x):
123 | return f_pmean(x), None
124 |
125 |
126 | def f_pmean_bwd(_, g):
127 | return jax.lax.pmean(g, "shard"),
128 |
129 |
130 | f_pmean.defvjp(f_pmean_fwd, f_pmean_bwd)
131 |
132 |
133 | # psum in forward pass, identity in backward
134 | @jax.custom_vjp
135 | def g_psum(x):
136 | return jax.lax.psum(x, "shard")
137 |
138 |
139 | def g_psum_fwd(x):
140 | return g_psum(x), None
141 |
142 |
143 | def g_psum_bwd(_, g):
144 | return g,
145 |
146 |
147 | g_psum.defvjp(g_psum_fwd, g_psum_bwd)
148 |
149 |
150 | def shard_axis(x, axis_size, axis_name):
151 | # in_shape = x.shape
152 | assert x.shape[0] % axis_size == 0
153 |
154 | x = x.reshape((axis_size, -1) + x.shape[1:])
155 |
156 | x = x[jax.lax.axis_index(axis_name)]
157 | # print("shard out", x.shape, "in", in_shape)
158 |
159 | # assert np.prod(x.shape) * axis_size == np.prod(in_shape)
160 |
161 | return x
162 |
163 |
164 | def unshard_axis(x, axis_name):
165 | # in_shape = x.shape
166 | x = jax.lax.all_gather(x, axis_name)
167 |
168 | x = x.reshape((-1, ) + x.shape[2:])
169 |
170 | # assert x.shape[-1] == 4096
171 | # print("unshard out", x.shape, "in", in_shape)
172 | return x
173 |
174 |
175 | # print but only on the first node
176 | def head_print(*args, **kwargs):
177 | if jax.host_id() == 0:
178 | print(*args, **kwargs)
179 |
180 |
181 | if __name__ == "__main__":
182 | sch = gpt3_schedule(1_000, 20_000, 1e-4, 1e-5)
183 |
184 | for i in range(150):
185 | i = i * 200
186 | print(i, sch(i))
187 |
--------------------------------------------------------------------------------
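gpt3_schedule, clip_by_global_norm, and additive_weight_decay above are building blocks for the optax optimizer chain that the training configs assemble. A rough sketch of such a chain; all hyperparameter values here are placeholders rather than the repo's settings:

import optax
from mesh_transformer.util import gpt3_schedule, clip_by_global_norm, additive_weight_decay

# linear warmup for 3k steps, then cosine decay from 1.2e-4 down to 1.2e-5 (placeholder values)
lr_schedule = gpt3_schedule(warmup_steps=3_000, total_steps=300_000, peak_lr=1.2e-4, end_lr=1.2e-5)

optimizer = optax.chain(
    clip_by_global_norm(1.0),        # psum'd global-norm clipping across model shards
    optax.scale_by_adam(),
    additive_weight_decay(0.1),      # decays only parameters with more than one dimension
    optax.scale(-1),                 # descend rather than ascend
    optax.scale_by_schedule(lr_schedule),
)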
/ray_tpu.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import os
3 | import subprocess
4 | import time
5 |
6 | import glob
7 | import requests
8 | from fabric import Connection
9 |
10 |
11 | @functools.lru_cache()
12 | def get_bearer():
13 | return subprocess.check_output("gcloud auth print-access-token", shell=True).decode("utf-8").strip()
14 |
15 |
16 | @functools.lru_cache()
17 | def get_project():
18 | return subprocess.check_output("gcloud config list --format 'value(core.project)'", shell=True).decode(
19 | "utf-8").strip()
20 |
21 |
22 | def create_tpu(
23 | name,
24 | zone,
25 | type,
26 | preemptible,
27 | ):
28 | headers = {
29 | 'Authorization': f'Bearer {get_bearer()}',
30 | 'Content-Type': 'application/json',
31 | }
32 |
33 | try:
34 | status = check_tpu(name, zone)
35 |
36 | if status["state"] not in ["CREATING", "READY"]:
37 | print("deleting TPU")
38 | delete_tpu(name, zone)
39 |
40 | while True:
41 | try:
42 | print("deleting check")
43 | print(check_tpu(name, zone)["state"])
44 |
45 | time.sleep(1)
46 | except Exception:
47 | break
48 | except Exception:
49 | pass
50 |
51 | params = (
52 | ('node_id', name),
53 | )
54 |
55 | data = {"accelerator_type":
56 | type,
57 | "runtime_version":
58 | 'v2-alpha',
59 | "network_config":
60 | {"enable_external_ips": True},
61 | }
62 |
63 | if preemptible:
64 | data["schedulingConfig"] = {"preemptible": True}
65 |
66 | response = requests.post(f'https://tpu.googleapis.com/v2alpha1/projects/{get_project()}/locations/{zone}/nodes',
67 | headers=headers, params=params, json=data)
68 |
69 | print(response.json())
70 |
71 | return response.status_code == 200
72 |
73 |
74 | def check_tpu(name, zone):
75 | headers = {
76 | 'Authorization': f'Bearer {get_bearer()}',
77 | }
78 |
79 | response = requests.get(
80 | f'https://tpu.googleapis.com/v2alpha1/projects/{get_project()}/locations/{zone}/nodes/{name}',
81 | headers=headers)
82 |
83 | return response.json()
84 |
85 |
86 | def delete_tpu(name, zone):
87 | headers = {
88 | 'Authorization': f'Bearer {get_bearer()}',
89 | }
90 |
91 | response = requests.delete(
92 | f'https://tpu.googleapis.com/v2alpha1/projects/{get_project()}/locations/{zone}/nodes/{name}',
93 | headers=headers)
94 |
95 | return response.json()
96 |
97 |
98 | def wait_til(name, zone, state):
99 | while True:
100 | ret = check_tpu(name, zone)
101 |
102 | print("wait_til check")
103 | print(ret)
104 |
105 | matches = True
106 | for k, expected_v in state.items():
107 | if k not in ret:
108 | matches = False
109 | continue
110 | if ret[k] != expected_v:
111 | matches = False
112 |
113 | if "error" in ret:
114 | return False
115 |
116 | if ret["state"] == "TERMINATED":
117 | return False
118 |
119 | if matches:
120 | return True
121 |
122 | time.sleep(1)
123 |
124 |
125 | def get_connection(
126 | name,
127 | zone,
128 | ):
129 | info = check_tpu(name, zone)
130 | outputs = []
131 | for i in info["networkEndpoints"]:
132 | outputs.append(Connection(i["ipAddress"],
133 | connect_kwargs={
134 | "key_filename": os.path.expanduser('~/.ssh/google_compute_engine'), }))
135 | return outputs
136 |
137 |
138 | def start_ray(conn, address, version=1):
139 | conn.sudo('rm -rf *.py')
140 | conn.sudo('rm -rf mesh_transformer')
141 |
142 | for i in glob.glob("*.py"):
143 | conn.put(i, "")
144 |
145 | conn.run("mkdir mesh_transformer -p")
146 |
147 | for i in glob.glob("mesh_transformer/*.py"):
148 | conn.put(i, "mesh_transformer/")
149 |
150 | conn.sudo('python3 setup.py install', hide=True)
151 |
152 | if version == 2:
153 | conn.put("scripts/init_ray_v2.sh", "/tmp/ray-tpu.sh")
154 | else:
155 | conn.put("scripts/init_ray.sh", "/tmp/ray-tpu.sh")
156 | conn.sudo('chmod +x /tmp/ray-tpu.sh', hide=True)
157 | conn.sudo('/tmp/ray-tpu.sh', hide=True)
158 | try:
159 | conn.run('ray stop -f', hide=True)
160 | except Exception:
161 | pass
162 |
163 | time.sleep(1)
164 |
165 | conn.run(f"TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD={32 * 1024**3} ray start --address={address} --resources='" + '{"tpu": 1}\' --include-dashboard False', hide=True)
166 |
--------------------------------------------------------------------------------
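Together these helpers implement the provision-and-bootstrap flow: create (or recreate) the TPU VM through the v2alpha1 REST API, poll until it reports ready, then open a Fabric SSH connection to every host and start a Ray worker that advertises the "tpu" resource. A condensed sketch of that flow; the TPU name, zone, accelerator type, and head-node address are placeholders:

from ray_tpu import create_tpu, wait_til, get_connection, start_ray

name, zone = "mtj-worker", "europe-west4-a"   # placeholder TPU name and zone

create_tpu(name, zone, type="v3-32", preemptible=True)
assert wait_til(name, zone, {"state": "READY", "health": "HEALTHY"})

# one Fabric connection per TPU host; push the code and start a Ray worker on each
for conn in get_connection(name, zone):
    start_ray(conn, address="10.0.0.2:6379", version=2)  # placeholder head-node address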
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy~=1.19.5
2 | tqdm~=4.62.0
3 | wandb>=0.11.2
4 | einops~=0.3.0
5 | requests~=2.25.1
6 | fabric~=2.6.0
7 | optax==0.0.9
8 | dm-haiku==0.0.5
9 | git+https://github.com/EleutherAI/lm-evaluation-harness/
10 | ray[default]==1.4.1
11 | jax~=0.2.12
12 | Flask~=1.1.2
13 | cloudpickle~=1.3.0
14 | tensorflow-cpu~=2.6.0
15 | google-cloud-storage~=1.36.2
16 | transformers
17 | smart_open[gcs]
18 | func_timeout
19 | ftfy
20 | fastapi
21 | lm_dataformat
22 | pathy
23 |
--------------------------------------------------------------------------------
/resharding_example.py:
--------------------------------------------------------------------------------
1 | # This was tested with an RTX 3090, peak memory usage is approximately 22.4GB during inference, and 19GB when loading the model
2 | # The following environment variables were also used: XLA_PYTHON_CLIENT_PREALLOCATE=false XLA_PYTHON_CLIENT_ALLOCATOR=platform
3 |
4 | import time
5 |
6 | import jax
7 | from jax.experimental import maps
8 | import numpy as np
9 | import optax
10 | import transformers
11 |
12 | from mesh_transformer.checkpoint import read_ckpt
13 | from mesh_transformer.sampling import nucleaus_sample
14 | from mesh_transformer.transformer_shard import CausalTransformer
15 |
16 | params = {
17 | "layers": 28,
18 | "d_model": 4096,
19 | "n_heads": 16,
20 | "n_vocab": 50400,
21 | "norm": "layernorm",
22 | "pe": "rotary",
23 | "pe_rotary_dims": 64,
24 | "early_cast": True,
25 | "seq": 2048,
26 | "cores_per_replica": 1, # only running on one GPU
27 | "per_replica_batch": 1,
28 | }
29 |
30 | per_replica_batch = params["per_replica_batch"]
31 | cores_per_replica = params["cores_per_replica"]
32 | seq = params["seq"]
33 |
34 |
35 | params["sampler"] = nucleaus_sample
36 |
37 | # here we "remove" the optimizer parameters from the model (as we don't need them for inference)
38 | params["optimizer"] = optax.scale(0)
39 |
40 | devices = np.array([jax.devices()[0]]).reshape((1, 1))
41 | maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))
42 |
43 | tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
44 |
45 | network = CausalTransformer(params)
46 |
47 | start = time.time()
48 |
49 | # here we load a checkpoint which was written with 8 shards into 1 shard
50 | network.state = read_ckpt(network.state, "step_383500/", 8, shards_out=cores_per_replica)
51 |
52 | # move the state to CPU/system memory so it's not duplicated by xmap
53 | network.state = jax.device_put(network.state, jax.devices("cpu")[0])
54 |
55 | def infer(context, top_k=40, top_p=0.9, temp=1.0, gen_len=512):
56 | tokens = tokenizer.encode(context)
57 |
58 | provided_ctx = len(tokens)
59 | pad_amount = seq - provided_ctx
60 |
61 | padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
62 | batched_tokens = np.array([padded_tokens] * per_replica_batch)
63 | length = np.ones(per_replica_batch, dtype=np.uint32) * len(tokens)
64 |
65 | start = time.time()
66 | output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(per_replica_batch) * top_p, "top_k": top_k is not None and (np.ones(per_replica_batch, dtype=np.int32) * top_k) or None, "temp": np.ones(per_replica_batch) * temp})
67 |
68 | samples = []
69 | decoded_tokens = output[1][0]
70 |
71 | for o in decoded_tokens[:, :, 0]:
72 | samples.append(tokenizer.decode(o))
73 |
74 | print(f"completion done in {time.time() - start:06}s")
75 | return samples
76 |
77 |
78 | infer("EleutherAI is")
79 |
--------------------------------------------------------------------------------
/scripts/create_serve_tpu.sh:
--------------------------------------------------------------------------------
1 | gcloud alpha compute tpus tpu-vm create "$1" --zone europe-west4-a --accelerator-type v3-8 --version v2-alpha
2 | sleep 120
3 | gcloud alpha compute tpus tpu-vm ssh "$1" --zone europe-west4-a --command="$(< scripts/deploy_server.sh)"
--------------------------------------------------------------------------------
/scripts/deploy_server.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 |
4 | rm -r mesh-transformer-jax || true
5 |
6 | git clone https://github.com/kingoflolz/mesh-transformer-jax
7 | pip install -r mesh-transformer-jax/requirements.txt
8 | pip install mesh-transformer-jax/ jax==0.2.12
9 |
10 | pushd mesh-transformer-jax || exit
11 | screen -d -m python3 device_serve.py --config configs/6B_roto_256.json
--------------------------------------------------------------------------------
/scripts/init_ray.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 |
4 | sudo /usr/bin/docker-credential-gcr configure-docker
5 |
6 | sudo docker rm libtpu || true
7 | sudo docker create --name libtpu gcr.io/cloud-tpu-v2-images/libtpu:libtpu_20210518_RC00 "/bin/bash" && sudo docker cp libtpu:libtpu.so /lib
8 |
9 | # this locks the python executable down to hopefully stop it from being fiddled with...
10 | screen -d -m python -c 'import time; time.sleep(999999999)'
11 |
12 | # initializes jax and installs ray on cloud TPUs
13 | sudo pip install --upgrade jaxlib==0.1.67 jax==0.2.12 ray[default]==1.5.1 fabric dataclasses optax git+https://github.com/deepmind/dm-haiku tqdm cloudpickle smart_open[gcs] einops func_timeout
--------------------------------------------------------------------------------
/scripts/init_ray_v2.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 |
4 | # this locks the python executable down to hopefully stop it from being fiddled with...
5 | screen -d -m python -c 'import time; time.sleep(999999999)'
6 |
7 | # initializes jax and installs ray on cloud TPUs
8 | sudo pip install "jax[tpu]>=0.2.18" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
9 | sudo pip install --upgrade ray[default]==1.5.1 fabric dataclasses optax git+https://github.com/deepmind/dm-haiku tqdm cloudpickle smart_open[gcs] einops func_timeout
--------------------------------------------------------------------------------
/scripts/init_serve.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 |
4 | # initializes jax and installs ray on cloud TPUs
5 | sudo pip install --upgrade jaxlib jax==0.2.12 ray==1.2.0 fabric dataclasses optax git+https://github.com/deepmind/dm-haiku tqdm cloudpickle smart_open[gcs] einops func_timeout transformers flask
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='mesh_transformer',
5 | version='0.0.0',
6 | packages=find_packages(include=['mesh_transformer', 'mesh_transformer.*'])
7 | )
8 |
--------------------------------------------------------------------------------
/slim_model.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import time
4 |
5 | import jax
6 | import numpy as np
7 | import optax
8 |
9 | from mesh_transformer import util
10 | from mesh_transformer.checkpoint import read_ckpt, write_ckpt
11 | from mesh_transformer.transformer_shard import CausalTransformer
12 | from smart_open import open
13 |
14 | from mesh_transformer.util import clip_by_global_norm, to_bf16, to_f16
15 |
16 |
17 | def parse_args():
18 | # Parse command line arguments
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("--config", type=str, default=None, help="Config file location")
21 | parser.add_argument("--ckpt-step", type=int, default=-1, help="Step number of the checkpoint to convert (if not specified, converts the most recent checkpoint)")
22 | parser.add_argument("--f16", default=False, action="store_true", help="Convert to float16 (instead of bfloat16)")
23 |
24 | args = parser.parse_args()
25 | return args
26 |
27 |
28 | if __name__ == "__main__":
29 | args = parse_args()
30 | params = json.load(open(args.config))
31 | convert_fn = to_f16 if args.f16 else to_bf16
32 |
33 | cores_per_replica = params["cores_per_replica"]
34 |
35 | assert cores_per_replica <= 8
36 |
37 | bucket = params["bucket"]
38 | model_dir = params["model_dir"]
39 |
40 | params["optimizer"] = optax.chain(
41 | optax.scale(1),
42 | clip_by_global_norm(1),
43 | optax.scale_by_adam(),
44 | optax.additive_weight_decay(0),
45 | optax.scale(-1),
46 | optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0))
47 | )
48 |
49 | start = time.time()
50 | print(f"jax devices: {jax.device_count()}")
51 | print(f"jax runtime initialized in {time.time() - start:.06}s")
52 |
53 | mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
54 | devices = np.array(jax.devices()).reshape(mesh_shape)
55 |
56 | with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f:
57 | meta = json.load(f)
58 |
59 | if args.ckpt_step > -1:
60 | ckpt_step = args.ckpt_step
61 | else:
62 | ckpt_step = meta["checkpoints"][-1]
63 | print(f"using checkpoint {ckpt_step}")
64 |
65 | with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
66 | network = CausalTransformer(params)
67 |
68 | start = time.time()
69 | network.state = read_ckpt(network.state, f"gs://{bucket}/{model_dir}/step_{ckpt_step}/", devices.shape[1])
70 | print(f"network loaded in {time.time() - start:.06}s")
71 |
72 | start = time.time()
73 | del network.state["opt_state"]
74 |
75 | network.state["params"] = convert_fn(network.state["params"])
76 | print(f"network converted in {time.time() - start:.06}s")
77 |
78 | suffix = "_slim_f16" if args.f16 else "_slim"
79 |
80 | for i in range(cores_per_replica):
81 | write_ckpt(network.state, f"gs://{bucket}/{model_dir}{suffix}/step_{ckpt_step}/", i)
82 | print(f"written shard {i}")
83 |
--------------------------------------------------------------------------------
/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from tasks.eval_harness import EvalHarnessAdaptor
--------------------------------------------------------------------------------
/tasks/eval_harness.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import transformers
4 | from lm_eval.base import LM
5 | from tqdm import tqdm
6 | import numpy as np
7 |
8 | from tasks.util import sample_batch, shrink_seq
9 | import multiprocessing
10 | import ftfy
11 |
12 | tokenizer = None
13 |
14 |
15 | def process_init():
16 | global tokenizer
17 | tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
18 | tokenizer.model_max_length = int(1e30)
19 | tokenizer.pad_token = "<|endoftext|>"
20 |
21 | assert tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
22 |
23 |
24 | def process_request(x, seq):
25 | global tokenizer
26 |
27 | ctx, cont = x
28 |
29 | ctx_tokens = tokenizer.encode("<|endoftext|>" + ftfy.fix_text(ctx, normalization="NFKC"))
30 | cont_tokens = tokenizer.encode(ftfy.fix_text(cont, normalization="NFKC"))
31 |
32 | all_tokens = ctx_tokens + cont_tokens
33 | all_tokens = np.array(all_tokens)[-seq:] # truncate sequence at seq length
34 |
35 | provided_ctx = len(all_tokens) - 1
36 | pad_amount = seq - provided_ctx
37 |
38 | return {
39 | "obs": np.pad(all_tokens[:-1], ((0, pad_amount),), constant_values=50256),
40 | "target": np.pad(all_tokens[1:], ((0, pad_amount),), constant_values=50256),
41 | "ctx_length": seq,
42 | "eval_mask": np.logical_and(
43 | np.arange(0, seq) > len(all_tokens) - len(cont_tokens) - 2,
44 | np.arange(0, seq) < len(all_tokens) - 1
45 | ),
46 | }
47 |
48 |
49 | class EvalHarnessAdaptor(LM):
50 | def greedy_until(self, requests):
51 | raise Exception("unimplemented")
52 |
53 | def loglikelihood_rolling(self, requests):
54 | raise Exception("unimplemented")
55 |
56 | def __init__(self, tpu_cluster, seq, batch, shrink, min_seq=None):
57 | super().__init__()
58 | self.tpu = tpu_cluster
59 | self.seq = seq
60 | self.batch = batch
61 | self.shrink = shrink
62 | self.min_seq = min_seq
63 |
64 | self.pool = multiprocessing.Pool(initializer=process_init)
65 | process_init()
66 |
67 | def convert_requests(self, requests):
68 | return self.pool.imap(partial(process_request, seq=self.seq), requests)
69 |
70 | def loglikelihood(self, requests):
71 | output = []
72 |
73 | r = self.convert_requests(requests)
74 | zero_example = process_request(requests[0], self.seq)
75 |
76 | for b in tqdm(sample_batch(r, self.batch, zero_example),
77 | desc="LM eval harness",
78 | total=len(requests) // self.batch):
79 |
80 | if self.shrink:
81 | b = shrink_seq(b, min_seq=self.min_seq)
82 |
83 | out = self.tpu.eval(b)
84 |
85 | for loss, correct in zip(out["mask_loss"], out["each_correct"]):
86 | output.append((float(-loss), bool(correct)))
87 |
88 | return output
89 |
--------------------------------------------------------------------------------
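process_request turns a (context, continuation) pair into a fixed-length next-token-prediction example: obs is the padded input, target is the same token stream shifted by one, and eval_mask is True only at the positions whose predictions score continuation tokens. A small sanity check, assuming the GPT-2 tokenizer can be downloaded:

from tasks.eval_harness import process_init, process_request

process_init()  # loads the GPT-2 tokenizer into this process

ex = process_request(("The capital of France is", " Paris"), seq=64)

print(ex["obs"].shape, ex["target"].shape)  # (64,) (64,)
# exactly one masked position if " Paris" tokenizes to a single BPE token
print(int(ex["eval_mask"].sum()))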
/tasks/util.py:
--------------------------------------------------------------------------------
1 | from itertools import zip_longest
2 |
3 | import numpy as np
4 |
5 |
6 | def grouper(n, iterable, fillvalue):
7 | "grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx"
8 | args = [iter(iterable)] * n
9 | return zip_longest(fillvalue=fillvalue, *args)
10 |
11 |
12 | # divide the seq length by 2 until it would truncate actual context
13 | def shrink_seq(examples, min_seq=None):
14 | length = examples["obs"].shape[-1]
15 |
16 | new_length = length // 2
17 |
18 | if min_seq is not None:
19 | if new_length < min_seq:
20 | return examples
21 |
22 | max_length = np.max(examples["eval_mask"] * np.arange(0, length)) + 1
23 |
24 | if max_length < new_length:
25 | examples["obs"] = examples["obs"][:, :new_length]
26 | examples["target"] = examples["target"][:, :new_length]
27 | examples["eval_mask"] = examples["eval_mask"][:, :new_length]
28 |
29 | return shrink_seq(examples, min_seq=min_seq)
30 | else:
31 | return examples
32 |
33 |
34 | def sample_batch(examples, bs, zero_example_shape):
35 | zero_example = {
36 | "obs": np.zeros_like(zero_example_shape["obs"]),
37 | "target": np.zeros_like(zero_example_shape["target"]),
38 | "eval_mask": np.zeros_like(zero_example_shape["eval_mask"]),
39 | "ctx_length": 0,
40 | }
41 |
42 | for batch in grouper(bs, examples, zero_example):
43 | batch_flattened = {
44 | "obs": [],
45 | "target": [],
46 | "eval_mask": [],
47 | "ctx_length": [],
48 | }
49 |
50 | for sample in batch:
51 | batch_flattened["obs"].append(sample["obs"])
52 | batch_flattened["target"].append(sample["target"])
53 | batch_flattened["eval_mask"].append(sample["eval_mask"])
54 | batch_flattened["ctx_length"].append(sample["ctx_length"])
55 |
56 | batch_flattened["obs"] = np.array(batch_flattened["obs"])
57 | batch_flattened["target"] = np.array(batch_flattened["target"])
58 | batch_flattened["eval_mask"] = np.array(batch_flattened["eval_mask"])
59 | batch_flattened["ctx_length"] = np.array(batch_flattened["ctx_length"])
60 |
61 | yield batch_flattened
62 |
--------------------------------------------------------------------------------
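shrink_seq repeatedly halves the padded sequence length as long as doing so cannot clip any position the eval mask still needs, which saves compute when a batch contains only short examples. A toy illustration with arbitrary contents:

import numpy as np
from tasks.util import shrink_seq

batch = {
    "obs": np.zeros((4, 2048), dtype=np.uint32),
    "target": np.zeros((4, 2048), dtype=np.uint32),
    "eval_mask": np.zeros((4, 2048), dtype=bool),
    "ctx_length": np.full(4, 100),
}
batch["eval_mask"][:, 90:100] = True  # nothing beyond position 99 is scored

print(shrink_seq(batch)["obs"].shape)  # (4, 128): halving stops once 64 would clip position 99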
/tfrecord_loader.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import tensorflow as tf
3 | import numpy as np
4 | from transformers import GPT2TokenizerFast
5 | import itertools
6 |
7 |
8 | class TFRecordLoader:
9 | def __init__(self, index_fname, batch_size, parse_fn, map_fn=None, restore_state=None):
10 | if restore_state is not None:
11 | self.file_idx = restore_state["file_idx"]
12 | self.file_idx_init = False
13 | self.used = restore_state["used"]
14 | else:
15 | self.file_idx = 0
16 | self.file_idx_init = True
17 | self.used = []
18 |
19 | self.index = open(index_fname).read().splitlines()
20 | self.clean_index = list(filter(lambda x: x not in self.used, self.index))
21 | self.bs = batch_size
22 | # self.seq = sample_size
23 | self.parse_fn = parse_fn
24 |
25 | if map_fn:
26 | self.map_fn = map_fn
27 | else:
28 | self.map_fn = lambda x: x
29 |
30 | self.sample_fn = self.sample_once()
31 |
32 | def reset(self):
33 | self.file_idx = 0
34 | self.file_idx_init = True
35 | self.used = []
36 |
37 | self.clean_index = list(filter(lambda x: x not in self.used, self.index))
38 | self.sample_fn = self.sample_once()
39 |
40 | def sample_once(self):
41 | for i in self.clean_index:
42 | compression = "ZLIB" if "zstd" in i else ""
43 |
44 | file = tf.data.TFRecordDataset(i, compression_type=compression).map(self.parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
45 | file = file.apply(tf.data.experimental.dense_to_ragged_batch(np.prod(self.bs), drop_remainder=True))
46 | file = file.prefetch(10)
47 |
48 | for file_idx, data in enumerate(file):
49 | data = jax.tree_map(lambda x: x.numpy(), data)
50 | data = self.map_fn(data)
51 |
52 | if not self.file_idx_init and file_idx <= self.file_idx:
53 | if file_idx % 1000 == 0:
54 | print(f"skipping to batch {self.file_idx}, currently at {file_idx}")
55 | continue
56 | self.file_idx_init = True
57 | self.file_idx = file_idx
58 | yield jax.tree_map(lambda x: x.reshape(self.bs + x.shape[1:]), data)
59 | self.used.append(i)
60 | self.file_idx = 0
61 |
62 | # this loops infinitely, use .sample_once to get an iterator for validation
63 | def get_samples(self):
64 | try:
65 | return next(self.sample_fn)
66 | except StopIteration:
67 | self.reset()
68 | return self.get_samples()
69 |
70 | def get_state(self):
71 | return {
72 | "used": self.used,
73 | "file_idx": self.file_idx
74 | }
75 |
76 |
77 | class TFRecordNewInputs(TFRecordLoader):
78 | def __init__(self, index_fname, batch_size, sample_size, restore_state=None):
79 | def tf_parse(example_proto):
80 | features = {
81 | "text": tf.io.VarLenFeature(tf.int64)
82 | }
83 | parsed_features = tf.io.parse_single_example(example_proto, features)
84 |
85 | return tf.cast(tf.sparse.to_dense(tf.sparse.reorder(parsed_features["text"])), tf.uint32)
86 |
87 | super().__init__(index_fname, batch_size, tf_parse, restore_state=restore_state)
88 |
89 |
90 | class TFRecordWIT(TFRecordLoader):
91 | def __init__(self, index_fname, batch_size, restore_state=None, text_tokens=256):
92 | self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
93 | self.tokenizer.pad_token = "<|endoftext|>"
94 | self.tokenizer.add_special_tokens({'sep_token': '<|sep|>', 'pad_token': '<|pad|>'})
95 |
96 | def map_fn(example):
97 | tokenizer = self.tokenizer
98 |
99 | def decode(x):
100 | return tokenizer(["<|endoftext|>" + i.decode() for i in x])["input_ids"]
101 |
102 | texts = [
103 | decode(example["context_page_description"]),
104 | decode(example["context_section_description"]),
105 | decode(example["caption_reference_description"]),
106 | decode(example["caption_alt_text_description"]),
107 | decode(example["caption_attribution_description"]),
108 | ]
109 |
110 | output = []
111 |
112 | for text, dalle in zip(zip(*texts), example["dalle"]):
113 | all_text = list(itertools.chain(*text))[-text_tokens+1:]
114 |
115 | all_text += [tokenizer.pad_token_id] * ((text_tokens - 1) - len(all_text))
116 |
117 | assert len(all_text) == text_tokens - 1
118 |
119 | all_tokens = all_text + [tokenizer.sep_token_id] + list(dalle + tokenizer.vocab_size + 1)
120 | output.append(all_tokens)
121 |
122 | return np.array(output)
123 |
124 | def tf_parse(example_proto):
125 | features = {
126 | "page_title": tf.io.FixedLenFeature([], tf.string),
127 | "section_title": tf.io.FixedLenFeature([], tf.string),
128 | "hierarchical_section_title": tf.io.FixedLenFeature([], tf.string),
129 | "caption_reference_description": tf.io.FixedLenFeature([], tf.string),
130 | "caption_attribution_description": tf.io.FixedLenFeature([], tf.string),
131 | "caption_alt_text_description": tf.io.FixedLenFeature([], tf.string),
132 | "mime_type": tf.io.FixedLenFeature([], tf.string),
133 | "context_page_description": tf.io.FixedLenFeature([], tf.string),
134 | "context_section_description": tf.io.FixedLenFeature([], tf.string),
135 |
136 | "dalle": tf.io.FixedLenFeature([1024], tf.int64),
137 | }
138 |
139 | parsed_features = tf.io.parse_single_example(example_proto, features)
140 |
141 | return parsed_features
142 |
143 | super().__init__(index_fname, batch_size, tf_parse, map_fn, restore_state=restore_state)
144 |
145 |
146 | if __name__ == "__main__":
147 | # d = TFRecordNewInputs("data/pile.val.index", (8, 32), 2048)
148 | # for idx, i in enumerate(d.sample_once()):
149 | # print(i)
150 | # break
151 |
152 | d = TFRecordWIT("data/wit_dalle.train.index", (8, 32))
153 | for idx, i in enumerate(d.sample_once()):
154 | print(i)
155 | break
156 |
157 | print()
158 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import time
4 |
5 | import numpy as np
6 | import wandb
7 | from tqdm import tqdm
8 |
9 | from mesh_transformer.build_model import build_model
10 | from lm_eval import evaluator, tasks
11 | from tasks.eval_harness import EvalHarnessAdaptor
12 | from tfrecord_loader import TFRecordNewInputs
13 | import multiprocessing
14 |
15 |
16 | def parse_args():
17 | # Parse command line arguments
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("--tpu", type=str, help="Name of TPU to train on.")
20 | parser.add_argument("--tpu_region", type=str, help="Region of TPU to train on.")
21 | parser.add_argument("--preemptible", action="store_true")
22 |
23 | parser.add_argument("--config", type=str, default=None, help="Config file location")
24 |
25 | parser.add_argument("--new", action="store_true", help="If set, deletes previous checkpoint, if it exists, and "
26 | "starts a new training run")
27 |
28 | parser.add_argument("--version", type=int, default=1, help="Choose which model version to use")
29 |
30 | args = parser.parse_args()
31 | return args
32 |
33 |
34 | if __name__ == "__main__":
35 | # huggingface tokenizers gets very angry if you fork
36 | multiprocessing.set_start_method("spawn")
37 |
38 | args = parse_args()
39 | params = json.load(open(args.config))
40 |
41 | if args.new:
42 | print(f"Starting experiment {params['name']} from scratch! "
43 | f"all data in gs://{params['bucket']}/{params['model_dir']}/ will be deleted")
44 | input("Hit enter to continue")
45 |
46 | tpu_name = args.tpu
47 | region = args.tpu_region
48 | preemptible = args.preemptible
49 | clean_start = args.new
50 |
51 | gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1)
52 | per_replica_batch = params["per_replica_batch"]
53 | tpu_size = params["tpu_size"]
54 | cores_per_replica = params["cores_per_replica"]
55 |
56 | bucket = params["bucket"]
57 | model_dir = params["model_dir"]
58 | layers = params["layers"]
59 | d_model = params["d_model"]
60 | n_heads = params["n_heads"]
61 | n_vocab = params["n_vocab"]
62 | seq = params["seq"]
63 | norm = params["norm"]
64 |
65 | val_batches = params["val_batches"]
66 | val_every = params["val_every"]
67 | ckpt_every = params["ckpt_every"]
68 | keep_every = params["keep_every"]
69 | eval_tasks = params["eval_harness_tasks"]
70 | total_steps = params["total_steps"]
71 |
72 | pe = params["pe"]
73 | assert pe in ["fixed", "rotary", "t5"]
74 |
75 | t = build_model(params, tpu_name, region, preemptible, version=args.version)
76 |
77 | try:
78 | t.save(0, bucket, model_dir, init=True, overwrite=clean_start)
79 | step = 0
80 | train_load_restore = None
81 | except Exception as e:
82 | print(f"Save failed with error {e}, trying to load instead...", e)
83 | step, aux = t.load(bucket, model_dir)
84 | train_load_restore = aux.get("train_loader", None)
85 |
86 | if train_load_restore is None:
87 | print("Failed to restore train loader state")
88 |
89 | train_dataset = TFRecordNewInputs(f"data/{params['train_set']}",
90 | batch_size=(
91 | gradient_accumulation_steps,
92 | per_replica_batch * tpu_size // cores_per_replica),
93 | sample_size=params['seq'],
94 | restore_state=train_load_restore)
95 |
96 | global_val_batch = int(per_replica_batch * tpu_size // cores_per_replica * params.get("val_batch_multiplier", 1))
97 |
98 | val_sets = {}
99 |
100 | for k, v in params['val_set'].items():
101 | val_sets[k] = TFRecordNewInputs(f"data/{v}",
102 | batch_size=(global_val_batch,),
103 | sample_size=seq)
104 |
105 | # use dynamic seq length unless pe is fixed
106 | adaptor = EvalHarnessAdaptor(t,
107 | seq,
108 | global_val_batch,
109 | shrink=pe != "fixed",
110 | min_seq=1024 if args.version == 2 else None) # work around suboptimal pjit layout
111 |
112 | start = time.time()
113 | t.train(train_dataset.get_samples())
114 | print(f"Train fn compiled in {time.time() - start:.06}s")
115 |
116 | start = time.time()
117 | for val_set in val_sets.values():
118 | t.eval(val_set.get_samples())
119 | print(f"Eval fn compiled in {time.time() - start:.06}s")
120 |
121 | project = params.get("wandb_project", "mesh-transformer-jax")
122 | wandb.init(project=project, entity="eleutherai", name=params["name"], config=params)
123 |
124 | eval_task_dict = tasks.get_task_dict(eval_tasks)
125 |
126 | pbar = tqdm(initial=step, total=total_steps, desc="Training progress")
127 |
128 | while True:
129 | loss, last_loss = t.train(train_dataset.get_samples())
130 | wandb.log({'train/loss': loss, 'train/last_loss': last_loss}, step)
131 |
132 | if (step % ckpt_every == 0 and step) or step == total_steps:
133 | t.save(step, bucket, model_dir,
134 | aux={"train_loader": train_dataset.get_state()},
135 | init=False,
136 | delete_old=step % keep_every != 0)
137 |
138 | if step == total_steps:
139 | print("training completed!")
140 | exit()
141 |
142 | if step % val_every == 0:
143 | for name, val_set in val_sets.items():
144 | val_loss = []
145 | for i, _ in tqdm(zip(val_set.sample_once(), range(val_batches)),
146 | desc=f"validation for step {step}, set {name}",
147 | total=val_batches):
148 | val_loss.append(t.eval(i))
149 | val_loss = np.array(val_loss).mean()
150 | print(f"validation loss for step {step}, set {name}: {val_loss}")
151 |
152 | wandb.log({f'val/loss_{name}': float(val_loss)}, step)
153 |
154 | results = evaluator.evaluate(adaptor, eval_task_dict, False, 0, None)
155 |
156 | flat_results = {}
157 |
158 | for task_name, task_res in results["results"].items():
159 | version = results["versions"][task_name]
160 | for metric_name, metric_res in task_res.items():
161 | flat_results[f"{task_name}-v{version}/{metric_name}"] = float(metric_res)
162 |
163 | dumped = json.dumps(results, indent=2)
164 | print(f"step {step} val results: {dumped}")
165 | wandb.log(flat_results, step)
166 | step += 1
167 |
168 | pbar.set_postfix({'loss': loss, 'last_loss': last_loss})
169 | pbar.update()
170 |
--------------------------------------------------------------------------------