├── .gitignore ├── LICENSE ├── README.md ├── assets ├── downstream.png ├── env_dump │ ├── lscpu.txt │ ├── nvidia_smi.txt │ └── pip_freeze.txt ├── ft_loss.png ├── nanoT5.png └── pt_loss.png ├── nanoT5 ├── __init__.py ├── configs │ ├── default.yaml │ ├── local_env │ │ └── default.yaml │ └── task │ │ ├── ft.yaml │ │ ├── pt.yaml │ │ ├── pt_12h.yaml │ │ ├── pt_16h.yaml │ │ ├── pt_20h.yaml │ │ ├── pt_24h.yaml │ │ ├── pt_4h.yaml │ │ └── pt_8h.yaml ├── main.py └── utils │ ├── __init__.py │ ├── copied_utils.py │ ├── gen_utils.py │ ├── logging_utils.py │ ├── model_utils.py │ ├── ni_dataset.py │ ├── t5_model.py │ └── train_utils.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | env.yaml 3 | logs 4 | scripts 5 | .neptune -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2022 - Piotr Nawrot. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 
49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright [yyyy] [name of copyright owner] 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 
204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # nanoT5 (Encoder-Decoder / Pre-training + Fine-Tuning) 2 | 3 | ![nanoT5](assets/nanoT5.png) 4 | 5 | [[Paper](https://arxiv.org/abs/2309.02373)] | [**TLDR**](#tldr) | [**Motivation**](#motivation) | [**Setup**](#setup) | [**Pre-training**](#pre-training) | [**Fine-tuning**](#fine-tuning) | [**Extras**](#extras) | [**Conclusions**](#conclusions) | [**References**](#references) | [**Cite**](#cite) | [**Issues**](#issues) 6 | 7 | ## TLDR: 8 | 9 | This repository comprises the code to reproduce the pre-training of a "Large Language Model" (T5) under a limited budget (1xA100 GPU, < 24 hours) in PyTorch. We start from a randomly initialised T5-base-v1.1 (248M parameters) model, pre-train it on the English subset of the C4 dataset, and then fine-tune it on the Super-Natural Instructions (SNI) benchmark. 10 | 11 | **In ~16 hours on a single GPU, we achieve 40.7 RougeL on the SNI test set, compared to 40.9 RougeL of the original model weights available on HuggingFace Hub and pretrained on 150x more data through "a combination of model and data parallelism [...] on slices of Cloud TPU Pods", each with 1024 TPUs.** 12 | 13 | Our core contribution is not the T5 model itself, which follows the HuggingFace implementation. Instead, we optimise everything else in the training pipeline to offer you a user-friendly starting template for your NLP application/research. Most importantly, we show that it is possible to pre-train the T5 model to top performance under a limited budget in PyTorch. 14 | 15 | ## Motivation 16 | 17 | Despite the continuously increasing size of pretrained [Transformers](https://arxiv.org/pdf/1706.03762.pdf), the research community still needs easy-to-reproduce and up-to-date baselines to test new research hypotheses fast and at a small scale. 18 | 19 | A recent effort from Andrej Karpathy, the [nanoGPT](https://github.com/karpathy/nanoGPT) repository, enables researchers to pre-train and fine-tune GPT-style (Decoder-only) language models. On the other hand, [Cramming](https://github.com/JonasGeiping/cramming) opts to find the optimal BERT-style (Encoder-only) pre-training setup for limited-compute settings. 20 | 21 | With [nanoT5](https://github.com/PiotrNawrot/nanoT5), we want to fill a gap (Community requests: [#1](https://github.com/huggingface/transformers/issues/18030) [#2](https://github.com/facebookresearch/fairseq/issues/1899) [#3](https://github.com/google-research/text-to-text-transfer-transformer/issues/172) [#4](https://discuss.huggingface.co/t/example-of-how-to-pretrain-t5/4129) [#5](https://github.com/huggingface/transformers/issues/5079)) of an accessible research template to pre-train and fine-tune T5-style (Encoder-Decoder) models.
**To the best of our knowledge, it is the first attempt to reproduce T5 v1.1 pre-training in PyTorch (previously available implementations are in Jax/Flax).** 22 | 23 | ## 24 | 25 | **We created this repository for people who want to pre-train T5-style models by themselves and evaluate their performance on downstream tasks.** This could be for a variety of reasons: 26 | - You are a researcher in academia with limited compute (like me), and you came up with a promising idea based on the T5 model, so you need a pipeline to evaluate it; 27 | - You have an in-house dataset that you think is more appropriate than the original pre-training dataset (C4) for your downstream task; 28 | - You want to experiment with continued pre-training or want to build on the T5 pre-training objective. 29 | 30 | **If you don't need to pre-train the T5 model, you'd be better off downloading the weights from HuggingFace Hub. Our checkpoints are worse because we work under limited compute.** 31 | 32 | ## 33 | 34 | In this project, we expose (for research purposes) and optimise everything in the training pipeline of T5, except for the model implementation. We include a simplified implementation of the T5 model (which is great for teaching & learning purposes); however, we do not optimise it with the latest techniques like [tensor or pipeline parallelism](https://arxiv.org/pdf/2104.04473.pdf), because they make the code much more complex and are not needed at a small scale. **Most importantly, we base our code on PyTorch, since access to TPUs is limited.** Key features: 35 | - **Dataset:** Downloading and preprocessing of the C4 dataset happens in parallel with the training of the model. The C4 dataset is > 300GB, so it takes a couple of hours to download it and even longer to preprocess it. This codebase does both on the fly without any detrimental effect on the training time and performance (we did not observe any, although slowdowns might occur with an old CPU (< 8 cores) or a slow internet connection). **As a result, you can initiate pre-training of your own T5 model within minutes.** 36 | - **Model Optimizer / LR Scheduler:** The original T5 uses the memory-efficient Adafactor optimizer. [A study on pre-training T5](https://huggingface.co/spaces/yhavinga/pre-training-dutch-t5-models), on the other hand, reports that training does not converge with AdamW, which we find strange given that AdamW estimates the second moment of the gradients exactly, whereas Adafactor relies on a low-rank approximation. We analysed the source of this discrepancy with several ablations. Although there are many subtle differences between Adafactor and AdamW, what ensures Adafactor's convergence is [matrix-wise LR scaling by its root mean square (RMS)](https://github.com/huggingface/transformers/blob/main/src/transformers/optimization.py#L595). We augmented the AdamW implementation with RMS scaling and observed that it becomes **more stable during pre-training, achieves better validation loss, and is faster**. 37 | - **Exposure and simplicity:** We try to balance the implementation of the training pipeline by keeping it customisable while retaining a sufficient level of abstraction. We use the [HuggingFace Accelerator](https://huggingface.co/docs/accelerate/index) to implement operations like Checkpoint Saving, Gradient Accumulation, Gradient Clipping, and moving tensors to the correct devices. We use [neptune.ai](https://neptune.ai) for experiment tracking and [hydra](https://hydra.cc/docs/intro/) for hyperparameters.
Additionally, we expose a [simplified implementation](nanoT5/utils/t5_model.py) of the T5 model, training loop, data preprocessing, etc. 38 | - **Efficiency:** We use mixed-precision training (TF32 & BF16), PyTorch 2.0 compile, and utilise all optimisations listed in established optimisation tutorials [#1](https://huggingface.co/docs/transformers/perf_train_gpu_one) [#2](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html). 39 | 40 | ## Setup 41 | 42 | ### Environment & Hardware: 43 | 44 | ``` 45 | git clone https://github.com/PiotrNawrot/nanoT5.git 46 | cd nanoT5 47 | conda create -n nanoT5 python=3.8 48 | conda activate nanoT5 49 | pip install -r requirements.txt 50 | ``` 51 | 52 | These commands result in the [pip freeze](assets/env_dump/pip_freeze.txt) captured on 18.06.2023. 53 | 54 | We also include our [lscpu](assets/env_dump/lscpu.txt) and [nvidia-smi](assets/env_dump/nvidia_smi.txt). 55 | 56 | ## Pre-training: 57 | 58 | ### Reference: 59 | 60 | The [T5 v1.1](https://arxiv.org/pdf/2002.05202.pdf) authors report **1.942** negative log-likelihood (NLL) on the held-out set of C4 after 2^16 steps. 61 | 62 | ### Legacy Optimizer (Adafactor) & LR Schedule (Inverse-Square-Root) 63 | 64 | We follow the original experimental setup for pre-training, including [Dataset (C4)](https://github.com/PiotrNawrot/nanoT5/blob/main/nanoT5/utils/model_utils.py#L74), [Training Objective (Span Filling)](https://github.com/PiotrNawrot/nanoT5/blob/main/nanoT5/utils/copied_utils.py#L16), [Model Architecture (T5-Base)](nanoT5/utils/t5_model.py), [Optimizer (Adafactor)](https://github.com/PiotrNawrot/nanoT5/blob/main/nanoT5/utils/model_utils.py#L248), and [LR Schedule (Inverse-Square-Root)](https://github.com/PiotrNawrot/nanoT5/blob/main/nanoT5/utils/model_utils.py#L288). 65 | 66 | Our negative log-likelihood on the held-out set is **1.995**, slightly worse than the reference. 67 | 68 | ### AdamW with RMS scaling Optimizer & Cosine LR Schedule 69 | 70 | We also experiment with the AdamW optimizer (instead of the original Adafactor) as it (theoretically) should offer more stability during training. Instead of using a low-rank approximation for the second moment of the gradients, it estimates it directly by storing the moving average for each parameter in memory. However, training diverges with AdamW, similar to [this study on T5 pre-training](https://huggingface.co/spaces/yhavinga/pre-training-dutch-t5-models). Through several ablations, we found that [matrix-wise LR scaling by its root mean square (RMS)](https://github.com/huggingface/transformers/blob/main/src/transformers/optimization.py#L617) is responsible for the convergence of Adafactor. We augmented the AdamW implementation with RMS scaling and observed that [it converges, becomes more stable during pre-training, and is slightly faster](assets/pt_loss.png) (it retrieves the second moment directly from memory instead of approximating it via matrix multiplications). A simplified sketch of the modified update rule is included below the results table. 71 | 72 | However, AdamW, when paired with the Inverse-Square-Root LR schedule, performs worse than Adafactor. For our final experiment, we replace ISR with the Cosine LR schedule. We achieve **1.953** negative log-likelihood on the held-out set and significantly outperform Adafactor with the ISR schedule. We consider this our default pre-training config. 73 | 74 | ### Training loss of runs with different optimisers (Adafactor vs AdamW) and schedulers (ISR vs Cosine). The rest of the [hyperparameters](nanoT5/configs/default.yaml) follows the original T5 v1.1 paper.
75 | 76 | ![pt_loss](assets/pt_loss.png) 77 | 78 | ### Negative log-likelihood on the held-out set of C4 79 | 80 |
81 | 82 | | | **Inverse-Square-Root** | **Cosine** | 83 | | :---: | :----: | :---: | 84 | | **Adafactor** | 1.995 | 1.993 | 85 | | **AdamW** | 2.040 | **1.953** | 86 | 87 |
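To make the modification concrete, below is a simplified sketch (in PyTorch) of the RMS-scaled AdamW update described above; bias correction and weight decay are omitted for brevity, and the function name and signature are illustrative only. The actual optimizer used for training is `AdamWScale` in [nanoT5/utils/copied_utils.py](nanoT5/utils/copied_utils.py).

```
import torch


@torch.no_grad()
def adamw_rms_update(param, exp_avg, exp_avg_sq, lr, beta1=0.9, beta2=0.999, eps=1e-6):
    # Standard Adam moment estimates.
    grad = param.grad
    exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
    denom = exp_avg_sq.sqrt().add_(eps)

    # Adafactor-style modification: scale the step size by the root mean square
    # of the parameter tensor (floored at 1e-3), so each matrix takes a step
    # proportional to its own magnitude.
    rms = param.norm(2).item() / (param.numel() ** 0.5)
    step_size = lr * max(1e-3, rms)

    param.addcdiv_(exp_avg, denom, value=-step_size)
```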
88 | 89 | ### Examples 90 | 91 | To reproduce any of the experiments mentioned above, choose any combination of hyperparameters as follows: 92 | 93 | ``` 94 | python -m nanoT5.main \ 95 | optim.name={adafactor,adamwscale} \ 96 | optim.lr_scheduler={legacy,cosine} 97 | ``` 98 | 99 | We recommend adding the `model.compile=true` flag for pre-training if you are able to install PyTorch 2.0. 100 | 101 | Suppose you don't have access to an 80GB A100 GPU. In that case, you can increase the number of gradient accumulation steps via `optim.grad_acc=steps`, where `batch_size` has to be divisible by `steps`. 102 | 103 | A summary of the optimisation process is printed every 100 steps in the following format. For instance: 104 | 105 | ``` 106 | [train] Step 100 out of 65536 | Loss --> 59.881 | Grad_l2 --> 61.126 | Weights_l2 --> 7042.931 | Lr --> 0.010 | Seconds_per_step --> 1.385 | 107 | ``` 108 | 109 | ### Efficiency statistics: 110 | 111 | Below we include the efficiency statistics for our pre-training experiments. We report the time it takes to perform a single pre-training step and the total pre-training time according to the [default config](nanoT5/configs/default.yaml). Please note that we need to increase **optim.grad_acc** to fit the model in precisions other than BF16. 112 | 113 |
114 | 115 | | **Mixed Precision Format** | **Torch 2.0 compile** | **Grad Acc Steps** | **Pre-training (1 step)** | **Total Pre-training time** | 116 | | :----: | :---: | :---: | :---: | :---: | 117 | | FP32 | No | 2 | ~4.10s | ~74.6h | 118 | | TF32 | No | 2 | ~1.39s | ~25.3h | 119 | | BF16 | No | 2 | ~1.30s | ~23.7h | 120 | | TF32 | Yes | 2 | ~0.95s | ~17.3h | 121 | | BF16 | Yes | 1 | ~0.56s | ~10.2h | 122 | 123 |
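As an example, the fastest configuration from the table above (BF16 with `torch.compile`) should correspond to the overrides below; `precision=bf16` and `optim.grad_acc=1` are already the defaults in the [default config](nanoT5/configs/default.yaml) and are spelled out only for clarity.

```
python -m nanoT5.main \
    precision=bf16 \
    model.compile=true \
    optim.grad_acc=1
```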
124 | 125 | ## Fine-tuning: 126 | 127 | To fine-tune our model, we use the popular meta-dataset called **Super-Natural Instructions (SNI)**, which aggregates datasets for many tasks. This meta-dataset was used to fine-tune many of the recent LLMs, e.g. [FlanT5](https://arxiv.org/pdf/2210.11416.pdf), [BLOOM](https://arxiv.org/pdf/2211.05100.pdf), and [Tk-Instruct](https://arxiv.org/pdf/2204.07705.pdf). While FlanT5 and BLOOM use other corpora in addition to SNI, Tk-Instruct's pipeline consists of starting from a pre-trained T5 model and fine-tuning it solely on SNI. 128 | 129 | In this repository, we reproduce the Tk-Instruct fine-tuning results and follow their pipeline to evaluate our pre-trained model. 130 | 131 | ### Download the Super-Natural Instructions data: 132 | 133 | ``` 134 | git clone https://github.com/allenai/natural-instructions.git data 135 | ``` 136 | 137 | ### Run fine-tuning: 138 | 139 | We strictly follow the fine-tuning [config](nanoT5/configs/task/ft.yaml) of Tk-Instruct. It remains unclear whether Tk-Instruct was initialised from a regular checkpoint (*google/t5-v1_1-base*) or the one adapted explicitly for Language Modelling with continued training (*google/t5-base-lm-adapt*). Therefore, we decided to evaluate both. Run the following command to reproduce the Tk-Instruct experiments: 140 | 141 | ``` 142 | python -m nanoT5.main task=ft \ 143 | model.name={google/t5-v1_1-base,google/t5-base-lm-adapt} \ 144 | model.random_init={true,false} \ 145 | model.checkpoint_path={"","/path/to/pytorch_model.bin"} 146 | ``` 147 | 148 | Setting `model.random_init=false model.checkpoint_path=""` corresponds to downloading pre-trained weights from HuggingFace Hub. 149 | 150 | Setting `model.random_init=false model.checkpoint_path="/path/to/pytorch_model.bin"` corresponds to using the weights [**pre-trained**](#pre-training) with nanoT5. 151 | 152 | Setting `model.random_init=true model.checkpoint_path=""` corresponds to a random initialisation. 153 | 154 | ### Rouge-L on the held-out test-set across different pre-training budgets: 155 | 156 | In the figure below, we compare the performance of the model trained in this repository under different time budgets ([4](nanoT5/configs/task/pt_4h.yaml), [8](nanoT5/configs/task/pt_8h.yaml), [12](nanoT5/configs/task/pt_12h.yaml), [16](nanoT5/configs/task/pt_16h.yaml), [20](nanoT5/configs/task/pt_20h.yaml), [24](nanoT5/configs/task/pt_24h.yaml) hours) with the original T5-base-v1.1 model weights available through HuggingFace Hub and its version adapted for Language Modelling (*google/t5-base-lm-adapt*). We observe that the model trained in our repository for 16 hours on a single GPU is only 0.2 RougeL worse on average than the original T5-base-v1.1 model, despite being pre-trained on 150x less data (according to the [T5 paper](https://arxiv.org/pdf/1910.10683.pdf), they pre-train their models for one million steps with a batch size of 2048, while our 16-hour config does 53,332 steps with a batch size of 256). The checkpoint explicitly adapted for Language Modelling (*google/t5-base-lm-adapt*) performs better than both the original T5-base-v1.1 model and our model; however, investigating this is beyond the scope of this repository. 157 | 158 | ![ft_rougeL](assets/downstream.png) 159 | 160 | We share the model's weights after pre-training for 24 hours on [HuggingFace Hub](https://huggingface.co/pnawrot/nanoT5-base), which you can download and fine-tune on SNI using nanoT5. 161 | We also share the [fine-tuning loss curves](assets/ft_loss.png).
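If you want to start fine-tuning from this shared 24-hour checkpoint, the sketch below fetches it with `huggingface_hub` and prints a local path that can be passed as `model.checkpoint_path`; note that the filename `pytorch_model.bin` is an assumption based on the checkpoint path format shown above, so adjust it to whatever file is actually hosted in the Hub repository.

```
# Sketch only: download the shared nanoT5 checkpoint for fine-tuning with nanoT5.
# The filename below is an assumption about the Hub repository layout.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(repo_id="pnawrot/nanoT5-base", filename="pytorch_model.bin")
print(ckpt_path)
# Then run, e.g.:
#   python -m nanoT5.main task=ft model.random_init=false model.checkpoint_path=<ckpt_path>
```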
162 | 163 | A single fine-tuning step takes ~0.18s, and full fine-tuning takes ~1 hour. 164 | 165 | ## Extras: 166 | 167 | ### Things we tried that didn't work out: 168 | 169 | - **Different optimizers:** We tried recent optimizers such as [Lion](https://arxiv.org/abs/2302.06675) and [Sophia](https://github.com/Liuhong99/Sophia); however, none of them worked better than AdamW with RMS scaling. 170 | - **Positional embeddings:** We tried to replace T5's learned relative positional embeddings with [ALiBi](https://arxiv.org/pdf/2108.12409.pdf). Possible benefits include a reduction of parameters and faster training & inference. Furthermore, if ALiBi had worked, we could have added [Flash Attention](https://github.com/HazyResearch/flash-attention), which currently supports only non-parametric biases (the T5 bias is trainable). However, with ALiBi the training was less stable and reached a worse pre-training loss. 171 | - **FP16 precision:** All experiments with FP16 precision diverged across different seeds. 172 | 173 | ## Conclusions: 174 | 175 | We show that it is possible to successfully pre-train a "Large Language Model" (T5) under a limited budget (1xA100 GPU, < 24 hours) in PyTorch. We make our codebase, configs, and training logs publicly available to enhance the accessibility of NLP research. We are keen to hear your suggestions to improve the codebase further. 176 | 177 | ### Acknowledgements: 178 | 179 | Thanks to [Edoardo Maria Ponti](https://ducdauge.github.io) for his feedback! 180 | 181 | ## References: 182 | - [T5 paper](https://arxiv.org/pdf/1910.10683.pdf) 183 | - [T5 v1.1 paper](https://arxiv.org/pdf/2002.05202.pdf) 184 | - [Super-Natural Instructions paper](https://arxiv.org/pdf/2204.07705.pdf) 185 | - [HuggingFace Flax Script](https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py) 186 | - [Karpathy's nanoGPT](https://github.com/karpathy/nanoGPT) 187 | - [Tk-Instruct codebase (Super-Natural Instructions)](https://github.com/yizhongw/Tk-Instruct) 188 | - [Blog about pre-training Dutch T5 in HuggingFace](https://huggingface.co/spaces/yhavinga/pre-training-dutch-t5-models) 189 | 190 | ## Cite 191 | 192 | If you find the repository useful, consider citing the paper describing this work. 193 | 194 | ``` 195 | @inproceedings{nawrot-2023-nanot5, 196 | title = {nano{T}5: Fast {\&} Simple Pre-training and Fine-tuning of {T}5 Models with Limited Resources}, 197 | author = {Nawrot, Piotr}, 198 | year = 2023, 199 | month = dec, 200 | booktitle = {Proceedings of the 3rd Workshop for Natural Language Processing Open Source Software (NLP-OSS 2023)}, 201 | publisher = {Association for Computational Linguistics}, 202 | doi = {10.18653/v1/2023.nlposs-1.11}, 203 | url = {https://aclanthology.org/2023.nlposs-1.11} 204 | } 205 | ``` 206 | 207 | Below you can also find an excellent work that uses nanoT5 in its experiments. 208 | 209 | ``` 210 | @article{Kaddour2023NoTN, 211 | title={No Train No Gain: Revisiting Efficient Training Algorithms For Transformer-based Language Models}, 212 | author={Jean Kaddour and Oscar Key and Piotr Nawrot and Pasquale Minervini and Matt J.
Kusner}, 213 | journal={ArXiv}, 214 | year={2023}, 215 | volume={abs/2307.06440}, 216 | } 217 | ``` 218 | 219 | ## Issues: 220 | 221 | If you have any questions, feel free to raise a Github issue or contact me directly at: piotr.nawrot@ed.ac.uk 222 | -------------------------------------------------------------------------------- /assets/downstream.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiotrNawrot/nanoT5/1375b389d33ab4f34754a9fca62e4cfa1dd52379/assets/downstream.png -------------------------------------------------------------------------------- /assets/env_dump/lscpu.txt: -------------------------------------------------------------------------------- 1 | Architecture: x86_64 2 | CPU op-mode(s): 32-bit, 64-bit 3 | Byte Order: Little Endian 4 | CPU(s): 128 5 | On-line CPU(s) list: 0-127 6 | Thread(s) per core: 1 7 | Core(s) per socket: 64 8 | Socket(s): 2 9 | NUMA node(s): 8 10 | Vendor ID: AuthenticAMD 11 | CPU family: 25 12 | Model: 1 13 | Model name: AMD EPYC 7763 64-Core Processor 14 | Stepping: 1 15 | CPU MHz: 2445.206 16 | BogoMIPS: 4890.41 17 | Virtualization: AMD-V 18 | L1d cache: 32K 19 | L1i cache: 32K 20 | L2 cache: 512K 21 | L3 cache: 32768K 22 | NUMA node0 CPU(s): 0-15 23 | NUMA node1 CPU(s): 16-31 24 | NUMA node2 CPU(s): 32-47 25 | NUMA node3 CPU(s): 48-63 26 | NUMA node4 CPU(s): 64-79 27 | NUMA node5 CPU(s): 80-95 28 | NUMA node6 CPU(s): 96-111 29 | NUMA node7 CPU(s): 112-127 30 | Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca 31 | -------------------------------------------------------------------------------- /assets/env_dump/nvidia_smi.txt: -------------------------------------------------------------------------------- 1 | Sun Jun 18 13:21:03 2023 2 | +-----------------------------------------------------------------------------+ 3 | | NVIDIA-SMI 525.105.17 Driver Version: 525.105.17 CUDA Version: 12.0 | 4 | |-------------------------------+----------------------+----------------------+ 5 | | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC | 6 | | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. | 7 | | | | MIG M. | 8 | |===============================+======================+======================| 9 | | 0 NVIDIA A100-SXM... 
On | 00000000:C1:00.0 Off | 0 | 10 | | N/A 65C P0 420W / 500W | 75277MiB / 81920MiB | 84% Default | 11 | | | | Disabled | 12 | +-------------------------------+----------------------+----------------------+ 13 | 14 | +-----------------------------------------------------------------------------+ 15 | | Processes: | 16 | | GPU GI CI PID Type Process name GPU Memory | 17 | | ID ID Usage | 18 | |=============================================================================| 19 | | 0 N/A N/A 50391 C python 75274MiB | 20 | +-----------------------------------------------------------------------------+ 21 | -------------------------------------------------------------------------------- /assets/env_dump/pip_freeze.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | accelerate==0.20.3 3 | aiohttp==3.8.4 4 | aiosignal==1.3.1 5 | antlr4-python3-runtime==4.9.3 6 | anyio==3.7.0 7 | argon2-cffi==21.3.0 8 | argon2-cffi-bindings==21.2.0 9 | arrow==1.2.3 10 | asttokens==2.2.1 11 | async-timeout==4.0.2 12 | attrs==23.1.0 13 | backcall==0.2.0 14 | beautifulsoup4==4.12.2 15 | bleach==6.0.0 16 | boto3==1.26.155 17 | botocore==1.29.155 18 | bravado==11.0.3 19 | bravado-core==5.17.1 20 | certifi==2023.5.7 21 | cffi==1.15.1 22 | charset-normalizer==3.1.0 23 | click==8.1.3 24 | cmake==3.26.4 25 | comm==0.1.3 26 | datasets==2.13.0 27 | debugpy==1.6.7 28 | decorator==5.1.1 29 | defusedxml==0.7.1 30 | dill==0.3.6 31 | docopt==0.6.2 32 | evaluate==0.4.0 33 | exceptiongroup==1.1.1 34 | executing==1.2.0 35 | fancycompleter==0.9.1 36 | fastjsonschema==2.17.1 37 | filelock==3.12.2 38 | fqdn==1.5.1 39 | frozenlist==1.3.3 40 | fsspec==2023.6.0 41 | future==0.18.3 42 | gitdb==4.0.10 43 | GitPython==3.1.31 44 | huggingface-hub==0.15.1 45 | hydra-core==1.3.2 46 | idna==3.4 47 | importlib-metadata==6.6.0 48 | importlib-resources==5.12.0 49 | ipykernel==6.23.2 50 | ipython==8.12.2 51 | ipython-genutils==0.2.0 52 | isoduration==20.11.0 53 | jedi==0.18.2 54 | Jinja2==3.1.2 55 | jmespath==1.0.1 56 | joblib==1.2.0 57 | jsonpointer==2.4 58 | jsonref==1.1.0 59 | jsonschema==4.17.3 60 | jupyter-events==0.6.3 61 | jupyter_client==8.2.0 62 | jupyter_core==5.3.1 63 | jupyter_server==2.6.0 64 | jupyter_server_terminals==0.4.4 65 | jupyterlab-pygments==0.2.2 66 | lit==16.0.6 67 | MarkupSafe==2.1.3 68 | matplotlib-inline==0.1.6 69 | mistune==2.0.5 70 | monotonic==1.6 71 | mpmath==1.3.0 72 | msgpack==1.0.5 73 | multidict==6.0.4 74 | multiprocess==0.70.14 75 | nbclassic==1.0.0 76 | nbclient==0.8.0 77 | nbconvert==7.5.0 78 | nbformat==5.9.0 79 | neptune==1.3.1 80 | nest-asyncio==1.5.6 81 | networkx==3.1 82 | nltk==3.8.1 83 | notebook==6.5.4 84 | notebook_shim==0.2.3 85 | numpy==1.24.3 86 | nvidia-cublas-cu11==11.10.3.66 87 | nvidia-cuda-cupti-cu11==11.7.101 88 | nvidia-cuda-nvrtc-cu11==11.7.99 89 | nvidia-cuda-runtime-cu11==11.7.99 90 | nvidia-cudnn-cu11==8.5.0.96 91 | nvidia-cufft-cu11==10.9.0.58 92 | nvidia-curand-cu11==10.2.10.91 93 | nvidia-cusolver-cu11==11.4.0.1 94 | nvidia-cusparse-cu11==11.7.4.91 95 | nvidia-nccl-cu11==2.14.3 96 | nvidia-nvtx-cu11==11.7.91 97 | oauthlib==3.2.2 98 | omegaconf==2.3.0 99 | overrides==7.3.1 100 | packaging==23.1 101 | pandas==2.0.2 102 | pandocfilters==1.5.0 103 | parso==0.8.3 104 | pdbpp==0.10.3 105 | pexpect==4.8.0 106 | pickleshare==0.7.5 107 | Pillow==9.5.0 108 | pipreqs==0.4.13 109 | pkgutil_resolve_name==1.3.10 110 | platformdirs==3.6.0 111 | prometheus-client==0.17.0 112 | prompt-toolkit==3.0.38 113 | protobuf==3.20.3 114 | psutil==5.9.5 115 | 
ptyprocess==0.7.0 116 | pure-eval==0.2.2 117 | pyarrow==12.0.1 118 | pycparser==2.21 119 | Pygments==2.15.1 120 | PyJWT==2.7.0 121 | pynvml==11.5.0 122 | pyrepl==0.9.0 123 | pyrsistent==0.19.3 124 | python-dateutil==2.8.2 125 | python-json-logger==2.0.7 126 | pytz==2023.3 127 | PyYAML==6.0 128 | pyzmq==25.1.0 129 | regex==2023.6.3 130 | requests==2.31.0 131 | requests-oauthlib==1.3.1 132 | responses==0.18.0 133 | rfc3339-validator==0.1.4 134 | rfc3986-validator==0.1.1 135 | rfc3987==1.3.8 136 | rouge-score==0.1.2 137 | s3transfer==0.6.1 138 | safetensors==0.3.1 139 | Send2Trash==1.8.2 140 | sentencepiece==0.1.99 141 | simplejson==3.19.1 142 | six==1.16.0 143 | smmap==5.0.0 144 | sniffio==1.3.0 145 | soupsieve==2.4.1 146 | stack-data==0.6.2 147 | swagger-spec-validator==3.0.3 148 | sympy==1.12 149 | terminado==0.17.1 150 | tinycss2==1.2.1 151 | tokenizers==0.13.3 152 | torch==2.0.1 153 | tornado==6.3.2 154 | tqdm==4.65.0 155 | traitlets==5.9.0 156 | transformers==4.30.2 157 | triton==2.0.0 158 | typing_extensions==4.6.3 159 | tzdata==2023.3 160 | uri-template==1.2.0 161 | urllib3==1.26.16 162 | wcwidth==0.2.6 163 | webcolors==1.13 164 | webencodings==0.5.1 165 | websocket-client==1.6.0 166 | wmctrl==0.4 167 | xxhash==3.2.0 168 | yarg==0.1.9 169 | yarl==1.9.2 170 | zipp==3.15.0 171 | -------------------------------------------------------------------------------- /assets/ft_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiotrNawrot/nanoT5/1375b389d33ab4f34754a9fca62e4cfa1dd52379/assets/ft_loss.png -------------------------------------------------------------------------------- /assets/nanoT5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiotrNawrot/nanoT5/1375b389d33ab4f34754a9fca62e4cfa1dd52379/assets/nanoT5.png -------------------------------------------------------------------------------- /assets/pt_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiotrNawrot/nanoT5/1375b389d33ab4f34754a9fca62e4cfa1dd52379/assets/pt_loss.png -------------------------------------------------------------------------------- /nanoT5/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiotrNawrot/nanoT5/1375b389d33ab4f34754a9fca62e4cfa1dd52379/nanoT5/__init__.py -------------------------------------------------------------------------------- /nanoT5/configs/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task: pt 4 | - local_env: default 5 | 6 | # Experiment args 7 | mode: 'pt' 8 | device: gpu 9 | precision: 'bf16' 10 | eval_only: false 11 | predict_only: false 12 | seed: 2137 13 | 14 | model: 15 | klass: local_t5 16 | name: 'google/t5-v1_1-base' 17 | overwrite: 18 | dropout_rate: 0.0 19 | add_config: 20 | is_bf16: false 21 | checkpoint_path: '' 22 | random_init: true 23 | compile: false # Pytorch 2.0 24 | 25 | data: 26 | input_length: 512 27 | mlm_probability: 0.15 28 | mean_noise_span_length: 3.0 29 | num_workers: 8 30 | 31 | optim: 32 | name: adamwscale 33 | base_lr: 2e-2 34 | batch_size: 128 35 | total_steps: 65536 36 | epochs: -1 # If it's > 0 it overwrites total_steps 37 | warmup_steps: 10000 38 | lr_scheduler: cosine 39 | weight_decay: 0.0 40 | grad_clip: 1.0 41 | grad_acc: 1 42 | final_cosine: 1e-5 43 | 44 | eval: 45 | 
every_steps: 100000 # Eval once in the end 46 | steps: 500 47 | 48 | checkpoint: 49 | every_steps: 100000 # Save checkpoint once in the end 50 | 51 | logging: 52 | neptune: false 53 | neptune_creds: 54 | project: 55 | api_token: 56 | tags: '' 57 | every_steps: 100 58 | grad_l2: true 59 | weights_l2: true 60 | 61 | hydra: 62 | job: 63 | chdir: True 64 | -------------------------------------------------------------------------------- /nanoT5/configs/local_env/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | hydra: 4 | run: 5 | dir: ./logs/${now:%Y-%m-%d}/${now:%H-%M-%S}-${logging.neptune_creds.tags} -------------------------------------------------------------------------------- /nanoT5/configs/task/ft.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | mode: 'ft' 4 | precision: 'no' 5 | 6 | model: 7 | klass: hf_t5 8 | 9 | data: 10 | max_seq_len: 1024 11 | max_target_len: 128 12 | max_num_instances_per_task: 100 13 | add_task_name: False 14 | add_task_definition: True 15 | num_pos_examples: 2 16 | num_neg_examples: 0 17 | add_explanation: False 18 | tk_instruct: False 19 | exec_file_path: ./nanoT5/utils/ni_dataset.py 20 | data_dir: ./data/splits/default 21 | task_dir: ./data/tasks 22 | 23 | optim: 24 | name: adamw 25 | base_lr: 5e-5 26 | batch_size: 8 27 | epochs: 2 28 | warmup_steps: 0 29 | lr_scheduler: constant 30 | weight_decay: 0.0 31 | grad_clip: 0.0 32 | grad_acc: 1 33 | 34 | eval: 35 | steps: 200 36 | -------------------------------------------------------------------------------- /nanoT5/configs/task/pt.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | -------------------------------------------------------------------------------- /nanoT5/configs/task/pt_12h.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | optim: 4 | warmup_steps: 5000 5 | batch_size: 256 6 | total_steps: 39999 7 | grad_acc: 2 -------------------------------------------------------------------------------- /nanoT5/configs/task/pt_16h.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | optim: 4 | warmup_steps: 6666 5 | batch_size: 256 6 | total_steps: 53332 7 | grad_acc: 2 -------------------------------------------------------------------------------- /nanoT5/configs/task/pt_20h.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | optim: 4 | warmup_steps: 8332 5 | batch_size: 256 6 | total_steps: 66665 7 | grad_acc: 2 -------------------------------------------------------------------------------- /nanoT5/configs/task/pt_24h.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | optim: 4 | warmup_steps: 10000 5 | batch_size: 256 6 | total_steps: 79998 7 | grad_acc: 2 -------------------------------------------------------------------------------- /nanoT5/configs/task/pt_4h.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | optim: 4 | warmup_steps: 1666 5 | batch_size: 256 6 | total_steps: 13333 7 | grad_acc: 2 -------------------------------------------------------------------------------- /nanoT5/configs/task/pt_8h.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | optim: 4 | warmup_steps: 3332 5 | batch_size: 256 6 | total_steps: 26666 7 | grad_acc: 2 -------------------------------------------------------------------------------- /nanoT5/main.py: -------------------------------------------------------------------------------- 1 | from accelerate import Accelerator 2 | from omegaconf import open_dict 3 | import hydra 4 | import torch 5 | import time 6 | 7 | from .utils import ( 8 | setup_basics, 9 | train, 10 | predict, 11 | eval, 12 | get_lr_scheduler, 13 | get_optimizer, 14 | get_tokenizer, 15 | get_model, 16 | get_dataloaders, 17 | get_config, 18 | ) 19 | 20 | 21 | @hydra.main(config_path="configs", config_name="default", version_base='1.1') 22 | def main(args): 23 | accelerator = Accelerator( 24 | cpu=args.device == "cpu", 25 | mixed_precision=args.precision, 26 | ) 27 | logger = setup_basics(accelerator, args) 28 | config = get_config(args) 29 | model = get_model(args, config) 30 | tokenizer = get_tokenizer(args) 31 | optimizer = get_optimizer(model, args) 32 | lr_scheduler = get_lr_scheduler(optimizer, args, logger) 33 | train_dataloader, test_dataloader = get_dataloaders(tokenizer, config, args) 34 | 35 | logger.log_args(args) 36 | 37 | ( 38 | model, 39 | optimizer, 40 | lr_scheduler, 41 | train_dataloader, 42 | test_dataloader, 43 | ) = accelerator.prepare( 44 | model, optimizer, lr_scheduler, train_dataloader, test_dataloader 45 | ) 46 | 47 | if args.model.compile: 48 | model = torch.compile(model) 49 | 50 | with open_dict(args): 51 | args.current_train_step = 1 52 | args.current_epoch = 1 53 | args.last_log = time.time() 54 | 55 | if args.eval_only: 56 | model.eval() 57 | with torch.no_grad(): 58 | eval(model, test_dataloader, logger, args, tokenizer) 59 | elif args.predict_only: 60 | model.eval() 61 | with torch.no_grad(): 62 | predict(model, test_dataloader, logger, 63 | args, tokenizer) 64 | else: 65 | train(model, train_dataloader, test_dataloader, accelerator, 66 | lr_scheduler, optimizer, logger, args, tokenizer) 67 | 68 | logger.finish() 69 | 70 | 71 | if __name__ == "__main__": 72 | main() 73 | -------------------------------------------------------------------------------- /nanoT5/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .gen_utils import * 2 | from .model_utils import * 3 | from .train_utils import * 4 | -------------------------------------------------------------------------------- /nanoT5/utils/copied_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | import numpy as np 3 | from transformers import BatchEncoding 4 | from dataclasses import dataclass 5 | from transformers import AutoTokenizer 6 | import torch 7 | import math 8 | from torch.optim import Optimizer 9 | from typing import Iterable, Tuple 10 | from torch import nn 11 | import random 12 | import string 13 | 14 | 15 | @dataclass 16 | class DataCollatorForT5MLM: 17 | """ 18 | [Copied from https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py] 19 | Data collator used for T5 span-masked language modeling. 20 | It is made sure that after masking the inputs are of length `data_args.max_seq_length` and targets are also of fixed length. 
21 | For more information on how T5 span-masked language modeling works, one can take a look 22 | at the `official paper `__ 23 | or the `official code for preprocessing `__ . 24 | Args: 25 | tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): 26 | The tokenizer used for encoding the data. 27 | noise_density (:obj:`float`): 28 | The probability with which to (randomly) mask tokens in the input. 29 | mean_noise_span_length (:obj:`float`): 30 | The average span length of the masked tokens. 31 | input_length (:obj:`int`): 32 | The expected input length after masking. 33 | target_length (:obj:`int`): 34 | The expected target length after masking. 35 | pad_token_id: (:obj:`int`): 36 | The pad token id of the model 37 | decoder_start_token_id: (:obj:`int): 38 | The decoder start token id of the model 39 | """ 40 | 41 | tokenizer: AutoTokenizer 42 | noise_density: float 43 | mean_noise_span_length: float 44 | input_length: int 45 | target_length: int 46 | pad_token_id: int 47 | 48 | def __call__(self, examples: List[Dict[str, np.ndarray]]) -> BatchEncoding: 49 | # convert list to dict and tensorize input 50 | batch = BatchEncoding( 51 | { 52 | k: np.array([examples[i][k] for i in range(len(examples))]) 53 | for k, v in examples[0].items() 54 | } 55 | ) 56 | 57 | input_ids = batch["input_ids"] 58 | batch_size, expandend_input_length = input_ids.shape 59 | 60 | mask_indices = np.asarray( 61 | [ 62 | self.random_spans_noise_mask(expandend_input_length) 63 | for i in range(batch_size) 64 | ] 65 | ) 66 | labels_mask = ~mask_indices 67 | 68 | input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8)) 69 | labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8)) 70 | 71 | batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel) 72 | batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel) 73 | 74 | if batch["input_ids"].shape[-1] != self.input_length: 75 | raise ValueError( 76 | f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but" 77 | f" should be {self.input_length}." 78 | ) 79 | 80 | if batch["labels"].shape[-1] != self.target_length: 81 | raise ValueError( 82 | f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be" 83 | f" {self.target_length}." 84 | ) 85 | 86 | batch = {k: torch.from_numpy(v) for k, v in batch.items()} 87 | return batch 88 | 89 | def create_sentinel_ids(self, mask_indices): 90 | """ 91 | Sentinel ids creation given the indices that should be masked. 92 | The start indices of each mask are replaced by the sentinel ids in increasing 93 | order. Consecutive mask indices to be deleted are replaced with `-1`. 94 | """ 95 | start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices 96 | start_indices[:, 0] = mask_indices[:, 0] 97 | 98 | sentinel_ids = np.where( 99 | start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices 100 | ) 101 | sentinel_ids = np.where( 102 | sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0 103 | ) 104 | sentinel_ids -= mask_indices - start_indices 105 | 106 | return sentinel_ids 107 | 108 | def filter_input_ids(self, input_ids, sentinel_ids): 109 | """ 110 | Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting. 111 | This will reduce the sequence length from `expanded_inputs_length` to `input_length`. 
112 | """ 113 | batch_size = input_ids.shape[0] 114 | 115 | input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids) 116 | # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are 117 | # masked tokens coming after sentinel tokens and should be removed 118 | input_ids = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1)) 119 | input_ids = np.concatenate( 120 | [ 121 | input_ids, 122 | np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32), 123 | ], 124 | axis=-1, 125 | ) 126 | return input_ids 127 | 128 | def random_spans_noise_mask(self, length): 129 | """This function is copy of `random_spans_helper `__ . 130 | 131 | Noise mask consisting of random spans of noise tokens. 132 | The number of noise tokens and the number of noise spans and non-noise spans 133 | are determined deterministically as follows: 134 | num_noise_tokens = round(length * noise_density) 135 | num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length) 136 | Spans alternate between non-noise and noise, beginning with non-noise. 137 | Subject to the above restrictions, all masks are equally likely. 138 | 139 | Args: 140 | length: an int32 scalar (length of the incoming token sequence) 141 | noise_density: a float - approximate density of output mask 142 | mean_noise_span_length: a number 143 | 144 | Returns: 145 | a boolean tensor with shape [length] 146 | """ 147 | 148 | orig_length = length 149 | 150 | num_noise_tokens = int(np.round(length * self.noise_density)) 151 | # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens. 152 | num_noise_tokens = min(max(num_noise_tokens, 1), length - 1) 153 | num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length)) 154 | 155 | # avoid degeneracy by ensuring positive number of noise spans 156 | num_noise_spans = max(num_noise_spans, 1) 157 | num_nonnoise_tokens = length - num_noise_tokens 158 | 159 | # pick the lengths of the noise spans and the non-noise spans 160 | def _random_segmentation(num_items, num_segments): 161 | """Partition a sequence of items randomly into non-empty segments. 
162 | Args: 163 | num_items: an integer scalar > 0 164 | num_segments: an integer scalar in [1, num_items] 165 | Returns: 166 | a Tensor with shape [num_segments] containing positive integers that add 167 | up to num_items 168 | """ 169 | mask_indices = np.arange(num_items - 1) < (num_segments - 1) 170 | np.random.shuffle(mask_indices) 171 | first_in_segment = np.pad(mask_indices, [[1, 0]]) 172 | segment_id = np.cumsum(first_in_segment) 173 | # count length of sub segments assuming that list is sorted 174 | _, segment_length = np.unique(segment_id, return_counts=True) 175 | return segment_length 176 | 177 | noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans) 178 | nonnoise_span_lengths = _random_segmentation( 179 | num_nonnoise_tokens, num_noise_spans 180 | ) 181 | 182 | interleaved_span_lengths = np.reshape( 183 | np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), 184 | [num_noise_spans * 2], 185 | ) 186 | span_starts = np.cumsum(interleaved_span_lengths)[:-1] 187 | span_start_indicator = np.zeros((length,), dtype=np.int8) 188 | span_start_indicator[span_starts] = True 189 | span_num = np.cumsum(span_start_indicator) 190 | is_noise = np.equal(span_num % 2, 1) 191 | 192 | return is_noise[:orig_length] 193 | 194 | 195 | def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length): 196 | """This function is copy of `random_spans_helper `__ . 197 | 198 | [Copied from https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py] 199 | Training parameters to avoid padding with random_spans_noise_mask. 200 | When training a model with random_spans_noise_mask, we would like to set the other 201 | training hyperparmeters in a way that avoids padding. 202 | This function helps us compute these hyperparameters. 203 | We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens, 204 | and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens. 205 | This function tells us the required number of tokens in the raw example (for split_tokens()) 206 | as well as the length of the encoded targets. Note that this function assumes 207 | the inputs and targets will have EOS appended and includes that in the reported length. 208 | 209 | Args: 210 | inputs_length: an integer - desired length of the tokenized inputs sequence 211 | noise_density: a float 212 | mean_noise_span_length: a float 213 | Returns: 214 | tokens_length: length of original text in tokens 215 | targets_length: an integer - length in tokens of encoded targets sequence 216 | """ 217 | 218 | def _tokens_length_to_inputs_length_targets_length(tokens_length): 219 | num_noise_tokens = int(round(tokens_length * noise_density)) 220 | num_nonnoise_tokens = tokens_length - num_noise_tokens 221 | num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length)) 222 | # inputs contain all nonnoise tokens, sentinels for all noise spans 223 | # and one EOS token. 
224 | _input_length = num_nonnoise_tokens + num_noise_spans + 1 225 | _output_length = num_noise_tokens + num_noise_spans + 1 226 | return _input_length, _output_length 227 | 228 | tokens_length = inputs_length 229 | 230 | while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length: 231 | tokens_length += 1 232 | 233 | inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length) 234 | 235 | # minor hack to get the targets length to be equal to inputs length 236 | # which is more likely to have been set to a nice round number. 237 | if noise_density == 0.5 and targets_length > inputs_length: 238 | tokens_length -= 1 239 | targets_length -= 1 240 | return tokens_length, targets_length 241 | 242 | 243 | class AdamWScale(Optimizer): 244 | """ 245 | This AdamW implementation is copied from Huggingface. 246 | We modified it with Adagrad scaling by rms of a weight tensor 247 | 248 | Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay 249 | Regularization](https://arxiv.org/abs/1711.05101). 250 | 251 | Parameters: 252 | params (`Iterable[nn.parameter.Parameter]`): 253 | Iterable of parameters to optimize or dictionaries defining parameter groups. 254 | lr (`float`, *optional*, defaults to 1e-3): 255 | The learning rate to use. 256 | betas (`Tuple[float,float]`, *optional*, defaults to (0.9, 0.999)): 257 | Adam's betas parameters (b1, b2). 258 | eps (`float`, *optional*, defaults to 1e-6): 259 | Adam's epsilon for numerical stability. 260 | weight_decay (`float`, *optional*, defaults to 0): 261 | Decoupled weight decay to apply. 262 | correct_bias (`bool`, *optional*, defaults to `True`): 263 | Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`). 264 | no_deprecation_warning (`bool`, *optional*, defaults to `False`): 265 | A flag used to disable the deprecation warning (set to `True` to disable the warning). 266 | """ 267 | 268 | def __init__( 269 | self, 270 | params: Iterable[nn.parameter.Parameter], 271 | lr: float = 1e-3, 272 | betas: Tuple[float, float] = (0.9, 0.999), 273 | eps: float = 1e-6, 274 | weight_decay: float = 0.0, 275 | correct_bias: bool = True, 276 | ): 277 | if lr < 0.0: 278 | raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0") 279 | if not 0.0 <= betas[0] < 1.0: 280 | raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)") 281 | if not 0.0 <= betas[1] < 1.0: 282 | raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)") 283 | if not 0.0 <= eps: 284 | raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0") 285 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias) 286 | super().__init__(params, defaults) 287 | 288 | @staticmethod 289 | def _rms(tensor): 290 | return tensor.norm(2) / (tensor.numel() ** 0.5) 291 | 292 | def step(self, closure=None): 293 | """ 294 | Performs a single optimization step. 295 | 296 | Arguments: 297 | closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. 
298 | """ 299 | loss = None 300 | if closure is not None: 301 | loss = closure() 302 | 303 | for group in self.param_groups: 304 | for p in group["params"]: 305 | if p.grad is None: 306 | continue 307 | grad = p.grad.data 308 | if grad.is_sparse: 309 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") 310 | 311 | state = self.state[p] 312 | beta1, beta2 = group["betas"] 313 | 314 | # State initialization 315 | if len(state) == 0: 316 | state["step"] = 0 317 | # Exponential moving average of gradient values 318 | state["exp_avg"] = torch.zeros_like(p.data) 319 | # Exponential moving average of squared gradient values 320 | state["exp_avg_sq"] = torch.zeros_like(p.data) 321 | 322 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 323 | 324 | state["step"] += 1 325 | 326 | # Decay the first and second moment running average coefficient 327 | # In-place operations to update the averages at the same time 328 | exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) 329 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) 330 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 331 | 332 | step_size = group["lr"] 333 | if group["correct_bias"]: # No bias correction for Bert 334 | bias_correction1 = 1.0 - beta1 ** state["step"] 335 | bias_correction2 = 1.0 - beta2 ** state["step"] 336 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 337 | 338 | # /Adapt Step from Adafactor 339 | step_size = step_size * max(1e-3, self._rms(p.data)) 340 | # /Adapt Step from Adafactor 341 | 342 | p.data.addcdiv_(exp_avg, denom, value=-step_size) 343 | 344 | # Just adding the square of the weights to the loss function is *not* 345 | # the correct way of using L2 regularization/weight decay with Adam, 346 | # since that will interact with the m and v parameters in strange ways. 347 | # 348 | # Instead we want to decay the weights in a manner that doesn't interact 349 | # with the m/v parameters. This is equivalent to adding the square 350 | # of the weights to the loss with plain (non-momentum) SGD. 
351 | # Add weight decay at the end (fixed version) 352 | if group["weight_decay"] > 0.0: 353 | p.data.add_(p.data, alpha=(-group["lr"] * group["weight_decay"])) 354 | 355 | return loss 356 | 357 | 358 | def tokenize_function(examples, tokenizer, in_length): 359 | tokenizer_out = tokenizer( 360 | text=examples["text"], 361 | return_attention_mask=False, 362 | ) 363 | 364 | input_ids = tokenizer_out["input_ids"] 365 | 366 | concatenated_ids = np.concatenate(input_ids) 367 | 368 | total_length = concatenated_ids.shape[0] 369 | total_length = (total_length // in_length) * in_length 370 | 371 | concatenated_ids = concatenated_ids[:total_length].reshape(-1, in_length) 372 | result = {"input_ids": concatenated_ids} 373 | 374 | return result 375 | 376 | 377 | from transformers.data.data_collator import * 378 | @dataclass 379 | class DataCollatorForNI: 380 | tokenizer: PreTrainedTokenizerBase 381 | padding: Union[bool, str, PaddingStrategy] = True 382 | max_source_length: Optional[int] = None 383 | max_target_length: Optional[int] = None 384 | pad_to_multiple_of: Optional[int] = None 385 | label_pad_token_id: int = -100 386 | return_tensors: str = "pt" 387 | add_task_name: bool = False 388 | add_task_definition: bool = True 389 | num_pos_examples: int = 0 390 | num_neg_examples: int = 0 391 | add_explanation: bool = False 392 | tk_instruct: bool = False 393 | text_only: bool = False 394 | 395 | def __call__(self, batch, return_tensors=None): 396 | 397 | if return_tensors is None: 398 | return_tensors = self.return_tensors 399 | 400 | sources = [] 401 | for instance in batch: 402 | if self.tk_instruct: 403 | all_valid_encodings = [ 404 | # instruction only 405 | { 406 | "add_task_name": False, 407 | "add_task_definition": True, 408 | "num_pos_examples": 0, 409 | "num_neg_examples": 0, 410 | "add_explanation": False, 411 | }, 412 | # example only 413 | { 414 | "add_task_name": False, 415 | "add_task_definition": False, 416 | "num_pos_examples": 2, 417 | "num_neg_examples": 0, 418 | "add_explanation": False, 419 | }, 420 | # instruction + pos examples 421 | { 422 | "add_task_name": False, 423 | "add_task_definition": True, 424 | "num_pos_examples": 2, 425 | "num_neg_examples": 0, 426 | "add_explanation": False, 427 | }, 428 | # instruction + pos examples + neg examples 429 | { 430 | "add_task_name": False, 431 | "add_task_definition": True, 432 | "num_pos_examples": 2, 433 | "num_neg_examples": 2, 434 | "add_explanation": False, 435 | }, 436 | # instruction + pos (w. explanation) 437 | { 438 | "add_task_name": False, 439 | "add_task_definition": True, 440 | "num_pos_examples": 2, 441 | "num_neg_examples": 0, 442 | "add_explanation": True, 443 | }, 444 | ] 445 | encoding_schema = random.choice(all_valid_encodings) 446 | add_task_name = encoding_schema["add_task_name"] 447 | add_task_definition = encoding_schema["add_task_definition"] 448 | num_pos_examples = encoding_schema["num_pos_examples"] 449 | num_neg_examples = encoding_schema["num_neg_examples"] 450 | add_explanation = encoding_schema["add_explanation"] 451 | else: 452 | add_task_name = self.add_task_name 453 | add_task_definition = self.add_task_definition 454 | num_pos_examples = self.num_pos_examples 455 | num_neg_examples = self.num_neg_examples 456 | add_explanation = self.add_explanation 457 | 458 | task_input = "" 459 | # add the input first. 
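            # The pieces assembled below are concatenated at the end as
            #   task_name + definition + positive examples + negative examples + task_input,
            # a Tk-Instruct-style prompt that ends with "Output: " for the model to complete.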
460 | task_input += "Now complete the following example -\n" 461 | task_input += f"Input: {instance['Instance']['input'].strip()}" 462 | if not task_input[-1] in string.punctuation: 463 | task_input += "." 464 | task_input += "\n" 465 | task_input += "Output: " 466 | 467 | task_name = "" 468 | if add_task_name: 469 | task_name += instance["Task"] + ". " 470 | 471 | definition = "" 472 | if add_task_definition: 473 | if isinstance(instance["Definition"], list): 474 | definition = ( 475 | "Definition: " + instance["Definition"][0].strip() 476 | ) 477 | else: 478 | definition = "Definition: " + instance["Definition"].strip() 479 | if not definition[-1] in string.punctuation: 480 | definition += "." 481 | definition += "\n\n" 482 | 483 | # try to add positive examples. 484 | pos_examples = [] 485 | for idx, pos_example in enumerate( 486 | instance["Positive Examples"][:num_pos_examples] 487 | ): 488 | pos_example_str = f" Positive Example {idx+1} -\n" 489 | pos_example_str += f"Input: {pos_example['input'].strip()}" 490 | if not pos_example_str[-1] in string.punctuation: 491 | pos_example_str += "." 492 | pos_example_str += "\n" 493 | pos_example_str += f" Output: {pos_example['output'].strip()}" 494 | if not pos_example_str[-1] in string.punctuation: 495 | pos_example_str += "." 496 | pos_example_str += "\n" 497 | if add_explanation and "explanation" in pos_example: 498 | pos_example_str += ( 499 | f" Explanation: {pos_example['explanation'].strip()}" 500 | ) 501 | if not pos_example_str[-1] in string.punctuation: 502 | pos_example_str += "." 503 | pos_example_str += "\n" 504 | pos_example_str += "\n" 505 | if ( 506 | len( 507 | self.tokenizer( 508 | definition 509 | + " ".join(pos_examples) 510 | + pos_example_str 511 | + task_input 512 | )["input_ids"] 513 | ) 514 | <= self.max_source_length 515 | ): 516 | pos_examples.append(pos_example_str) 517 | else: 518 | break 519 | 520 | # try to add negative examples. 521 | neg_examples = [] 522 | for idx, neg_example in enumerate( 523 | instance["Negative Examples"][:num_neg_examples] 524 | ): 525 | neg_example_str = f" Negative Example {idx+1} -\n" 526 | neg_example_str += f"Input: {neg_example['input'].strip()}" 527 | if not neg_example_str[-1] in string.punctuation: 528 | neg_example_str += "." 529 | neg_example_str += "\n" 530 | neg_example_str += f" Output: {neg_example['output'].strip()}" 531 | if not neg_example_str[-1] in string.punctuation: 532 | neg_example_str += "." 533 | neg_example_str += "\n" 534 | if add_explanation and "explanation" in neg_example: 535 | neg_example_str += ( 536 | f" Explanation: {neg_example['explanation'].strip()}" 537 | ) 538 | if not neg_example_str[-1] in string.punctuation: 539 | neg_example_str += "." 
540 | neg_example_str += "\n" 541 | neg_example_str += "\n" 542 | if ( 543 | len( 544 | self.tokenizer( 545 | definition 546 | + " ".join(pos_examples) 547 | + " ".join(neg_examples) 548 | + neg_example_str 549 | + task_input 550 | )["input_ids"] 551 | ) 552 | <= self.max_source_length 553 | ): 554 | neg_examples.append(neg_example_str) 555 | else: 556 | break 557 | 558 | source = ( 559 | task_name 560 | + definition 561 | + "".join(pos_examples) 562 | + "".join(neg_examples) 563 | + task_input 564 | ) 565 | tokenized_source = self.tokenizer(source)["input_ids"] 566 | if len(tokenized_source) <= self.max_source_length: 567 | sources.append(source) 568 | else: 569 | sources.append( 570 | self.tokenizer.decode( 571 | tokenized_source[: self.max_source_length], 572 | skip_special_tokens=True, 573 | ) 574 | ) 575 | 576 | if self.text_only: 577 | model_inputs = {"inputs": sources} 578 | else: 579 | model_inputs = self.tokenizer( 580 | sources, 581 | max_length=self.max_source_length, 582 | padding=self.padding, 583 | return_tensors=self.return_tensors, 584 | truncation=True, 585 | pad_to_multiple_of=self.pad_to_multiple_of, 586 | ) 587 | 588 | if "output" in batch[0]["Instance"] and batch[0]["Instance"]["output"]: 589 | # Randomly select one reference if multiple are provided. 590 | labels = [random.choice(ex["Instance"]["output"]) for ex in batch] 591 | if self.text_only: 592 | model_inputs["labels"] = labels 593 | else: 594 | labels = self.tokenizer( 595 | labels, 596 | max_length=self.max_target_length, 597 | padding=self.padding, 598 | return_tensors=self.return_tensors, 599 | truncation=True, 600 | pad_to_multiple_of=self.pad_to_multiple_of, 601 | ) 602 | label_mask = labels["attention_mask"].bool() 603 | model_inputs["labels"] = labels["input_ids"].masked_fill( 604 | ~label_mask, self.label_pad_token_id 605 | ) 606 | else: 607 | model_inputs["labels"] = None 608 | 609 | return model_inputs 610 | -------------------------------------------------------------------------------- /nanoT5/utils/gen_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | from accelerate.utils import set_seed 5 | from omegaconf import open_dict 6 | from .logging_utils import Logger 7 | from hydra.utils import to_absolute_path 8 | 9 | 10 | def check_args_and_env(args): 11 | assert args.optim.batch_size % args.optim.grad_acc == 0 12 | 13 | # Train log must happen before eval log 14 | assert args.eval.every_steps % args.logging.every_steps == 0 15 | 16 | if args.device == 'gpu': 17 | assert torch.cuda.is_available(), 'We use GPU to train/eval the model' 18 | 19 | assert not (args.eval_only and args.predict_only) 20 | 21 | if args.predict_only: 22 | assert args.mode == 'ft' 23 | 24 | 25 | def opti_flags(args): 26 | # This lines reduce training step by 2.4x 27 | torch.backends.cuda.matmul.allow_tf32 = True 28 | torch.backends.cudnn.allow_tf32 = True 29 | 30 | if args.precision == 'bf16' and args.device == 'gpu' and args.model.klass == 'local_t5': 31 | args.model.add_config.is_bf16 = True 32 | 33 | 34 | def update_args_with_env_info(args): 35 | with open_dict(args): 36 | slurm_id = os.getenv('SLURM_JOB_ID') 37 | 38 | if slurm_id is not None: 39 | args.slurm_id = slurm_id 40 | else: 41 | args.slurm_id = 'none' 42 | 43 | args.working_dir = os.getcwd() 44 | 45 | 46 | def update_paths(args): 47 | if args.mode == 'ft': 48 | args.data.exec_file_path = to_absolute_path(args.data.exec_file_path) 49 | args.data.data_dir = 
to_absolute_path(args.data.data_dir) 50 | args.data.task_dir = to_absolute_path(args.data.task_dir) 51 | 52 | 53 | def setup_basics(accelerator, args): 54 | check_args_and_env(args) 55 | update_args_with_env_info(args) 56 | update_paths(args) 57 | opti_flags(args) 58 | 59 | if args.seed is not None: 60 | set_seed(args.seed) 61 | 62 | logger = Logger(args=args, accelerator=accelerator) 63 | 64 | return logger 65 | -------------------------------------------------------------------------------- /nanoT5/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | from accelerate.logging import get_logger 4 | from omegaconf import OmegaConf, open_dict 5 | import logging 6 | import datasets 7 | import transformers 8 | import neptune 9 | import os 10 | 11 | 12 | class Averager: 13 | def __init__(self, weight: float = 1): 14 | self.weight = weight 15 | self.reset() 16 | 17 | def reset(self): 18 | self.total = defaultdict(float) 19 | self.counter = defaultdict(float) 20 | 21 | def update(self, stats): 22 | for key, value in stats.items(): 23 | self.total[key] = self.total[key] * self.weight + value * self.weight 24 | self.counter[key] = self.counter[key] * self.weight + self.weight 25 | 26 | def average(self): 27 | averaged_stats = { 28 | key: tot / self.counter[key] for key, tot in self.total.items() 29 | } 30 | self.reset() 31 | 32 | return averaged_stats 33 | 34 | 35 | class Logger: 36 | def __init__(self, args, accelerator): 37 | self.logger = get_logger('Main') 38 | 39 | # Make one log on every process with the configuration for debugging. 40 | logging.basicConfig( 41 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 42 | datefmt="%m/%d/%Y %H:%M:%S", 43 | level=logging.INFO, 44 | ) 45 | self.logger.info(accelerator.state, main_process_only=False) 46 | self.logger.info(f'Working directory is {os.getcwd()}') 47 | 48 | if accelerator.is_local_main_process: 49 | datasets.utils.logging.set_verbosity_warning() 50 | transformers.utils.logging.set_verbosity_info() 51 | else: 52 | datasets.utils.logging.set_verbosity_error() 53 | transformers.utils.logging.set_verbosity_error() 54 | 55 | self.setup_neptune(args) 56 | 57 | def setup_neptune(self, args): 58 | if args.logging.neptune: 59 | neptune_logger = neptune.init_run( 60 | project=args.logging.neptune_creds.project, 61 | api_token=args.logging.neptune_creds.api_token, 62 | tags=[str(item) for item in args.logging.neptune_creds.tags.split(",")], 63 | ) 64 | else: 65 | neptune_logger = None 66 | 67 | self.neptune_logger = neptune_logger 68 | 69 | with open_dict(args): 70 | if neptune_logger is not None: 71 | args.neptune_id = neptune_logger["sys/id"].fetch() 72 | 73 | def log_args(self, args): 74 | if self.neptune_logger is not None: 75 | logging_args = OmegaConf.to_container(args, resolve=True) 76 | self.neptune_logger['args'] = logging_args 77 | 78 | def log_stats(self, stats, step, args, prefix=''): 79 | if self.neptune_logger is not None: 80 | for k, v in stats.items(): 81 | self.neptune_logger[f'{prefix}{k}'].log(v, step=step) 82 | 83 | msg_start = f'[{prefix[:-1]}] Step {step} out of {args.optim.total_steps}' + ' | ' 84 | dict_msg = ' | '.join([f'{k.capitalize()} --> {v:.3f}' for k, v in stats.items()]) + ' | ' 85 | 86 | msg = msg_start + dict_msg 87 | 88 | self.log_message(msg) 89 | 90 | def log_message(self, msg): 91 | self.logger.info(msg) 92 | 93 | def finish(self): 94 | if self.neptune_logger is not None: 95 | 
self.neptune_logger.stop() 96 | -------------------------------------------------------------------------------- /nanoT5/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import datasets 3 | from torch.utils.data import DataLoader 4 | from omegaconf import open_dict 5 | from datasets.iterable_dataset import IterableDataset 6 | from transformers import ( 7 | AutoTokenizer, 8 | T5ForConditionalGeneration, 9 | AutoConfig, 10 | ) 11 | 12 | from .copied_utils import ( 13 | compute_input_and_target_lengths, 14 | DataCollatorForT5MLM, 15 | tokenize_function, 16 | DataCollatorForNI, 17 | ) 18 | from .t5_model import MyT5 19 | 20 | 21 | def get_model(args, config): 22 | klass = { 23 | 'hf_t5': T5ForConditionalGeneration, 24 | 'local_t5': MyT5, 25 | }[args.model.klass] 26 | 27 | if args.model.checkpoint_path: 28 | model = klass(config) 29 | model.load_state_dict(torch.load(args.model.checkpoint_path)) 30 | elif args.model.random_init: 31 | model = klass(config) 32 | else: 33 | assert klass == T5ForConditionalGeneration, 'To load HFs weights you need to use HF model' 34 | model = klass.from_pretrained( 35 | args.model.name, 36 | config=config, 37 | ) 38 | 39 | with open_dict(args): 40 | args.n_all_param = sum([p.nelement() for p in model.parameters()]) 41 | 42 | return model 43 | 44 | 45 | def get_config(args): 46 | config = AutoConfig.from_pretrained( 47 | args.model.name, 48 | ) 49 | 50 | if hasattr(args.model, 'overwrite'): 51 | for k, v in args.model.overwrite.items(): 52 | assert hasattr(config, k), f'config does not have attribute {k}' 53 | setattr(config, k, v) 54 | 55 | if hasattr(args.model, 'add_config'): 56 | for k, v in args.model.add_config.items(): 57 | assert not hasattr(config, k), f'config already has attribute {k}' 58 | setattr(config, k, v) 59 | 60 | return config 61 | 62 | 63 | def get_tokenizer(args): 64 | tokenizer = AutoTokenizer.from_pretrained( 65 | args.model.name, 66 | use_fast=True 67 | ) 68 | tokenizer.model_max_length = int(1e9) 69 | 70 | return tokenizer 71 | 72 | 73 | def load_dataset_splits(args): 74 | if args.mode == 'pt': 75 | dataset = datasets.load_dataset( 76 | 'c4', 77 | 'en', 78 | streaming=True, 79 | ) 80 | 81 | dataset = dataset.remove_columns( 82 | ['timestamp', 'url'] 83 | ) 84 | 85 | dataset_splits = { 86 | 'train': dataset['train'], 87 | 'test': dataset['validation'], 88 | } 89 | 90 | assert ( 91 | dataset['train'].n_shards == 1024 92 | ), "We want to have many shards for efficient processing with num_workes in PyTorch dataloader" 93 | elif args.mode == 'ft': 94 | dataset_splits = datasets.load_dataset( 95 | args.data.exec_file_path, 96 | data_dir=args.data.data_dir, 97 | task_dir=args.data.task_dir, 98 | max_num_instances_per_task=args.data.max_num_instances_per_task, 99 | max_num_instances_per_eval_task=args.data.max_num_instances_per_task 100 | ) 101 | else: 102 | raise NotImplementedError 103 | 104 | return dataset_splits 105 | 106 | 107 | def process_dataset(dataset_splits, args, tokenizer): 108 | if args.mode == 'pt': 109 | final_datasets = {} 110 | 111 | for split, dataset_split in dataset_splits.items(): 112 | 113 | # We increase the input_length, because instead of masking tokens T5 replaces 114 | # masked spans with a single token, therefore to avoid padding we need to have 115 | # longer sequences at the start, before masking 116 | before_mask_input_length, target_length = compute_input_and_target_lengths( 117 | inputs_length=args.data.input_length, 118 | 
noise_density=args.data.mlm_probability, 119 | mean_noise_span_length=args.data.mean_noise_span_length, 120 | ) 121 | 122 | with open_dict(args): 123 | args.data.before_mask_input_length = before_mask_input_length 124 | args.data.target_length = target_length 125 | 126 | dataset_split = dataset_split.map( 127 | tokenize_function, 128 | batched=True, 129 | fn_kwargs={ 130 | 'tokenizer': tokenizer, 131 | 'in_length': before_mask_input_length, 132 | }, 133 | remove_columns=['text'], 134 | ) 135 | 136 | dataset_split = dataset_split.shuffle(buffer_size=10_000, seed=args.seed) 137 | final_datasets[split] = dataset_split 138 | elif args.mode == 'ft': 139 | final_datasets = dataset_splits 140 | else: 141 | raise NotImplementedError 142 | 143 | return final_datasets 144 | 145 | 146 | def get_data_collator(tokenizer, config, args): 147 | if args.mode == 'pt': 148 | data_collator = DataCollatorForT5MLM( 149 | tokenizer=tokenizer, 150 | noise_density=args.data.mlm_probability, 151 | mean_noise_span_length=args.data.mean_noise_span_length, 152 | input_length=args.data.input_length, 153 | target_length=args.data.target_length, 154 | pad_token_id=config.pad_token_id, 155 | ) 156 | elif args.mode == 'ft': 157 | data_collator = DataCollatorForNI( 158 | tokenizer, 159 | padding="longest", 160 | max_source_length=args.data.max_seq_len, 161 | max_target_length=args.data.max_target_len, 162 | label_pad_token_id=-100, 163 | pad_to_multiple_of=8, 164 | add_task_name=args.data.add_task_name, 165 | add_task_definition=args.data.add_task_definition, 166 | num_pos_examples=args.data.num_pos_examples, 167 | num_neg_examples=args.data.num_neg_examples, 168 | add_explanation=args.data.add_explanation, 169 | tk_instruct=args.data.tk_instruct 170 | ) 171 | else: 172 | raise NotImplementedError 173 | 174 | return data_collator 175 | 176 | 177 | def get_dataloaders(tokenizer, config, args): 178 | dataset_splits = load_dataset_splits(args) 179 | dataset = process_dataset(dataset_splits=dataset_splits, args=args, tokenizer=tokenizer) 180 | data_collator = get_data_collator(tokenizer=tokenizer, config=config, 181 | args=args) 182 | 183 | is_iterable = isinstance(dataset['train'], IterableDataset) 184 | 185 | dataloaders = {} 186 | 187 | for split in ['train', 'test']: 188 | batch_size = args.optim.batch_size // args.optim.grad_acc 189 | 190 | shuffle = (split == 'train') and not is_iterable 191 | 192 | if args.mode == 'ft' and split == 'train': 193 | assert shuffle is True 194 | else: 195 | assert shuffle is False 196 | 197 | dataloaders[split] = DataLoader( 198 | dataset[split], 199 | shuffle=shuffle, 200 | collate_fn=data_collator, 201 | batch_size=batch_size, 202 | num_workers=args.data.num_workers, 203 | pin_memory=True, 204 | drop_last=False, 205 | ) 206 | 207 | # Add & Check args about data loaders 208 | with open_dict(args): 209 | if not is_iterable: 210 | args.data.train_batches = len(dataloaders['train']) 211 | args.data.test_batches = len(dataloaders['test']) 212 | 213 | if args.optim.epochs > 0: 214 | assert not is_iterable 215 | args.optim.total_steps = (len(dataloaders['train']) // args.optim.grad_acc) * args.optim.epochs 216 | 217 | args.eval.corrected_steps = args.eval.steps 218 | 219 | return dataloaders['train'], dataloaders['test'] 220 | 221 | 222 | def get_optimizer(model, args): 223 | no_decay = ["bias", "LayerNorm", "layernorm", "layer_norm", "ln"] 224 | 225 | optimizer_grouped_parameters = [ 226 | { 227 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 228 | 
"weight_decay": args.optim.weight_decay, 229 | }, 230 | { 231 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 232 | "weight_decay": 0.0, 233 | }, 234 | ] 235 | 236 | if args.optim.name == 'adamw': 237 | from transformers import AdamW 238 | optimizer = AdamW( 239 | optimizer_grouped_parameters, 240 | lr=args.optim.base_lr, 241 | ) 242 | elif args.optim.name == 'adamwscale': 243 | from .copied_utils import AdamWScale 244 | optimizer = AdamWScale( 245 | optimizer_grouped_parameters, 246 | lr=args.optim.base_lr, 247 | ) 248 | elif args.optim.name == 'adafactor': 249 | from transformers import Adafactor 250 | optimizer = Adafactor( 251 | optimizer_grouped_parameters, 252 | lr=args.optim.base_lr, 253 | relative_step=False, 254 | ) 255 | else: 256 | raise NotImplementedError 257 | 258 | return optimizer 259 | 260 | 261 | def get_lr_scheduler(optimizer, args, logger): 262 | if args.optim.lr_scheduler == 'cosine': 263 | from torch.optim.lr_scheduler import ( 264 | SequentialLR, 265 | LinearLR, 266 | CosineAnnealingLR, 267 | ) 268 | 269 | scheduler1 = LinearLR( 270 | optimizer, 271 | start_factor=0.5, 272 | end_factor=1, 273 | total_iters=args.optim.warmup_steps, 274 | last_epoch=-1, 275 | ) 276 | 277 | scheduler2 = CosineAnnealingLR( 278 | optimizer, 279 | T_max=args.optim.total_steps - args.optim.warmup_steps, 280 | eta_min=args.optim.final_cosine, 281 | ) 282 | 283 | lr_scheduler = SequentialLR( 284 | optimizer, 285 | schedulers=[scheduler1, scheduler2], 286 | milestones=[args.optim.warmup_steps] 287 | ) 288 | elif args.optim.lr_scheduler == 'legacy': 289 | import math 290 | from torch.optim.lr_scheduler import ( 291 | SequentialLR, 292 | LinearLR, 293 | LambdaLR, 294 | ) 295 | 296 | msg = "You are using T5 legacy LR Schedule, it's independent from the optim.base_lr" 297 | logger.log_message(msg) 298 | 299 | num_steps_optimizer1 = math.ceil(args.optim.total_steps * 0.9) 300 | iters_left_for_optimizer2 = args.optim.total_steps - num_steps_optimizer1 301 | 302 | scheduler1 = LambdaLR( 303 | optimizer, 304 | lambda step: min( 305 | 1e-2, 1.0 / math.sqrt(step) 306 | ) / args.optim.base_lr if step else 1e-2 / args.optim.base_lr 307 | ) 308 | 309 | scheduler2 = LinearLR( 310 | optimizer, 311 | start_factor=( 312 | min(1e-2, 1.0 / math.sqrt(num_steps_optimizer1)) / args.optim.base_lr 313 | ), 314 | end_factor=0, 315 | total_iters=iters_left_for_optimizer2, 316 | last_epoch=-1, 317 | ) 318 | 319 | lr_scheduler = SequentialLR( 320 | optimizer, 321 | schedulers=[scheduler1, scheduler2], 322 | milestones=[num_steps_optimizer1] 323 | ) 324 | elif args.optim.lr_scheduler == 'constant': 325 | from transformers import get_scheduler 326 | lr_scheduler = get_scheduler( 327 | name=args.optim.lr_scheduler, 328 | optimizer=optimizer, 329 | ) 330 | else: 331 | raise NotImplementedError 332 | 333 | return lr_scheduler 334 | -------------------------------------------------------------------------------- /nanoT5/utils/ni_dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Natural Instruction V2 Dataset.""" 18 | 19 | 20 | import json 21 | import os 22 | import random 23 | import datasets 24 | 25 | logger = datasets.logging.get_logger(__name__) 26 | 27 | _CITATION = """ 28 | @article{wang2022benchmarking, 29 | title={Benchmarking Generalization via In-Context Instructions on 1,600+ Language Tasks}, 30 | author={Wang, Yizhong and Mishra, Swaroop and Alipoormolabashi, Pegah and Kordi, Yeganeh and others}, 31 | journal={arXiv preprint arXiv:2204.07705}, 32 | year={2022} 33 | } 34 | """ 35 | 36 | _DESCRIPTION = """ 37 | Natural-Instructions v2 is a benchmark of 1,600+ diverse language tasks and their expert-written instructions. 38 | It covers 70+ distinct task types, such as tagging, in-filling, and rewriting. 39 | These tasks are collected with contributions of NLP practitioners in the community and 40 | through an iterative peer review process to ensure their quality. 41 | """ 42 | 43 | _URL = "https://instructions.apps.allenai.org/" 44 | 45 | class NIConfig(datasets.BuilderConfig): 46 | def __init__(self, *args, task_dir=None, max_num_instances_per_task=None, max_num_instances_per_eval_task=None, **kwargs): 47 | super().__init__(*args, **kwargs) 48 | self.task_dir: str = task_dir 49 | self.max_num_instances_per_task: int = max_num_instances_per_task 50 | self.max_num_instances_per_eval_task: int = max_num_instances_per_eval_task 51 | 52 | 53 | class NaturalInstructions(datasets.GeneratorBasedBuilder): 54 | """NaturalInstructions Dataset.""" 55 | 56 | VERSION = datasets.Version("2.0.0") 57 | BUILDER_CONFIG_CLASS = NIConfig 58 | BUILDER_CONFIGS = [ 59 | NIConfig(name="default", description="Default config for NaturalInstructions") 60 | ] 61 | DEFAULT_CONFIG_NAME = "default" 62 | 63 | def _info(self): 64 | return datasets.DatasetInfo( 65 | description=_DESCRIPTION, 66 | features=datasets.Features( 67 | { 68 | "id": datasets.Value("string"), 69 | "Task": datasets.Value("string"), 70 | "Contributors": datasets.Value("string"), 71 | "Source": [datasets.Value("string")], 72 | "URL": [datasets.Value("string")], 73 | "Categories": [datasets.Value("string")], 74 | "Reasoning": [datasets.Value("string")], 75 | "Definition": [datasets.Value("string")], 76 | "Positive Examples": [{ 77 | "input": datasets.Value("string"), 78 | "output": datasets.Value("string"), 79 | "explanation": datasets.Value("string") 80 | }], 81 | "Negative Examples": [{ 82 | "input": datasets.Value("string"), 83 | "output": datasets.Value("string"), 84 | "explanation": datasets.Value("string") 85 | }], 86 | "Input_language": [datasets.Value("string")], 87 | "Output_language": [datasets.Value("string")], 88 | "Instruction_language": [datasets.Value("string")], 89 | "Domains": [datasets.Value("string")], 90 | # "Instances": [{ 91 | # "input": datasets.Value("string"), 92 | # "output": [datasets.Value("string")] 93 | # }], 94 | "Instance": { 95 | "id": datasets.Value("string"), 96 | "input": datasets.Value("string"), 97 | "output": [datasets.Value("string")] 98 | }, 99 | "Instance License": [datasets.Value("string")] 100 | } 
101 | ), 102 | supervised_keys=None, 103 | homepage="https://github.com/allenai/natural-instructions", 104 | citation=_CITATION, 105 | ) 106 | 107 | def _split_generators(self, dl_manager): 108 | """Returns SplitGenerators.""" 109 | if self.config.data_dir is None or self.config.task_dir is None: 110 | dl_path = dl_manager.download_and_extract(_URL) 111 | self.config.data_dir = self.config.data_dir or os.path.join(dl_path, "splits") 112 | self.config.task_dir = self.config.task_dir or os.path.join(dl_path, "tasks") 113 | 114 | split_dir = self.config.data_dir 115 | task_dir = self.config.task_dir 116 | 117 | return [ 118 | datasets.SplitGenerator( 119 | name=datasets.Split.TRAIN, 120 | gen_kwargs={ 121 | "path": os.path.join(split_dir, "train_tasks.txt"), 122 | "task_dir": task_dir, 123 | "max_num_instances_per_task": self.config.max_num_instances_per_task, 124 | "subset": "train" 125 | }), 126 | # datasets.SplitGenerator( 127 | # name=datasets.Split.VALIDATION, 128 | # gen_kwargs={ 129 | # "path": os.path.join(split_dir, "dev_tasks.txt"), 130 | # "task_dir": task_dir, 131 | # "max_num_instances_per_task": self.config.max_num_instances_per_eval_task, 132 | # "subset": "dev" 133 | # }), 134 | datasets.SplitGenerator( 135 | name=datasets.Split.TEST, 136 | gen_kwargs={ 137 | "path": os.path.join(split_dir, "test_tasks.txt"), 138 | "task_dir": task_dir, 139 | "max_num_instances_per_task": self.config.max_num_instances_per_eval_task, 140 | "subset": "test" 141 | }), 142 | ] 143 | 144 | def _generate_examples(self, path=None, task_dir=None, max_num_instances_per_task=None, subset=None): 145 | """Yields examples.""" 146 | logger.info(f"Generating tasks from = {path}") 147 | with open(path, encoding="utf-8") as split_f: 148 | for line in split_f: 149 | task_name = line.strip() 150 | task_path = os.path.join(task_dir, task_name + ".json") 151 | with open(task_path, encoding="utf-8") as task_f: 152 | s = task_f.read() 153 | task_data = json.loads(s) 154 | task_data["Task"] = task_name 155 | if "Instruction Source" in task_data: 156 | task_data.pop("Instruction Source") 157 | all_instances = task_data.pop("Instances") 158 | if subset == "test": 159 | # for testing tasks, 100 instances are selected for efficient evaluation and they are label-balanced. 160 | # we put them in the first for reproducibility. 
161 | # so, we use them here 162 | instances = all_instances[:100] 163 | else: 164 | instances = all_instances 165 | if max_num_instances_per_task is not None and max_num_instances_per_task >= 0: 166 | random.shuffle(instances) 167 | instances = instances[:max_num_instances_per_task] 168 | for idx, instance in enumerate(instances): 169 | example = task_data.copy() 170 | example["id"] = instance["id"] 171 | example["Instance"] = instance 172 | yield f"{task_name}_{idx}", example 173 | 174 | -------------------------------------------------------------------------------- /nanoT5/utils/t5_model.py: -------------------------------------------------------------------------------- 1 | # From: https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py 2 | 3 | import copy 4 | import math 5 | from typing import Optional 6 | from dataclasses import dataclass 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import CrossEntropyLoss 11 | 12 | from transformers.modeling_utils import ModuleUtilsMixin 13 | from transformers.modeling_outputs import ModelOutput 14 | from transformers.models.t5.configuration_t5 import T5Config 15 | from transformers.models.t5.modeling_t5 import ( 16 | T5LayerNorm, 17 | T5DenseGatedActDense, 18 | ) 19 | 20 | 21 | @dataclass 22 | class EncoderOutput(ModelOutput): 23 | hidden_states: torch.FloatTensor = None 24 | attention_mask: torch.FloatTensor = None 25 | 26 | 27 | @dataclass 28 | class Seq2SeqLMOutput(ModelOutput): 29 | loss: torch.FloatTensor = None 30 | logits: torch.FloatTensor = None 31 | encoder_outputs: EncoderOutput = None 32 | 33 | 34 | class T5LayerFF(nn.Module): 35 | def __init__(self, config: T5Config): 36 | super().__init__() 37 | assert config.is_gated_act 38 | self.DenseReluDense = T5DenseGatedActDense(config) 39 | self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) 40 | self.dropout = nn.Dropout(config.dropout_rate) 41 | 42 | def forward(self, hidden_states): 43 | forwarded_states = self.layer_norm(hidden_states).type_as(hidden_states) 44 | forwarded_states = self.DenseReluDense(forwarded_states) 45 | hidden_states = hidden_states + self.dropout(forwarded_states) 46 | return hidden_states 47 | 48 | 49 | class T5Attention(nn.Module): 50 | def __init__(self, config: T5Config, has_relative_attention_bias=False): 51 | super().__init__() 52 | self.is_decoder = config.is_decoder 53 | self.has_relative_attention_bias = has_relative_attention_bias 54 | self.relative_attention_num_buckets = config.relative_attention_num_buckets 55 | self.relative_attention_max_distance = config.relative_attention_max_distance 56 | self.d_model = config.d_model 57 | self.key_value_proj_dim = config.d_kv 58 | self.n_heads = config.num_heads 59 | self.dropout = config.dropout_rate 60 | self.inner_dim = self.n_heads * self.key_value_proj_dim 61 | 62 | self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) 63 | self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) 64 | self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) 65 | self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) 66 | 67 | if self.has_relative_attention_bias: 68 | self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) 69 | 70 | @staticmethod 71 | def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): 72 | """ 73 | Adapted from Mesh Tensorflow: 74 | 
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 75 | 76 | Translate relative position to a bucket number for relative attention. The relative position is defined as 77 | memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to 78 | position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for 79 | small absolute relative_position and larger buckets for larger absolute relative_positions. All relative 80 | positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 81 | This should allow for more graceful generalization to longer sequences than the model has been trained on 82 | 83 | Args: 84 | relative_position: an int32 Tensor 85 | bidirectional: a boolean - whether the attention is bidirectional 86 | num_buckets: an integer 87 | max_distance: an integer 88 | 89 | Returns: 90 | a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) 91 | """ 92 | relative_buckets = 0 93 | if bidirectional: 94 | num_buckets //= 2 95 | relative_buckets += (relative_position > 0).to(torch.long) * num_buckets 96 | relative_position = torch.abs(relative_position) 97 | else: 98 | relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) 99 | # now relative_position is in the range [0, inf) 100 | 101 | # half of the buckets are for exact increments in positions 102 | max_exact = num_buckets // 2 103 | is_small = relative_position < max_exact 104 | 105 | # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance 106 | relative_position_if_large = max_exact + ( 107 | torch.log(relative_position.float() / max_exact) 108 | / math.log(max_distance / max_exact) 109 | * (num_buckets - max_exact) 110 | ).to(torch.long) 111 | relative_position_if_large = torch.min( 112 | relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) 113 | ) 114 | 115 | relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) 116 | return relative_buckets 117 | 118 | def compute_bias(self, query_length, key_length, device=None): 119 | """Compute binned relative position bias""" 120 | if device is None: 121 | device = self.relative_attention_bias.weight.device 122 | context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] 123 | memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] 124 | relative_position = memory_position - context_position # shape (query_length, key_length) 125 | relative_position_bucket = self._relative_position_bucket( 126 | relative_position, # shape (query_length, key_length) 127 | bidirectional=(not self.is_decoder), 128 | num_buckets=self.relative_attention_num_buckets, 129 | max_distance=self.relative_attention_max_distance, 130 | ) 131 | values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) 132 | values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) 133 | return values 134 | 135 | def forward( 136 | self, 137 | hidden_states, 138 | mask=None, 139 | key_value_states=None, 140 | position_bias=None, 141 | ): 142 | """ 143 | Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
144 | """ 145 | # Input is (batch_size, seq_length, dim) 146 | # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) 147 | batch_size, seq_length = hidden_states.shape[:2] 148 | real_seq_length = seq_length 149 | key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] 150 | 151 | def shape(states): 152 | """projection""" 153 | return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) 154 | 155 | def unshape(states): 156 | """reshape""" 157 | return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) 158 | 159 | query_states = self.q(hidden_states) 160 | if key_value_states is None: 161 | key_states, value_states = self.k(hidden_states), self.v(hidden_states) 162 | else: 163 | key_states, value_states = self.k(key_value_states), self.v(key_value_states) 164 | query_states, key_states, value_states = shape(query_states), shape(key_states), shape(value_states) 165 | 166 | scores = torch.matmul( 167 | query_states, key_states.transpose(3, 2) 168 | ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 169 | 170 | if position_bias is None: 171 | if not self.has_relative_attention_bias: 172 | position_bias = torch.zeros( 173 | (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype 174 | ) 175 | else: 176 | position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) 177 | 178 | if mask is not None: 179 | # Masking happens here, masked elements in the mask have the value of -inf 180 | position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) 181 | 182 | position_bias_masked = position_bias 183 | 184 | scores += position_bias_masked 185 | attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( 186 | scores 187 | ) # (batch_size, n_heads, seq_length, key_length) 188 | attn_weights = nn.functional.dropout( 189 | attn_weights, p=self.dropout, training=self.training 190 | ) # (batch_size, n_heads, seq_length, key_length) 191 | 192 | attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) 193 | attn_output = self.o(attn_output) 194 | 195 | return (attn_output, position_bias) 196 | 197 | 198 | class T5LayerSelfAttention(nn.Module): 199 | def __init__(self, config, has_relative_attention_bias=False): 200 | super().__init__() 201 | self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) 202 | self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) 203 | self.dropout = nn.Dropout(config.dropout_rate) 204 | 205 | def forward( 206 | self, 207 | hidden_states, 208 | attention_mask=None, 209 | position_bias=None, 210 | ): 211 | normed_hidden_states = self.layer_norm(hidden_states).type_as(hidden_states) 212 | attention_output = self.SelfAttention( 213 | normed_hidden_states, 214 | mask=attention_mask, 215 | position_bias=position_bias, 216 | ) 217 | hidden_states = hidden_states + self.dropout(attention_output[0]) 218 | outputs = (hidden_states,) + attention_output[1:] 219 | return outputs 220 | 221 | 222 | class T5LayerCrossAttention(nn.Module): 223 | def __init__(self, config): 224 | super().__init__() 225 | self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) 226 | self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) 227 | self.dropout = nn.Dropout(config.dropout_rate) 228 | 229 | def forward( 230 
| self, 231 | hidden_states, 232 | key_value_states, 233 | attention_mask=None, 234 | position_bias=None, 235 | ): 236 | normed_hidden_states = self.layer_norm(hidden_states) 237 | attention_output = self.EncDecAttention( 238 | normed_hidden_states, 239 | mask=attention_mask, 240 | key_value_states=key_value_states, 241 | position_bias=position_bias, 242 | ) 243 | layer_output = hidden_states + self.dropout(attention_output[0]) 244 | outputs = (layer_output,) + attention_output[1:] 245 | return outputs 246 | 247 | 248 | class T5Block(nn.Module): 249 | def __init__(self, config, has_relative_attention_bias=False): 250 | super().__init__() 251 | self.is_decoder = config.is_decoder 252 | self.layer = nn.ModuleList() 253 | self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) 254 | if self.is_decoder: 255 | self.layer.append(T5LayerCrossAttention(config)) 256 | 257 | self.layer.append(T5LayerFF(config)) 258 | 259 | def forward( 260 | self, 261 | hidden_states, 262 | attention_mask=None, 263 | position_bias=None, 264 | encoder_hidden_states=None, 265 | encoder_attention_mask=None, 266 | encoder_decoder_position_bias=None, 267 | ): 268 | self_attention_outputs = self.layer[0]( 269 | hidden_states, 270 | attention_mask=attention_mask, 271 | position_bias=position_bias, 272 | ) 273 | hidden_states = self_attention_outputs[0] 274 | attention_outputs = self_attention_outputs[1:] # Relative position weights 275 | 276 | if self.is_decoder and encoder_hidden_states is not None: 277 | cross_attention_outputs = self.layer[1]( 278 | hidden_states, 279 | key_value_states=encoder_hidden_states, 280 | attention_mask=encoder_attention_mask, 281 | position_bias=encoder_decoder_position_bias, 282 | ) 283 | hidden_states = cross_attention_outputs[0] 284 | 285 | # Keep relative position weights 286 | attention_outputs = attention_outputs + cross_attention_outputs[1:] 287 | 288 | # Apply Feed Forward layer 289 | hidden_states = self.layer[-1](hidden_states) 290 | 291 | outputs = (hidden_states,) 292 | outputs = outputs + attention_outputs 293 | 294 | return outputs # hidden-states, (self-attention position bias), (cross-attention position bias) 295 | 296 | 297 | class T5Stack(nn.Module, ModuleUtilsMixin): 298 | def __init__(self, config, embed_tokens): 299 | super().__init__() 300 | assert embed_tokens is not None 301 | 302 | self.config = config 303 | self.embed_tokens = embed_tokens 304 | self.is_decoder = config.is_decoder 305 | 306 | self.block = nn.ModuleList( 307 | [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] 308 | ) 309 | 310 | self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) 311 | self.dropout = nn.Dropout(config.dropout_rate) 312 | 313 | def forward( 314 | self, 315 | input_ids=None, 316 | attention_mask=None, 317 | encoder_hidden_states=None, 318 | encoder_attention_mask=None, 319 | ) -> EncoderOutput: 320 | input_shape = input_ids.size() 321 | batch_size, seq_length = input_shape 322 | 323 | inputs_embeds = self.embed_tokens(input_ids) 324 | 325 | if hasattr(self.config, 'is_bf16') and self.config.is_bf16: 326 | inputs_embeds = inputs_embeds.to(torch.bfloat16) 327 | 328 | # Masking 329 | if attention_mask is None: 330 | attention_mask = torch.ones(batch_size, seq_length, device=inputs_embeds.device) 331 | 332 | if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: 333 | encoder_seq_length = encoder_hidden_states.shape[1] 334 | 
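                # No cross-attention mask was provided, so default to attending
                # over every encoder position.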
encoder_attention_mask = torch.ones( 335 | batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long 336 | ) 337 | 338 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 339 | # ourselves in which case we just need to make it broadcastable to all heads. 340 | extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) 341 | 342 | # If a 2D or 3D attention mask is provided for the cross-attention 343 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 344 | if self.is_decoder and encoder_hidden_states is not None: 345 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 346 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 347 | if encoder_attention_mask is None: 348 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) 349 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 350 | else: 351 | encoder_extended_attention_mask = None 352 | 353 | position_bias = None 354 | encoder_decoder_position_bias = None 355 | 356 | hidden_states = self.dropout(inputs_embeds) 357 | 358 | for _, layer_module in enumerate(self.block): 359 | layer_outputs = layer_module( 360 | hidden_states, 361 | attention_mask=extended_attention_mask, 362 | position_bias=position_bias, 363 | encoder_hidden_states=encoder_hidden_states, 364 | encoder_attention_mask=encoder_extended_attention_mask, 365 | encoder_decoder_position_bias=encoder_decoder_position_bias, 366 | ) 367 | hidden_states = layer_outputs[0] 368 | 369 | # We share the position biases between the layers - the first layer store them 370 | position_bias = layer_outputs[1] 371 | if self.is_decoder and encoder_hidden_states is not None: 372 | encoder_decoder_position_bias = layer_outputs[2] 373 | 374 | hidden_states = self.final_layer_norm(hidden_states).type_as(hidden_states) 375 | hidden_states = self.dropout(hidden_states) 376 | 377 | return EncoderOutput( 378 | hidden_states=hidden_states, 379 | attention_mask=attention_mask, 380 | ) 381 | 382 | 383 | class MyT5(nn.Module): 384 | def __init__(self, config: T5Config): 385 | super().__init__() 386 | config.is_encoder_decoder = False 387 | assert not config.tie_word_embeddings 388 | 389 | self.config = config 390 | self.model_dim = config.d_model 391 | self.shared = nn.Embedding(config.vocab_size, config.d_model) 392 | 393 | encoder_config = copy.deepcopy(config) 394 | encoder_config.is_decoder = False 395 | self.encoder = T5Stack(encoder_config, self.shared) 396 | 397 | decoder_config = copy.deepcopy(config) 398 | decoder_config.is_decoder = True 399 | decoder_config.num_layers = config.num_decoder_layers 400 | self.decoder = T5Stack(decoder_config, self.shared) 401 | 402 | self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) 403 | self.generation_config = None 404 | 405 | self.apply(self._init_weights) 406 | 407 | def generate( 408 | self, 409 | input_ids: Optional[torch.LongTensor] = None, 410 | attention_mask: Optional[torch.FloatTensor] = None, 411 | max_length = None, 412 | **kwargs, 413 | ) -> torch.LongTensor: 414 | """ 415 | input_ids: B x L_encoder, int64 416 | attention_mask: B x L_encoder, int64 417 | 1 for tokens to attend to, 0 for tokens to ignore 418 | 419 | Generation: 420 | Starts with 0, ends with 1, padding is 0 421 | 422 | # For 20 input/outputs, the diff between my implementation and HF is 9.8s vs 11.4s 423 | """ 424 | B, _ = 
input_ids.size() 425 | labels = torch.zeros(B, 1, dtype=torch.long, device=input_ids.device) 426 | encoder_outputs = None 427 | 428 | for _ in range(max_length): 429 | out = self.forward( 430 | input_ids=input_ids, 431 | attention_mask=attention_mask, 432 | decoder_input_ids=labels, 433 | encoder_outputs=encoder_outputs, 434 | ) 435 | encoder_outputs = out.encoder_outputs 436 | top_labels = out.logits[:, -1].argmax(-1).unsqueeze(-1) 437 | labels = torch.cat([labels, top_labels], dim=-1) 438 | 439 | if (labels == 1).sum(-1).clamp(min=0, max=1).sum().item() == B: 440 | break 441 | 442 | labels[:, -1] = 1 443 | 444 | # Mask out the padding, i.e., all positions after the first 1 with 0 445 | B, L = labels.size() 446 | mask = torch.arange(L, device=labels.device).unsqueeze(0) <= (labels == 1).long().argmax(-1).unsqueeze(-1) 447 | labels = labels.masked_fill(~mask, 0) 448 | 449 | return labels 450 | 451 | def forward( 452 | self, 453 | input_ids: Optional[torch.LongTensor] = None, 454 | attention_mask: Optional[torch.FloatTensor] = None, 455 | decoder_input_ids: Optional[torch.LongTensor] = None, 456 | decoder_attention_mask: Optional[torch.BoolTensor] = None, 457 | labels: Optional[torch.LongTensor] = None, 458 | encoder_outputs = None, 459 | ) -> Seq2SeqLMOutput: 460 | """ 461 | input_ids: B x L_encoder, int64 462 | attention_mask: B x L_encoder, int64 463 | 1 for tokens to attend to, 0 for tokens to ignore 464 | labels: B x L_decoder, int64 465 | """ 466 | if encoder_outputs is None: 467 | encoder_outputs = self.encoder( 468 | input_ids=input_ids, 469 | attention_mask=attention_mask, 470 | ) 471 | 472 | hidden_states = encoder_outputs.hidden_states 473 | 474 | if labels is not None and decoder_input_ids is None: 475 | decoder_input_ids = self._shift_right(labels) 476 | 477 | decoder_outputs = self.decoder( 478 | input_ids=decoder_input_ids, 479 | attention_mask=decoder_attention_mask, 480 | encoder_hidden_states=hidden_states, 481 | encoder_attention_mask=attention_mask, 482 | ) 483 | 484 | sequence_output = decoder_outputs[0] 485 | lm_logits = self.lm_head(sequence_output) 486 | 487 | loss = None 488 | if labels is not None: 489 | loss_fct = CrossEntropyLoss(ignore_index=-100) 490 | loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) 491 | 492 | return Seq2SeqLMOutput( 493 | loss=loss, 494 | logits=lm_logits, 495 | encoder_outputs=encoder_outputs, 496 | ) 497 | 498 | def _init_weights(self, module): 499 | factor = self.config.initializer_factor # Used for testing weights initialization 500 | if isinstance(module, T5LayerNorm): 501 | module.weight.data.fill_(factor * 1.0) 502 | elif isinstance(module, (MyT5)): 503 | module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) 504 | if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: 505 | module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) 506 | elif isinstance(module, T5DenseGatedActDense): 507 | d_ff, d_model = module.wi_0.weight.data.size() 508 | module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) 509 | module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) 510 | module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5)) 511 | elif isinstance(module, T5Attention): 512 | d_model = self.config.d_model 513 | key_value_proj_dim = self.config.d_kv 514 | n_heads = self.config.num_heads 515 | module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) 516 | module.k.weight.data.normal_(mean=0.0, std=factor 
* (d_model**-0.5)) 517 | module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) 518 | module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) 519 | if hasattr(module, "relative_attention_bias"): 520 | module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) 521 | 522 | def _shift_right(self, input_ids): 523 | decoder_start_token_id = self.config.decoder_start_token_id 524 | pad_token_id = self.config.pad_token_id 525 | 526 | assert decoder_start_token_id is not None and pad_token_id is not None 527 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 528 | shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() 529 | shifted_input_ids[..., 0] = decoder_start_token_id 530 | 531 | # replace possible -100 values in labels by `pad_token_id` 532 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 533 | 534 | return shifted_input_ids -------------------------------------------------------------------------------- /nanoT5/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import evaluate 4 | from .logging_utils import Averager 5 | from datasets.iterable_dataset import IterableDataset 6 | 7 | 8 | def maybe_save_checkpoint(accelerator, args): 9 | if ( 10 | args.current_train_step > args.optim.total_steps 11 | or args.current_train_step % args.checkpoint.every_steps == 0 12 | ): 13 | output_dir = f'checkpoint-{args.mode}-{args.current_train_step}' 14 | accelerator.save_state(output_dir=output_dir) 15 | 16 | 17 | def maybe_eval_predict(model, dataloader, logger, args, tokenizer): 18 | if ( 19 | args.current_train_step > args.optim.total_steps 20 | or args.current_train_step % args.eval.every_steps == 0 21 | ): 22 | model.eval() 23 | 24 | with torch.no_grad(): 25 | eval(model, dataloader, logger, args, tokenizer) 26 | 27 | if args.mode == 'ft': 28 | predict( 29 | model, dataloader, logger, args, tokenizer 30 | ) 31 | 32 | args.last_log = time.time() 33 | model.train() 34 | 35 | 36 | def maybe_logging(averager, args, model, optimizer, logger): 37 | if args.current_train_step % args.logging.every_steps == 0: 38 | stats = extra_stats(args, model, optimizer) 39 | 40 | averager.update(stats) 41 | averaged_stats = averager.average() 42 | 43 | logger.log_stats( 44 | stats=averaged_stats, 45 | step=args.current_train_step, 46 | args=args, 47 | prefix='train/' 48 | ) 49 | 50 | args.last_log = time.time() 51 | 52 | 53 | def maybe_grad_clip_and_grad_calc(accelerator, model, args): 54 | if args.optim.grad_clip > 0: 55 | grad_l2 = accelerator.clip_grad_norm_( 56 | parameters=model.parameters(), 57 | max_norm=args.optim.grad_clip, 58 | norm_type=2, 59 | ) 60 | else: 61 | grad_l2 = None 62 | 63 | if args.logging.grad_l2: 64 | if grad_l2 is None: 65 | grad_l2 = ( 66 | sum(p.grad.detach().data.norm(2).item() ** 2 for p in model.parameters()) ** 0.5 67 | ) 68 | 69 | return {'grad_l2': grad_l2} 70 | else: 71 | return {} 72 | 73 | 74 | def extra_stats(args, model, optimizer): 75 | stats = {} 76 | 77 | if args.logging.weights_l2: 78 | weights_l2 = sum(p.detach().norm(2).item() ** 2 for p in model.parameters()) ** 0.5 79 | stats['weights_l2'] = weights_l2 80 | 81 | stats['lr'] = optimizer.param_groups[0]['lr'] 82 | stats['seconds_per_step'] = (time.time() - args.last_log) / args.logging.every_steps 83 | 84 | return stats 85 | 86 | 87 | def forward(model, batch, calc_acc=False): 88 | outputs = model(**batch) 89 | loss = 
--------------------------------------------------------------------------------
/nanoT5/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import evaluate
4 | from .logging_utils import Averager
5 | from datasets.iterable_dataset import IterableDataset
6 | 
7 | 
8 | def maybe_save_checkpoint(accelerator, args):
9 |     if (
10 |         args.current_train_step > args.optim.total_steps
11 |         or args.current_train_step % args.checkpoint.every_steps == 0
12 |     ):
13 |         output_dir = f'checkpoint-{args.mode}-{args.current_train_step}'
14 |         accelerator.save_state(output_dir=output_dir)
15 | 
16 | 
17 | def maybe_eval_predict(model, dataloader, logger, args, tokenizer):
18 |     if (
19 |         args.current_train_step > args.optim.total_steps
20 |         or args.current_train_step % args.eval.every_steps == 0
21 |     ):
22 |         model.eval()
23 | 
24 |         with torch.no_grad():
25 |             eval(model, dataloader, logger, args, tokenizer)
26 | 
27 |             if args.mode == 'ft':
28 |                 predict(
29 |                     model, dataloader, logger, args, tokenizer
30 |                 )
31 | 
32 |         args.last_log = time.time()
33 |         model.train()
34 | 
35 | 
36 | def maybe_logging(averager, args, model, optimizer, logger):
37 |     if args.current_train_step % args.logging.every_steps == 0:
38 |         stats = extra_stats(args, model, optimizer)
39 | 
40 |         averager.update(stats)
41 |         averaged_stats = averager.average()
42 | 
43 |         logger.log_stats(
44 |             stats=averaged_stats,
45 |             step=args.current_train_step,
46 |             args=args,
47 |             prefix='train/'
48 |         )
49 | 
50 |         args.last_log = time.time()
51 | 
52 | 
53 | def maybe_grad_clip_and_grad_calc(accelerator, model, args):
54 |     if args.optim.grad_clip > 0:
55 |         grad_l2 = accelerator.clip_grad_norm_(
56 |             parameters=model.parameters(),
57 |             max_norm=args.optim.grad_clip,
58 |             norm_type=2,
59 |         )
60 |     else:
61 |         grad_l2 = None
62 | 
63 |     if args.logging.grad_l2:
64 |         if grad_l2 is None:
65 |             grad_l2 = (
66 |                 sum(p.grad.detach().data.norm(2).item() ** 2 for p in model.parameters()) ** 0.5
67 |             )
68 | 
69 |         return {'grad_l2': grad_l2}
70 |     else:
71 |         return {}
72 | 
73 | 
74 | def extra_stats(args, model, optimizer):
75 |     stats = {}
76 | 
77 |     if args.logging.weights_l2:
78 |         weights_l2 = sum(p.detach().norm(2).item() ** 2 for p in model.parameters()) ** 0.5
79 |         stats['weights_l2'] = weights_l2
80 | 
81 |     stats['lr'] = optimizer.param_groups[0]['lr']
82 |     stats['seconds_per_step'] = (time.time() - args.last_log) / args.logging.every_steps
83 | 
84 |     return stats
85 | 
86 | 
87 | def forward(model, batch, calc_acc=False):
88 |     outputs = model(**batch)
89 |     loss = outputs.loss
90 | 
91 |     stats = {}
92 |     stats['loss'] = loss.detach().float().item()
93 | 
94 |     if calc_acc:
95 |         correct = (outputs.logits.argmax(-1) == batch["labels"]).sum().item()
96 |         accuracy = correct / batch["labels"].numel()
97 |         stats['accuracy'] = accuracy
98 | 
99 |     return loss, stats
100 | 
101 | 
102 | def eval(model, dataloader, logger, args, tokenizer):
103 |     args.last_log = time.time()
104 |     averager = Averager()
105 | 
106 |     for batch_id, batch in enumerate(dataloader, start=1):
107 |         if batch_id == args.eval.corrected_steps * args.optim.grad_acc:
108 |             break
109 | 
110 |         _, stats = forward(model, batch, calc_acc=True)
111 |         averager.update(stats)
112 | 
113 |     averager.update({'time': time.time() - args.last_log})
114 |     averaged_stats = averager.average()
115 | 
116 |     logger.log_stats(
117 |         stats=averaged_stats,
118 |         step=args.current_train_step,
119 |         args=args,
120 |         prefix='eval/'
121 |     )
122 | 
123 | 
124 | def predict(model, dataloader, logger, args, tokenizer):
125 |     args.last_log = time.time()
126 |     metric = evaluate.load('rouge')
127 |     samples_seen = 0
128 | 
129 |     def decode(preds):
130 |         preds[preds == -100] = tokenizer.pad_token_id
131 |         preds = tokenizer.batch_decode(
132 |             preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
133 |         )
134 |         preds = [pred.strip() for pred in preds]
135 |         return preds
136 | 
137 |     for step, batch in enumerate(dataloader):
138 |         predictions = model.generate(
139 |             input_ids=batch['input_ids'],
140 |             attention_mask=batch['attention_mask'],
141 |             max_length=args.data.max_target_len,
142 |             generation_config=model.generation_config,
143 |         )
144 |         predictions = decode(predictions)
145 |         references = decode(batch["labels"])
146 | 
147 |         # If we are in a multiprocess environment, the last batch has duplicates
148 |         if step == len(dataloader) - 1:
149 |             predictions = predictions[: len(dataloader.dataset) - samples_seen]
150 |             references = references[: len(dataloader.dataset) - samples_seen]
151 |         else:
152 |             samples_seen += len(references)
153 | 
154 |         metric.add_batch(
155 |             predictions=predictions,
156 |             references=references,
157 |         )
158 | 
159 |     eval_metric = metric.compute(use_stemmer=True, use_aggregator=False)
160 |     rougeL = sum(eval_metric["rougeL"]) * 100 / len(eval_metric["rougeL"])
161 | 
162 |     logger.log_stats(
163 |         stats={
164 |             "rougeL": rougeL,
165 |             "time": time.time() - args.last_log,
166 |         },
167 |         step=args.current_train_step,
168 |         args=args,
169 |         prefix="test/",
170 |     )
171 | 
172 | 
173 | def train(model, train_dataloader, test_dataloader, accelerator, lr_scheduler,
174 |           optimizer, logger, args, tokenizer):
175 |     model.train()
176 | 
177 |     train_averager = Averager()
178 | 
179 |     while args.current_train_step <= args.optim.total_steps:
180 |         if isinstance(train_dataloader.dataset, IterableDataset):
181 |             train_dataloader.dataset.set_epoch(args.current_epoch)
182 | 
183 |         # In case there is a remainder from previous epoch, we need to reset the optimizer
184 |         optimizer.zero_grad(set_to_none=True)
185 | 
186 |         for batch_id, batch in enumerate(train_dataloader, start=1):
187 |             if args.current_train_step > args.optim.total_steps:
188 |                 break
189 | 
190 |             loss, stats = forward(model, batch)
191 |             accelerator.backward(loss / args.optim.grad_acc)
192 |             train_averager.update(stats)
193 | 
194 |             if batch_id % args.optim.grad_acc == 0:
195 |                 stats = maybe_grad_clip_and_grad_calc(accelerator, model, args)
196 |                 train_averager.update(stats)
197 | 
198 |                 optimizer.step()
199 |                 lr_scheduler.step()
200 |                 optimizer.zero_grad(set_to_none=True)
201 | 
202 |                 maybe_logging(train_averager, args, model, optimizer, logger)
203 |                 maybe_eval_predict(model, test_dataloader, logger, args, tokenizer)
204 |                 maybe_save_checkpoint(accelerator, args)
205 | 
206 |                 args.current_train_step += 1
207 | 
208 |         args.current_epoch += 1
209 | 
210 |     maybe_eval_predict(model, test_dataloader, logger, args, tokenizer)
211 |     maybe_save_checkpoint(accelerator, args)
212 | 
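
As a minimal, self-contained illustration (toy model and optimizer, no accelerate, scheduler, or logging) of the gradient-accumulation pattern that `train` implements: each micro-batch loss is scaled by 1/grad_acc, the optimizer steps once every grad_acc micro-batches, and the gradients are then reset. All names below are illustrative, not taken from the repository:

    import torch

    grad_acc = 4                                      # accumulate gradients over 4 micro-batches
    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    optimizer.zero_grad(set_to_none=True)
    for batch_id in range(1, 13):                     # 12 micro-batches -> 3 optimizer steps
        x, y = torch.randn(2, 8), torch.randn(2, 1)
        loss = torch.nn.functional.mse_loss(model(x), y)
        (loss / grad_acc).backward()                  # accumulate scaled gradients
        if batch_id % grad_acc == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
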
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | datasets >= 1.8.0
3 | sentencepiece != 0.1.92
4 | transformers
5 | neptune
6 | pdbpp
7 | notebook
8 | protobuf==3.20.*
9 | pyyaml
10 | pynvml
11 | hydra-core
12 | evaluate
13 | nltk
14 | absl-py
15 | rouge_score
16 | torch>=1.13.1,<=2.0.1
--------------------------------------------------------------------------------