├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── downstream.png
│   ├── env_dump
│   │   ├── lscpu.txt
│   │   ├── nvidia_smi.txt
│   │   └── pip_freeze.txt
│   ├── ft_loss.png
│   ├── nanoT5.png
│   └── pt_loss.png
├── nanoT5
│   ├── __init__.py
│   ├── configs
│   │   ├── default.yaml
│   │   ├── local_env
│   │   │   └── default.yaml
│   │   └── task
│   │       ├── ft.yaml
│   │       ├── pt.yaml
│   │       ├── pt_12h.yaml
│   │       ├── pt_16h.yaml
│   │       ├── pt_20h.yaml
│   │       ├── pt_24h.yaml
│   │       ├── pt_4h.yaml
│   │       └── pt_8h.yaml
│   ├── main.py
│   └── utils
│       ├── __init__.py
│       ├── copied_utils.py
│       ├── gen_utils.py
│       ├── logging_utils.py
│       ├── model_utils.py
│       ├── ni_dataset.py
│       ├── t5_model.py
│       └── train_utils.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | env.yaml
3 | logs
4 | scripts
5 | .neptune
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2022 - Piotr Nawrot. All rights reserved.
2 |
3 | Apache License
4 | Version 2.0, January 2004
5 | http://www.apache.org/licenses/
6 |
7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8 |
9 | 1. Definitions.
10 |
11 | "License" shall mean the terms and conditions for use, reproduction,
12 | and distribution as defined by Sections 1 through 9 of this document.
13 |
14 | "Licensor" shall mean the copyright owner or entity authorized by
15 | the copyright owner that is granting the License.
16 |
17 | "Legal Entity" shall mean the union of the acting entity and all
18 | other entities that control, are controlled by, or are under common
19 | control with that entity. For the purposes of this definition,
20 | "control" means (i) the power, direct or indirect, to cause the
21 | direction or management of such entity, whether by contract or
22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
23 | outstanding shares, or (iii) beneficial ownership of such entity.
24 |
25 | "You" (or "Your") shall mean an individual or Legal Entity
26 | exercising permissions granted by this License.
27 |
28 | "Source" form shall mean the preferred form for making modifications,
29 | including but not limited to software source code, documentation
30 | source, and configuration files.
31 |
32 | "Object" form shall mean any form resulting from mechanical
33 | transformation or translation of a Source form, including but
34 | not limited to compiled object code, generated documentation,
35 | and conversions to other media types.
36 |
37 | "Work" shall mean the work of authorship, whether in Source or
38 | Object form, made available under the License, as indicated by a
39 | copyright notice that is included in or attached to the work
40 | (an example is provided in the Appendix below).
41 |
42 | "Derivative Works" shall mean any work, whether in Source or Object
43 | form, that is based on (or derived from) the Work and for which the
44 | editorial revisions, annotations, elaborations, or other modifications
45 | represent, as a whole, an original work of authorship. For the purposes
46 | of this License, Derivative Works shall not include works that remain
47 | separable from, or merely link (or bind by name) to the interfaces of,
48 | the Work and Derivative Works thereof.
49 |
50 | "Contribution" shall mean any work of authorship, including
51 | the original version of the Work and any modifications or additions
52 | to that Work or Derivative Works thereof, that is intentionally
53 | submitted to Licensor for inclusion in the Work by the copyright owner
54 | or by an individual or Legal Entity authorized to submit on behalf of
55 | the copyright owner. For the purposes of this definition, "submitted"
56 | means any form of electronic, verbal, or written communication sent
57 | to the Licensor or its representatives, including but not limited to
58 | communication on electronic mailing lists, source code control systems,
59 | and issue tracking systems that are managed by, or on behalf of, the
60 | Licensor for the purpose of discussing and improving the Work, but
61 | excluding communication that is conspicuously marked or otherwise
62 | designated in writing by the copyright owner as "Not a Contribution."
63 |
64 | "Contributor" shall mean Licensor and any individual or Legal Entity
65 | on behalf of whom a Contribution has been received by Licensor and
66 | subsequently incorporated within the Work.
67 |
68 | 2. Grant of Copyright License. Subject to the terms and conditions of
69 | this License, each Contributor hereby grants to You a perpetual,
70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71 | copyright license to reproduce, prepare Derivative Works of,
72 | publicly display, publicly perform, sublicense, and distribute the
73 | Work and such Derivative Works in Source or Object form.
74 |
75 | 3. Grant of Patent License. Subject to the terms and conditions of
76 | this License, each Contributor hereby grants to You a perpetual,
77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78 | (except as stated in this section) patent license to make, have made,
79 | use, offer to sell, sell, import, and otherwise transfer the Work,
80 | where such license applies only to those patent claims licensable
81 | by such Contributor that are necessarily infringed by their
82 | Contribution(s) alone or by combination of their Contribution(s)
83 | with the Work to which such Contribution(s) was submitted. If You
84 | institute patent litigation against any entity (including a
85 | cross-claim or counterclaim in a lawsuit) alleging that the Work
86 | or a Contribution incorporated within the Work constitutes direct
87 | or contributory patent infringement, then any patent licenses
88 | granted to You under this License for that Work shall terminate
89 | as of the date such litigation is filed.
90 |
91 | 4. Redistribution. You may reproduce and distribute copies of the
92 | Work or Derivative Works thereof in any medium, with or without
93 | modifications, and in Source or Object form, provided that You
94 | meet the following conditions:
95 |
96 | (a) You must give any other recipients of the Work or
97 | Derivative Works a copy of this License; and
98 |
99 | (b) You must cause any modified files to carry prominent notices
100 | stating that You changed the files; and
101 |
102 | (c) You must retain, in the Source form of any Derivative Works
103 | that You distribute, all copyright, patent, trademark, and
104 | attribution notices from the Source form of the Work,
105 | excluding those notices that do not pertain to any part of
106 | the Derivative Works; and
107 |
108 | (d) If the Work includes a "NOTICE" text file as part of its
109 | distribution, then any Derivative Works that You distribute must
110 | include a readable copy of the attribution notices contained
111 | within such NOTICE file, excluding those notices that do not
112 | pertain to any part of the Derivative Works, in at least one
113 | of the following places: within a NOTICE text file distributed
114 | as part of the Derivative Works; within the Source form or
115 | documentation, if provided along with the Derivative Works; or,
116 | within a display generated by the Derivative Works, if and
117 | wherever such third-party notices normally appear. The contents
118 | of the NOTICE file are for informational purposes only and
119 | do not modify the License. You may add Your own attribution
120 | notices within Derivative Works that You distribute, alongside
121 | or as an addendum to the NOTICE text from the Work, provided
122 | that such additional attribution notices cannot be construed
123 | as modifying the License.
124 |
125 | You may add Your own copyright statement to Your modifications and
126 | may provide additional or different license terms and conditions
127 | for use, reproduction, or distribution of Your modifications, or
128 | for any such Derivative Works as a whole, provided Your use,
129 | reproduction, and distribution of the Work otherwise complies with
130 | the conditions stated in this License.
131 |
132 | 5. Submission of Contributions. Unless You explicitly state otherwise,
133 | any Contribution intentionally submitted for inclusion in the Work
134 | by You to the Licensor shall be under the terms and conditions of
135 | this License, without any additional terms or conditions.
136 | Notwithstanding the above, nothing herein shall supersede or modify
137 | the terms of any separate license agreement you may have executed
138 | with Licensor regarding such Contributions.
139 |
140 | 6. Trademarks. This License does not grant permission to use the trade
141 | names, trademarks, service marks, or product names of the Licensor,
142 | except as required for reasonable and customary use in describing the
143 | origin of the Work and reproducing the content of the NOTICE file.
144 |
145 | 7. Disclaimer of Warranty. Unless required by applicable law or
146 | agreed to in writing, Licensor provides the Work (and each
147 | Contributor provides its Contributions) on an "AS IS" BASIS,
148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149 | implied, including, without limitation, any warranties or conditions
150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151 | PARTICULAR PURPOSE. You are solely responsible for determining the
152 | appropriateness of using or redistributing the Work and assume any
153 | risks associated with Your exercise of permissions under this License.
154 |
155 | 8. Limitation of Liability. In no event and under no legal theory,
156 | whether in tort (including negligence), contract, or otherwise,
157 | unless required by applicable law (such as deliberate and grossly
158 | negligent acts) or agreed to in writing, shall any Contributor be
159 | liable to You for damages, including any direct, indirect, special,
160 | incidental, or consequential damages of any character arising as a
161 | result of this License or out of the use or inability to use the
162 | Work (including but not limited to damages for loss of goodwill,
163 | work stoppage, computer failure or malfunction, or any and all
164 | other commercial damages or losses), even if such Contributor
165 | has been advised of the possibility of such damages.
166 |
167 | 9. Accepting Warranty or Additional Liability. While redistributing
168 | the Work or Derivative Works thereof, You may choose to offer,
169 | and charge a fee for, acceptance of support, warranty, indemnity,
170 | or other liability obligations and/or rights consistent with this
171 | License. However, in accepting such obligations, You may act only
172 | on Your own behalf and on Your sole responsibility, not on behalf
173 | of any other Contributor, and only if You agree to indemnify,
174 | defend, and hold each Contributor harmless for any liability
175 | incurred by, or claims asserted against, such Contributor by reason
176 | of your accepting any such warranty or additional liability.
177 |
178 | END OF TERMS AND CONDITIONS
179 |
180 | APPENDIX: How to apply the Apache License to your work.
181 |
182 | To apply the Apache License to your work, attach the following
183 | boilerplate notice, with the fields enclosed by brackets "[]"
184 | replaced with your own identifying information. (Don't include
185 | the brackets!) The text should be enclosed in the appropriate
186 | comment syntax for the file format. We also recommend that a
187 | file or class name and description of purpose be included on the
188 | same "printed page" as the copyright notice for easier
189 | identification within third-party archives.
190 |
191 | Copyright [yyyy] [name of copyright owner]
192 |
193 | Licensed under the Apache License, Version 2.0 (the "License");
194 | you may not use this file except in compliance with the License.
195 | You may obtain a copy of the License at
196 |
197 | http://www.apache.org/licenses/LICENSE-2.0
198 |
199 | Unless required by applicable law or agreed to in writing, software
200 | distributed under the License is distributed on an "AS IS" BASIS,
201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202 | See the License for the specific language governing permissions and
203 | limitations under the License.
204 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # nanoT5 (Encoder-Decoder / Pre-training + Fine-Tuning)
2 |
3 | 
4 |
5 | [[Paper](https://arxiv.org/abs/2309.02373)] | [**TLDR**](#tldr) | [**Motivation**](#motivation) | [**Setup**](#setup) | [**Pre-training**](#pre-training) | [**Fine-tuning**](#fine-tuning) | [**Extras**](#extras) | [**Conclusions**](#conclusions) | [**References**](#references) | [**Cite**](#cite) | [**Issues**](#issues)
6 |
7 | ## TLDR:
8 |
9 | This repository comprises the code to reproduce the pre-training of a "Large Language Model" (T5) under a limited budget (1xA100 GPU, < 24 hours) in PyTorch. We start from the randomly initialised T5-base-v1.1 (248M parameters) model, pre-train it on the English subset of the C4 dataset, and then fine-tune it on the Super-Natural Instructions (SNI) benchmark.
10 |
11 | **In ~16 hours on a single GPU, we achieve 40.7 RougeL on the SNI test set, compared to 40.9 RougeL of the original model weights available on HuggingFace Hub and pretrained on 150x more data through "a combination of model and data parallelism [...] on slices of Cloud TPU Pods", each with 1024 TPUs.**
12 |
13 | Our core contribution is not the T5 model itself, which follows the HuggingFace implementation. Instead, we optimise everything else in the training pipeline to offer you a user-friendly starting template for your NLP application/research. Most importantly, we show that it is possible to pre-train the T5 model to top performance under a limited budget in PyTorch.
14 |
15 | ## Motivation
16 |
17 | Despite the continuously increasing size of pretrained [Transformers](https://arxiv.org/pdf/1706.03762.pdf), the research community still needs easy-to-reproduce and up-to-date baselines to test new research hypotheses fast and at a small scale.
18 |
19 | A recent effort from Andrej Karpathy, the [nanoGPT](https://github.com/karpathy/nanoGPT) repository, enables researchers to pre-train and fine-tune GPT-style (Decoder-only) language models. On the other hand, [Cramming](https://github.com/JonasGeiping/cramming) opts to find the optimal BERT-style (Encoder-only) pre-training setup for limited-compute settings.
20 |
21 | With [nanoT5](https://github.com/PiotrNawrot/nanoT5), we want to fill a gap (Community requests: [#1](https://github.com/huggingface/transformers/issues/18030) [#2](https://github.com/facebookresearch/fairseq/issues/1899) [#3](https://github.com/google-research/text-to-text-transfer-transformer/issues/172) [#4](https://discuss.huggingface.co/t/example-of-how-to-pretrain-t5/4129) [#5](https://github.com/huggingface/transformers/issues/5079)) of an accessible research template to pre-train and fine-tune T5-style (Encoder-Decoder) models. **To the best of our knowledge, it is the first attempt to reproduce T5 v1.1 pre-training in PyTorch (previously available implementations are in Jax/Flax).**
22 |
23 | ##
24 |
25 | **We created this repository for people who want to pre-train T5-style models by themselves and evaluate their performance on downstream tasks.** This could be for a variety of reasons:
26 | - You are a researcher in academia with limited compute (like me), and you came up with a promising idea based on the T5 model, so you need a pipeline to evaluate it;
27 | - You have an in-house dataset that you think is more appropriate than the original pre-training dataset (C4) for your downstream task;
28 | - You want to experiment with continued pre-training or want to build on the T5 pre-training objective.
29 |
30 | **If you don't need to pre-train the T5 model, you'd be better off downloading the weights from HuggingFace Hub. Our checkpoints are worse because we work under limited compute.**
31 |
32 | ##
33 |
34 | In this project, we expose (for research purposes) and optimise everything in the T5 training pipeline except for the model implementation. We include a simplified implementation of the T5 model (which is great for teaching & learning purposes); however, we do not optimise it with the latest techniques, such as [tensor or pipeline parallelism](https://arxiv.org/pdf/2104.04473.pdf), because they make the code much more complex and are not needed at a small scale. **Most importantly, we base our code on PyTorch, since access to TPUs is limited.** Key features:
35 | - **Dataset:** Downloading and preprocessing of the C4 dataset happen in parallel with model training (a minimal streaming sketch follows this list). The C4 dataset is > 300GB, so it takes a couple of hours to download and even longer to preprocess. This codebase does both on the fly without any detrimental effect on training time or performance (at least none that we observed; it might occur with an old CPU (< 8 cores) or a slow internet connection). **As a result, you can initiate pre-training of your own T5 model within minutes.**
36 | - **Model Optimizer / LR Scheduler:** The original T5 uses the memory-efficient Adafactor optimizer. [A study on pre-training T5](https://huggingface.co/spaces/yhavinga/pre-training-dutch-t5-models), on the other hand, reports that training does not converge with AdamW, which we found strange given that AdamW relies on theoretically better approximations than Adafactor. We analysed the source of this discrepancy with several ablations. Although there are many subtle differences between Adafactor and AdamW, what ensures Adafactor's convergence is [matrix-wise LR scaling by its root mean square (RMS)](https://github.com/huggingface/transformers/blob/main/src/transformers/optimization.py#L595). We augmented the AdamW implementation with RMS scaling and observed that it becomes **more stable during pre-training, achieves better validation loss, and is faster**.
37 | - **Exposure and simplicity:** We try to balance the implementation of the training pipeline by keeping it customisable while retaining a sufficient level of abstraction. We use the [HuggingFace Accelerator](https://huggingface.co/docs/accelerate/index) to implement operations like Checkpoint Saving, Gradient Accumulation, Gradient Clipping, and moving tensors to the correct devices. We use [neptune.ai](https://neptune.ai) for experiment tracking and [hydra](https://hydra.cc/docs/intro/) for hyperparameters. Additionally, we expose a [simplified implementation](nanoT5/utils/t5_model.py) of the T5 model, training loop, data preprocessing, etc.
38 | - **Efficiency:** We use mixed-precision training (TF32 & BF16), PyTorch 2.0 compile, and utilise all optimisations listed in established optimisation tutorials [#1](https://huggingface.co/docs/transformers/perf_train_gpu_one) [#2](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html).
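
As mentioned in the **Dataset** point above, here is a minimal sketch of the streaming idea. It is not the repository's exact data pipeline (see `nanoT5/utils/model_utils.py` for that); the dataset id, column names, and tokenizer choice below are assumptions for illustration only.

```
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-base")

# streaming=True avoids materialising the >300GB corpus on disk before training starts.
c4_stream = load_dataset("c4", "en", split="train", streaming=True)

def tokenize(example):
    # Return raw token ids; span corruption and packing happen later in the collator.
    return tokenizer(example["text"], return_attention_mask=False)

tokenized_stream = c4_stream.map(tokenize, remove_columns=["text", "timestamp", "url"])
```

A `torch.utils.data.DataLoader` with several workers can then pull examples lazily from such a stream, so later shards are still downloading while the GPU trains on earlier ones.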
39 |
40 | ## Setup
41 |
42 | ### Environment & Hardware:
43 |
44 | ```
45 | git clone https://github.com/PiotrNawrot/nanoT5.git
46 | cd nanoT5
47 | conda create -n nanoT5 python=3.8
48 | conda activate nanoT5
49 | pip install -r requirements.txt
50 | ```
51 |
52 | The commands above result in the following [pip freeze](assets/env_dump/pip_freeze.txt) as of 18.06.2023.
53 |
54 | We also include our [lscpu](assets/env_dump/lscpu.txt) and [nvidia-smi](assets/env_dump/nvidia_smi.txt).
55 |
56 | ## Pre-training:
57 |
58 | ### Reference:
59 |
60 | The [T5 v1.1](https://arxiv.org/pdf/2002.05202.pdf) authors report **1.942** negative log-likelihood (NLL) on the held-out set of C4 after 2^16 steps.
61 |
62 | ### Legacy Optimizer (Adafactor) & LR Schedule (Inverse-Square-Root)
63 |
64 | We follow the original experimental setup for pre-training, including [Dataset (C4)](https://github.com/PiotrNawrot/nanoT5/blob/main/nanoT5/utils/model_utils.py#L74), [Training Objective (Span Filling)](https://github.com/PiotrNawrot/nanoT5/blob/main/nanoT5/utils/copied_utils.py#L16), [Model Architecture (T5-Base)](nanoT5/utils/t5_model.py), [Optimizer (Adafactor)](https://github.com/PiotrNawrot/nanoT5/blob/main/nanoT5/utils/model_utils.py#L248), and [LR Schedule (Inverse-Square-Root)](https://github.com/PiotrNawrot/nanoT5/blob/main/nanoT5/utils/model_utils.py#L288).
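
For intuition, the span-filling objective replaces random spans of the input with sentinel tokens and trains the decoder to reconstruct the masked spans. The canonical example from the T5 paper (the actual tokenizer uses `<extra_id_0>`, `<extra_id_1>`, ... as sentinels):

```
original = "Thank you for inviting me to your party last week."
inputs   = "Thank you <X> me to your party <Y> week."
targets  = "<X> for inviting <Y> last <Z>"
```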
65 |
66 | Our negative log-likelihood on the held-out set is **1.995**, slightly worse than the reference.
67 |
68 | ### AdamW with RMS scaling Optimizer & Cosine LR Schedule
69 |
70 | We also experiment with the AdamW optimizer (instead of the original Adafactor), as it should (theoretically) offer more stability during training: instead of using a low-rank approximation of the second moment of the gradients, it estimates it directly by storing a moving average for each parameter in memory. However, training diverges with AdamW, similar to [this study on T5 pre-training](https://huggingface.co/spaces/yhavinga/pre-training-dutch-t5-models). Through several ablations, we found that [matrix-wise LR scaling by its root mean square (RMS)](https://github.com/huggingface/transformers/blob/main/src/transformers/optimization.py#L617) is responsible for Adafactor's convergence. We augmented the AdamW implementation with RMS scaling and observed that [it converges, becomes more stable during pre-training, and is slightly faster](assets/pt_loss.png) (it retrieves the second moment directly from memory instead of approximating it via matrix multiplications).
71 |
72 | However, AdamW, when paired with the Inverse-Square-Root LR schedule, performs worse than Adafactor. For our final experiment, we replace ISR with a Cosine LR schedule. We achieve **1.953** negative log-likelihood on the held-out set and significantly outperform Adafactor with the ISR schedule. We consider this our default pre-training config.
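
To make this concrete, below is a minimal sketch of the idea behind our `AdamWScale` optimizer (see `nanoT5/utils/copied_utils.py` for the actual class, whose interface follows HuggingFace's AdamW): a standard Adam update whose step size is additionally scaled by the RMS of the weight tensor, as in Adafactor. Hyperparameter defaults here are illustrative.

```
import torch

def rms(tensor):
    # Root mean square of a tensor, as in Adafactor and the AdamWScale class.
    return tensor.norm(2) / (tensor.numel() ** 0.5)

@torch.no_grad()
def adamw_rms_update(p, exp_avg, exp_avg_sq, step, lr=2e-2,
                     betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0):
    beta1, beta2 = betas
    grad = p.grad

    # Standard Adam first/second-moment updates with bias correction.
    exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
    update = (exp_avg / (1.0 - beta1 ** step)) / \
             ((exp_avg_sq / (1.0 - beta2 ** step)).sqrt() + eps)

    # Adafactor-style trick: scale the learning rate by the RMS of the weight
    # tensor (Adafactor clamps this factor from below at 1e-3).
    step_size = lr * max(1e-3, rms(p).item())
    p.add_(update, alpha=-step_size)

    # Decoupled (AdamW-style) weight decay.
    if weight_decay > 0.0:
        p.add_(p, alpha=-lr * weight_decay)
```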
73 |
74 | ### Training loss of runs with different optimisers (Adafactor vs AdamW) and schedulers (ISR vs Cosine). The remaining [hyperparameters](nanoT5/configs/default.yaml) follow the original T5 v1.1 paper.
75 |
76 | 
77 |
78 | ### Negative log-likelihood on the held-out set of C4
79 |
80 |
81 |
82 | | | **Inverse-Square-Root** | **Cosine** |
83 | | :---: | :----: | :---: |
84 | | **Adafactor** | 1.995 | 1.993 |
85 | | **AdamW** | 2.040 | **1.953** |
86 |
87 |
88 |
89 | ### Examples
90 |
91 | To reproduce any of the experiments mentioned above, choose a combination of hyperparameters as follows:
92 |
93 | ```
94 | python -m nanoT5.main \
95 | optim.name={adafactor,adamwscale} \
96 | optim.lr_scheduler={legacy,cosine}
97 | ```
98 |
99 | We recommend adding the `model.compile=true` flag for pre-training if you are able to install PyTorch 2.0.
100 |
101 | If you don't have access to an 80GB A100 GPU, you can increase the number of gradient accumulation steps with `optim.grad_acc=steps`, where `batch_size` must be divisible by `steps`.
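
To illustrate what gradient accumulation does conceptually (a self-contained toy sketch, not the repository's training loop): each batch is split into `grad_acc` micro-batches, and the optimizer steps once per full batch, so the effective batch size is unchanged while peak memory is lower.

```
import torch
from torch import nn

# Toy model and data, just to make the sketch runnable.
torch.manual_seed(0)
model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

batch_size, grad_acc = 16, 4          # batch_size must be divisible by grad_acc
micro_bs = batch_size // grad_acc
x, y = torch.randn(batch_size, 8), torch.randn(batch_size, 1)

optimizer.zero_grad()
for i in range(grad_acc):
    xb = x[i * micro_bs:(i + 1) * micro_bs]
    yb = y[i * micro_bs:(i + 1) * micro_bs]
    # Divide by grad_acc so the accumulated gradient equals the full-batch average.
    loss = nn.functional.mse_loss(model(xb), yb) / grad_acc
    loss.backward()
optimizer.step()                      # one optimizer step per full (effective) batch
```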
102 |
103 | A summary of the optimisation process is printed every 100 steps, for instance:
104 |
105 | ```
106 | [train] Step 100 out of 65536 | Loss --> 59.881 | Grad_l2 --> 61.126 | Weights_l2 --> 7042.931 | Lr --> 0.010 | Seconds_per_step --> 1.385 |
107 | ```
108 |
109 | ### Efficiency statistics:
110 |
111 | Below we include the efficiency statistics for our pre-training experiments. We report the time per pre-training step and the total pre-training time according to the [default config](nanoT5/configs/default.yaml). Note that we need to increase **optim.grad_acc** to fit the model in memory with precisions other than BF16.
112 |
113 |
114 |
115 | | **Mixed Precision Format** | **Torch 2.0 compile** | **Grad Acc Steps** | **Pre-training (1 step)** | **Total Pre-training time** |
116 | | :----: | :---: | :---: | :---: | :---: |
117 | | FP32 | No | 2 | ~4.10s | ~74.6h |
118 | | TF32 | No | 2 | ~1.39s | ~25.3h |
119 | | BF16 | No | 2 | ~1.30s | ~23.7h |
120 | | TF32 | Yes | 2 | ~0.95s | ~17.3h |
121 | | BF16 | Yes | 1 | ~0.56s | ~10.2h |
122 |
123 |
124 |
125 | ## Fine-tuning:
126 |
127 | To fine-tune our model, we use the popular meta-dataset called **Super Natural-Instructions (SNI)**, which aggregates datasets for many tasks. This meta-dataset was used to fine-tune many of the recent LLMs, e.g. [FlanT5](https://arxiv.org/pdf/2210.11416.pdf), [BLOOM](https://arxiv.org/pdf/2211.05100.pdf), and [Tk-Instruct](https://arxiv.org/pdf/2204.07705.pdf). While FlanT5 and BLOOM use other corpora in addition to SNI, Tk-Instruct's pipeline consists of starting from the pre-trained T5 model and fine-tuning it solely on SNI.
128 |
129 | In this repository, we reproduce the Tk-Instruct fine-tuning results and follow their pipeline to evaluate our pre-trained model.
130 |
131 | ### Download the Super-Natural Instructions data:
132 |
133 | ```
134 | git clone https://github.com/allenai/natural-instructions.git data
135 | ```
136 |
137 | ### Run fine-tuning:
138 |
139 | We strictly follow the fine-tuning [config](nanoT5/configs/task/ft.yaml) of Tk-Instruct. It remains unclear whether Tk-Instruct was initialised from a regular checkpoint (*google/t5-v1_1-base*) or the one adapted explicitly for Language Modelling with continued training (*google/t5-base-lm-adapt*). Therefore, we decided to evaluate both. Run the following command to reproduce the Tk-Instruct experiments:
140 |
141 | ```
142 | python -m nanoT5.main task=ft \
143 | model.name={google/t5-v1_1-base,google/t5-base-lm-adapt} \
144 | model.random_init={true,false} \
145 | model.checkpoint_path={"","/path/to/pytorch_model.bin"}
146 | ```
147 |
148 | Setting `model.random_init=false model.checkpoint_path=""` corresponds to downloading pre-trained weights from HuggingFace Hub.
149 |
150 | Setting `model.random_init=false model.checkpoint_path="/path/to/pytorch_model.bin"` corresponds to using the weights [**pre-trained**](#pre-training) with nanoT5.
151 |
152 | Setting `model.random_init=true model.checkpoint_path=""` corresponds to a random initialisation.
153 |
154 | ### Rouge-L on the held-out test-set across different pre-training budgets:
155 |
156 | In the figure below, we compare the performance of models trained in this repository under different time budgets ([4](nanoT5/configs/task/pt_4h.yaml), [8](nanoT5/configs/task/pt_8h.yaml), [12](nanoT5/configs/task/pt_12h.yaml), [16](nanoT5/configs/task/pt_16h.yaml), [20](nanoT5/configs/task/pt_20h.yaml), [24](nanoT5/configs/task/pt_24h.yaml) hours) with the original T5-base-v1.1 weights available through the HuggingFace Hub and their version adapted for Language Modelling (*google/t5-base-lm-adapt*). The model trained in our repository for 16 hours on a single GPU is only 0.2 RougeL worse on average than the original T5-base-v1.1 model, despite being pre-trained on 150x less data (according to the [T5 paper](https://arxiv.org/pdf/1910.10683.pdf), they pre-train their models for one million steps with a batch size of 2048, whereas our 16-hour config does 53332 steps with a batch size of 256). The checkpoint explicitly adapted for Language Modelling (*google/t5-base-lm-adapt*) performs better than both the original T5-base-v1.1 model and our model; closing this gap, however, goes beyond the scope of this repository.
157 |
158 | 
159 |
160 | We share the model's weights after pre-training for 24 hours on [HuggingFace Hub](https://huggingface.co/pnawrot/nanoT5-base), which you can download and fine-tune on SNI using nanoT5.
161 | We also share the [fine-tuning loss curves](assets/ft_loss.png).
162 |
163 | A single fine-tuning step takes ~0.18s, and a full fine-tuning run takes ~1 hour.
164 |
165 | ## Extras:
166 |
167 | ### Things we tried and didn't work out:
168 |
169 | - **Different optimizers:** We tried recent optimizers such as [Lion](https://arxiv.org/abs/2302.06675) and [Sophia](https://github.com/Liuhong99/Sophia); however, none of them worked better than AdamW with RMS scaling.
170 | - **Positional embeddings:** We tried to replace T5's learned relative positional embeddings with [ALiBi](https://arxiv.org/pdf/2108.12409.pdf). Possible benefits include a reduction in parameters and faster training & inference. Furthermore, if ALiBi worked, we could add [Flash Attention](https://github.com/HazyResearch/flash-attention), which currently supports only non-parametric bias (the T5 bias is trainable). However, with ALiBi the training was less stable and reached a worse pre-training loss.
171 | - **FP16 precision:** All experiments with FP16 precision diverged across different seeds.
172 |
173 | ## Conclusions:
174 |
175 | We show that it is possible to successfully pre-train a "Large Language Model" (T5) under a limited budget (1xA100 GPU, < 24 hours) in PyTorch. We make our codebase, configs and training logs publicly available to enhance the accessibility of NLP research. We are keen to hear your suggestions to improve the codebase further.
176 |
177 | ### Acknowledgements:
178 |
179 | Thanks to [Edoardo Maria Ponti](https://ducdauge.github.io) for his feedback!
180 |
181 | ## References:
182 | - [T5 paper](https://arxiv.org/pdf/1910.10683.pdf)
183 | - [T5 v1.1 paper](https://arxiv.org/pdf/2002.05202.pdf)
184 | - [Super-Natural Instructions paper](https://arxiv.org/pdf/2204.07705.pdf)
185 | - [HuggingFace Flax Script](https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py)
186 | - [Karpathy's nanoGPT](https://github.com/karpathy/nanoGPT)
187 | - [Instruct-GPT codebase (Super-Natural Instructions)](https://github.com/yizhongw/Tk-Instruct)
188 | - [Blog about pre-training Dutch T5 in HuggingFace](https://huggingface.co/spaces/yhavinga/pre-training-dutch-t5-models)
189 |
190 | ## Cite
191 |
192 | If you found the repository useful, consider citing the paper describing this work.
193 |
194 | ```
195 | @inproceedings{nawrot-2023-nanot5,
196 | title = {nano{T}5: Fast {\&} Simple Pre-training and Fine-tuning of {T}5 Models with Limited Resources},
197 | author = {Nawrot, Piotr},
198 | year = 2023,
199 | month = dec,
200 | booktitle = {Proceedings of the 3rd Workshop for Natural Language Processing Open Source Software (NLP-OSS 2023)},
201 | publisher = {Association for Computational Linguistics},
202 | doi = {10.18653/v1/2023.nlposs-1.11},
203 | url = {https://aclanthology.org/2023.nlposs-1.11}
204 | }
205 | ```
206 |
207 | Below you can also find an excellent work that uses nanoT5 in its experiments.
208 |
209 | ```
210 | @article{Kaddour2023NoTN,
211 | title={No Train No Gain: Revisiting Efficient Training Algorithms For Transformer-based Language Models},
212 | author={Jean Kaddour and Oscar Key and Piotr Nawrot and Pasquale Minervini and Matt J. Kusner},
213 | journal={ArXiv},
214 | year={2023},
215 | volume={abs/2307.06440},
216 | }
217 | ```
218 |
219 | ## Issues:
220 |
221 | If you have any questions, feel free to raise a GitHub issue or contact me directly at: piotr.nawrot@ed.ac.uk
222 |
--------------------------------------------------------------------------------
/assets/downstream.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiotrNawrot/nanoT5/1375b389d33ab4f34754a9fca62e4cfa1dd52379/assets/downstream.png
--------------------------------------------------------------------------------
/assets/env_dump/lscpu.txt:
--------------------------------------------------------------------------------
1 | Architecture: x86_64
2 | CPU op-mode(s): 32-bit, 64-bit
3 | Byte Order: Little Endian
4 | CPU(s): 128
5 | On-line CPU(s) list: 0-127
6 | Thread(s) per core: 1
7 | Core(s) per socket: 64
8 | Socket(s): 2
9 | NUMA node(s): 8
10 | Vendor ID: AuthenticAMD
11 | CPU family: 25
12 | Model: 1
13 | Model name: AMD EPYC 7763 64-Core Processor
14 | Stepping: 1
15 | CPU MHz: 2445.206
16 | BogoMIPS: 4890.41
17 | Virtualization: AMD-V
18 | L1d cache: 32K
19 | L1i cache: 32K
20 | L2 cache: 512K
21 | L3 cache: 32768K
22 | NUMA node0 CPU(s): 0-15
23 | NUMA node1 CPU(s): 16-31
24 | NUMA node2 CPU(s): 32-47
25 | NUMA node3 CPU(s): 48-63
26 | NUMA node4 CPU(s): 64-79
27 | NUMA node5 CPU(s): 80-95
28 | NUMA node6 CPU(s): 96-111
29 | NUMA node7 CPU(s): 112-127
30 | Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca
31 |
--------------------------------------------------------------------------------
/assets/env_dump/nvidia_smi.txt:
--------------------------------------------------------------------------------
1 | Sun Jun 18 13:21:03 2023
2 | +-----------------------------------------------------------------------------+
3 | | NVIDIA-SMI 525.105.17 Driver Version: 525.105.17 CUDA Version: 12.0 |
4 | |-------------------------------+----------------------+----------------------+
5 | | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
6 | | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
7 | | | | MIG M. |
8 | |===============================+======================+======================|
9 | | 0 NVIDIA A100-SXM... On | 00000000:C1:00.0 Off | 0 |
10 | | N/A 65C P0 420W / 500W | 75277MiB / 81920MiB | 84% Default |
11 | | | | Disabled |
12 | +-------------------------------+----------------------+----------------------+
13 |
14 | +-----------------------------------------------------------------------------+
15 | | Processes: |
16 | | GPU GI CI PID Type Process name GPU Memory |
17 | | ID ID Usage |
18 | |=============================================================================|
19 | | 0 N/A N/A 50391 C python 75274MiB |
20 | +-----------------------------------------------------------------------------+
21 |
--------------------------------------------------------------------------------
/assets/env_dump/pip_freeze.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.4.0
2 | accelerate==0.20.3
3 | aiohttp==3.8.4
4 | aiosignal==1.3.1
5 | antlr4-python3-runtime==4.9.3
6 | anyio==3.7.0
7 | argon2-cffi==21.3.0
8 | argon2-cffi-bindings==21.2.0
9 | arrow==1.2.3
10 | asttokens==2.2.1
11 | async-timeout==4.0.2
12 | attrs==23.1.0
13 | backcall==0.2.0
14 | beautifulsoup4==4.12.2
15 | bleach==6.0.0
16 | boto3==1.26.155
17 | botocore==1.29.155
18 | bravado==11.0.3
19 | bravado-core==5.17.1
20 | certifi==2023.5.7
21 | cffi==1.15.1
22 | charset-normalizer==3.1.0
23 | click==8.1.3
24 | cmake==3.26.4
25 | comm==0.1.3
26 | datasets==2.13.0
27 | debugpy==1.6.7
28 | decorator==5.1.1
29 | defusedxml==0.7.1
30 | dill==0.3.6
31 | docopt==0.6.2
32 | evaluate==0.4.0
33 | exceptiongroup==1.1.1
34 | executing==1.2.0
35 | fancycompleter==0.9.1
36 | fastjsonschema==2.17.1
37 | filelock==3.12.2
38 | fqdn==1.5.1
39 | frozenlist==1.3.3
40 | fsspec==2023.6.0
41 | future==0.18.3
42 | gitdb==4.0.10
43 | GitPython==3.1.31
44 | huggingface-hub==0.15.1
45 | hydra-core==1.3.2
46 | idna==3.4
47 | importlib-metadata==6.6.0
48 | importlib-resources==5.12.0
49 | ipykernel==6.23.2
50 | ipython==8.12.2
51 | ipython-genutils==0.2.0
52 | isoduration==20.11.0
53 | jedi==0.18.2
54 | Jinja2==3.1.2
55 | jmespath==1.0.1
56 | joblib==1.2.0
57 | jsonpointer==2.4
58 | jsonref==1.1.0
59 | jsonschema==4.17.3
60 | jupyter-events==0.6.3
61 | jupyter_client==8.2.0
62 | jupyter_core==5.3.1
63 | jupyter_server==2.6.0
64 | jupyter_server_terminals==0.4.4
65 | jupyterlab-pygments==0.2.2
66 | lit==16.0.6
67 | MarkupSafe==2.1.3
68 | matplotlib-inline==0.1.6
69 | mistune==2.0.5
70 | monotonic==1.6
71 | mpmath==1.3.0
72 | msgpack==1.0.5
73 | multidict==6.0.4
74 | multiprocess==0.70.14
75 | nbclassic==1.0.0
76 | nbclient==0.8.0
77 | nbconvert==7.5.0
78 | nbformat==5.9.0
79 | neptune==1.3.1
80 | nest-asyncio==1.5.6
81 | networkx==3.1
82 | nltk==3.8.1
83 | notebook==6.5.4
84 | notebook_shim==0.2.3
85 | numpy==1.24.3
86 | nvidia-cublas-cu11==11.10.3.66
87 | nvidia-cuda-cupti-cu11==11.7.101
88 | nvidia-cuda-nvrtc-cu11==11.7.99
89 | nvidia-cuda-runtime-cu11==11.7.99
90 | nvidia-cudnn-cu11==8.5.0.96
91 | nvidia-cufft-cu11==10.9.0.58
92 | nvidia-curand-cu11==10.2.10.91
93 | nvidia-cusolver-cu11==11.4.0.1
94 | nvidia-cusparse-cu11==11.7.4.91
95 | nvidia-nccl-cu11==2.14.3
96 | nvidia-nvtx-cu11==11.7.91
97 | oauthlib==3.2.2
98 | omegaconf==2.3.0
99 | overrides==7.3.1
100 | packaging==23.1
101 | pandas==2.0.2
102 | pandocfilters==1.5.0
103 | parso==0.8.3
104 | pdbpp==0.10.3
105 | pexpect==4.8.0
106 | pickleshare==0.7.5
107 | Pillow==9.5.0
108 | pipreqs==0.4.13
109 | pkgutil_resolve_name==1.3.10
110 | platformdirs==3.6.0
111 | prometheus-client==0.17.0
112 | prompt-toolkit==3.0.38
113 | protobuf==3.20.3
114 | psutil==5.9.5
115 | ptyprocess==0.7.0
116 | pure-eval==0.2.2
117 | pyarrow==12.0.1
118 | pycparser==2.21
119 | Pygments==2.15.1
120 | PyJWT==2.7.0
121 | pynvml==11.5.0
122 | pyrepl==0.9.0
123 | pyrsistent==0.19.3
124 | python-dateutil==2.8.2
125 | python-json-logger==2.0.7
126 | pytz==2023.3
127 | PyYAML==6.0
128 | pyzmq==25.1.0
129 | regex==2023.6.3
130 | requests==2.31.0
131 | requests-oauthlib==1.3.1
132 | responses==0.18.0
133 | rfc3339-validator==0.1.4
134 | rfc3986-validator==0.1.1
135 | rfc3987==1.3.8
136 | rouge-score==0.1.2
137 | s3transfer==0.6.1
138 | safetensors==0.3.1
139 | Send2Trash==1.8.2
140 | sentencepiece==0.1.99
141 | simplejson==3.19.1
142 | six==1.16.0
143 | smmap==5.0.0
144 | sniffio==1.3.0
145 | soupsieve==2.4.1
146 | stack-data==0.6.2
147 | swagger-spec-validator==3.0.3
148 | sympy==1.12
149 | terminado==0.17.1
150 | tinycss2==1.2.1
151 | tokenizers==0.13.3
152 | torch==2.0.1
153 | tornado==6.3.2
154 | tqdm==4.65.0
155 | traitlets==5.9.0
156 | transformers==4.30.2
157 | triton==2.0.0
158 | typing_extensions==4.6.3
159 | tzdata==2023.3
160 | uri-template==1.2.0
161 | urllib3==1.26.16
162 | wcwidth==0.2.6
163 | webcolors==1.13
164 | webencodings==0.5.1
165 | websocket-client==1.6.0
166 | wmctrl==0.4
167 | xxhash==3.2.0
168 | yarg==0.1.9
169 | yarl==1.9.2
170 | zipp==3.15.0
171 |
--------------------------------------------------------------------------------
/assets/ft_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiotrNawrot/nanoT5/1375b389d33ab4f34754a9fca62e4cfa1dd52379/assets/ft_loss.png
--------------------------------------------------------------------------------
/assets/nanoT5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiotrNawrot/nanoT5/1375b389d33ab4f34754a9fca62e4cfa1dd52379/assets/nanoT5.png
--------------------------------------------------------------------------------
/assets/pt_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiotrNawrot/nanoT5/1375b389d33ab4f34754a9fca62e4cfa1dd52379/assets/pt_loss.png
--------------------------------------------------------------------------------
/nanoT5/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiotrNawrot/nanoT5/1375b389d33ab4f34754a9fca62e4cfa1dd52379/nanoT5/__init__.py
--------------------------------------------------------------------------------
/nanoT5/configs/default.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - task: pt
4 | - local_env: default
5 |
6 | # Experiment args
7 | mode: 'pt'
8 | device: gpu
9 | precision: 'bf16'
10 | eval_only: false
11 | predict_only: false
12 | seed: 2137
13 |
14 | model:
15 | klass: local_t5
16 | name: 'google/t5-v1_1-base'
17 | overwrite:
18 | dropout_rate: 0.0
19 | add_config:
20 | is_bf16: false
21 | checkpoint_path: ''
22 | random_init: true
23 | compile: false # Pytorch 2.0
24 |
25 | data:
26 | input_length: 512
27 | mlm_probability: 0.15
28 | mean_noise_span_length: 3.0
29 | num_workers: 8
30 |
31 | optim:
32 | name: adamwscale
33 | base_lr: 2e-2
34 | batch_size: 128
35 | total_steps: 65536
36 | epochs: -1 # If it's > 0 it overwrites total_steps
37 | warmup_steps: 10000
38 | lr_scheduler: cosine
39 | weight_decay: 0.0
40 | grad_clip: 1.0
41 | grad_acc: 1
42 | final_cosine: 1e-5
43 |
44 | eval:
45 | every_steps: 100000 # Eval once in the end
46 | steps: 500
47 |
48 | checkpoint:
49 | every_steps: 100000 # Save checkpoint once in the end
50 |
51 | logging:
52 | neptune: false
53 | neptune_creds:
54 | project:
55 | api_token:
56 | tags: ''
57 | every_steps: 100
58 | grad_l2: true
59 | weights_l2: true
60 |
61 | hydra:
62 | job:
63 | chdir: True
64 |
--------------------------------------------------------------------------------
/nanoT5/configs/local_env/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | hydra:
4 | run:
5 | dir: ./logs/${now:%Y-%m-%d}/${now:%H-%M-%S}-${logging.neptune_creds.tags}
--------------------------------------------------------------------------------
/nanoT5/configs/task/ft.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | mode: 'ft'
4 | precision: 'no'
5 |
6 | model:
7 | klass: hf_t5
8 |
9 | data:
10 | max_seq_len: 1024
11 | max_target_len: 128
12 | max_num_instances_per_task: 100
13 | add_task_name: False
14 | add_task_definition: True
15 | num_pos_examples: 2
16 | num_neg_examples: 0
17 | add_explanation: False
18 | tk_instruct: False
19 | exec_file_path: ./nanoT5/utils/ni_dataset.py
20 | data_dir: ./data/splits/default
21 | task_dir: ./data/tasks
22 |
23 | optim:
24 | name: adamw
25 | base_lr: 5e-5
26 | batch_size: 8
27 | epochs: 2
28 | warmup_steps: 0
29 | lr_scheduler: constant
30 | weight_decay: 0.0
31 | grad_clip: 0.0
32 | grad_acc: 1
33 |
34 | eval:
35 | steps: 200
36 |
--------------------------------------------------------------------------------
/nanoT5/configs/task/pt.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
--------------------------------------------------------------------------------
/nanoT5/configs/task/pt_12h.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | optim:
4 | warmup_steps: 5000
5 | batch_size: 256
6 | total_steps: 39999
7 | grad_acc: 2
--------------------------------------------------------------------------------
/nanoT5/configs/task/pt_16h.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | optim:
4 | warmup_steps: 6666
5 | batch_size: 256
6 | total_steps: 53332
7 | grad_acc: 2
--------------------------------------------------------------------------------
/nanoT5/configs/task/pt_20h.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | optim:
4 | warmup_steps: 8332
5 | batch_size: 256
6 | total_steps: 66665
7 | grad_acc: 2
--------------------------------------------------------------------------------
/nanoT5/configs/task/pt_24h.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | optim:
4 | warmup_steps: 10000
5 | batch_size: 256
6 | total_steps: 79998
7 | grad_acc: 2
--------------------------------------------------------------------------------
/nanoT5/configs/task/pt_4h.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | optim:
4 | warmup_steps: 1666
5 | batch_size: 256
6 | total_steps: 13333
7 | grad_acc: 2
--------------------------------------------------------------------------------
/nanoT5/configs/task/pt_8h.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | optim:
4 | warmup_steps: 3332
5 | batch_size: 256
6 | total_steps: 26666
7 | grad_acc: 2
--------------------------------------------------------------------------------
/nanoT5/main.py:
--------------------------------------------------------------------------------
1 | from accelerate import Accelerator
2 | from omegaconf import open_dict
3 | import hydra
4 | import torch
5 | import time
6 |
7 | from .utils import (
8 | setup_basics,
9 | train,
10 | predict,
11 | eval,
12 | get_lr_scheduler,
13 | get_optimizer,
14 | get_tokenizer,
15 | get_model,
16 | get_dataloaders,
17 | get_config,
18 | )
19 |
20 |
21 | @hydra.main(config_path="configs", config_name="default", version_base='1.1')
22 | def main(args):
23 | accelerator = Accelerator(
24 | cpu=args.device == "cpu",
25 | mixed_precision=args.precision,
26 | )
27 | logger = setup_basics(accelerator, args)
28 | config = get_config(args)
29 | model = get_model(args, config)
30 | tokenizer = get_tokenizer(args)
31 | optimizer = get_optimizer(model, args)
32 | lr_scheduler = get_lr_scheduler(optimizer, args, logger)
33 | train_dataloader, test_dataloader = get_dataloaders(tokenizer, config, args)
34 |
35 | logger.log_args(args)
36 |
37 | (
38 | model,
39 | optimizer,
40 | lr_scheduler,
41 | train_dataloader,
42 | test_dataloader,
43 | ) = accelerator.prepare(
44 | model, optimizer, lr_scheduler, train_dataloader, test_dataloader
45 | )
46 |
47 | if args.model.compile:
48 | model = torch.compile(model)
49 |
50 | with open_dict(args):
51 | args.current_train_step = 1
52 | args.current_epoch = 1
53 | args.last_log = time.time()
54 |
55 | if args.eval_only:
56 | model.eval()
57 | with torch.no_grad():
58 | eval(model, test_dataloader, logger, args, tokenizer)
59 | elif args.predict_only:
60 | model.eval()
61 | with torch.no_grad():
62 | predict(model, test_dataloader, logger,
63 | args, tokenizer)
64 | else:
65 | train(model, train_dataloader, test_dataloader, accelerator,
66 | lr_scheduler, optimizer, logger, args, tokenizer)
67 |
68 | logger.finish()
69 |
70 |
71 | if __name__ == "__main__":
72 | main()
73 |
--------------------------------------------------------------------------------
/nanoT5/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .gen_utils import *
2 | from .model_utils import *
3 | from .train_utils import *
4 |
--------------------------------------------------------------------------------
/nanoT5/utils/copied_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 | import numpy as np
3 | from transformers import BatchEncoding
4 | from dataclasses import dataclass
5 | from transformers import AutoTokenizer
6 | import torch
7 | import math
8 | from torch.optim import Optimizer
9 | from typing import Iterable, Tuple
10 | from torch import nn
11 | import random
12 | import string
13 |
14 |
15 | @dataclass
16 | class DataCollatorForT5MLM:
17 | """
18 | [Copied from https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py]
19 | Data collator used for T5 span-masked language modeling.
20 | It is made sure that after masking the inputs are of length `data_args.max_seq_length` and targets are also of fixed length.
21 | For more information on how T5 span-masked language modeling works, one can take a look
22 | at the `official paper `__
23 | or the `official code for preprocessing `__ .
24 | Args:
25 | tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
26 | The tokenizer used for encoding the data.
27 | noise_density (:obj:`float`):
28 | The probability with which to (randomly) mask tokens in the input.
29 | mean_noise_span_length (:obj:`float`):
30 | The average span length of the masked tokens.
31 | input_length (:obj:`int`):
32 | The expected input length after masking.
33 | target_length (:obj:`int`):
34 | The expected target length after masking.
35 | pad_token_id: (:obj:`int`):
36 | The pad token id of the model
37 | decoder_start_token_id: (:obj:`int):
38 | The decoder start token id of the model
39 | """
40 |
41 | tokenizer: AutoTokenizer
42 | noise_density: float
43 | mean_noise_span_length: float
44 | input_length: int
45 | target_length: int
46 | pad_token_id: int
47 |
48 | def __call__(self, examples: List[Dict[str, np.ndarray]]) -> BatchEncoding:
49 | # convert list to dict and tensorize input
50 | batch = BatchEncoding(
51 | {
52 | k: np.array([examples[i][k] for i in range(len(examples))])
53 | for k, v in examples[0].items()
54 | }
55 | )
56 |
57 | input_ids = batch["input_ids"]
58 | batch_size, expandend_input_length = input_ids.shape
59 |
60 | mask_indices = np.asarray(
61 | [
62 | self.random_spans_noise_mask(expandend_input_length)
63 | for i in range(batch_size)
64 | ]
65 | )
66 | labels_mask = ~mask_indices
67 |
68 | input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
69 | labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))
70 |
71 | batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel)
72 | batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel)
73 |
74 | if batch["input_ids"].shape[-1] != self.input_length:
75 | raise ValueError(
76 | f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but"
77 | f" should be {self.input_length}."
78 | )
79 |
80 | if batch["labels"].shape[-1] != self.target_length:
81 | raise ValueError(
82 | f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be"
83 | f" {self.target_length}."
84 | )
85 |
86 | batch = {k: torch.from_numpy(v) for k, v in batch.items()}
87 | return batch
88 |
89 | def create_sentinel_ids(self, mask_indices):
90 | """
91 | Sentinel ids creation given the indices that should be masked.
92 | The start indices of each mask are replaced by the sentinel ids in increasing
93 | order. Consecutive mask indices to be deleted are replaced with `-1`.
94 | """
95 | start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
96 | start_indices[:, 0] = mask_indices[:, 0]
97 |
98 | sentinel_ids = np.where(
99 | start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices
100 | )
101 | sentinel_ids = np.where(
102 | sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0
103 | )
104 | sentinel_ids -= mask_indices - start_indices
105 |
106 | return sentinel_ids
107 |
108 | def filter_input_ids(self, input_ids, sentinel_ids):
109 | """
110 | Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting.
111 | This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
112 | """
113 | batch_size = input_ids.shape[0]
114 |
115 | input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
116 | # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
117 | # masked tokens coming after sentinel tokens and should be removed
118 | input_ids = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1))
119 | input_ids = np.concatenate(
120 | [
121 | input_ids,
122 | np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32),
123 | ],
124 | axis=-1,
125 | )
126 | return input_ids
127 |
128 | def random_spans_noise_mask(self, length):
129 | """This function is copy of `random_spans_helper `__ .
130 |
131 | Noise mask consisting of random spans of noise tokens.
132 | The number of noise tokens and the number of noise spans and non-noise spans
133 | are determined deterministically as follows:
134 | num_noise_tokens = round(length * noise_density)
135 | num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
136 | Spans alternate between non-noise and noise, beginning with non-noise.
137 | Subject to the above restrictions, all masks are equally likely.
138 |
139 | Args:
140 | length: an int32 scalar (length of the incoming token sequence)
141 | noise_density: a float - approximate density of output mask
142 | mean_noise_span_length: a number
143 |
144 | Returns:
145 | a boolean tensor with shape [length]
146 | """
147 |
148 | orig_length = length
149 |
150 | num_noise_tokens = int(np.round(length * self.noise_density))
151 | # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
152 | num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
153 | num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))
154 |
155 | # avoid degeneracy by ensuring positive number of noise spans
156 | num_noise_spans = max(num_noise_spans, 1)
157 | num_nonnoise_tokens = length - num_noise_tokens
158 |
159 | # pick the lengths of the noise spans and the non-noise spans
160 | def _random_segmentation(num_items, num_segments):
161 | """Partition a sequence of items randomly into non-empty segments.
162 | Args:
163 | num_items: an integer scalar > 0
164 | num_segments: an integer scalar in [1, num_items]
165 | Returns:
166 | a Tensor with shape [num_segments] containing positive integers that add
167 | up to num_items
168 | """
169 | mask_indices = np.arange(num_items - 1) < (num_segments - 1)
170 | np.random.shuffle(mask_indices)
171 | first_in_segment = np.pad(mask_indices, [[1, 0]])
172 | segment_id = np.cumsum(first_in_segment)
173 | # count length of sub segments assuming that list is sorted
174 | _, segment_length = np.unique(segment_id, return_counts=True)
175 | return segment_length
176 |
177 | noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
178 | nonnoise_span_lengths = _random_segmentation(
179 | num_nonnoise_tokens, num_noise_spans
180 | )
181 |
182 | interleaved_span_lengths = np.reshape(
183 | np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
184 | [num_noise_spans * 2],
185 | )
186 | span_starts = np.cumsum(interleaved_span_lengths)[:-1]
187 | span_start_indicator = np.zeros((length,), dtype=np.int8)
188 | span_start_indicator[span_starts] = True
189 | span_num = np.cumsum(span_start_indicator)
190 | is_noise = np.equal(span_num % 2, 1)
191 |
192 | return is_noise[:orig_length]
193 |
194 |
195 | def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
196 | """This function is copy of `random_spans_helper `__ .
197 |
198 | [Copied from https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py]
199 | Training parameters to avoid padding with random_spans_noise_mask.
200 | When training a model with random_spans_noise_mask, we would like to set the other
201 | training hyperparmeters in a way that avoids padding.
202 | This function helps us compute these hyperparameters.
203 | We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
204 | and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
205 | This function tells us the required number of tokens in the raw example (for split_tokens())
206 | as well as the length of the encoded targets. Note that this function assumes
207 | the inputs and targets will have EOS appended and includes that in the reported length.
208 |
209 | Args:
210 | inputs_length: an integer - desired length of the tokenized inputs sequence
211 | noise_density: a float
212 | mean_noise_span_length: a float
213 | Returns:
214 | tokens_length: length of original text in tokens
215 | targets_length: an integer - length in tokens of encoded targets sequence
216 | """
217 |
218 | def _tokens_length_to_inputs_length_targets_length(tokens_length):
219 | num_noise_tokens = int(round(tokens_length * noise_density))
220 | num_nonnoise_tokens = tokens_length - num_noise_tokens
221 | num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
222 | # inputs contain all nonnoise tokens, sentinels for all noise spans
223 | # and one EOS token.
224 | _input_length = num_nonnoise_tokens + num_noise_spans + 1
225 | _output_length = num_noise_tokens + num_noise_spans + 1
226 | return _input_length, _output_length
227 |
228 | tokens_length = inputs_length
229 |
230 | while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
231 | tokens_length += 1
232 |
233 | inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
234 |
235 | # minor hack to get the targets length to be equal to inputs length
236 | # which is more likely to have been set to a nice round number.
237 | if noise_density == 0.5 and targets_length > inputs_length:
238 | tokens_length -= 1
239 | targets_length -= 1
240 | return tokens_length, targets_length
241 |
242 |
243 | class AdamWScale(Optimizer):
244 | """
245 | This AdamW implementation is copied from Huggingface.
246 | We modified it to scale the parameter update by the RMS of each weight tensor, as in Adafactor.
247 |
248 | Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
249 | Regularization](https://arxiv.org/abs/1711.05101).
250 |
251 | Parameters:
252 | params (`Iterable[nn.parameter.Parameter]`):
253 | Iterable of parameters to optimize or dictionaries defining parameter groups.
254 | lr (`float`, *optional*, defaults to 1e-3):
255 | The learning rate to use.
256 | betas (`Tuple[float,float]`, *optional*, defaults to (0.9, 0.999)):
257 | Adam's betas parameters (b1, b2).
258 | eps (`float`, *optional*, defaults to 1e-6):
259 | Adam's epsilon for numerical stability.
260 | weight_decay (`float`, *optional*, defaults to 0):
261 | Decoupled weight decay to apply.
262 | correct_bias (`bool`, *optional*, defaults to `True`):
263 | Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
266 | """
267 |
268 | def __init__(
269 | self,
270 | params: Iterable[nn.parameter.Parameter],
271 | lr: float = 1e-3,
272 | betas: Tuple[float, float] = (0.9, 0.999),
273 | eps: float = 1e-6,
274 | weight_decay: float = 0.0,
275 | correct_bias: bool = True,
276 | ):
277 | if lr < 0.0:
278 | raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
279 | if not 0.0 <= betas[0] < 1.0:
280 | raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
281 | if not 0.0 <= betas[1] < 1.0:
282 | raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
283 | if not 0.0 <= eps:
284 | raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
285 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
286 | super().__init__(params, defaults)
287 |
288 | @staticmethod
289 | def _rms(tensor):
290 | return tensor.norm(2) / (tensor.numel() ** 0.5)
291 |
292 | def step(self, closure=None):
293 | """
294 | Performs a single optimization step.
295 |
296 | Arguments:
297 | closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
298 | """
299 | loss = None
300 | if closure is not None:
301 | loss = closure()
302 |
303 | for group in self.param_groups:
304 | for p in group["params"]:
305 | if p.grad is None:
306 | continue
307 | grad = p.grad.data
308 | if grad.is_sparse:
309 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
310 |
311 | state = self.state[p]
312 | beta1, beta2 = group["betas"]
313 |
314 | # State initialization
315 | if len(state) == 0:
316 | state["step"] = 0
317 | # Exponential moving average of gradient values
318 | state["exp_avg"] = torch.zeros_like(p.data)
319 | # Exponential moving average of squared gradient values
320 | state["exp_avg_sq"] = torch.zeros_like(p.data)
321 |
322 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
323 |
324 | state["step"] += 1
325 |
326 | # Decay the first and second moment running average coefficient
327 | # In-place operations to update the averages at the same time
328 | exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
329 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
330 | denom = exp_avg_sq.sqrt().add_(group["eps"])
331 |
332 | step_size = group["lr"]
333 | if group["correct_bias"]: # No bias correction for Bert
334 | bias_correction1 = 1.0 - beta1 ** state["step"]
335 | bias_correction2 = 1.0 - beta2 ** state["step"]
336 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
337 |
338 | # Adapt Step from Adafactor: scale the update by max(1e-3, RMS(p))
339 | step_size = step_size * max(1e-3, self._rms(p.data))
340 | # /Adapt Step from Adafactor
341 |
342 | p.data.addcdiv_(exp_avg, denom, value=-step_size)
343 |
344 | # Just adding the square of the weights to the loss function is *not*
345 | # the correct way of using L2 regularization/weight decay with Adam,
346 | # since that will interact with the m and v parameters in strange ways.
347 | #
348 | # Instead we want to decay the weights in a manner that doesn't interact
349 | # with the m/v parameters. This is equivalent to adding the square
350 | # of the weights to the loss with plain (non-momentum) SGD.
351 | # Add weight decay at the end (fixed version)
352 | if group["weight_decay"] > 0.0:
353 | p.data.add_(p.data, alpha=(-group["lr"] * group["weight_decay"]))
354 |
355 | return loss
356 |
357 |
358 | def tokenize_function(examples, tokenizer, in_length):
359 | tokenizer_out = tokenizer(
360 | text=examples["text"],
361 | return_attention_mask=False,
362 | )
363 |
364 | input_ids = tokenizer_out["input_ids"]
365 |
366 | concatenated_ids = np.concatenate(input_ids)
367 |
368 | total_length = concatenated_ids.shape[0]
369 | total_length = (total_length // in_length) * in_length
370 |
371 | concatenated_ids = concatenated_ids[:total_length].reshape(-1, in_length)
372 | result = {"input_ids": concatenated_ids}
373 |
374 | return result
375 |
376 |
377 | from transformers.data.data_collator import *
378 | @dataclass
379 | class DataCollatorForNI:
380 | tokenizer: PreTrainedTokenizerBase
381 | padding: Union[bool, str, PaddingStrategy] = True
382 | max_source_length: Optional[int] = None
383 | max_target_length: Optional[int] = None
384 | pad_to_multiple_of: Optional[int] = None
385 | label_pad_token_id: int = -100
386 | return_tensors: str = "pt"
387 | add_task_name: bool = False
388 | add_task_definition: bool = True
389 | num_pos_examples: int = 0
390 | num_neg_examples: int = 0
391 | add_explanation: bool = False
392 | tk_instruct: bool = False
393 | text_only: bool = False
394 |
395 | def __call__(self, batch, return_tensors=None):
396 |
397 | if return_tensors is None:
398 | return_tensors = self.return_tensors
399 |
400 | sources = []
401 | for instance in batch:
402 | if self.tk_instruct:
403 | all_valid_encodings = [
404 | # instruction only
405 | {
406 | "add_task_name": False,
407 | "add_task_definition": True,
408 | "num_pos_examples": 0,
409 | "num_neg_examples": 0,
410 | "add_explanation": False,
411 | },
412 | # examples only (no instruction)
413 | {
414 | "add_task_name": False,
415 | "add_task_definition": False,
416 | "num_pos_examples": 2,
417 | "num_neg_examples": 0,
418 | "add_explanation": False,
419 | },
420 | # instruction + pos examples
421 | {
422 | "add_task_name": False,
423 | "add_task_definition": True,
424 | "num_pos_examples": 2,
425 | "num_neg_examples": 0,
426 | "add_explanation": False,
427 | },
428 | # instruction + pos examples + neg examples
429 | {
430 | "add_task_name": False,
431 | "add_task_definition": True,
432 | "num_pos_examples": 2,
433 | "num_neg_examples": 2,
434 | "add_explanation": False,
435 | },
436 | # instruction + pos (w. explanation)
437 | {
438 | "add_task_name": False,
439 | "add_task_definition": True,
440 | "num_pos_examples": 2,
441 | "num_neg_examples": 0,
442 | "add_explanation": True,
443 | },
444 | ]
445 | encoding_schema = random.choice(all_valid_encodings)
446 | add_task_name = encoding_schema["add_task_name"]
447 | add_task_definition = encoding_schema["add_task_definition"]
448 | num_pos_examples = encoding_schema["num_pos_examples"]
449 | num_neg_examples = encoding_schema["num_neg_examples"]
450 | add_explanation = encoding_schema["add_explanation"]
451 | else:
452 | add_task_name = self.add_task_name
453 | add_task_definition = self.add_task_definition
454 | num_pos_examples = self.num_pos_examples
455 | num_neg_examples = self.num_neg_examples
456 | add_explanation = self.add_explanation
457 |
458 | task_input = ""
459 | # add the input first.
460 | task_input += "Now complete the following example -\n"
461 | task_input += f"Input: {instance['Instance']['input'].strip()}"
462 | if not task_input[-1] in string.punctuation:
463 | task_input += "."
464 | task_input += "\n"
465 | task_input += "Output: "
466 |
467 | task_name = ""
468 | if add_task_name:
469 | task_name += instance["Task"] + ". "
470 |
471 | definition = ""
472 | if add_task_definition:
473 | if isinstance(instance["Definition"], list):
474 | definition = (
475 | "Definition: " + instance["Definition"][0].strip()
476 | )
477 | else:
478 | definition = "Definition: " + instance["Definition"].strip()
479 | if not definition[-1] in string.punctuation:
480 | definition += "."
481 | definition += "\n\n"
482 |
483 | # try to add positive examples.
484 | pos_examples = []
485 | for idx, pos_example in enumerate(
486 | instance["Positive Examples"][:num_pos_examples]
487 | ):
488 | pos_example_str = f" Positive Example {idx+1} -\n"
489 | pos_example_str += f"Input: {pos_example['input'].strip()}"
490 | if not pos_example_str[-1] in string.punctuation:
491 | pos_example_str += "."
492 | pos_example_str += "\n"
493 | pos_example_str += f" Output: {pos_example['output'].strip()}"
494 | if not pos_example_str[-1] in string.punctuation:
495 | pos_example_str += "."
496 | pos_example_str += "\n"
497 | if add_explanation and "explanation" in pos_example:
498 | pos_example_str += (
499 | f" Explanation: {pos_example['explanation'].strip()}"
500 | )
501 | if not pos_example_str[-1] in string.punctuation:
502 | pos_example_str += "."
503 | pos_example_str += "\n"
504 | pos_example_str += "\n"
505 | if (
506 | len(
507 | self.tokenizer(
508 | definition
509 | + " ".join(pos_examples)
510 | + pos_example_str
511 | + task_input
512 | )["input_ids"]
513 | )
514 | <= self.max_source_length
515 | ):
516 | pos_examples.append(pos_example_str)
517 | else:
518 | break
519 |
520 | # try to add negative examples.
521 | neg_examples = []
522 | for idx, neg_example in enumerate(
523 | instance["Negative Examples"][:num_neg_examples]
524 | ):
525 | neg_example_str = f" Negative Example {idx+1} -\n"
526 | neg_example_str += f"Input: {neg_example['input'].strip()}"
527 | if not neg_example_str[-1] in string.punctuation:
528 | neg_example_str += "."
529 | neg_example_str += "\n"
530 | neg_example_str += f" Output: {neg_example['output'].strip()}"
531 | if not neg_example_str[-1] in string.punctuation:
532 | neg_example_str += "."
533 | neg_example_str += "\n"
534 | if add_explanation and "explanation" in neg_example:
535 | neg_example_str += (
536 | f" Explanation: {neg_example['explanation'].strip()}"
537 | )
538 | if not neg_example_str[-1] in string.punctuation:
539 | neg_example_str += "."
540 | neg_example_str += "\n"
541 | neg_example_str += "\n"
542 | if (
543 | len(
544 | self.tokenizer(
545 | definition
546 | + " ".join(pos_examples)
547 | + " ".join(neg_examples)
548 | + neg_example_str
549 | + task_input
550 | )["input_ids"]
551 | )
552 | <= self.max_source_length
553 | ):
554 | neg_examples.append(neg_example_str)
555 | else:
556 | break
557 |
558 | source = (
559 | task_name
560 | + definition
561 | + "".join(pos_examples)
562 | + "".join(neg_examples)
563 | + task_input
564 | )
565 | tokenized_source = self.tokenizer(source)["input_ids"]
566 | if len(tokenized_source) <= self.max_source_length:
567 | sources.append(source)
568 | else:
569 | sources.append(
570 | self.tokenizer.decode(
571 | tokenized_source[: self.max_source_length],
572 | skip_special_tokens=True,
573 | )
574 | )
575 |
576 | if self.text_only:
577 | model_inputs = {"inputs": sources}
578 | else:
579 | model_inputs = self.tokenizer(
580 | sources,
581 | max_length=self.max_source_length,
582 | padding=self.padding,
583 | return_tensors=self.return_tensors,
584 | truncation=True,
585 | pad_to_multiple_of=self.pad_to_multiple_of,
586 | )
587 |
588 | if "output" in batch[0]["Instance"] and batch[0]["Instance"]["output"]:
589 | # Randomly select one reference if multiple are provided.
590 | labels = [random.choice(ex["Instance"]["output"]) for ex in batch]
591 | if self.text_only:
592 | model_inputs["labels"] = labels
593 | else:
594 | labels = self.tokenizer(
595 | labels,
596 | max_length=self.max_target_length,
597 | padding=self.padding,
598 | return_tensors=self.return_tensors,
599 | truncation=True,
600 | pad_to_multiple_of=self.pad_to_multiple_of,
601 | )
602 | label_mask = labels["attention_mask"].bool()
603 | model_inputs["labels"] = labels["input_ids"].masked_fill(
604 | ~label_mask, self.label_pad_token_id
605 | )
606 | else:
607 | model_inputs["labels"] = None
608 |
609 | return model_inputs
610 |
--------------------------------------------------------------------------------
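
A quick sanity check for compute_input_and_target_lengths from the file above. This is a hypothetical usage sketch, not part of the repository; it assumes the package is importable from the repo root and uses the standard T5 span-corruption settings (input_length=512, noise_density=0.15, mean_noise_span_length=3.0), which may differ from the values in this repo's configs.

# Hypothetical sketch: verify that the expanded "before masking" length yields
# exactly `inputs_length` encoder tokens after span corruption.
from nanoT5.utils.copied_utils import compute_input_and_target_lengths

tokens_length, targets_length = compute_input_and_target_lengths(
    inputs_length=512, noise_density=0.15, mean_noise_span_length=3.0
)
print(tokens_length, targets_length)  # 568 114 for these settings

# With 568 raw tokens: 85 noise tokens grouped into 28 spans, so the encoder
# input is 483 non-noise tokens + 28 sentinels + 1 EOS = 512 tokens, and the
# target is 85 noise tokens + 28 sentinels + 1 EOS = 114 tokens, i.e. no padding.
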
/nanoT5/utils/gen_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 |
4 | from accelerate.utils import set_seed
5 | from omegaconf import open_dict
6 | from .logging_utils import Logger
7 | from hydra.utils import to_absolute_path
8 |
9 |
10 | def check_args_and_env(args):
11 | assert args.optim.batch_size % args.optim.grad_acc == 0
12 |
13 | # Train log must happen before eval log
14 | assert args.eval.every_steps % args.logging.every_steps == 0
15 |
16 | if args.device == 'gpu':
17 | assert torch.cuda.is_available(), 'We use GPU to train/eval the model'
18 |
19 | assert not (args.eval_only and args.predict_only)
20 |
21 | if args.predict_only:
22 | assert args.mode == 'ft'
23 |
24 |
25 | def opti_flags(args):
26 | # These flags cut the time per training step by ~2.4x
27 | torch.backends.cuda.matmul.allow_tf32 = True
28 | torch.backends.cudnn.allow_tf32 = True
29 |
30 | if args.precision == 'bf16' and args.device == 'gpu' and args.model.klass == 'local_t5':
31 | args.model.add_config.is_bf16 = True
32 |
33 |
34 | def update_args_with_env_info(args):
35 | with open_dict(args):
36 | slurm_id = os.getenv('SLURM_JOB_ID')
37 |
38 | if slurm_id is not None:
39 | args.slurm_id = slurm_id
40 | else:
41 | args.slurm_id = 'none'
42 |
43 | args.working_dir = os.getcwd()
44 |
45 |
46 | def update_paths(args):
47 | if args.mode == 'ft':
48 | args.data.exec_file_path = to_absolute_path(args.data.exec_file_path)
49 | args.data.data_dir = to_absolute_path(args.data.data_dir)
50 | args.data.task_dir = to_absolute_path(args.data.task_dir)
51 |
52 |
53 | def setup_basics(accelerator, args):
54 | check_args_and_env(args)
55 | update_args_with_env_info(args)
56 | update_paths(args)
57 | opti_flags(args)
58 |
59 | if args.seed is not None:
60 | set_seed(args.seed)
61 |
62 | logger = Logger(args=args, accelerator=accelerator)
63 |
64 | return logger
65 |
--------------------------------------------------------------------------------
/nanoT5/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | from accelerate.logging import get_logger
4 | from omegaconf import OmegaConf, open_dict
5 | import logging
6 | import datasets
7 | import transformers
8 | import neptune
9 | import os
10 |
11 |
12 | class Averager:
13 | def __init__(self, weight: float = 1):
14 | self.weight = weight
15 | self.reset()
16 |
17 | def reset(self):
18 | self.total = defaultdict(float)
19 | self.counter = defaultdict(float)
20 |
21 | def update(self, stats):
22 | for key, value in stats.items():
23 | self.total[key] = self.total[key] * self.weight + value * self.weight
24 | self.counter[key] = self.counter[key] * self.weight + self.weight
25 |
26 | def average(self):
27 | averaged_stats = {
28 | key: tot / self.counter[key] for key, tot in self.total.items()
29 | }
30 | self.reset()
31 |
32 | return averaged_stats
33 |
34 |
35 | class Logger:
36 | def __init__(self, args, accelerator):
37 | self.logger = get_logger('Main')
38 |
39 | # Make one log on every process with the configuration for debugging.
40 | logging.basicConfig(
41 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
42 | datefmt="%m/%d/%Y %H:%M:%S",
43 | level=logging.INFO,
44 | )
45 | self.logger.info(accelerator.state, main_process_only=False)
46 | self.logger.info(f'Working directory is {os.getcwd()}')
47 |
48 | if accelerator.is_local_main_process:
49 | datasets.utils.logging.set_verbosity_warning()
50 | transformers.utils.logging.set_verbosity_info()
51 | else:
52 | datasets.utils.logging.set_verbosity_error()
53 | transformers.utils.logging.set_verbosity_error()
54 |
55 | self.setup_neptune(args)
56 |
57 | def setup_neptune(self, args):
58 | if args.logging.neptune:
59 | neptune_logger = neptune.init_run(
60 | project=args.logging.neptune_creds.project,
61 | api_token=args.logging.neptune_creds.api_token,
62 | tags=[str(item) for item in args.logging.neptune_creds.tags.split(",")],
63 | )
64 | else:
65 | neptune_logger = None
66 |
67 | self.neptune_logger = neptune_logger
68 |
69 | with open_dict(args):
70 | if neptune_logger is not None:
71 | args.neptune_id = neptune_logger["sys/id"].fetch()
72 |
73 | def log_args(self, args):
74 | if self.neptune_logger is not None:
75 | logging_args = OmegaConf.to_container(args, resolve=True)
76 | self.neptune_logger['args'] = logging_args
77 |
78 | def log_stats(self, stats, step, args, prefix=''):
79 | if self.neptune_logger is not None:
80 | for k, v in stats.items():
81 | self.neptune_logger[f'{prefix}{k}'].log(v, step=step)
82 |
83 | msg_start = f'[{prefix[:-1]}] Step {step} out of {args.optim.total_steps}' + ' | '
84 | dict_msg = ' | '.join([f'{k.capitalize()} --> {v:.3f}' for k, v in stats.items()]) + ' | '
85 |
86 | msg = msg_start + dict_msg
87 |
88 | self.log_message(msg)
89 |
90 | def log_message(self, msg):
91 | self.logger.info(msg)
92 |
93 | def finish(self):
94 | if self.neptune_logger is not None:
95 | self.neptune_logger.stop()
96 |
--------------------------------------------------------------------------------
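
A small usage sketch for the Averager class above (hypothetical, not part of the repository; it assumes the repo's dependencies are installed so that logging_utils imports cleanly). With the default weight of 1 it reduces to a plain arithmetic mean per key, which is how the training loop aggregates statistics between logging steps.

# Hypothetical example: aggregate per-step stats between two logging events.
from nanoT5.utils.logging_utils import Averager

averager = Averager()                            # weight=1 -> plain mean per key
averager.update({'loss': 2.0, 'grad_l2': 1.0})
averager.update({'loss': 1.0})                   # keys may differ between steps
print(averager.average())                        # {'loss': 1.5, 'grad_l2': 1.0}
# average() also resets the internal sums, ready for the next logging window.
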
/nanoT5/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import datasets
3 | from torch.utils.data import DataLoader
4 | from omegaconf import open_dict
5 | from datasets.iterable_dataset import IterableDataset
6 | from transformers import (
7 | AutoTokenizer,
8 | T5ForConditionalGeneration,
9 | AutoConfig,
10 | )
11 |
12 | from .copied_utils import (
13 | compute_input_and_target_lengths,
14 | DataCollatorForT5MLM,
15 | tokenize_function,
16 | DataCollatorForNI,
17 | )
18 | from .t5_model import MyT5
19 |
20 |
21 | def get_model(args, config):
22 | klass = {
23 | 'hf_t5': T5ForConditionalGeneration,
24 | 'local_t5': MyT5,
25 | }[args.model.klass]
26 |
27 | if args.model.checkpoint_path:
28 | model = klass(config)
29 | model.load_state_dict(torch.load(args.model.checkpoint_path))
30 | elif args.model.random_init:
31 | model = klass(config)
32 | else:
33 | assert klass == T5ForConditionalGeneration, 'To load HFs weights you need to use HF model'
34 | model = klass.from_pretrained(
35 | args.model.name,
36 | config=config,
37 | )
38 |
39 | with open_dict(args):
40 | args.n_all_param = sum([p.nelement() for p in model.parameters()])
41 |
42 | return model
43 |
44 |
45 | def get_config(args):
46 | config = AutoConfig.from_pretrained(
47 | args.model.name,
48 | )
49 |
50 | if hasattr(args.model, 'overwrite'):
51 | for k, v in args.model.overwrite.items():
52 | assert hasattr(config, k), f'config does not have attribute {k}'
53 | setattr(config, k, v)
54 |
55 | if hasattr(args.model, 'add_config'):
56 | for k, v in args.model.add_config.items():
57 | assert not hasattr(config, k), f'config already has attribute {k}'
58 | setattr(config, k, v)
59 |
60 | return config
61 |
62 |
63 | def get_tokenizer(args):
64 | tokenizer = AutoTokenizer.from_pretrained(
65 | args.model.name,
66 | use_fast=True
67 | )
68 | tokenizer.model_max_length = int(1e9)
69 |
70 | return tokenizer
71 |
72 |
73 | def load_dataset_splits(args):
74 | if args.mode == 'pt':
75 | dataset = datasets.load_dataset(
76 | 'c4',
77 | 'en',
78 | streaming=True,
79 | )
80 |
81 | dataset = dataset.remove_columns(
82 | ['timestamp', 'url']
83 | )
84 |
85 | dataset_splits = {
86 | 'train': dataset['train'],
87 | 'test': dataset['validation'],
88 | }
89 |
90 | assert (
91 | dataset['train'].n_shards == 1024
92 | ), "We want to have many shards for efficient processing with num_workes in PyTorch dataloader"
93 | elif args.mode == 'ft':
94 | dataset_splits = datasets.load_dataset(
95 | args.data.exec_file_path,
96 | data_dir=args.data.data_dir,
97 | task_dir=args.data.task_dir,
98 | max_num_instances_per_task=args.data.max_num_instances_per_task,
99 | max_num_instances_per_eval_task=args.data.max_num_instances_per_task
100 | )
101 | else:
102 | raise NotImplementedError
103 |
104 | return dataset_splits
105 |
106 |
107 | def process_dataset(dataset_splits, args, tokenizer):
108 | if args.mode == 'pt':
109 | final_datasets = {}
110 |
111 | for split, dataset_split in dataset_splits.items():
112 |
113 | # We increase the input_length because, instead of masking individual tokens,
114 | # T5 replaces each masked span with a single sentinel token; to avoid padding
115 | # we therefore need longer raw sequences before masking.
116 | before_mask_input_length, target_length = compute_input_and_target_lengths(
117 | inputs_length=args.data.input_length,
118 | noise_density=args.data.mlm_probability,
119 | mean_noise_span_length=args.data.mean_noise_span_length,
120 | )
121 |
122 | with open_dict(args):
123 | args.data.before_mask_input_length = before_mask_input_length
124 | args.data.target_length = target_length
125 |
126 | dataset_split = dataset_split.map(
127 | tokenize_function,
128 | batched=True,
129 | fn_kwargs={
130 | 'tokenizer': tokenizer,
131 | 'in_length': before_mask_input_length,
132 | },
133 | remove_columns=['text'],
134 | )
135 |
136 | dataset_split = dataset_split.shuffle(buffer_size=10_000, seed=args.seed)
137 | final_datasets[split] = dataset_split
138 | elif args.mode == 'ft':
139 | final_datasets = dataset_splits
140 | else:
141 | raise NotImplementedError
142 |
143 | return final_datasets
144 |
145 |
146 | def get_data_collator(tokenizer, config, args):
147 | if args.mode == 'pt':
148 | data_collator = DataCollatorForT5MLM(
149 | tokenizer=tokenizer,
150 | noise_density=args.data.mlm_probability,
151 | mean_noise_span_length=args.data.mean_noise_span_length,
152 | input_length=args.data.input_length,
153 | target_length=args.data.target_length,
154 | pad_token_id=config.pad_token_id,
155 | )
156 | elif args.mode == 'ft':
157 | data_collator = DataCollatorForNI(
158 | tokenizer,
159 | padding="longest",
160 | max_source_length=args.data.max_seq_len,
161 | max_target_length=args.data.max_target_len,
162 | label_pad_token_id=-100,
163 | pad_to_multiple_of=8,
164 | add_task_name=args.data.add_task_name,
165 | add_task_definition=args.data.add_task_definition,
166 | num_pos_examples=args.data.num_pos_examples,
167 | num_neg_examples=args.data.num_neg_examples,
168 | add_explanation=args.data.add_explanation,
169 | tk_instruct=args.data.tk_instruct
170 | )
171 | else:
172 | raise NotImplementedError
173 |
174 | return data_collator
175 |
176 |
177 | def get_dataloaders(tokenizer, config, args):
178 | dataset_splits = load_dataset_splits(args)
179 | dataset = process_dataset(dataset_splits=dataset_splits, args=args, tokenizer=tokenizer)
180 | data_collator = get_data_collator(tokenizer=tokenizer, config=config,
181 | args=args)
182 |
183 | is_iterable = isinstance(dataset['train'], IterableDataset)
184 |
185 | dataloaders = {}
186 |
187 | for split in ['train', 'test']:
188 | batch_size = args.optim.batch_size // args.optim.grad_acc
189 |
190 | shuffle = (split == 'train') and not is_iterable
191 |
192 | if args.mode == 'ft' and split == 'train':
193 | assert shuffle is True
194 | else:
195 | assert shuffle is False
196 |
197 | dataloaders[split] = DataLoader(
198 | dataset[split],
199 | shuffle=shuffle,
200 | collate_fn=data_collator,
201 | batch_size=batch_size,
202 | num_workers=args.data.num_workers,
203 | pin_memory=True,
204 | drop_last=False,
205 | )
206 |
207 | # Add & Check args about data loaders
208 | with open_dict(args):
209 | if not is_iterable:
210 | args.data.train_batches = len(dataloaders['train'])
211 | args.data.test_batches = len(dataloaders['test'])
212 |
213 | if args.optim.epochs > 0:
214 | assert not is_iterable
215 | args.optim.total_steps = (len(dataloaders['train']) // args.optim.grad_acc) * args.optim.epochs
216 |
217 | args.eval.corrected_steps = args.eval.steps
218 |
219 | return dataloaders['train'], dataloaders['test']
220 |
221 |
222 | def get_optimizer(model, args):
223 | no_decay = ["bias", "LayerNorm", "layernorm", "layer_norm", "ln"]
224 |
225 | optimizer_grouped_parameters = [
226 | {
227 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
228 | "weight_decay": args.optim.weight_decay,
229 | },
230 | {
231 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
232 | "weight_decay": 0.0,
233 | },
234 | ]
235 |
236 | if args.optim.name == 'adamw':
237 | from transformers import AdamW
238 | optimizer = AdamW(
239 | optimizer_grouped_parameters,
240 | lr=args.optim.base_lr,
241 | )
242 | elif args.optim.name == 'adamwscale':
243 | from .copied_utils import AdamWScale
244 | optimizer = AdamWScale(
245 | optimizer_grouped_parameters,
246 | lr=args.optim.base_lr,
247 | )
248 | elif args.optim.name == 'adafactor':
249 | from transformers import Adafactor
250 | optimizer = Adafactor(
251 | optimizer_grouped_parameters,
252 | lr=args.optim.base_lr,
253 | relative_step=False,
254 | )
255 | else:
256 | raise NotImplementedError
257 |
258 | return optimizer
259 |
260 |
261 | def get_lr_scheduler(optimizer, args, logger):
262 | if args.optim.lr_scheduler == 'cosine':
263 | from torch.optim.lr_scheduler import (
264 | SequentialLR,
265 | LinearLR,
266 | CosineAnnealingLR,
267 | )
268 |
269 | scheduler1 = LinearLR(
270 | optimizer,
271 | start_factor=0.5,
272 | end_factor=1,
273 | total_iters=args.optim.warmup_steps,
274 | last_epoch=-1,
275 | )
276 |
277 | scheduler2 = CosineAnnealingLR(
278 | optimizer,
279 | T_max=args.optim.total_steps - args.optim.warmup_steps,
280 | eta_min=args.optim.final_cosine,
281 | )
282 |
283 | lr_scheduler = SequentialLR(
284 | optimizer,
285 | schedulers=[scheduler1, scheduler2],
286 | milestones=[args.optim.warmup_steps]
287 | )
288 | elif args.optim.lr_scheduler == 'legacy':
289 | import math
290 | from torch.optim.lr_scheduler import (
291 | SequentialLR,
292 | LinearLR,
293 | LambdaLR,
294 | )
295 |
296 | msg = "You are using T5 legacy LR Schedule, it's independent from the optim.base_lr"
297 | logger.log_message(msg)
298 |
299 | num_steps_optimizer1 = math.ceil(args.optim.total_steps * 0.9)
300 | iters_left_for_optimizer2 = args.optim.total_steps - num_steps_optimizer1
301 |
302 | scheduler1 = LambdaLR(
303 | optimizer,
304 | lambda step: min(
305 | 1e-2, 1.0 / math.sqrt(step)
306 | ) / args.optim.base_lr if step else 1e-2 / args.optim.base_lr
307 | )
308 |
309 | scheduler2 = LinearLR(
310 | optimizer,
311 | start_factor=(
312 | min(1e-2, 1.0 / math.sqrt(num_steps_optimizer1)) / args.optim.base_lr
313 | ),
314 | end_factor=0,
315 | total_iters=iters_left_for_optimizer2,
316 | last_epoch=-1,
317 | )
318 |
319 | lr_scheduler = SequentialLR(
320 | optimizer,
321 | schedulers=[scheduler1, scheduler2],
322 | milestones=[num_steps_optimizer1]
323 | )
324 | elif args.optim.lr_scheduler == 'constant':
325 | from transformers import get_scheduler
326 | lr_scheduler = get_scheduler(
327 | name=args.optim.lr_scheduler,
328 | optimizer=optimizer,
329 | )
330 | else:
331 | raise NotImplementedError
332 |
333 | return lr_scheduler
334 |
--------------------------------------------------------------------------------
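
To make the 'cosine' branch of get_lr_scheduler above concrete, here is a self-contained sketch of the same warmup-then-cosine composition; the step counts and learning rates are toy values, not the repository's defaults.

# Toy illustration of the schedule composed in get_lr_scheduler: a LinearLR
# warmup starting at 0.5x of the base LR, then cosine annealing down to a small
# final LR, stitched together with SequentialLR.
import torch
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

base_lr, warmup_steps, total_steps, final_lr = 2e-2, 10, 100, 1e-5  # toy values
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=base_lr)

warmup = LinearLR(optimizer, start_factor=0.5, end_factor=1.0,
                  total_iters=warmup_steps)
cosine = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps,
                           eta_min=final_lr)
lr_scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine],
                            milestones=[warmup_steps])

for _ in range(total_steps):
    optimizer.step()
    lr_scheduler.step()
    # optimizer.param_groups[0]['lr'] ramps from 1e-2 to 2e-2 over the first 10
    # steps, then follows a cosine curve from 2e-2 down to ~1e-5.
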
/nanoT5/utils/ni_dataset.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """Natural Instruction V2 Dataset."""
18 |
19 |
20 | import json
21 | import os
22 | import random
23 | import datasets
24 |
25 | logger = datasets.logging.get_logger(__name__)
26 |
27 | _CITATION = """
28 | @article{wang2022benchmarking,
29 | title={Benchmarking Generalization via In-Context Instructions on 1,600+ Language Tasks},
30 | author={Wang, Yizhong and Mishra, Swaroop and Alipoormolabashi, Pegah and Kordi, Yeganeh and others},
31 | journal={arXiv preprint arXiv:2204.07705},
32 | year={2022}
33 | }
34 | """
35 |
36 | _DESCRIPTION = """
37 | Natural-Instructions v2 is a benchmark of 1,600+ diverse language tasks and their expert-written instructions.
38 | It covers 70+ distinct task types, such as tagging, in-filling, and rewriting.
39 | These tasks are collected with contributions of NLP practitioners in the community and
40 | through an iterative peer review process to ensure their quality.
41 | """
42 |
43 | _URL = "https://instructions.apps.allenai.org/"
44 |
45 | class NIConfig(datasets.BuilderConfig):
46 | def __init__(self, *args, task_dir=None, max_num_instances_per_task=None, max_num_instances_per_eval_task=None, **kwargs):
47 | super().__init__(*args, **kwargs)
48 | self.task_dir: str = task_dir
49 | self.max_num_instances_per_task: int = max_num_instances_per_task
50 | self.max_num_instances_per_eval_task: int = max_num_instances_per_eval_task
51 |
52 |
53 | class NaturalInstructions(datasets.GeneratorBasedBuilder):
54 | """NaturalInstructions Dataset."""
55 |
56 | VERSION = datasets.Version("2.0.0")
57 | BUILDER_CONFIG_CLASS = NIConfig
58 | BUILDER_CONFIGS = [
59 | NIConfig(name="default", description="Default config for NaturalInstructions")
60 | ]
61 | DEFAULT_CONFIG_NAME = "default"
62 |
63 | def _info(self):
64 | return datasets.DatasetInfo(
65 | description=_DESCRIPTION,
66 | features=datasets.Features(
67 | {
68 | "id": datasets.Value("string"),
69 | "Task": datasets.Value("string"),
70 | "Contributors": datasets.Value("string"),
71 | "Source": [datasets.Value("string")],
72 | "URL": [datasets.Value("string")],
73 | "Categories": [datasets.Value("string")],
74 | "Reasoning": [datasets.Value("string")],
75 | "Definition": [datasets.Value("string")],
76 | "Positive Examples": [{
77 | "input": datasets.Value("string"),
78 | "output": datasets.Value("string"),
79 | "explanation": datasets.Value("string")
80 | }],
81 | "Negative Examples": [{
82 | "input": datasets.Value("string"),
83 | "output": datasets.Value("string"),
84 | "explanation": datasets.Value("string")
85 | }],
86 | "Input_language": [datasets.Value("string")],
87 | "Output_language": [datasets.Value("string")],
88 | "Instruction_language": [datasets.Value("string")],
89 | "Domains": [datasets.Value("string")],
90 | # "Instances": [{
91 | # "input": datasets.Value("string"),
92 | # "output": [datasets.Value("string")]
93 | # }],
94 | "Instance": {
95 | "id": datasets.Value("string"),
96 | "input": datasets.Value("string"),
97 | "output": [datasets.Value("string")]
98 | },
99 | "Instance License": [datasets.Value("string")]
100 | }
101 | ),
102 | supervised_keys=None,
103 | homepage="https://github.com/allenai/natural-instructions",
104 | citation=_CITATION,
105 | )
106 |
107 | def _split_generators(self, dl_manager):
108 | """Returns SplitGenerators."""
109 | if self.config.data_dir is None or self.config.task_dir is None:
110 | dl_path = dl_manager.download_and_extract(_URL)
111 | self.config.data_dir = self.config.data_dir or os.path.join(dl_path, "splits")
112 | self.config.task_dir = self.config.task_dir or os.path.join(dl_path, "tasks")
113 |
114 | split_dir = self.config.data_dir
115 | task_dir = self.config.task_dir
116 |
117 | return [
118 | datasets.SplitGenerator(
119 | name=datasets.Split.TRAIN,
120 | gen_kwargs={
121 | "path": os.path.join(split_dir, "train_tasks.txt"),
122 | "task_dir": task_dir,
123 | "max_num_instances_per_task": self.config.max_num_instances_per_task,
124 | "subset": "train"
125 | }),
126 | # datasets.SplitGenerator(
127 | # name=datasets.Split.VALIDATION,
128 | # gen_kwargs={
129 | # "path": os.path.join(split_dir, "dev_tasks.txt"),
130 | # "task_dir": task_dir,
131 | # "max_num_instances_per_task": self.config.max_num_instances_per_eval_task,
132 | # "subset": "dev"
133 | # }),
134 | datasets.SplitGenerator(
135 | name=datasets.Split.TEST,
136 | gen_kwargs={
137 | "path": os.path.join(split_dir, "test_tasks.txt"),
138 | "task_dir": task_dir,
139 | "max_num_instances_per_task": self.config.max_num_instances_per_eval_task,
140 | "subset": "test"
141 | }),
142 | ]
143 |
144 | def _generate_examples(self, path=None, task_dir=None, max_num_instances_per_task=None, subset=None):
145 | """Yields examples."""
146 | logger.info(f"Generating tasks from = {path}")
147 | with open(path, encoding="utf-8") as split_f:
148 | for line in split_f:
149 | task_name = line.strip()
150 | task_path = os.path.join(task_dir, task_name + ".json")
151 | with open(task_path, encoding="utf-8") as task_f:
152 | s = task_f.read()
153 | task_data = json.loads(s)
154 | task_data["Task"] = task_name
155 | if "Instruction Source" in task_data:
156 | task_data.pop("Instruction Source")
157 | all_instances = task_data.pop("Instances")
158 | if subset == "test":
159 | # For test tasks, 100 label-balanced instances are selected for efficient
160 | # evaluation. They are placed first in the file for reproducibility,
161 | # so we use them here.
162 | instances = all_instances[:100]
163 | else:
164 | instances = all_instances
165 | if max_num_instances_per_task is not None and max_num_instances_per_task >= 0:
166 | random.shuffle(instances)
167 | instances = instances[:max_num_instances_per_task]
168 | for idx, instance in enumerate(instances):
169 | example = task_data.copy()
170 | example["id"] = instance["id"]
171 | example["Instance"] = instance
172 | yield f"{task_name}_{idx}", example
173 |
174 |
--------------------------------------------------------------------------------
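
A hedged sketch of how this builder script is typically invoked, mirroring load_dataset_splits in model_utils.py. The data_dir and task_dir values below are placeholders for a local natural-instructions checkout (its splits/ and tasks/ directories) and must be adapted to your setup; a datasets version that still supports loading builder scripts is assumed.

# Hypothetical usage of the Natural-Instructions v2 builder defined above.
import datasets

ni = datasets.load_dataset(
    'nanoT5/utils/ni_dataset.py',        # path to this builder script
    data_dir='data/splits/default',      # placeholder: NI splits directory
    task_dir='data/tasks',               # placeholder: NI tasks directory
    max_num_instances_per_task=100,
    max_num_instances_per_eval_task=100,
)
print(ni['train'][0]['Task'], ni['train'][0]['Instance']['input'][:80])
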
/nanoT5/utils/t5_model.py:
--------------------------------------------------------------------------------
1 | # From: https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
2 |
3 | import copy
4 | import math
5 | from typing import Optional
6 | from dataclasses import dataclass
7 |
8 | import torch
9 | from torch import nn
10 | from torch.nn import CrossEntropyLoss
11 |
12 | from transformers.modeling_utils import ModuleUtilsMixin
13 | from transformers.modeling_outputs import ModelOutput
14 | from transformers.models.t5.configuration_t5 import T5Config
15 | from transformers.models.t5.modeling_t5 import (
16 | T5LayerNorm,
17 | T5DenseGatedActDense,
18 | )
19 |
20 |
21 | @dataclass
22 | class EncoderOutput(ModelOutput):
23 | hidden_states: torch.FloatTensor = None
24 | attention_mask: torch.FloatTensor = None
25 |
26 |
27 | @dataclass
28 | class Seq2SeqLMOutput(ModelOutput):
29 | loss: torch.FloatTensor = None
30 | logits: torch.FloatTensor = None
31 | encoder_outputs: EncoderOutput = None
32 |
33 |
34 | class T5LayerFF(nn.Module):
35 | def __init__(self, config: T5Config):
36 | super().__init__()
37 | assert config.is_gated_act
38 | self.DenseReluDense = T5DenseGatedActDense(config)
39 | self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
40 | self.dropout = nn.Dropout(config.dropout_rate)
41 |
42 | def forward(self, hidden_states):
43 | forwarded_states = self.layer_norm(hidden_states).type_as(hidden_states)
44 | forwarded_states = self.DenseReluDense(forwarded_states)
45 | hidden_states = hidden_states + self.dropout(forwarded_states)
46 | return hidden_states
47 |
48 |
49 | class T5Attention(nn.Module):
50 | def __init__(self, config: T5Config, has_relative_attention_bias=False):
51 | super().__init__()
52 | self.is_decoder = config.is_decoder
53 | self.has_relative_attention_bias = has_relative_attention_bias
54 | self.relative_attention_num_buckets = config.relative_attention_num_buckets
55 | self.relative_attention_max_distance = config.relative_attention_max_distance
56 | self.d_model = config.d_model
57 | self.key_value_proj_dim = config.d_kv
58 | self.n_heads = config.num_heads
59 | self.dropout = config.dropout_rate
60 | self.inner_dim = self.n_heads * self.key_value_proj_dim
61 |
62 | self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
63 | self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
64 | self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
65 | self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
66 |
67 | if self.has_relative_attention_bias:
68 | self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
69 |
70 | @staticmethod
71 | def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
72 | """
73 | Adapted from Mesh Tensorflow:
74 | https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
75 |
76 | Translate relative position to a bucket number for relative attention. The relative position is defined as
77 | memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
78 | position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
79 | small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
80 | positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
81 | This should allow for more graceful generalization to longer sequences than the model has been trained on
82 |
83 | Args:
84 | relative_position: an int32 Tensor
85 | bidirectional: a boolean - whether the attention is bidirectional
86 | num_buckets: an integer
87 | max_distance: an integer
88 |
89 | Returns:
90 | a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
91 | """
92 | relative_buckets = 0
93 | if bidirectional:
94 | num_buckets //= 2
95 | relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
96 | relative_position = torch.abs(relative_position)
97 | else:
98 | relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
99 | # now relative_position is in the range [0, inf)
100 |
101 | # half of the buckets are for exact increments in positions
102 | max_exact = num_buckets // 2
103 | is_small = relative_position < max_exact
104 |
105 | # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
106 | relative_position_if_large = max_exact + (
107 | torch.log(relative_position.float() / max_exact)
108 | / math.log(max_distance / max_exact)
109 | * (num_buckets - max_exact)
110 | ).to(torch.long)
111 | relative_position_if_large = torch.min(
112 | relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
113 | )
114 |
115 | relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
116 | return relative_buckets
117 |
118 | def compute_bias(self, query_length, key_length, device=None):
119 | """Compute binned relative position bias"""
120 | if device is None:
121 | device = self.relative_attention_bias.weight.device
122 | context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
123 | memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
124 | relative_position = memory_position - context_position # shape (query_length, key_length)
125 | relative_position_bucket = self._relative_position_bucket(
126 | relative_position, # shape (query_length, key_length)
127 | bidirectional=(not self.is_decoder),
128 | num_buckets=self.relative_attention_num_buckets,
129 | max_distance=self.relative_attention_max_distance,
130 | )
131 | values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
132 | values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
133 | return values
134 |
135 | def forward(
136 | self,
137 | hidden_states,
138 | mask=None,
139 | key_value_states=None,
140 | position_bias=None,
141 | ):
142 | """
143 | Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
144 | """
145 | # Input is (batch_size, seq_length, dim)
146 | # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
147 | batch_size, seq_length = hidden_states.shape[:2]
148 | real_seq_length = seq_length
149 | key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
150 |
151 | def shape(states):
152 | """projection"""
153 | return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
154 |
155 | def unshape(states):
156 | """reshape"""
157 | return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
158 |
159 | query_states = self.q(hidden_states)
160 | if key_value_states is None:
161 | key_states, value_states = self.k(hidden_states), self.v(hidden_states)
162 | else:
163 | key_states, value_states = self.k(key_value_states), self.v(key_value_states)
164 | query_states, key_states, value_states = shape(query_states), shape(key_states), shape(value_states)
165 |
166 | scores = torch.matmul(
167 | query_states, key_states.transpose(3, 2)
168 | ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
169 |
170 | if position_bias is None:
171 | if not self.has_relative_attention_bias:
172 | position_bias = torch.zeros(
173 | (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
174 | )
175 | else:
176 | position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
177 |
178 | if mask is not None:
179 | # Masking happens here, masked elements in the mask have the value of -inf
180 | position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
181 |
182 | position_bias_masked = position_bias
183 |
184 | scores += position_bias_masked
185 | attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
186 | scores
187 | ) # (batch_size, n_heads, seq_length, key_length)
188 | attn_weights = nn.functional.dropout(
189 | attn_weights, p=self.dropout, training=self.training
190 | ) # (batch_size, n_heads, seq_length, key_length)
191 |
192 | attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
193 | attn_output = self.o(attn_output)
194 |
195 | return (attn_output, position_bias)
196 |
197 |
198 | class T5LayerSelfAttention(nn.Module):
199 | def __init__(self, config, has_relative_attention_bias=False):
200 | super().__init__()
201 | self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
202 | self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
203 | self.dropout = nn.Dropout(config.dropout_rate)
204 |
205 | def forward(
206 | self,
207 | hidden_states,
208 | attention_mask=None,
209 | position_bias=None,
210 | ):
211 | normed_hidden_states = self.layer_norm(hidden_states).type_as(hidden_states)
212 | attention_output = self.SelfAttention(
213 | normed_hidden_states,
214 | mask=attention_mask,
215 | position_bias=position_bias,
216 | )
217 | hidden_states = hidden_states + self.dropout(attention_output[0])
218 | outputs = (hidden_states,) + attention_output[1:]
219 | return outputs
220 |
221 |
222 | class T5LayerCrossAttention(nn.Module):
223 | def __init__(self, config):
224 | super().__init__()
225 | self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
226 | self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
227 | self.dropout = nn.Dropout(config.dropout_rate)
228 |
229 | def forward(
230 | self,
231 | hidden_states,
232 | key_value_states,
233 | attention_mask=None,
234 | position_bias=None,
235 | ):
236 | normed_hidden_states = self.layer_norm(hidden_states)
237 | attention_output = self.EncDecAttention(
238 | normed_hidden_states,
239 | mask=attention_mask,
240 | key_value_states=key_value_states,
241 | position_bias=position_bias,
242 | )
243 | layer_output = hidden_states + self.dropout(attention_output[0])
244 | outputs = (layer_output,) + attention_output[1:]
245 | return outputs
246 |
247 |
248 | class T5Block(nn.Module):
249 | def __init__(self, config, has_relative_attention_bias=False):
250 | super().__init__()
251 | self.is_decoder = config.is_decoder
252 | self.layer = nn.ModuleList()
253 | self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
254 | if self.is_decoder:
255 | self.layer.append(T5LayerCrossAttention(config))
256 |
257 | self.layer.append(T5LayerFF(config))
258 |
259 | def forward(
260 | self,
261 | hidden_states,
262 | attention_mask=None,
263 | position_bias=None,
264 | encoder_hidden_states=None,
265 | encoder_attention_mask=None,
266 | encoder_decoder_position_bias=None,
267 | ):
268 | self_attention_outputs = self.layer[0](
269 | hidden_states,
270 | attention_mask=attention_mask,
271 | position_bias=position_bias,
272 | )
273 | hidden_states = self_attention_outputs[0]
274 | attention_outputs = self_attention_outputs[1:] # Relative position weights
275 |
276 | if self.is_decoder and encoder_hidden_states is not None:
277 | cross_attention_outputs = self.layer[1](
278 | hidden_states,
279 | key_value_states=encoder_hidden_states,
280 | attention_mask=encoder_attention_mask,
281 | position_bias=encoder_decoder_position_bias,
282 | )
283 | hidden_states = cross_attention_outputs[0]
284 |
285 | # Keep relative position weights
286 | attention_outputs = attention_outputs + cross_attention_outputs[1:]
287 |
288 | # Apply Feed Forward layer
289 | hidden_states = self.layer[-1](hidden_states)
290 |
291 | outputs = (hidden_states,)
292 | outputs = outputs + attention_outputs
293 |
294 | return outputs # hidden-states, (self-attention position bias), (cross-attention position bias)
295 |
296 |
297 | class T5Stack(nn.Module, ModuleUtilsMixin):
298 | def __init__(self, config, embed_tokens):
299 | super().__init__()
300 | assert embed_tokens is not None
301 |
302 | self.config = config
303 | self.embed_tokens = embed_tokens
304 | self.is_decoder = config.is_decoder
305 |
306 | self.block = nn.ModuleList(
307 | [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
308 | )
309 |
310 | self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
311 | self.dropout = nn.Dropout(config.dropout_rate)
312 |
313 | def forward(
314 | self,
315 | input_ids=None,
316 | attention_mask=None,
317 | encoder_hidden_states=None,
318 | encoder_attention_mask=None,
319 | ) -> EncoderOutput:
320 | input_shape = input_ids.size()
321 | batch_size, seq_length = input_shape
322 |
323 | inputs_embeds = self.embed_tokens(input_ids)
324 |
325 | if hasattr(self.config, 'is_bf16') and self.config.is_bf16:
326 | inputs_embeds = inputs_embeds.to(torch.bfloat16)
327 |
328 | # Masking
329 | if attention_mask is None:
330 | attention_mask = torch.ones(batch_size, seq_length, device=inputs_embeds.device)
331 |
332 | if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
333 | encoder_seq_length = encoder_hidden_states.shape[1]
334 | encoder_attention_mask = torch.ones(
335 | batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
336 | )
337 |
338 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
339 | # ourselves in which case we just need to make it broadcastable to all heads.
340 | extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
341 |
342 | # If a 2D or 3D attention mask is provided for the cross-attention
343 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
344 | if self.is_decoder and encoder_hidden_states is not None:
345 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
346 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
347 | if encoder_attention_mask is None:
348 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
349 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
350 | else:
351 | encoder_extended_attention_mask = None
352 |
353 | position_bias = None
354 | encoder_decoder_position_bias = None
355 |
356 | hidden_states = self.dropout(inputs_embeds)
357 |
358 | for _, layer_module in enumerate(self.block):
359 | layer_outputs = layer_module(
360 | hidden_states,
361 | attention_mask=extended_attention_mask,
362 | position_bias=position_bias,
363 | encoder_hidden_states=encoder_hidden_states,
364 | encoder_attention_mask=encoder_extended_attention_mask,
365 | encoder_decoder_position_bias=encoder_decoder_position_bias,
366 | )
367 | hidden_states = layer_outputs[0]
368 |
369 | # We share the position biases between the layers - the first layer stores them
370 | position_bias = layer_outputs[1]
371 | if self.is_decoder and encoder_hidden_states is not None:
372 | encoder_decoder_position_bias = layer_outputs[2]
373 |
374 | hidden_states = self.final_layer_norm(hidden_states).type_as(hidden_states)
375 | hidden_states = self.dropout(hidden_states)
376 |
377 | return EncoderOutput(
378 | hidden_states=hidden_states,
379 | attention_mask=attention_mask,
380 | )
381 |
382 |
383 | class MyT5(nn.Module):
384 | def __init__(self, config: T5Config):
385 | super().__init__()
386 | config.is_encoder_decoder = False
387 | assert not config.tie_word_embeddings
388 |
389 | self.config = config
390 | self.model_dim = config.d_model
391 | self.shared = nn.Embedding(config.vocab_size, config.d_model)
392 |
393 | encoder_config = copy.deepcopy(config)
394 | encoder_config.is_decoder = False
395 | self.encoder = T5Stack(encoder_config, self.shared)
396 |
397 | decoder_config = copy.deepcopy(config)
398 | decoder_config.is_decoder = True
399 | decoder_config.num_layers = config.num_decoder_layers
400 | self.decoder = T5Stack(decoder_config, self.shared)
401 |
402 | self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
403 | self.generation_config = None
404 |
405 | self.apply(self._init_weights)
406 |
407 | def generate(
408 | self,
409 | input_ids: Optional[torch.LongTensor] = None,
410 | attention_mask: Optional[torch.FloatTensor] = None,
411 | max_length = None,
412 | **kwargs,
413 | ) -> torch.LongTensor:
414 | """
415 | input_ids: B x L_encoder, int64
416 | attention_mask: B x L_encoder, int64
417 | 1 for tokens to attend to, 0 for tokens to ignore
418 |
419 | Generation:
420 | Starts with 0, ends with 1, padding is 0
421 |
422 | # For 20 input/outputs, the diff between my implementation and HF is 9.8s vs 11.4s
423 | """
424 | B, _ = input_ids.size()
425 | labels = torch.zeros(B, 1, dtype=torch.long, device=input_ids.device)
426 | encoder_outputs = None
427 |
428 | for _ in range(max_length):
429 | out = self.forward(
430 | input_ids=input_ids,
431 | attention_mask=attention_mask,
432 | decoder_input_ids=labels,
433 | encoder_outputs=encoder_outputs,
434 | )
435 | encoder_outputs = out.encoder_outputs
436 | top_labels = out.logits[:, -1].argmax(-1).unsqueeze(-1)
437 | labels = torch.cat([labels, top_labels], dim=-1)
438 |
439 | if (labels == 1).sum(-1).clamp(min=0, max=1).sum().item() == B:
440 | break
441 |
442 | labels[:, -1] = 1
443 |
444 | # Mask out the padding, i.e., fill all positions after the first EOS (token id 1) with 0
445 | B, L = labels.size()
446 | mask = torch.arange(L, device=labels.device).unsqueeze(0) <= (labels == 1).long().argmax(-1).unsqueeze(-1)
447 | labels = labels.masked_fill(~mask, 0)
448 |
449 | return labels
450 |
451 | def forward(
452 | self,
453 | input_ids: Optional[torch.LongTensor] = None,
454 | attention_mask: Optional[torch.FloatTensor] = None,
455 | decoder_input_ids: Optional[torch.LongTensor] = None,
456 | decoder_attention_mask: Optional[torch.BoolTensor] = None,
457 | labels: Optional[torch.LongTensor] = None,
458 | encoder_outputs = None,
459 | ) -> Seq2SeqLMOutput:
460 | """
461 | input_ids: B x L_encoder, int64
462 | attention_mask: B x L_encoder, int64
463 | 1 for tokens to attend to, 0 for tokens to ignore
464 | labels: B x L_decoder, int64
465 | """
466 | if encoder_outputs is None:
467 | encoder_outputs = self.encoder(
468 | input_ids=input_ids,
469 | attention_mask=attention_mask,
470 | )
471 |
472 | hidden_states = encoder_outputs.hidden_states
473 |
474 | if labels is not None and decoder_input_ids is None:
475 | decoder_input_ids = self._shift_right(labels)
476 |
477 | decoder_outputs = self.decoder(
478 | input_ids=decoder_input_ids,
479 | attention_mask=decoder_attention_mask,
480 | encoder_hidden_states=hidden_states,
481 | encoder_attention_mask=attention_mask,
482 | )
483 |
484 | sequence_output = decoder_outputs[0]
485 | lm_logits = self.lm_head(sequence_output)
486 |
487 | loss = None
488 | if labels is not None:
489 | loss_fct = CrossEntropyLoss(ignore_index=-100)
490 | loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
491 |
492 | return Seq2SeqLMOutput(
493 | loss=loss,
494 | logits=lm_logits,
495 | encoder_outputs=encoder_outputs,
496 | )
497 |
498 | def _init_weights(self, module):
499 | factor = self.config.initializer_factor # Used for testing weights initialization
500 | if isinstance(module, T5LayerNorm):
501 | module.weight.data.fill_(factor * 1.0)
502 | elif isinstance(module, (MyT5)):
503 | module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
504 | if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
505 | module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
506 | elif isinstance(module, T5DenseGatedActDense):
507 | d_ff, d_model = module.wi_0.weight.data.size()
508 | module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
509 | module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
510 | module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5))
511 | elif isinstance(module, T5Attention):
512 | d_model = self.config.d_model
513 | key_value_proj_dim = self.config.d_kv
514 | n_heads = self.config.num_heads
515 | module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
516 | module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
517 | module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
518 | module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
519 | if hasattr(module, "relative_attention_bias"):
520 | module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
521 |
522 | def _shift_right(self, input_ids):
523 | decoder_start_token_id = self.config.decoder_start_token_id
524 | pad_token_id = self.config.pad_token_id
525 |
526 | assert decoder_start_token_id is not None and pad_token_id is not None
527 | shifted_input_ids = input_ids.new_zeros(input_ids.shape)
528 | shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
529 | shifted_input_ids[..., 0] = decoder_start_token_id
530 |
531 | # replace possible -100 values in labels by `pad_token_id`
532 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
533 |
534 | return shifted_input_ids
--------------------------------------------------------------------------------
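
A standalone illustration of the relative position bucketing implemented in T5Attention._relative_position_bucket above. The sequence length and bucket settings below are just example values; it assumes the module is importable from the repo root with the repo's dependencies installed.

# Hypothetical demo: nearby relative positions get their own buckets, while
# distant ones share logarithmically spaced buckets (one half per direction).
import torch
from nanoT5.utils.t5_model import T5Attention

positions = torch.arange(8)
relative_position = positions[None, :] - positions[:, None]  # memory - query

buckets = T5Attention._relative_position_bucket(
    relative_position, bidirectional=True, num_buckets=32, max_distance=128
)
print(buckets[0])  # buckets seen by query position 0 attending to positions 0..7
# Offsets with |distance| >= max_distance all collapse into the last bucket on
# each side, which is what lets the bias generalize to longer inputs.
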
/nanoT5/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import evaluate
4 | from .logging_utils import Averager
5 | from datasets.iterable_dataset import IterableDataset
6 |
7 |
8 | def maybe_save_checkpoint(accelerator, args):
9 | if (
10 | args.current_train_step > args.optim.total_steps
11 | or args.current_train_step % args.checkpoint.every_steps == 0
12 | ):
13 | output_dir = f'checkpoint-{args.mode}-{args.current_train_step}'
14 | accelerator.save_state(output_dir=output_dir)
15 |
16 |
17 | def maybe_eval_predict(model, dataloader, logger, args, tokenizer):
18 | if (
19 | args.current_train_step > args.optim.total_steps
20 | or args.current_train_step % args.eval.every_steps == 0
21 | ):
22 | model.eval()
23 |
24 | with torch.no_grad():
25 | eval(model, dataloader, logger, args, tokenizer)
26 |
27 | if args.mode == 'ft':
28 | predict(
29 | model, dataloader, logger, args, tokenizer
30 | )
31 |
32 | args.last_log = time.time()
33 | model.train()
34 |
35 |
36 | def maybe_logging(averager, args, model, optimizer, logger):
37 | if args.current_train_step % args.logging.every_steps == 0:
38 | stats = extra_stats(args, model, optimizer)
39 |
40 | averager.update(stats)
41 | averaged_stats = averager.average()
42 |
43 | logger.log_stats(
44 | stats=averaged_stats,
45 | step=args.current_train_step,
46 | args=args,
47 | prefix='train/'
48 | )
49 |
50 | args.last_log = time.time()
51 |
52 |
53 | def maybe_grad_clip_and_grad_calc(accelerator, model, args):
54 | if args.optim.grad_clip > 0:
55 | grad_l2 = accelerator.clip_grad_norm_(
56 | parameters=model.parameters(),
57 | max_norm=args.optim.grad_clip,
58 | norm_type=2,
59 | )
60 | else:
61 | grad_l2 = None
62 |
63 | if args.logging.grad_l2:
64 | if grad_l2 is None:
65 | grad_l2 = (
66 | sum(p.grad.detach().data.norm(2).item() ** 2 for p in model.parameters()) ** 0.5
67 | )
68 |
69 | return {'grad_l2': grad_l2}
70 | else:
71 | return {}
72 |
73 |
74 | def extra_stats(args, model, optimizer):
75 | stats = {}
76 |
77 | if args.logging.weights_l2:
78 | weights_l2 = sum(p.detach().norm(2).item() ** 2 for p in model.parameters()) ** 0.5
79 | stats['weights_l2'] = weights_l2
80 |
81 | stats['lr'] = optimizer.param_groups[0]['lr']
82 | stats['seconds_per_step'] = (time.time() - args.last_log) / args.logging.every_steps
83 |
84 | return stats
85 |
86 |
87 | def forward(model, batch, calc_acc=False):
88 | outputs = model(**batch)
89 | loss = outputs.loss
90 |
91 | stats = {}
92 | stats['loss'] = loss.detach().float().item()
93 |
94 | if calc_acc:
95 | correct = (outputs.logits.argmax(-1) == batch["labels"]).sum().item()
96 | accuracy = correct / batch["labels"].numel()
97 | stats['accuracy'] = accuracy
98 |
99 | return loss, stats
100 |
101 |
102 | def eval(model, dataloader, logger, args, tokenizer):
103 | args.last_log = time.time()
104 | averager = Averager()
105 |
106 | for batch_id, batch in enumerate(dataloader, start=1):
107 | if batch_id == args.eval.corrected_steps * args.optim.grad_acc:
108 | break
109 |
110 | _, stats = forward(model, batch, calc_acc=True)
111 | averager.update(stats)
112 |
113 | averager.update({'time': time.time() - args.last_log})
114 | averaged_stats = averager.average()
115 |
116 | logger.log_stats(
117 | stats=averaged_stats,
118 | step=args.current_train_step,
119 | args=args,
120 | prefix='eval/'
121 | )
122 |
123 |
124 | def predict(model, dataloader, logger, args, tokenizer):
125 | args.last_log = time.time()
126 | metric = evaluate.load('rouge')
127 | samples_seen = 0
128 |
129 | def decode(preds):
130 | preds[preds == -100] = tokenizer.pad_token_id
131 | preds = tokenizer.batch_decode(
132 | preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
133 | )
134 | preds = [pred.strip() for pred in preds]
135 | return preds
136 |
137 | for step, batch in enumerate(dataloader):
138 | predictions = model.generate(
139 | input_ids=batch['input_ids'],
140 | attention_mask=batch['attention_mask'],
141 | max_length=args.data.max_target_len,
142 | generation_config=model.generation_config,
143 | )
144 | predictions = decode(predictions)
145 | references = decode(batch["labels"])
146 |
147 |         # In a multi-process run, the last batch may contain duplicated samples added for padding; drop the excess
148 | if step == len(dataloader) - 1:
149 | predictions = predictions[: len(dataloader.dataset) - samples_seen]
150 | references = references[: len(dataloader.dataset) - samples_seen]
151 | else:
152 | samples_seen += len(references)
153 |
154 | metric.add_batch(
155 | predictions=predictions,
156 | references=references,
157 | )
158 |
159 | eval_metric = metric.compute(use_stemmer=True, use_aggregator=False)
160 | rougeL = sum(eval_metric["rougeL"]) * 100 / len(eval_metric["rougeL"])
161 |
162 | logger.log_stats(
163 | stats={
164 | "rougeL": rougeL,
165 | "time": time.time() - args.last_log,
166 | },
167 | step=args.current_train_step,
168 | args=args,
169 | prefix="test/",
170 | )
171 |
172 |
173 | def train(model, train_dataloader, test_dataloader, accelerator, lr_scheduler,
174 | optimizer, logger, args, tokenizer):
175 | model.train()
176 |
177 | train_averager = Averager()
178 |
179 | while args.current_train_step <= args.optim.total_steps:
180 | if isinstance(train_dataloader.dataset, IterableDataset):
181 | train_dataloader.dataset.set_epoch(args.current_epoch)
182 |
183 |         # In case gradients remain from the previous epoch's incomplete accumulation window, reset them
184 | optimizer.zero_grad(set_to_none=True)
185 |
186 | for batch_id, batch in enumerate(train_dataloader, start=1):
187 | if args.current_train_step > args.optim.total_steps:
188 | break
189 |
190 | loss, stats = forward(model, batch)
191 | accelerator.backward(loss / args.optim.grad_acc)
192 | train_averager.update(stats)
193 |
194 | if batch_id % args.optim.grad_acc == 0:
195 | stats = maybe_grad_clip_and_grad_calc(accelerator, model, args)
196 | train_averager.update(stats)
197 |
198 | optimizer.step()
199 | lr_scheduler.step()
200 | optimizer.zero_grad(set_to_none=True)
201 |
202 | maybe_logging(train_averager, args, model, optimizer, logger)
203 | maybe_eval_predict(model, test_dataloader, logger, args, tokenizer)
204 | maybe_save_checkpoint(accelerator, args)
205 |
206 | args.current_train_step += 1
207 |
208 | args.current_epoch += 1
209 |
210 | maybe_eval_predict(model, test_dataloader, logger, args, tokenizer)
211 | maybe_save_checkpoint(accelerator, args)
212 |
--------------------------------------------------------------------------------
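
The train loop above implements plain gradient accumulation: each micro-batch loss is divided by optim.grad_acc before backward so the accumulated gradient equals the mean over the effective batch, and only every grad_acc-th micro-batch triggers optional clipping, an optimizer and scheduler step, and a zero_grad. The self-contained sketch below shows that accumulate -> clip -> step ordering; the toy model, data, and hyper-parameters are stand-ins, not values from this repository.

    # Sketch of the accumulation pattern used in train(); torch.nn.utils.clip_grad_norm_
    # plays the role of accelerator.clip_grad_norm_ here.
    import torch

    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    grad_acc, grad_clip = 4, 1.0

    optimizer.zero_grad(set_to_none=True)
    for batch_id in range(1, 17):                     # 16 micro-batches -> 4 optimizer steps
        x, y = torch.randn(2, 8), torch.randn(2, 1)
        loss = torch.nn.functional.mse_loss(model(x), y)
        (loss / grad_acc).backward()                  # scale so summed grads average over the effective batch
        if batch_id % grad_acc == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
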
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | datasets >= 1.8.0
3 | sentencepiece != 0.1.92
4 | transformers
5 | neptune
6 | pdbpp
7 | notebook
8 | protobuf==3.20.*
9 | pyyaml
10 | pynvml
11 | hydra-core
12 | evaluate
13 | nltk
14 | absl-py
15 | rouge_score
16 | torch>=1.13.1,<=2.0.1
--------------------------------------------------------------------------------
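
The dependencies above are typically installed into a fresh environment with pip install -r requirements.txt; note the explicit pins on torch (>=1.13.1,<=2.0.1), datasets (>=1.8.0), and protobuf (3.20.*). The snippet below is an optional sanity check, an illustration rather than part of the repository, that prints the installed versions of those pinned packages before launching training.

    # Illustrative sanity check: report installed versions of the explicitly pinned packages.
    import importlib.metadata as md

    for pkg, pin in [('torch', '>=1.13.1,<=2.0.1'), ('datasets', '>=1.8.0'), ('protobuf', '==3.20.*')]:
        print(f'{pkg:10s} {md.version(pkg):12s} (pinned {pin})')
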